Compare commits

..

167 Commits

SHA1  Message  Date
2f75d3aebe  Update [ghstack-poisoned]  2025-10-29 21:23:30 +08:00
cd51ee276b  Update (base update) [ghstack-poisoned]  2025-10-29 21:23:30 +08:00
f44a8f5201  Update [ghstack-poisoned]  2025-10-29 21:02:54 +08:00
95039738e3  Update (base update) [ghstack-poisoned]  2025-10-29 21:02:54 +08:00
3f0985bf89  Update [ghstack-poisoned]  2025-10-11 21:35:31 +08:00
d96b71c3ef  Update (base update) [ghstack-poisoned]  2025-10-11 21:35:31 +08:00
6676fe538f  Update [ghstack-poisoned]  2025-10-08 22:52:13 +08:00
29d16a10f3  Update (base update) [ghstack-poisoned]  2025-10-08 22:52:13 +08:00
70183368c6  Update [ghstack-poisoned]  2025-10-08 22:39:24 +08:00
c374b66c75  Update (base update) [ghstack-poisoned]  2025-10-08 22:39:24 +08:00
c601a1ea72  Update [ghstack-poisoned]  2025-09-19 18:03:46 +08:00
faa50fa6c4  Update (base update) [ghstack-poisoned]  2025-09-19 18:03:46 +08:00
bbb546f542  Update [ghstack-poisoned]  2025-09-06 11:34:34 +08:00
dc53fc2af2  Update (base update) [ghstack-poisoned]  2025-09-06 11:34:34 +08:00
a2d5216c04  Update [ghstack-poisoned]  2025-08-17 16:23:38 +08:00
26a1088f9f  Update (base update) [ghstack-poisoned]  2025-08-17 16:23:38 +08:00
9f3385822d  Update [ghstack-poisoned]  2025-08-09 02:51:18 +08:00
968b72ca2c  Update (base update) [ghstack-poisoned]  2025-08-09 02:51:18 +08:00
5ae581762b  Update [ghstack-poisoned]  2025-07-31 15:19:09 +08:00
54911834c4  Update (base update) [ghstack-poisoned]  2025-07-31 15:19:09 +08:00
332e835040  Update [ghstack-poisoned]  2025-07-25 20:00:31 +08:00
1e1b37ed77  Update (base update) [ghstack-poisoned]  2025-07-25 20:00:31 +08:00
d05480c236  Update [ghstack-poisoned]  2025-07-17 15:02:04 +08:00
8e2c6ff709  Update (base update) [ghstack-poisoned]  2025-07-17 15:02:04 +08:00
3e66bd8fa8  Update [ghstack-poisoned]  2025-07-09 19:01:34 +08:00
f589cb4a72  Update (base update) [ghstack-poisoned]  2025-07-09 19:01:34 +08:00
369df36d49  Update [ghstack-poisoned]  2025-07-03 16:24:23 +08:00
7aff2fc214  Update (base update) [ghstack-poisoned]  2025-07-03 16:24:23 +08:00
b27cc37252  Update [ghstack-poisoned]  2025-06-28 20:59:47 +08:00
41663a247d  Update (base update) [ghstack-poisoned]  2025-06-28 20:59:47 +08:00
80e40a5976  Update [ghstack-poisoned]  2025-06-27 21:27:45 +08:00
1c337ea84b  Update (base update) [ghstack-poisoned]  2025-06-27 21:27:45 +08:00
fb2b16422d  Update [ghstack-poisoned]  2025-06-23 22:51:21 +08:00
6d00bd774f  Update (base update) [ghstack-poisoned]  2025-06-23 22:51:21 +08:00
2a9ceb2f1f  Update [ghstack-poisoned]  2025-06-18 23:17:48 +08:00
303e7afcd9  Update (base update) [ghstack-poisoned]  2025-06-18 23:17:48 +08:00
cead985182  Update [ghstack-poisoned]  2025-06-06 19:50:50 +08:00
c2e1972c18  Update (base update) [ghstack-poisoned]  2025-06-06 19:50:50 +08:00
11d7c79cea  Update [ghstack-poisoned]  2025-05-31 21:59:59 +08:00
b32f36ce35  Update (base update) [ghstack-poisoned]  2025-05-31 21:59:59 +08:00
979efcb825  Update [ghstack-poisoned]  2025-05-28 20:43:33 +08:00
2c60570864  Update (base update) [ghstack-poisoned]  2025-05-28 20:43:33 +08:00
99836e07fb  Update [ghstack-poisoned]  2025-05-16 11:37:32 +08:00
faea6584f8  Update (base update) [ghstack-poisoned]  2025-05-16 11:37:32 +08:00
08d95fe5c4  Update [ghstack-poisoned]  2025-05-14 20:35:01 +08:00
3d383f42e9  Update (base update) [ghstack-poisoned]  2025-05-14 20:35:01 +08:00
e57894677e  Update [ghstack-poisoned]  2025-05-08 21:19:08 +08:00
27ac21550e  Update (base update) [ghstack-poisoned]  2025-05-08 21:19:08 +08:00
41259dc86e  Update [ghstack-poisoned]  2025-05-04 02:10:44 +08:00
28c48d60aa  Update (base update) [ghstack-poisoned]  2025-05-04 02:10:44 +08:00
113b8306a8  Update [ghstack-poisoned]  2025-05-03 02:34:22 +08:00
ab474eebfd  Update (base update) [ghstack-poisoned]  2025-05-03 02:34:22 +08:00
dbd47d2dae  Update [ghstack-poisoned]  2025-05-03 01:14:43 +08:00
727a1aa849  Update (base update) [ghstack-poisoned]  2025-05-03 01:14:43 +08:00
ba32ef92a6  Update [ghstack-poisoned]  2025-05-03 00:45:00 +08:00
c659299f60  Update (base update) [ghstack-poisoned]  2025-05-03 00:45:00 +08:00
313bfeea17  Update [ghstack-poisoned]  2025-05-03 00:40:33 +08:00
4fd958b566  Update (base update) [ghstack-poisoned]  2025-05-03 00:40:33 +08:00
5710b9d4af  Update [ghstack-poisoned]  2025-05-02 02:30:03 +08:00
f3e5113185  Update (base update) [ghstack-poisoned]  2025-05-02 02:30:03 +08:00
599b39676b  Update [ghstack-poisoned]  2025-05-02 02:25:06 +08:00
e4d59f3a0a  Update (base update) [ghstack-poisoned]  2025-05-02 02:25:06 +08:00
fffb30d445  Update [ghstack-poisoned]  2025-05-02 01:44:37 +08:00
cd38ad58a5  Update (base update) [ghstack-poisoned]  2025-05-02 01:44:37 +08:00
c309091454  Update [ghstack-poisoned]  2025-05-02 01:39:06 +08:00
3e53b1fb27  Update (base update) [ghstack-poisoned]  2025-05-02 01:39:06 +08:00
bfe4839177  Update [ghstack-poisoned]  2025-04-26 11:34:39 +08:00
121f110b83  Update (base update) [ghstack-poisoned]  2025-04-26 11:34:39 +08:00
ad0668cbb8  Update [ghstack-poisoned]  2025-04-23 21:35:46 +08:00
7a962a2f0d  Update (base update) [ghstack-poisoned]  2025-04-23 21:35:46 +08:00
0d8d2a0360  Update [ghstack-poisoned]  2025-04-15 22:19:49 +08:00
f6795a4922  Update (base update) [ghstack-poisoned]  2025-04-15 22:19:49 +08:00
898c2037ed  Update [ghstack-poisoned]  2025-04-15 22:12:31 +08:00
f98737b13a  Update (base update) [ghstack-poisoned]  2025-04-15 22:12:31 +08:00
97cf576c2c  Update [ghstack-poisoned]  2025-04-15 22:10:41 +08:00
f643f2e78b  Update (base update) [ghstack-poisoned]  2025-04-15 22:10:41 +08:00
c9607615a4  Update [ghstack-poisoned]  2025-04-15 22:03:25 +08:00
d5e568d6e0  Update (base update) [ghstack-poisoned]  2025-04-15 22:03:25 +08:00
aa53182976  Update [ghstack-poisoned]  2025-04-11 19:04:56 +08:00
2549707053  Update (base update) [ghstack-poisoned]  2025-04-11 19:04:56 +08:00
0625bbf0c7  Update [ghstack-poisoned]  2025-04-11 18:35:20 +08:00
3b8b6ef6fb  Update (base update) [ghstack-poisoned]  2025-04-11 18:35:20 +08:00
ddb82b6b96  Update [ghstack-poisoned]  2025-04-11 18:27:39 +08:00
57f2575735  Update (base update) [ghstack-poisoned]  2025-04-11 18:27:39 +08:00
760f4fb105  Update [ghstack-poisoned]  2025-04-11 18:16:46 +08:00
d1aba50677  Update (base update) [ghstack-poisoned]  2025-04-11 18:16:46 +08:00
98984e1561  Update [ghstack-poisoned]  2025-04-10 17:25:15 +08:00
46bb41bd37  Update (base update) [ghstack-poisoned]  2025-04-10 17:25:15 +08:00
246b2fd7a0  Update [ghstack-poisoned]  2025-04-07 22:41:39 +08:00
d90e80dd35  Update (base update) [ghstack-poisoned]  2025-04-07 22:41:39 +08:00
1fa721545a  Update [ghstack-poisoned]  2025-04-05 23:26:58 +08:00
8c13a8323a  Update (base update) [ghstack-poisoned]  2025-04-05 23:26:58 +08:00
f0e9ee0bdc  Update [ghstack-poisoned]  2025-04-03 23:11:59 +08:00
f649b7bfbd  Update (base update) [ghstack-poisoned]  2025-04-03 23:11:59 +08:00
6e6343130f  Update [ghstack-poisoned]  2025-04-03 22:22:53 +08:00
45f34b99a9  Update (base update) [ghstack-poisoned]  2025-04-03 22:22:53 +08:00
5df659dcf8  Update [ghstack-poisoned]  2025-04-03 21:58:28 +08:00
53f09a5136  Update (base update) [ghstack-poisoned]  2025-04-03 21:58:28 +08:00
2e6d995297  Update [ghstack-poisoned]  2025-04-02 00:14:59 +08:00
aa65799ee0  Update (base update) [ghstack-poisoned]  2025-04-02 00:14:58 +08:00
02a382b7be  Update [ghstack-poisoned]  2025-03-21 00:08:15 +08:00
6abac60294  Update (base update) [ghstack-poisoned]  2025-03-21 00:08:15 +08:00
4037b4fc22  Update [ghstack-poisoned]  2025-03-14 12:47:28 +08:00
af5bc4e801  Update (base update) [ghstack-poisoned]  2025-03-14 12:47:28 +08:00
108d7f193a  Update [ghstack-poisoned]  2025-03-13 04:41:40 +08:00
a8fb34cae5  Update (base update) [ghstack-poisoned]  2025-03-13 04:41:40 +08:00
7b122690be  Update [ghstack-poisoned]  2025-03-07 18:09:20 +08:00
4caf34ab53  Update (base update) [ghstack-poisoned]  2025-03-07 18:09:20 +08:00
3409a1e033  Update [ghstack-poisoned]  2025-03-07 03:57:21 +08:00
66681eea1b  Update (base update) [ghstack-poisoned]  2025-03-07 03:57:21 +08:00
b7838168c3  Update [ghstack-poisoned]  2025-03-07 03:19:28 +08:00
26b6913b3d  Update (base update) [ghstack-poisoned]  2025-03-07 03:19:28 +08:00
ca192e08bd  Update [ghstack-poisoned]  2025-03-06 21:45:27 +08:00
88a13fdc94  Update (base update) [ghstack-poisoned]  2025-03-06 21:45:27 +08:00
d099b63055  Update [ghstack-poisoned]  2025-03-05 20:36:42 +08:00
5fd8eb6fb8  Update (base update) [ghstack-poisoned]  2025-03-05 20:36:41 +08:00
a4a2c1cffb  Update [ghstack-poisoned]  2025-03-05 20:19:47 +08:00
400df72bda  Update (base update) [ghstack-poisoned]  2025-03-05 20:19:47 +08:00
33f1963cc4  Update [ghstack-poisoned]  2025-03-05 20:14:48 +08:00
ac68c29b4c  Update (base update) [ghstack-poisoned]  2025-03-05 20:14:48 +08:00
2832a51c54  Update [ghstack-poisoned]  2025-03-05 01:46:28 +08:00
3b3af30466  Update (base update) [ghstack-poisoned]  2025-03-05 00:41:45 +08:00
49df7a7617  Update [ghstack-poisoned]  2025-03-05 00:41:45 +08:00
92176bfdff  Update (base update) [ghstack-poisoned]  2025-03-04 22:37:24 +08:00
2f41576c02  Update [ghstack-poisoned]  2025-03-04 22:37:24 +08:00
0b4314efe2  Update (base update) [ghstack-poisoned]  2025-03-04 19:05:38 +08:00
f092c0e8ce  Update [ghstack-poisoned]  2025-03-04 19:05:38 +08:00
0c224957b6  Update (base update) [ghstack-poisoned]  2025-03-04 17:15:30 +08:00
a912652a09  Update [ghstack-poisoned]  2025-03-04 17:15:30 +08:00
afb24fc9c1  Update (base update) [ghstack-poisoned]  2025-03-04 11:43:02 +08:00
7017bcda6e  Update [ghstack-poisoned]  2025-03-04 11:43:02 +08:00
98b5a2fc77  Update (base update) [ghstack-poisoned]  2025-03-04 04:46:17 +08:00
c2e569992b  Update [ghstack-poisoned]  2025-03-04 04:46:17 +08:00
4cdcd94061  Update (base update) [ghstack-poisoned]  2025-03-04 03:31:33 +08:00
b2727a655f  Update [ghstack-poisoned]  2025-03-04 03:31:33 +08:00
3ae5f28df9  Update (base update) [ghstack-poisoned]  2025-03-04 03:09:37 +08:00
130df3a1d6  Update [ghstack-poisoned]  2025-03-04 03:09:37 +08:00
ffe60c3005  Update (base update) [ghstack-poisoned]  2025-03-04 02:44:32 +08:00
f03956e5ae  Update [ghstack-poisoned]  2025-03-04 02:44:32 +08:00
782d543bf7  Update (base update) [ghstack-poisoned]  2025-03-01 21:29:32 +08:00
a5d05d3d4e  Update [ghstack-poisoned]  2025-03-01 21:29:32 +08:00
7915feda28  Update (base update) [ghstack-poisoned]  2025-03-01 19:17:04 +08:00
7c8aeffc4a  Update [ghstack-poisoned]  2025-03-01 19:17:04 +08:00
0518f254ed  Update [ghstack-poisoned]  2025-03-01 03:09:57 +08:00
ef3adf6eac  Update [ghstack-poisoned]  2025-03-01 02:56:44 +08:00
c7e5b56a7d  Update [ghstack-poisoned]  2025-03-01 01:36:03 +08:00
86001ed575  Update (base update) [ghstack-poisoned]  2025-03-01 00:37:37 +08:00
b8eadce989  Update [ghstack-poisoned]  2025-03-01 00:37:37 +08:00
06a7899877  Update (base update) [ghstack-poisoned]  2025-03-01 00:36:38 +08:00
4b8ceddea7  Update [ghstack-poisoned]  2025-03-01 00:36:38 +08:00
d55bb26dc5  Update (base update) [ghstack-poisoned]  2025-03-01 00:24:08 +08:00
c93a280d5a  Update [ghstack-poisoned]  2025-03-01 00:24:08 +08:00
0e45af0a50  Update (base update) [ghstack-poisoned]  2025-03-01 00:03:44 +08:00
f62063d25c  Update [ghstack-poisoned]  2025-03-01 00:03:44 +08:00
06d1a7fa0b  Update (base update) [ghstack-poisoned]  2025-02-28 23:53:49 +08:00
00a6482df6  Update [ghstack-poisoned]  2025-02-28 23:53:49 +08:00
1d6e7ada97  Update (base update) [ghstack-poisoned]  2025-02-28 23:45:41 +08:00
308f6a05ac  Update [ghstack-poisoned]  2025-02-28 23:45:41 +08:00
4e773f0037  Update [ghstack-poisoned]  2025-02-28 23:39:42 +08:00
672915aece  Update (base update) [ghstack-poisoned]  2025-02-28 22:39:24 +08:00
08affe6664  Update [ghstack-poisoned]  2025-02-28 22:39:24 +08:00
c36201b8b2  Update [ghstack-poisoned]  2025-02-28 22:11:12 +08:00
347e4a1001  Update (base update) [ghstack-poisoned]  2025-02-28 20:16:12 +08:00
c5f436b865  Update [ghstack-poisoned]  2025-02-28 20:16:12 +08:00
51232f8496  Update [ghstack-poisoned]  2025-02-28 19:35:17 +08:00
a9a875cb5c  Update (base update) [ghstack-poisoned]  2025-02-28 18:51:42 +08:00
3ae649453b  Update [ghstack-poisoned]  2025-02-28 18:51:42 +08:00
544 changed files with 6780 additions and 13466 deletions

View File

@ -195,16 +195,13 @@ case "$tag" in
NINJA_VERSION=1.9.0
TRITON=yes
;;
pytorch-linux-jammy-xpu-n-py3 | pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks)
pytorch-linux-jammy-xpu-n-py3)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=11
VISION=yes
XPU_VERSION=2025.2
NINJA_VERSION=1.9.0
TRITON=yes
if [[ $tag =~ "benchmarks" ]]; then
INDUCTOR_BENCHMARKS=yes
fi
;;
pytorch-linux-jammy-py3-gcc11-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.10

View File

@ -49,20 +49,12 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then
export SYSROOT_DEP="sysroot_linux-64=2.17"
fi
# Install correct Python version
# Also ensure sysroot is using a modern GLIBC to match system compilers
if [ "$ANACONDA_PYTHON_VERSION" = "3.14" ]; then
as_jenkins conda create -n py_$ANACONDA_PYTHON_VERSION -y\
python="3.14.0" \
${SYSROOT_DEP} \
-c conda-forge
else
# Install correct Python version
# Also ensure sysroot is using a modern GLIBC to match system compilers
as_jenkins conda create -n py_$ANACONDA_PYTHON_VERSION -y\
python="$ANACONDA_PYTHON_VERSION" \
${SYSROOT_DEP}
fi
# libstdcxx from conda default channels are too old, we need GLIBCXX_3.4.30
# which is provided in libstdcxx 12 and up.
conda_install libstdcxx-ng=12.3.0 --update-deps -c conda-forge
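
A minimal Python sketch (not repository code) of the version branch shown in this hunk: Python 3.14 is pinned to exactly 3.14.0 and pulled from conda-forge, while other versions install from the default channels. The helper `conda_create_cmd` and its defaults are illustrative assumptions.

```python
def conda_create_cmd(python_version: str,
                     sysroot_dep: str = "sysroot_linux-64=2.17") -> list[str]:
    # exact pin for 3.14, plain version spec otherwise
    pin = "3.14.0" if python_version == "3.14" else python_version
    cmd = [
        "conda", "create", "-n", f"py_{python_version}", "-y",
        f"python={pin}",
        sysroot_dep,  # keep the sysroot GLIBC in step with system compilers
    ]
    if python_version == "3.14":
        cmd += ["-c", "conda-forge"]  # 3.14 comes from conda-forge in this branch
    return cmd


print(" ".join(conda_create_cmd("3.14")))
print(" ".join(conda_create_cmd("3.10")))
```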

View File

@ -40,7 +40,11 @@ EOF
# Default url values
rocm_baseurl="http://repo.radeon.com/rocm/apt/${ROCM_VERSION}"
amdgpu_baseurl="https://repo.radeon.com/amdgpu/${ROCM_VERSION}/ubuntu"
# Add amdgpu repository
UBUNTU_VERSION_NAME=`cat /etc/os-release | grep UBUNTU_CODENAME | awk -F= '{print $2}'`
echo "deb [arch=amd64] ${amdgpu_baseurl} ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/amdgpu.list
# Add rocm repository
wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
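
For readers unfamiliar with the `cat /etc/os-release | grep | awk` pipeline this hunk touches, here is an equivalent Python sketch of the codename lookup and the resulting apt source line; `ubuntu_codename` is a hypothetical helper and the default `rocm_version` value is an assumption.

```python
def ubuntu_codename(os_release_path: str = "/etc/os-release") -> str:
    """Mirror `grep UBUNTU_CODENAME /etc/os-release | awk -F= '{print $2}'`."""
    with open(os_release_path) as f:
        for line in f:
            key, _, value = line.strip().partition("=")
            if key == "UBUNTU_CODENAME":
                return value.strip('"')
    raise KeyError("UBUNTU_CODENAME not found")


def amdgpu_sources_line(rocm_version: str = "7.0") -> str:
    # matches the `deb [arch=amd64] ...` entry written to amdgpu.list above
    amdgpu_baseurl = f"https://repo.radeon.com/amdgpu/{rocm_version}/ubuntu"
    return f"deb [arch=amd64] {amdgpu_baseurl} {ubuntu_codename()} main"
```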

View File

@ -12,8 +12,8 @@ function do_install() {
rocm_version_nodot=${rocm_version//./}
# post merge of https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f
# https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
magma_archive="magma-rocm${rocm_version_nodot}-${MAGMA_VERSION}-1.tar.bz2"
rocm_dir="/opt/rocm"

View File

@ -138,12 +138,10 @@ numba==0.60.0 ; python_version == "3.12" and platform_machine != "s390x"
#test_binary_ufuncs.py
numpy==1.22.4; python_version == "3.10"
numpy==1.26.2; python_version == "3.11" or python_version == "3.12"
numpy==2.1.2; python_version >= "3.13" and python_version < "3.14"
numpy==2.3.4; python_version >= "3.14"
numpy==2.1.2; python_version >= "3.13"
pandas==2.0.3; python_version < "3.13"
pandas==2.2.3; python_version >= "3.13" and python_version < "3.14"
pandas==2.3.3; python_version >= "3.14"
pandas==2.2.3; python_version >= "3.13"
#onnxruntime
#Description: scoring engine for Open Neural Network Exchange (ONNX) models
@ -155,8 +153,7 @@ opt-einsum==3.3
#Pinned versions: 3.3
#test that import: test_linalg.py
optree==0.13.0 ; python_version < "3.14"
optree==0.17.0 ; python_version >= "3.14"
optree==0.13.0
#Description: A library for tree manipulation
#Pinned versions: 0.13.0
#test that import: test_vmap.py, test_aotdispatch.py, test_dynamic_shapes.py,
@ -255,8 +252,7 @@ scikit-image==0.22.0
#test that import:
scipy==1.10.1 ; python_version <= "3.11"
scipy==1.14.1 ; python_version > "3.11" and python_version < "3.14"
scipy==1.16.2 ; python_version >= "3.14"
scipy==1.14.1 ; python_version >= "3.12"
# Pin SciPy because of failing distribution tests (see #60347)
#Description: scientific python
#Pinned versions: 1.10.1
@ -328,8 +324,7 @@ pywavelets==1.7.0 ; python_version >= "3.12"
#Pinned versions: 1.4.1
#test that import:
lxml==5.3.0 ; python_version < "3.14"
lxml==6.0.2 ; python_version >= "3.14"
lxml==5.3.0
#Description: This is a requirement of unittest-xml-reporting
PyGithub==2.3.0
@ -339,9 +334,7 @@ sympy==1.13.3
#Pinned versions:
#test that import:
onnx==1.19.1 ; python_version < "3.14"
# Unpin once Python 3.14 is supported. See onnxruntime issue 26309.
onnx==1.18.0 ; python_version == "3.14"
onnx==1.19.1
#Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal
#Pinned versions:
#test that import:
@ -366,7 +359,7 @@ pwlf==2.2.1
#test that import: test_sac_estimator.py
# To build PyTorch itself
pyyaml==6.0.3
pyyaml==6.0.2
pyzstd
setuptools==78.1.1
packaging==23.1
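
The pins in this hunk are gated by PEP 508 environment markers such as `python_version >= "3.13"`. A short sketch of how such markers are evaluated, using the third-party `packaging` library (the implementation pip vendors); the pin list below is an illustrative subset.

```python
from packaging.markers import Marker

pins = [
    ("numpy==2.1.2", 'python_version >= "3.13"'),
    ("scipy==1.14.1", 'python_version >= "3.12"'),
    ("pyyaml==6.0.2", None),  # unconditional pin
]

for requirement, marker in pins:
    # Marker.evaluate() tests the expression against the running interpreter
    active = marker is None or Marker(marker).evaluate()
    print(f"{requirement}: {'install' if active else 'skip'}")
```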

View File

@ -54,15 +54,12 @@ ENV OPENSSL_DIR /opt/openssl
RUN rm install_openssl.sh
ARG INDUCTOR_BENCHMARKS
ARG ANACONDA_PYTHON_VERSION
ENV ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION
COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps.sh
COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/huggingface-requirements.txt huggingface-requirements.txt
COPY ci_commit_pins/timm.txt timm.txt
COPY ci_commit_pins/torchbench.txt torchbench.txt
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt
# Install XPU Dependencies
ARG XPU_VERSION

View File

@ -6,7 +6,7 @@ dependencies = [
"GitPython==3.1.45",
"docker==7.1.0",
"pytest==7.3.2",
"uv==0.9.6"
"uv==0.9.5"
]
[tool.setuptools]

View File

@ -1,7 +1,7 @@
SHELL=/usr/bin/env bash
DOCKER_CMD ?= docker
DESIRED_ROCM ?= 7.1
DESIRED_ROCM ?= 7.0
DESIRED_ROCM_SHORT = $(subst .,,$(DESIRED_ROCM))
PACKAGE_NAME = magma-rocm
# inherit this from underlying docker image, do not pass this env var to docker
@ -16,7 +16,6 @@ DOCKER_RUN = set -eou pipefail; ${DOCKER_CMD} run --rm -i \
magma-rocm/build_magma.sh
.PHONY: all
all: magma-rocm71
all: magma-rocm70
all: magma-rocm64
@ -25,11 +24,6 @@ clean:
$(RM) -r magma-*
$(RM) -r output
.PHONY: magma-rocm71
magma-rocm71: DESIRED_ROCM := 7.1
magma-rocm71:
$(DOCKER_RUN)
.PHONY: magma-rocm70
magma-rocm70: DESIRED_ROCM := 7.0
magma-rocm70:

View File

@ -6,8 +6,8 @@ set -eou pipefail
# The script expects DESIRED_CUDA and PACKAGE_NAME to be set
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
# post merge of https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f
# https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
# Folders for the build
PACKAGE_FILES=${ROOT_DIR}/magma-rocm/package_files # metadata
@ -20,7 +20,7 @@ mkdir -p ${PACKAGE_DIR} ${PACKAGE_OUTPUT}/linux-64 ${PACKAGE_BUILD} ${PACKAGE_RE
# Fetch magma sources and verify checksum
pushd ${PACKAGE_DIR}
git clone https://github.com/icl-utk-edu/magma
git clone https://github.com/jeffdaily/magma
pushd magma
git checkout ${MAGMA_VERSION}
popd

View File

@ -426,7 +426,7 @@ fi
if [[ "$BUILD_ENVIRONMENT" != *libtorch* && "$BUILD_ENVIRONMENT" != *bazel* ]]; then
# export test times so that potential sharded tests that'll branch off this build will use consistent data
# don't do this for libtorch as libtorch is C++ only and thus won't have python tests run on its build
PYTHONPATH=. python tools/stats/export_test_times.py
python tools/stats/export_test_times.py
fi
# don't do this for bazel or s390x or riscv64 as they don't use sccache
if [[ "$BUILD_ENVIRONMENT" != *s390x* && "$BUILD_ENVIRONMENT" != *riscv64* && "$BUILD_ENVIRONMENT" != *-bazel-* ]]; then

View File

@ -572,8 +572,6 @@ fi
if [[ "${TEST_CONFIG}" == *cpu* ]]; then
DYNAMO_BENCHMARK_FLAGS+=(--device cpu)
elif [[ "${TEST_CONFIG}" == *xpu* ]]; then
DYNAMO_BENCHMARK_FLAGS+=(--device xpu)
else
DYNAMO_BENCHMARK_FLAGS+=(--device cuda)
fi
@ -667,8 +665,6 @@ test_perf_for_dashboard() {
device=cuda_b200
elif [[ "${TEST_CONFIG}" == *rocm* ]]; then
device=rocm
elif [[ "${TEST_CONFIG}" == *xpu* ]]; then
device=xpu
fi
for mode in "${modes[@]}"; do
@ -1761,7 +1757,7 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then
else
# Do this after checkout_install_torchbench to ensure we clobber any
# nightlies that torchbench may pull in
if [[ "${TEST_CONFIG}" != *cpu* && "${TEST_CONFIG}" != *xpu* ]]; then
if [[ "${TEST_CONFIG}" != *cpu* ]]; then
install_torchrec_and_fbgemm
fi
PYTHONPATH=/torchbench test_dynamo_benchmark torchbench "$id"

View File

@ -27,9 +27,7 @@ runs:
docker system prune -af
diskspace_new=$(df -H --output=pcent ${docker_root_dir} | sed -n 2p | sed 's/%//' | sed 's/ //')
if [[ "$diskspace_new" -gt "$diskspace_cutoff" ]] ; then
diskspace_cutoff_int=$((diskspace_cutoff + 0))
difference=$((100 - diskspace_cutoff_int))
echo "Error: Available diskspace is less than $difference percent. Not enough diskspace."
echo "Error: Available diskspace is less than $diskspace_cutoff percent. Not enough diskspace."
echo "$msg"
exit 1
else
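
The reworded error in this hunk reports the remaining free space rather than the usage cutoff; a tiny sketch of the arithmetic, with illustrative values:

```python
diskspace_cutoff = 90  # maximum allowed usage percent (illustrative)
diskspace_used = 93    # parsed from `df -H --output=pcent`, '%' stripped

difference = 100 - diskspace_cutoff  # free-space floor implied by the cutoff
if diskspace_used > diskspace_cutoff:
    print(f"Error: Available diskspace is less than {difference} percent. "
          "Not enough diskspace.")
```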

View File

@ -1 +1 @@
3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2
69bbe7363897764f9e758d851cd0340147d27f94

View File

@ -19,7 +19,6 @@ ciflow_push_tags:
- ciflow/inductor-perf-test-nightly-rocm-mi300
- ciflow/inductor-perf-test-nightly-rocm-mi355
- ciflow/inductor-perf-test-nightly-x86-zen
- ciflow/inductor-perf-test-nightly-xpu
- ciflow/inductor-periodic
- ciflow/inductor-rocm
- ciflow/linux-aarch64
@ -27,7 +26,6 @@ ciflow_push_tags:
- ciflow/nightly
- ciflow/op-benchmark
- ciflow/periodic
- ciflow/periodic-rocm-mi200
- ciflow/periodic-rocm-mi300
- ciflow/pull
- ciflow/quantization-periodic

View File

@ -11,17 +11,11 @@ architectures:
* Latest XPU
"""
import json
import os
import re
from pathlib import Path
from typing import Optional
SCRIPT_DIR = Path(__file__).absolute().parent
REPO_ROOT = SCRIPT_DIR.parent.parent
# NOTE: Please also update the CUDA sources in `PIP_SOURCES` in tools/nightly.py when changing this
CUDA_ARCHES = ["12.6", "12.8", "12.9", "13.0"]
CUDA_STABLE = "12.8"
CUDA_ARCHES_FULL_VERSION = {
@ -37,7 +31,8 @@ CUDA_ARCHES_CUDNN_VERSION = {
"13.0": "9",
}
ROCM_ARCHES = ["7.0", "7.1"]
# NOTE: Please also update the ROCm sources in `PIP_SOURCES` in tools/nightly.py when changing this
ROCM_ARCHES = ["6.4", "7.0"]
XPU_ARCHES = ["xpu"]
@ -142,48 +137,9 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
}
# Used by tools/nightly.py
PYTORCH_NIGHTLY_PIP_INDEX_URL = "https://download.pytorch.org/whl/nightly"
NIGHTLY_SOURCE_MATRIX = {
"cpu": dict(
name="cpu",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cpu",
supported_platforms=["Linux", "macOS", "Windows"],
accelerator="cpu",
)
}
CUDA_NIGHTLY_SOURCE_MATRIX = {
f"cuda-{major}.{minor}": dict(
name=f"cuda-{major}.{minor}",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu{major}{minor}",
supported_platforms=["Linux", "Windows"],
accelerator="cuda",
)
for major, minor in (map(int, version.split(".")) for version in CUDA_ARCHES)
}
ROCM_NIGHTLY_SOURCE_MATRIX = {
f"rocm-{major}.{minor}": dict(
name=f"rocm-{major}.{minor}",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/rocm{major}.{minor}",
supported_platforms=["Linux"],
accelerator="rocm",
)
for major, minor in (map(int, version.split(".")) for version in ROCM_ARCHES)
}
XPU_NIGHTLY_SOURCE_MATRIX = {
"xpu": dict(
name="xpu",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/xpu",
supported_platforms=["Linux"],
accelerator="xpu",
)
}
NIGHTLY_SOURCE_MATRIX.update(CUDA_NIGHTLY_SOURCE_MATRIX)
NIGHTLY_SOURCE_MATRIX.update(ROCM_NIGHTLY_SOURCE_MATRIX)
NIGHTLY_SOURCE_MATRIX.update(XPU_NIGHTLY_SOURCE_MATRIX)
def get_nccl_wheel_version(arch_version: str) -> str:
import re
requirements = map(
str.strip, re.split("[;|]", PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version])
)
@ -191,14 +147,17 @@ def get_nccl_wheel_version(arch_version: str) -> str:
def read_nccl_pin(arch_version: str) -> str:
nccl_pin_path = (
REPO_ROOT
/ ".ci"
/ "docker"
/ "ci_commit_pins"
/ f"nccl-cu{arch_version[:2]}.txt"
from pathlib import Path
nccl_pin_path = os.path.join(
Path(__file__).absolute().parents[2],
".ci",
"docker",
"ci_commit_pins",
f"nccl-cu{arch_version[:2]}.txt",
)
return nccl_pin_path.read_text().strip()
with open(nccl_pin_path) as f:
return f.read().strip()
def validate_nccl_dep_consistency(arch_version: str) -> None:
@ -206,8 +165,7 @@ def validate_nccl_dep_consistency(arch_version: str) -> None:
wheel_ver = get_nccl_wheel_version(arch_version)
if not nccl_release_tag.startswith(f"v{wheel_ver}"):
raise RuntimeError(
f"{arch_version} NCCL release tag version {nccl_release_tag} "
f"does not correspond to wheel version {wheel_ver}"
f"{arch_version} NCCL release tag version {nccl_release_tag} does not correspond to wheel version {wheel_ver}"
)
@ -454,14 +412,7 @@ def generate_wheels_matrix(
return ret
arch_version = ""
for arch_version in CUDA_ARCHES:
validate_nccl_dep_consistency(arch_version)
del arch_version
if __name__ == "__main__":
# Used by tools/nightly.py
(SCRIPT_DIR / "nightly_source_matrix.json").write_text(
json.dumps(NIGHTLY_SOURCE_MATRIX, indent=4) + "\n"
)
validate_nccl_dep_consistency("13.0")
validate_nccl_dep_consistency("12.9")
validate_nccl_dep_consistency("12.8")
validate_nccl_dep_consistency("12.6")
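
A condensed, standalone sketch of the source-matrix pattern visible in this hunk: one dict comprehension per accelerator family, keyed as `<accelerator>-<version>`. The arch list is copied from the hunk for illustration and is not authoritative.

```python
PYTORCH_NIGHTLY_PIP_INDEX_URL = "https://download.pytorch.org/whl/nightly"
CUDA_ARCHES = ["12.6", "12.8", "12.9", "13.0"]

CUDA_NIGHTLY_SOURCE_MATRIX = {
    f"cuda-{major}.{minor}": dict(
        name=f"cuda-{major}.{minor}",
        index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu{major}{minor}",
        supported_platforms=["Linux", "Windows"],
        accelerator="cuda",
    )
    # "12.8" -> (12, 8): split each version string, map both parts to int
    for major, minor in (map(int, version.split(".")) for version in CUDA_ARCHES)
}

assert CUDA_NIGHTLY_SOURCE_MATRIX["cuda-12.8"]["index_url"].endswith("/cu128")
```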

View File

@ -38,10 +38,6 @@ on:
default: ""
description: |
List of tests to include (empty string implies default list)
dashboard-tag:
required: false
type: string
default: ""
disable-monitor:
description: |
[Experimental] Disable utilization monitoring for tests.
@ -62,11 +58,6 @@ on:
required: false
type: number
default: 1
secrets:
HUGGING_FACE_HUB_TOKEN:
required: false
description: |
HF Auth token to avoid rate limits when downloading models or datasets from hub
permissions:
id-token: write
contents: read
@ -205,8 +196,6 @@ jobs:
PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }}
PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }}
TESTS_TO_INCLUDE: ${{ inputs.tests-to-include }}
DASHBOARD_TAG: ${{ inputs.dashboard-tag }}
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
timeout-minutes: ${{ fromJson(steps.test-timeout.outputs.timeout) }}
run: |
# Fetch aws credential from IMDs
@ -257,8 +246,6 @@ jobs:
-e PYTORCH_TEST_RERUN_DISABLED_TESTS \
-e TESTS_TO_INCLUDE \
-e ZE_AFFINITY_MASK \
-e HUGGING_FACE_HUB_TOKEN \
-e DASHBOARD_TAG \
--env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
--ulimit stack=10485760:83886080 \
--ulimit core=0 \

View File

@ -36,7 +36,7 @@ jobs:
runs-on: linux.9xlarge.ephemeral
strategy:
matrix:
tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm7.0", "rocm7.1", "cpu"]
tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm6.4", "rocm7.0", "cpu"]
steps:
- name: Build docker image
uses: pytorch/pytorch/.github/actions/binary-docker-build@main

View File

@ -52,8 +52,8 @@ jobs:
{ tag: "cuda12.9" },
{ tag: "cuda12.8" },
{ tag: "cuda12.6" },
{ tag: "rocm6.4" },
{ tag: "rocm7.0" },
{ tag: "rocm7.1" },
{ tag: "cpu" },
]
steps:

View File

@ -34,7 +34,7 @@ jobs:
id-token: write
strategy:
matrix:
rocm_version: ["71", "70"]
rocm_version: ["70", "64"]
steps:
- name: Checkout PyTorch
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

View File

@ -54,8 +54,8 @@ jobs:
{ name: "manylinuxaarch64-builder", tag: "cuda12.9", runner: "linux.arm64.2xlarge.ephemeral" },
{ name: "manylinuxaarch64-builder", tag: "cuda12.8", runner: "linux.arm64.2xlarge.ephemeral" },
{ name: "manylinuxaarch64-builder", tag: "cuda12.6", runner: "linux.arm64.2xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "rocm6.4", runner: "linux.9xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "rocm7.0", runner: "linux.9xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "rocm7.1", runner: "linux.9xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "cpu", runner: "linux.9xlarge.ephemeral" },
{ name: "manylinux2_28_aarch64-builder", tag: "cpu-aarch64", runner: "linux.arm64.2xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "xpu", runner: "linux.9xlarge.ephemeral" },

View File

@ -55,7 +55,7 @@ jobs:
docker-image: ["pytorch/manylinux2_28-builder:cpu"]
include:
- device: "rocm"
rocm_version: "7.1"
rocm_version: "7.0"
runs_on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge"
- device: "cuda"
rocm_version: ""

View File

@ -57,7 +57,6 @@ jobs:
pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11,
pytorch-linux-jammy-py3.10-clang12,
pytorch-linux-jammy-py3.13-clang12,
pytorch-linux-jammy-py3.14-clang12,
pytorch-linux-jammy-rocm-n-py3,
pytorch-linux-noble-rocm-n-py3,
pytorch-linux-jammy-rocm-n-py3-benchmarks,
@ -67,7 +66,6 @@ jobs:
pytorch-linux-jammy-py3.12-halide,
pytorch-linux-jammy-xpu-n-1-py3,
pytorch-linux-jammy-xpu-n-py3,
pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks,
pytorch-linux-jammy-py3-clang18-asan,
pytorch-linux-jammy-py3-clang12-onnx,
pytorch-linux-jammy-linter,

View File

@ -384,6 +384,124 @@ jobs:
github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-rocm6_4-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.4
GPU_ARCH_VERSION: "6.4"
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
timeout-minutes: 300
build_name: libtorch-rocm6_4-shared-with-deps-release
build_environment: linux-binary-libtorch
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
libtorch-rocm6_4-shared-with-deps-release-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- libtorch-rocm6_4-shared-with-deps-release-build
- get-label-type
runs-on: linux.rocm.gpu.mi250
timeout-minutes: 240
env:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.4
GPU_ARCH_VERSION: "6.4"
GPU_ARCH_TYPE: rocm
SKIP_ALL_TESTS: 1
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
- uses: actions/download-artifact@v4.1.7
name: Download Build Artifacts
with:
name: libtorch-rocm6_4-shared-with-deps-release
path: "${{ runner.temp }}/artifacts/"
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: ROCm set GPU_FLAG
run: |
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
- name: configure aws credentials
id: aws_creds
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }}
uses: aws-actions/configure-aws-credentials@v4
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Calculate docker image
id: calculate-docker-image
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
with:
docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }}
docker-image-name: libtorch-cxx11-builder
custom-tag-prefix: rocm6.4
docker-build-dir: .ci/docker
working-directory: pytorch
- name: Pull Docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Test Pytorch binary
uses: ./pytorch/.github/actions/test-pytorch-binary
env:
DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Teardown ROCm
uses: ./.github/actions/teardown-rocm
libtorch-rocm6_4-shared-with-deps-release-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: libtorch-rocm6_4-shared-with-deps-release-test
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.4
GPU_ARCH_VERSION: "6.4"
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
build_name: libtorch-rocm6_4-shared-with-deps-release
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-rocm7_0-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
@ -501,121 +619,3 @@ jobs:
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-rocm7_1-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm7.1
GPU_ARCH_VERSION: "7.1"
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: rocm7.1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
timeout-minutes: 300
build_name: libtorch-rocm7_1-shared-with-deps-release
build_environment: linux-binary-libtorch
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
libtorch-rocm7_1-shared-with-deps-release-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- libtorch-rocm7_1-shared-with-deps-release-build
- get-label-type
runs-on: linux.rocm.gpu.mi250
timeout-minutes: 240
env:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm7.1
GPU_ARCH_VERSION: "7.1"
GPU_ARCH_TYPE: rocm
SKIP_ALL_TESTS: 1
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: rocm7.1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
- uses: actions/download-artifact@v4.1.7
name: Download Build Artifacts
with:
name: libtorch-rocm7_1-shared-with-deps-release
path: "${{ runner.temp }}/artifacts/"
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: ROCm set GPU_FLAG
run: |
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
- name: configure aws credentials
id: aws_creds
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }}
uses: aws-actions/configure-aws-credentials@v4
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Calculate docker image
id: calculate-docker-image
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
with:
docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }}
docker-image-name: libtorch-cxx11-builder
custom-tag-prefix: rocm7.1
docker-build-dir: .ci/docker
working-directory: pytorch
- name: Pull Docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Test Pytorch binary
uses: ./pytorch/.github/actions/test-pytorch-binary
env:
DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Teardown ROCm
uses: ./.github/actions/teardown-rocm
libtorch-rocm7_1-shared-with-deps-release-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: libtorch-rocm7_1-shared-with-deps-release-test
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm7.1
GPU_ARCH_VERSION: "7.1"
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: rocm7.1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
build_name: libtorch-rocm7_1-shared-with-deps-release
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml

File diff suppressed because it is too large.

View File

@ -1,148 +0,0 @@
name: inductor-perf-nightly-xpu
on:
push:
tags:
- ciflow/inductor-perf-test-nightly-xpu/*
schedule:
- cron: 30 17 * * *
workflow_dispatch:
inputs:
training:
description: Run training (on by default)?
required: false
type: boolean
default: true
inference:
description: Run inference (on by default)?
required: false
type: boolean
default: true
default:
description: Run inductor_default?
required: false
type: boolean
default: false
dynamic:
description: Run inductor_dynamic_shapes?
required: false
type: boolean
default: false
cppwrapper:
description: Run inductor_cpp_wrapper?
required: false
type: boolean
default: false
cudagraphs:
description: Run inductor_cudagraphs?
required: false
type: boolean
default: false
freezing_cudagraphs:
description: Run inductor_cudagraphs with freezing for inference?
required: false
type: boolean
default: false
aotinductor:
description: Run aot_inductor for inference?
required: false
type: boolean
default: false
maxautotune:
description: Run inductor_max_autotune?
required: false
type: boolean
default: false
benchmark_configs:
description: The list of configs used the benchmark
required: false
type: string
default: inductor_huggingface_perf,inductor_timm_perf,inductor_torchbench_perf,cachebench
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions: read-all
jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
opt_out_experiments: lf
xpu-n-py3_10-inductor-benchmark-build:
name: xpu-n-py3.10-inductor-benchmark
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks
runner: linux.c7i.12xlarge
test-matrix: |
{ include: [
{ config: "inductor_huggingface_perf_xpu", shard: 1, num_shards: 5, runner: "linux.idc.xpu" },
{ config: "inductor_huggingface_perf_xpu", shard: 2, num_shards: 5, runner: "linux.idc.xpu" },
{ config: "inductor_huggingface_perf_xpu", shard: 3, num_shards: 5, runner: "linux.idc.xpu" },
{ config: "inductor_huggingface_perf_xpu", shard: 4, num_shards: 5, runner: "linux.idc.xpu" },
{ config: "inductor_huggingface_perf_xpu", shard: 5, num_shards: 5, runner: "linux.idc.xpu" },
{ config: "inductor_timm_perf_xpu", shard: 1, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_timm_perf_xpu", shard: 2, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_timm_perf_xpu", shard: 3, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_timm_perf_xpu", shard: 4, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_timm_perf_xpu", shard: 5, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_timm_perf_xpu", shard: 6, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_torchbench_perf_xpu", shard: 1, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_torchbench_perf_xpu", shard: 2, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_torchbench_perf_xpu", shard: 3, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_torchbench_perf_xpu", shard: 4, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_torchbench_perf_xpu", shard: 5, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_torchbench_perf_xpu", shard: 6, num_shards: 6, runner: "linux.idc.xpu" },
]}
secrets: inherit
xpu-n-py3_10-inductor-benchmark-test-nightly:
permissions:
id-token: write
contents: read
if: github.event_name != 'workflow_dispatch'
name: xpu-n-py3.10-inductor-benchmark
uses: ./.github/workflows/_xpu-test.yml
needs: xpu-n-py3_10-inductor-benchmark-build
with:
build-environment: linux-jammy-xpu-n-py3.10
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-false-cppwrapper-true-aotinductor-true-freezing_cudagraphs-false-cudagraphs_low_precision-false
docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
timeout-minutes: 720
# Disable monitor in perf tests for more investigation
disable-monitor: true
monitor-log-interval: 10
monitor-data-collect-interval: 2
secrets: inherit
xpu-n-py3_10-inductor-benchmark-test:
permissions:
id-token: write
contents: read
if: github.event_name == 'workflow_dispatch'
name: xpu-n-py3.10-inductor-test
uses: ./.github/workflows/_xpu-test.yml
needs: xpu-n-py3_10-inductor-benchmark-build
with:
build-environment: linux-jammy-xpu-n-py3.10
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
timeout-minutes: 720
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit

View File

@ -1,84 +0,0 @@
name: periodic-rocm-mi200
on:
schedule:
# We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs.
# Also run less frequently on weekends.
- cron: 45 0,8,16 * * 1-5
- cron: 45 4 * * 0,6
- cron: 45 4,12,20 * * 1-5
- cron: 45 12 * * 0,6
- cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests
push:
tags:
- ciflow/periodic/*
- ciflow/periodic-rocm-mi200/*
branches:
- release/*
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
llm-td:
if: github.repository_owner == 'pytorch'
name: before-test
uses: ./.github/workflows/llm_td_retrieval.yml
permissions:
id-token: write
contents: read
target-determination:
name: before-test
uses: ./.github/workflows/target_determination.yml
needs: llm-td
permissions:
id-token: write
contents: read
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch'
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-rocm-py3_10-build:
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
test-matrix: |
{ include: [
{ config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.4", owners: ["module:rocm", "oncall:distributed"] },
{ config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.4", owners: ["module:rocm", "oncall:distributed"] },
{ config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.4", owners: ["module:rocm", "oncall:distributed"] },
]}
secrets: inherit
linux-jammy-rocm-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
secrets: inherit

View File

@ -204,6 +204,37 @@ jobs:
test-matrix: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-rocm-py3_10-build:
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
test-matrix: |
{ include: [
{ config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.mi250.4", owners: ["module:rocm", "oncall:distributed"] },
{ config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.mi250.4", owners: ["module:rocm", "oncall:distributed"] },
{ config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.mi250.4", owners: ["module:rocm", "oncall:distributed"] },
]}
secrets: inherit
linux-jammy-rocm-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-cuda12_8-py3-gcc11-slow-gradcheck-build:
name: linux-jammy-cuda12.8-py3-gcc11-slow-gradcheck
uses: ./.github/workflows/_linux-build.yml
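
The `include` blocks above enumerate one entry per shard by hand. A sketch of generating the same matrix programmatically; `shard_matrix` is a hypothetical helper, with config and runner values copied from the hunk for illustration.

```python
import json


def shard_matrix(config: str, num_shards: int, runner: str,
                 owners: list[str]) -> dict:
    return {
        "include": [
            {"config": config, "shard": shard, "num_shards": num_shards,
             "runner": runner, "owners": owners}
            for shard in range(1, num_shards + 1)  # shards are 1-indexed
        ]
    }


print(json.dumps(
    shard_matrix("distributed", 3, "linux.rocm.gpu.mi250.4",
                 ["module:rocm", "oncall:distributed"]),
    indent=2,
))
```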

View File

@ -6,7 +6,6 @@ on:
- pull
- trunk
- periodic
- periodic-rocm-mi200
- periodic-rocm-mi300
- inductor
- unstable

.gitignore (vendored)
View File

@ -143,7 +143,6 @@ scripts/release_notes/*.json
sccache-stats*.json
lint.json
merge_record.json
.github/scripts/nightly_source_matrix.json
# These files get copied over on invoking setup.py
torchgen/packaged/*

View File

@ -183,6 +183,7 @@ include_patterns = [
'benchmarks/instruction_counts/**/*.py',
'tools/**/*.py',
'torchgen/**/*.py',
'torch/utils/pytree/__init__.py',
'torch/utils/_pytree.py',
'torch/utils/_cxx_pytree.py',
'torch/utils/benchmark/utils/common.py',

View File

@ -374,7 +374,7 @@ cmake_dependent_option(
"Build the lazy Torchscript backend, not compatible with mobile builds" ON
"NOT INTERN_BUILD_MOBILE" OFF)
cmake_dependent_option(BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF)
cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin folder"
cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin fodler"
OFF "USE_CUDA" OFF)
cmake_dependent_option(USE_KLEIDIAI "Use KleidiAI for the ARM CPU & AARCH64 architecture." ON
"CPU_AARCH64" OFF)

View File

@ -195,6 +195,7 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
/torch/utils/_pytree.py @XuehaiPan
/torch/utils/_cxx_pytree.py @XuehaiPan
/torch/utils/pytree/ @XuehaiPan
/torch/pytree.py @XuehaiPan
/torch/_dynamo/polyfills/pytree.py @XuehaiPan
# Relating to libtorch ABI

View File

@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
if(USE_CUDA)
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*")
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")

View File

@ -354,9 +354,47 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
Vectorized frac() const;
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
#ifdef __ARM_FEATURE_BF16
Vectorized<c10::BFloat16> neg() const {
return -values;
}
Vectorized<c10::BFloat16> reciprocal() const {
return 1.0f / values;
}
Vectorized<c10::BFloat16> operator==(
const Vectorized<c10::BFloat16>& other) const {
return values == other.values;
}
Vectorized<c10::BFloat16> operator!=(
const Vectorized<c10::BFloat16>& other) const {
return values != other.values;
}
Vectorized<c10::BFloat16> operator<(
const Vectorized<c10::BFloat16>& other) const {
return values < other.values;
}
Vectorized<c10::BFloat16> operator<=(
const Vectorized<c10::BFloat16>& other) const {
return values <= other.values;
}
Vectorized<c10::BFloat16> operator>(
const Vectorized<c10::BFloat16>& other) const {
return values > other.values;
}
Vectorized<c10::BFloat16> operator>=(
const Vectorized<c10::BFloat16>& other) const {
return values >= other.values;
}
#else
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
@ -364,6 +402,7 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
#endif
#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
@ -412,28 +451,52 @@ template <>
Vectorized<c10::BFloat16> inline operator+(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x + y;
#else
return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
#endif
}
template <>
Vectorized<c10::BFloat16> inline operator-(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x - y;
#else
return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
#endif
}
template <>
Vectorized<c10::BFloat16> inline operator*(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x * y;
#else
return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
#endif
}
template <>
Vectorized<c10::BFloat16> inline operator/(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x / y;
#else
return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
#endif
}
// frac. Implement this here so we can use subtraction
@ -544,12 +607,19 @@ Vectorized<c10::BFloat16> inline fmadd(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return x * y + z;
#else
// NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also,
// vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
// elements, not the bottom and top half, so they don't seem
// particularly useful here. Ideally we would include dot product in
// the Vectorized interface...
return a * b + c;
#endif
}
template <>
@ -557,8 +627,15 @@ Vectorized<c10::BFloat16> inline fnmadd(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return (-x) * y + z;
#else
// See NOTE [BF16 FMA] above.
return -a * b + c;
#endif
}
template <>
@ -566,8 +643,15 @@ Vectorized<c10::BFloat16> inline fmsub(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return x * y - z;
#else
// See NOTE [BF16 FMA] above.
return a * b - c;
#endif
}
template <>
@ -575,8 +659,15 @@ Vectorized<c10::BFloat16> inline fnmsub(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return (-x) * y - z;
#else
// See NOTE [BF16 FMA] above.
return -a * b - c;
#endif
}
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
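
The non-`__ARM_FEATURE_BF16` paths above fall back to computing in float and narrowing the result back to bfloat16 (`binary_operator_via_float`). A Python sketch of that widen, compute, round strategy, using the standard bit-level view of bfloat16 as the top 16 bits of an IEEE-754 float32 with round-to-nearest-even (NaN handling omitted):

```python
import struct


def f32_to_bf16_bits(x: float) -> int:
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    # round to nearest, ties to even, on the discarded low 16 bits
    rounding = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding) >> 16) & 0xFFFF


def bf16_bits_to_f32(b: int) -> float:
    (x,) = struct.unpack("<f", struct.pack("<I", b << 16))
    return x


def bf16_add(a: float, b: float) -> float:
    # widening bf16 to float32 is exact; only the final narrowing rounds
    wide = bf16_bits_to_f32(f32_to_bf16_bits(a)) + bf16_bits_to_f32(f32_to_bf16_bits(b))
    return bf16_bits_to_f32(f32_to_bf16_bits(wide))


print(bf16_add(1.0, 3.14159))  # 4.125 after bfloat16 rounding
```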

View File

@ -309,7 +309,7 @@ class Vectorized<float> {
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
// Implementation copied from Arm Optimized Routine
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
inline Vectorized<float> vexpq_f32_u20() const {
Vectorized<float> exp_u20() const {
// bail out to sleef if it's a special case:
// i.e. there's an input s.t. |input| > 87.3....
const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f);
@ -348,9 +348,6 @@ class Vectorized<float> {
return vfmaq_f32(scale, poly, scale);
}
Vectorized<float> exp_u20() const {
return vexpq_f32_u20();
}
Vectorized<float> fexp_u20() const {
return exp_u20();
}
@ -637,7 +634,7 @@ inline Vectorized<float> Vectorized<float>::erf() const {
// - exp(- x * x)
auto pow_2 = (*this) * (*this);
auto neg_pow_2 = pow_2 ^ neg_zero_vec;
auto tmp4 = neg_pow_2.vexpq_f32_u20();
auto tmp4 = neg_pow_2.exp();
auto tmp5 = tmp4 ^ neg_zero_vec;
// erf(x) = sign(x) * (1 - r * t * exp(- x * x))
auto tmp6 = t * tmp5;

View File

@ -7,6 +7,17 @@
#endif
#if defined(USE_ROCM)
// hipSparse const API added in v2.4.0
#if HIPSPARSE_VERSION >= 200400
#define AT_USE_HIPSPARSE_GENERIC_API() 1
#else
#define AT_USE_HIPSPARSE_GENERIC_API() 1
#endif
#else // USE_ROCM
#define AT_USE_HIPSPARSE_GENERIC_API() 0
#endif // USE_ROCM
// cuSparse Generic API spsv function was added in CUDA 11.3.0
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500)
#define AT_USE_CUSPARSE_GENERIC_SPSV() 1

View File

@ -1,7 +1,6 @@
#include <ATen/cuda/CUDAContextLight.h>
#include <ATen/cuda/Sleep.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAStream.h>
@ -25,22 +24,8 @@ __global__ void spin_kernel(int64_t cycles) {
#endif
}
}
thread_local int *flag = nullptr;
__global__ void busy_wait_for_flag_kernel(int *flag) {
atomicExch(flag, 1);
while (atomicAdd(flag, 0) == 1) {
// do nothing
}
}
__global__ void clear_flag_kernel(int *flag) {
atomicExch(flag, 0);
}
} // anonymous namespace
void sleep(int64_t cycles) {
dim3 grid(1);
dim3 block(1);
@ -48,26 +33,6 @@ void sleep(int64_t cycles) {
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
void busy_wait_for_flag() {
if (!flag) {
flag = (int*)c10::cuda::CUDACachingAllocator::raw_alloc(sizeof(int));
}
dim3 grid(1);
dim3 block(1);
busy_wait_for_flag_kernel<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(flag);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
void clear_flag() {
if (!flag) {
flag = (int*)c10::cuda::CUDACachingAllocator::raw_alloc(sizeof(int));
}
dim3 grid(1);
dim3 block(1);
clear_flag_kernel<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(flag);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#ifdef USE_ROCM
__global__ void flush_icache_kernel()
{

View File

@ -7,11 +7,6 @@ namespace at::cuda {
// enqueues a kernel that spins for the specified number of cycles
TORCH_CUDA_CU_API void sleep(int64_t cycles);
// enqueues a kernel that spins until a flag is cleared by a
// corresponding call to clear_flag()
TORCH_CUDA_CU_API void busy_wait_for_flag();
TORCH_CUDA_CU_API void clear_flag();
// flushes instruction cache for ROCm; no-op for CUDA
TORCH_CUDA_CU_API void flush_icache();

View File

@ -1,6 +1,5 @@
#pragma once
#include <c10/core/CachingDeviceAllocator.h>
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
@ -152,36 +151,6 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
}
virtual bool isAvailable() const override;
/* MTIAGraph related APIs */
virtual int64_t mtiagraphCreate(bool keep_graph = false) const {
FAIL_MTIAHOOKS_FUNC(__func__);
return -1;
}
virtual void mtiagraphCaptureBegin(int64_t handle, MempoolId_t pool) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
virtual void mtiagraphCaptureEnd(int64_t handle) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
virtual void mtiagraphInstantiate(int64_t handle) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
virtual void mtiagraphReplay(int64_t handle) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
virtual void mtiagraphReset(int64_t handle) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
virtual MempoolId_t mtiagraphPool(int64_t handle) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
};
struct TORCH_API MTIAHooksArgs {};

View File

@ -410,8 +410,8 @@ struct ConvParams {
return false;
}
static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
// broken on cuDNN 9.8 - 9.14
if (cudnn_version >= 90800 && cudnn_version < 91500) {
// broken on cuDNN 9.8
if (cudnn_version >= 90800) {
if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous &&
(input.scalar_type() == at::kBFloat16 || input.scalar_type() == at::kHalf) &&
weight.dim() == 5) {

View File

@ -50,18 +50,35 @@ static inline bool parseLinearFlatten3d() {
// `_flatten_nd_linear` flattens all but the last dimension of the input tensor
// before passing it to linear operation
static inline Tensor _flatten_nd_linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
const auto input_sizes = input.sym_sizes();
// can't use -1 in reshape because it errors when a dimension is 0
c10::SymInt flattened_dim = 1;
for (int64_t i = 0, ndim = input_sizes.size(); i < ndim - 1; ++i) {
flattened_dim = flattened_dim * input_sizes[i];
const auto input_sizes = input.sym_sizes();
const auto result_flattened = [&]() -> Tensor {
const auto input_ncols = input_sizes.back();
const auto input_flattened_nrows = [&]() -> c10::SymInt {
// can't use -1 in reshape because it errors when a dimension is 0
auto flattened_nrows = c10::SymInt{1};
for (const auto& size : input_sizes.slice(0, input_sizes.size() - 1)) {
flattened_nrows *= size;
}
return flattened_nrows;
}();
const auto input_flattened = input.view_symint({input_flattened_nrows, input_ncols});
if (weight.layout() == c10::kStrided) {
return at::addmm(bias, input_flattened, weight.t());
} else {
// weight is sparse, and addmm for sparse expects matmul lhs to be sparse,
// so we transpose the problem.
// NOTE: at::matmul handles (dense @ sparse) similarly.
const auto bias_t = (bias.dim() >= 2) ? bias.mT() : bias.unsqueeze(-1);
return at::addmm(bias_t, weight, input_flattened.t()).t();
}
auto inp_reshape = input.reshape_symint({flattened_dim, input_sizes.at(input_sizes.size() -1)});
const auto result = at::addmm(bias, inp_reshape, weight.t());
auto new_size = input_sizes.slice(0, input_sizes.size() - 1);
c10::SymDimVector sizes_vec(new_size.begin(), new_size.end());
sizes_vec.push_back(result.sym_size(1));
return result.view_symint(sizes_vec);
}();
// Unflatten flattened row dims
auto result_sizes = c10::SymDimVector{input_sizes.begin(), input_sizes.end()};
result_sizes.back() = result_flattened.sym_size(1);
return result_flattened.view_symint(result_sizes);
}
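// Worked shape example (hypothetical sizes) for the flattened path above:
//   input [4, 7, 16], weight [32, 16], bias [32]
//   input_flattened = input.view({4 * 7, 16})      -> [28, 16]
//   at::addmm(bias, input_flattened, weight.t())   -> [28, 32]
//   result_flattened.view_symint({4, 7, 32})       -> leading dims restored
// The sparse branch instead relies on the transpose identity
//   addmm(bias_t, weight, input_flattened.t()).t() == input_flattened * weight.t() + bias
// so that the sparse operand stays on the lhs of addmm.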
@ -90,15 +107,23 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optional<Ten
// Fused op is marginally faster.
return at::addmm(*bias, input, weight.t());
}
if (bias->defined() && !input.is_xla()) {
// Also hit the fused path for contiguous 3D input, if not using xla
const auto is_bias_likely_fusable = (
bias->defined() &&
// cuBLASLt: will fuse in the epilogue without copies
// when input/weight/bias are all strided.
// When weight is not strided, bias will not be fused,
// but we can still dispatch here to avoid at::matmul
// path which will probably use a very similar
// flattening optimization.
(bias->dim() == 1 && bias->is_contiguous_or_false())
);
if (is_bias_likely_fusable && !input.is_xla()) {
// Also hit the fused path for contiguous nD input, if not using xla
// backend. Reshaping/flattening has some performance implications on xla.
bool is_contiguous = input.is_contiguous_or_false();
if (is_contiguous && input_dim == 3) {
if (input.is_contiguous_or_false()) {
return _flatten_nd_linear(input, weight, *bias);
} else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) {
return _flatten_nd_linear(input, weight, *bias);
} else if (parseLinearFlatten3d() && input_dim == 3) {
} else if (parseLinearFlatten3d()) {
// If user forces flattening via env var
const Tensor input_cont = input.contiguous();
return _flatten_nd_linear(input_cont, weight, *bias);

View File

@ -170,14 +170,10 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const
#if defined(CUDA_VERSION) || defined(USE_ROCM)
const auto scalar_type = mat1.scalar_type();
return (beta.toComplexDouble() == 1.0
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use lt interface only when self is bias.
&& self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
&& result.dim() == 2 && result.is_contiguous()
// Conditions for bias to be fusable
&& (
self.is_contiguous() &&
// NOTE: fine to have 1-len dims to the left from the right-most one
(self.dim() == 1 || self.squeeze().dim() == 1) &&
self.sizes().back() == mat2_sizes[1]
)
&& ( // some dtype restrictions
#ifndef USE_ROCM
scalar_type == at::ScalarType::Double ||

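// Hedged examples of bias ("self") shapes accepted by the relaxed check above,
// assuming self is contiguous and its last dim equals mat2_sizes[1]:
//   [N]       -> self.dim() == 1                -> fusable
//   [1, N]    -> self.squeeze().dim() == 1      -> fusable
//   [1, 1, N] -> squeeze() drops leading 1-dims -> fusable
//   [M, N]    -> squeeze() stays 2-D            -> rejected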
View File

@ -22,6 +22,9 @@
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#ifdef USE_ROCM
#include <ATen/native/hip/ck_group_gemm.h>
#endif
#include <ATen/ceil_div.h>
#ifdef USE_FBGEMM_GENAI
@ -213,9 +216,9 @@ _f4_f4_bf16_grouped_mm_fbgemm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const std::optional<Tensor>& global_scale_a,
const Tensor& global_scale_a,
const Tensor& scale_b,
const std::optional<Tensor>& global_scale_b,
const Tensor& global_scale_b,
const std::optional<Tensor>& offs,
const std::optional<Tensor>& bias,
Tensor& out) {
@ -225,28 +228,14 @@ _f4_f4_bf16_grouped_mm_fbgemm(
"mat_a must be Float4_e2n1fn_2, got: ", mat_a.scalar_type());
TORCH_CHECK_VALUE(mat_b.scalar_type() == at::kFloat4_e2m1fn_x2,
"mat_b must be Float4_e2n1fn_2, got: ", mat_b.scalar_type());
std::optional<Tensor> combined_global_scale = std::nullopt;
if (global_scale_a.has_value() || global_scale_b.has_value()) {
// NVFP4
TORCH_CHECK_VALUE(global_scale_a.has_value() && global_scale_b.has_value(),
"For NVFP4 grouped gemm, both global_scale_a and global_scale_b must have values");
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e4m3fn,
"scale_a must be Float8_e4m3fn, got: ", scale_a.scalar_type());
TORCH_CHECK_VALUE(scale_b.scalar_type() == at::kFloat8_e4m3fn,
"scale_b must be Float8_e4m3fn, got: ", scale_b.scalar_type());
TORCH_CHECK_VALUE(global_scale_a.value().scalar_type() == at::kFloat,
"global_scale_a must be Float, got: ", global_scale_a.value().scalar_type());
TORCH_CHECK_VALUE(global_scale_b.value().scalar_type() == at::kFloat,
"global_scale_b must be Float, got: ", global_scale_b.value().scalar_type());
combined_global_scale = global_scale_a.value().mul(global_scale_b.value());
} else {
// MXFP4
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu,
"scale_a must be Float8_e8m0fnu, got: ", scale_a.scalar_type());
TORCH_CHECK_VALUE(scale_b.scalar_type() == at::kFloat8_e8m0fnu,
"scale_b must be Float8_e8m0fnu, got: ", scale_b.scalar_type());
}
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e4m3fn,
"scale_a must be Float8_e4m3fn, got: ", scale_a.scalar_type());
TORCH_CHECK_VALUE(scale_b.scalar_type() == at::kFloat8_e4m3fn,
"scale_b must be Float8_e4m3fn, got: ", scale_b.scalar_type());
TORCH_CHECK_VALUE(global_scale_a.scalar_type() == at::kFloat,
"global_scale_a must be Float, got: ", global_scale_a.scalar_type());
TORCH_CHECK_VALUE(global_scale_b.scalar_type() == at::kFloat,
"global_scale_b must be Float, got: ", global_scale_b.scalar_type());
auto o = fbgemm_gpu::f4f4bf16_grouped_mm(
mat_a,
@ -255,7 +244,7 @@ _f4_f4_bf16_grouped_mm_fbgemm(
scale_b,
offs.value(),
out,
combined_global_scale
global_scale_a.mul(global_scale_b)
);
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "nvfp4 grouped gemm is not supported without USE_FBGEMM_GENAI, and only for CUDA")
@ -485,10 +474,9 @@ namespace {
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 4> scale_grouped_kernel_dispatch = {{
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 3> scale_grouped_kernel_dispatch = {{
{ "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8},
{ "mxfp4_mxfp4", scaled_blas::check_mxfp4_recipe, ScaledGemmImplementation::MXFP4_MXFP4},
{ "nvfp4_nvfp4", scaled_blas::check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4}}};
} // anonymous namespace
@ -614,21 +602,6 @@ _scaled_grouped_mm_cuda_v2(
offs.value(),
out);
}
case ScaledGemmImplementation::MXFP4_MXFP4: {
// scale shape checks
_check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_blocked(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
return _f4_f4_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
scale_a[0], /* block-scale A */
std::nullopt, /* global-scale A */
scale_b[0], /* block-scale B */
std::nullopt, /* global-scale B */
offs.value(),
std::nullopt, /* bias */
out);
}
case ScaledGemmImplementation::NVFP4_NVFP4: {
// scale shape checks
_check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
@ -666,12 +639,19 @@ std::optional<c10::ScalarType> out_dtype) {
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
use_fast_path = true;
}
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
if (use_fast_path) {
// fast path, no d2h sync needed
#ifndef USE_ROCM
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#endif
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
}

View File

@ -13,7 +13,7 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx,
if (allow_neg_indices) {
ind = (ind < 0) ? ind + ind_dim_size : ind;
}
CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds", "Expected 0 <= index < ind_dim_size(%ld), but got index = %ld", ind_dim_size, ind);
CUDA_KERNEL_ASSERT(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds");
int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; // off is guaranteed to be within int32 limits
if (off >= slice_size) return;
auto vec = at::native::memory::ld_vec<Alignment>(inp + ind * inp_stride + off);

View File

@ -59,22 +59,6 @@
// forward declare
class cublasCommonArgs;
namespace fbgemm_gpu {
// NOTE(slayton58): FBGemm_GPU kernels come from <fbgemm_gpu/torch_ops.h> within the FBGemm repo.
// To update supported ops means a submodule bump, which is.. painful. Instead, we
// can simply forward-declare the methods we want to use.. Works at least as a short-term
// thing, but should still be fixed somewhere/somehow.
at::Tensor f4f4bf16(
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
std::optional<at::Tensor>,
bool use_mx);
} // namespace fbgemm_gpu
using at::blas::ScalingType;
using at::blas::SwizzleType;
@ -810,24 +794,6 @@ void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const Sc
}
}
void
_check_deepseek_support() {
#ifndef USE_ROCM
auto dprops = at::cuda::getCurrentDeviceProperties();
// Only on Hopper GPUs
TORCH_CHECK_NOT_IMPLEMENTED(
dprops->major == 9,
"DeepSeek style (1x128, 128x128) scaling only supported in CUDA for SM90");
// Only in cublasLt >= 12.9
TORCH_CHECK_NOT_IMPLEMENTED(
CUBLAS_VERSION >= 120900 && cublasLtGetVersion() >= 120900,
"DeepSeek style (1x128, 128x128) scaling requires cublasLt >= 12.9"
);
#endif
}
Tensor&
_scaled_block1x128_block1x128(
const Tensor& mat_a, const Tensor& mat_b,
@ -836,12 +802,8 @@ _scaled_block1x128_block1x128(
const c10::ScalarType out_dtype,
const bool use_fast_accum,
Tensor& out) {
#ifndef USE_ROCM
// Restrictions:
// A, B are FP8, scales are fp32, shape K//128
// CUDA: Only Hopper GPUs
_check_deepseek_support();
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
@ -859,12 +821,6 @@ _scaled_block1x128_block1x128(
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
#else
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"1x128 and 128x128 scaling not available with ROCm"
);
#endif
}
Tensor&
@ -875,12 +831,10 @@ _scaled_block128x128_block1x128(
const c10::ScalarType out_dtype,
const bool use_fast_accum,
Tensor& out) {
#ifndef USE_ROCM
// Restrictions:
// A, B are FP8, scales are fp32, shape K//128
// CUDA: Only Hopper GPUs
_check_deepseek_support();
std::cout << "mat_b: " << mat_b.dim() << ", " << mat_b.sizes() << ", " << mat_b.strides() << std::endl;
std::cout << "scale_b: " << scale_b.dim() << ", " << scale_b.sizes() << ", " << scale_b.strides() << std::endl;
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
@ -898,12 +852,6 @@ _scaled_block128x128_block1x128(
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
#else
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"1x128 and 128x128 scaling not available with ROCm"
);
#endif
}
Tensor&
@ -914,12 +862,8 @@ _scaled_block1x128_block128x128(
const c10::ScalarType out_dtype,
const bool use_fast_accum,
Tensor& out) {
#ifndef USE_ROCM
// Restrictions:
// A, B are FP8, scales are fp32, A: shape K//128, B: K//128, N//128
// CUDA: Only Hopper GPUs
_check_deepseek_support();
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
@ -937,12 +881,6 @@ _scaled_block1x128_block128x128(
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
#else
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"1x128 and 128x128 scaling not available with ROCm"
);
#endif
}
Tensor&
@ -1013,47 +951,26 @@ _scaled_mxfp4_mxfp4(
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
Tensor& out) {
#if !defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI)
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
#ifndef USE_ROCM
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
#endif
// Restrictions:
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
// Packed FP4 format means actual-K = 2 * reported-K -- adjust
auto K_multiplier = 2;
#ifdef USE_ROCM
// AMD
auto scale_a_elems = ceil_div<int64_t>(K_multiplier * mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(K_multiplier * mat_b.size(1), 32) * mat_b.size(0);
#else
// NVIDIA
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_a.size(1), 32), 4);
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_b.size(0), 32), 4);
#endif
auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
#ifdef USE_ROCM
// AMD
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)");
#else
// NVIDIA
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
#endif
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
"For Blockwise scaling both scales should be contiguous");
TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out.scalar_type());
#ifdef USE_ROCM
// AMD
auto scaling_choice_a = ScalingType::BlockWise1x32;
auto scaling_choice_b = ScalingType::BlockWise1x32;
@ -1068,29 +985,11 @@ _scaled_mxfp4_mxfp4(
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
out.scalar_type() == ScalarType::Half,
"Block-wise scaling only supports BFloat16 or Half output types");
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
#endif
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
#else
// NVIDIA
// NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor,
// but we have one we need to use. Two clear options are to copy into
// our output (slow), or use a move-assignment-operator (faster).
// However, the compiler can complain about the explicit move preventing
// copy elision because the return from f4f4bf16 is a temporary object.
// So we don't explicitly move, and trust the compiler here...
// In the longer term this should be fixed on the FBGemm side.
out = fbgemm_gpu::f4f4bf16(
mat_a,
mat_b.transpose(-2, -1),
scale_a,
scale_b,
std::nullopt, /* global_scale */
true /* use_mx */
);
return out;
#endif
}
Tensor&
@ -1215,20 +1114,17 @@ _scaled_mm_cuda_v2_out(
mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")");
}
// Handle fp4 packed-K dimension
int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1;
TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1],
" but got ", bias->numel());
TORCH_CHECK_VALUE(
K_multiplier * mat_a.sizes()[1] % 16 == 0,
mat_a.sizes()[1] % 16 == 0,
"Expected trailing dimension of mat1 to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes()[0],
"x",
K_multiplier * mat_a.sizes()[1],
mat_a.sizes()[1],
").");
TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
mat_b.sizes()[1], ") must be divisible by 16");
// TODO(slayton): Existing checks, not sure if they should really be here.
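// Worked example (hypothetical sizes): an fp4 mat1 reported as [M, 64] packs
// two e2m1 values per element, so the logical K is 128 and the divisibility
// checks above use K_multiplier * 64 rather than the packed size alone.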

View File

@ -160,8 +160,8 @@ struct _cuda_scatter_gather_internal_kernel {
auto offsets = offset_calc.get(i);
int64_t idx_dim = *(index_t*)(index_ptr + offsets[2]);
CUDA_KERNEL_ASSERT_VERBOSE(idx_dim >= 0 && idx_dim < index_size
&& "scatter gather kernel index out of bounds", "Expected 0 <= idx_dim < index_size (%ld), but got idx_dim = %ld", index_size, idx_dim);
CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
&& "scatter gather kernel index out of bounds");
f(
(scalar_t*)(self_ptr + offsets[0]),
@ -406,8 +406,9 @@ struct _cuda_scatter_fill_internal_kernel {
auto offsets = offset_calc.get(i);
int64_t idx_dim = *(index_t*)(index_ptr + offsets[1]);
CUDA_KERNEL_ASSERT_VERBOSE(idx_dim >= 0 && idx_dim < index_size
&& "index out of bounds", "Expected 0 <= idx_dim < index_size (%ld), but got idx_dim = %ld", index_size, idx_dim);
CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
&& "index out of bounds"
);
f(
(scalar_t*)(self_ptr + offsets[0]),

View File

@ -141,8 +141,7 @@ WelfordDataLN cuWelfordOnlineSum(
if constexpr (!rms_norm){
U delta = val - curr_sum.mean;
U new_count = curr_sum.count + 1.f;
//Due to low CU count, we run into accuracy issues on gfx90a with `__builtin_amdgcn_rcpf`
#if defined(USE_ROCM) && !defined(__gfx90a__) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
U new_mean = curr_sum.mean + delta * __builtin_amdgcn_rcpf(new_count);
#else
U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster
@ -164,8 +163,7 @@ WelfordDataLN cuWelfordCombine(
U count = dataA.count + dataB.count;
U mean, sigma2;
if (count > decltype(dataB.count){0}) {
//Due to low CU count, we run into accuracy issues on gfx90a with `__builtin_amdgcn_rcpf`
#if defined(USE_ROCM) && !defined(__gfx90a__) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
auto coef = __builtin_amdgcn_rcpf(count);
#else
auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division
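// Both branches above compute the same Welford coefficient (coef = 1/count,
// or delta/new_count for the online sum); the ROCm fast path merely trades
// exact IEEE division for the hardware reciprocal approximation
// __builtin_amdgcn_rcpf.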

View File

@ -0,0 +1,19 @@
#pragma once
#include <ATen/Tensor.h>
#include <c10/core/ScalarType.h>
#include <optional>
namespace at {
namespace hip {
namespace detail {
void group_gemm_ck(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
at::Tensor& out);
} // namespace detail
} // namespace hip
} // namespace at

View File

@ -0,0 +1,458 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/hip/HIPContext.h>
#include <ATen/Tensor.h>
#include <ATen/TensorAccessor.h>
#include <c10/hip/HIPStream.h>
#include <iostream>
#include <vector>
#include <optional>
#include <type_traits>
#include <ck/ck.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/utility/tuple.hpp>
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
namespace at {
namespace hip {
namespace detail {
namespace CkTypes {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
}
template <typename ALayout, typename BLayout, typename DataType>
using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<
ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor,
DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType,
CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough,
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
3, 8, 8, 1,
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
3, 8, 8, 1,
1, 1,
S<1,32,1,8>, 4
>;
template <typename ALayout, typename BLayout, typename DataType>
void launch_grouped_bgemm_ck_impl_dispatch(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
at::Tensor& out)
{
using DeviceOp = GroupedGemmKernel<ALayout, BLayout, DataType>;
using PassThrough = CkTypes::PassThrough;
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a_ptrs, p_b_ptrs;
std::vector<void*> p_e_ptrs;
// Note: d_ptrs will be resized after we populate the other vectors
const int mat_a_dim = mat_a.dim();
const int mat_b_dim = mat_b.dim();
const char* a_ptr_base = reinterpret_cast<const char*>(mat_a.data_ptr());
const char* b_ptr_base = reinterpret_cast<const char*>(mat_b.data_ptr());
char* out_ptr_base = reinterpret_cast<char*>(out.data_ptr());
const size_t a_element_size = mat_a.element_size();
const size_t b_element_size = mat_b.element_size();
const size_t out_element_size = out.element_size();
// for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses.
if (mat_a_dim == 2 && mat_b_dim == 2) {
// 2D*2D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
const int M = mat_a.size(0); // number of rows in A
const int N = mat_b.size(1); // number of columns in B
const int K = mat_a.size(1); // columns in A == rows in B
// for 2d*2d input, output is 3d.
// for each group, A columns (K) are sliced. M and N dimensions are not sliced.
for (int i = 0; i < num_groups; ++i) {
int start_k = (i == 0) ? 0 : offs_accessor[i-1];
int end_k = offs_accessor[i];
int k = end_k - start_k;
// The K dimension is sliced, hence select stride(1) always.
// K is always dimension 1, regardless of memory layout (row/column major).
const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size;
const void* group_b_ptr;
int ldb;
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset
group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size;
// Leading dimension = distance between rows = stride(0)
ldb = mat_b.stride(0);
} else {
// Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset
group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size;
// Leading dimension = distance between columns = stride(1)
ldb = mat_b.stride(1);
}
// Calculate output pointer for group i in 3D tensor [num_groups, M, N]
// stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
int lda, ldc;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A [M,K]: leading dimension = distance between rows = stride(0)
lda = mat_a.stride(0);
} else {
// Column-major A [M,K]: leading dimension = distance between columns = stride(1)
lda = mat_a.stride(1);
}
// Output is always row-major in 3D tensor [num_groups, M, N]
// Leading dimension for each group's [M,N] slice = stride(1) = N
ldc = out.stride(1);
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(k),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc)
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 2 && mat_b_dim == 3) {
// 2D*3D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
// 2d*3d input, output is 2d.
// A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n]
// Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B
const int K = mat_a.size(1); // columns in A
// For 2D-3D case: The output determines N (result width)
const int N = out.size(1); // N is the width of the output tensor
for (int i = 0; i < num_groups; ++i) {
int start_m = (i == 0) ? 0 : offs_accessor[i - 1];
int end_m = offs_accessor[i];
int m = end_m - start_m;
// Skip zero-sized groups but continue processing subsequent groups
if (m <= 0) {
continue;
}
// Select A rows for group i: skip start_m rows
const void* group_a_ptr;
int lda;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
lda = mat_a.stride(0); // distance between rows
} else {
// Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows)
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
// Detect stride pattern for A tensor to determine appropriate lda calculation
bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0));
if (a_is_strided_tensor) {
// For strided A tensors: stride(0) gives the actual leading dimension
lda = mat_a.stride(0);
} else {
// For non-strided A tensors: use the M dimension (total rows)
lda = mat_a.size(0); // Total M dimension for column-major layout
}
}
// Select B batch for group i: B[i, :, :]
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
int ldb;
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed
ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N]
} else {
// Detect stride pattern to determine appropriate ldb calculation
bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2));
if (is_strided_tensor) {
// For strided tensors: stride(2) gives the actual leading dimension
ldb = mat_b.stride(2);
} else {
// For non-strided tensors: use the N dimension
ldb = mat_b.size(1);
}
}
// Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N]
void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size;
int ldc = out.stride(0); // distance between rows in output (should be N for 2D case)
gemm_descs.push_back({
static_cast<ck::index_t>(m),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc)
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 3 && mat_b_dim == 3) {
// 3d*3d input, output is 3d - batched matrix multiplication
// A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n]
// Each batch is processed as a separate GEMM operation
const int batch_size = mat_a.size(0);
const int M = mat_a.size(1); // rows in each A matrix
const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed)
// Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout
int N;
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
N = mat_b.size(2);
} else if (mat_b.size(2) == K) {
// B is [batch, n, k] - transposed layout
N = mat_b.size(1);
} else {
TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[",
batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]");
}
for (int i = 0; i < batch_size; ++i) {
// Select A batch for group i: A[i, :, :]
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
// Select B batch for group i: B[i, :, :]
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
// Select output batch for group i: Output[i, :, :]
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
int lda, ldb, ldc;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A: leading dimension = distance between rows = stride(1)
lda = mat_a.stride(1);
} else {
// Column-major A: leading dimension = distance between columns = stride(2)
lda = mat_a.stride(2);
}
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B: leading dimension = distance between rows
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
ldb = mat_b.stride(1); // stride between K rows
} else {
// B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM
ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n])
}
} else {
// Column-major B: leading dimension = distance between columns
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
ldb = mat_b.stride(2); // stride between N columns
} else {
// B is [batch, n, k] - transposed layout
ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]→[k,n])
}
}
// Output is typically row-major: leading dimension = distance between rows = stride(1)
ldc = out.stride(1);
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc)
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 3 && mat_b_dim == 2) {
// 3D*2D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
// 3d*2d input, output is 3d.
// A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both)
// Offset divides N dimension of B, each group gets different slice of B and different batch of A
const int batch_size = mat_a.size(0); // n_groups
const int M = mat_a.size(1); // rows in each A matrix
const int K = mat_a.size(2); // columns in A
// For row-major A and B case: B should be [K, total_N]
const int total_N = mat_b.size(1); // B is [K, total_N] for row-major
for (int i = 0; i < num_groups; ++i) {
int start_n = (i == 0) ? 0 : offs_accessor[i - 1];
int end_n = offs_accessor[i];
int n = end_n - start_n;
// Skip zero-sized groups but continue processing subsequent groups
if (n <= 0) {
continue;
}
// Select A batch for group i: A[i, :, :]
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
// Select B slice for group i: B[:, start_n:end_n] (B[K, total_N])
const void* group_b_ptr;
int ldb;
// Check if B is row-major or column-major
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B [K, total_N]: slice columns [start_n:end_n]
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
ldb = mat_b.stride(0); // distance between rows (should be total_N)
} else {
// Column-major B [K, total_N]: slice columns [start_n:end_n]
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
ldb = mat_b.stride(1); // distance between columns (should be K)
}
// Select output slice for group i: Output[:, start_n:end_n]
void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size;
int lda, ldc;
// Row-major A: leading dimension = distance between rows = stride(1)
lda = mat_a.stride(1);
// Output is row-major: leading dimension = distance between rows = stride(0)
ldc = out.stride(0);
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(n),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc)
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else {
TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim);
}
TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups");
// Initialize d_ptrs with the correct size
std::vector<std::array<const void*, 0>> d_ptrs(p_a_ptrs.size());
static DeviceOp gemm_instance;
auto argument = gemm_instance.MakeArgument(
p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs,
gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}
);
TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument),
"CK Group GEMM: argument unsupported (shape/strides/type config)");
size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument);
size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument);
void* gemm_arg_buf = nullptr;
void* ws_buf = nullptr;
hipMalloc(&gemm_arg_buf, arg_buf_size);
hipMalloc(&ws_buf, ws_size);
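// NOTE: these raw hipMalloc calls are unchecked; verifying their return codes
// (or allocating through the HIP caching allocator) would be more robust. The
// hipFree calls below also rely on hipFree synchronizing with the
// asynchronous Run() enqueued on the stream.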
gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf);
gemm_instance.SetWorkSpacePointer(&argument, ws_buf);
auto invoker = gemm_instance.MakeInvoker();
hipStream_t stream = c10::hip::getCurrentHIPStream();
invoker.Run(argument, {stream});
hipFree(gemm_arg_buf);
hipFree(ws_buf);
}
void group_gemm_ck(
const at::Tensor& input_a,
const at::Tensor& input_b_colmajor,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& /*bias*/,
at::Tensor& out)
{
// Detect if input_a is row-major based on stride pattern
bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1);
bool b_col_major = (input_b_colmajor.dim() == 3) ? (input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1);
// Ensure tensor A is row-major and contiguous if not already
at::Tensor mat_a = input_a;
if (!a_row_major) {
// If A is not row-major, make it contiguous (row-major)
mat_a = input_a.contiguous();
}
// Force tensor B to be column-major using double transpose trick
// This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape
at::Tensor mat_b = input_b_colmajor;
if (!b_col_major) {
mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1);
}
// For 3D tensors, check the last dimension stride for row-major detection
a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1);
bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1);
if (mat_a.dtype() == at::kBFloat16) {
// bf16 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
}
} else if (mat_a.dtype() == at::kHalf) {
// fp16 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
}
} else if (mat_a.dtype() == at::kFloat) {
// fp32 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
}
} else {
TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype");
}
}
} // namespace detail
} // namespace hip
} // namespace at
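// Hedged usage sketch (hypothetical sizes) for the 2D*2D path, where offs
// holds cumulative K boundaries and the output is 3-D [num_groups, M, N]:
//
//   auto a    = at::randn({64, 128}, at::dtype(at::kBFloat16).device(at::kCUDA));
//   auto b    = at::randn({128, 48}, a.options());
//   auto offs = at::tensor({32, 96, 128}, at::dtype(at::kInt));
//   auto out  = at::empty({3, 64, 48}, a.options());
//   at::hip::detail::group_gemm_ck(a, b, offs, std::nullopt, out);
//
// Here offs = {32, 96, 128} slices K into per-group k of {32, 64, 32}; the
// 2D*3D and 3D*2D paths slice M and N the same way. Note the implementation
// reads offs through a host accessor (offs->accessor<int, 1>()), so a
// CPU-resident offs appears to be assumed.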

View File

@ -40,37 +40,14 @@ bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) {
return true;
}
bool input_require_grad(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const std::optional<at::Tensor>& attn_mask) {
return at::GradMode::is_enabled() &&
(query.requires_grad() || key.requires_grad() || value.requires_grad() ||
(attn_mask.has_value() && attn_mask.value().requires_grad()));
}
bool check_grad(sdp::sdp_params const& params, bool debug) {
if (!input_require_grad(
params.query, params.key, params.value, params.attn_mask))
return true;
auto q_num_heads = params.query.sym_size(-3);
auto k_num_heads = params.key.sym_size(-3);
auto v_num_heads = params.value.sym_size(-3);
bool is_gqa = q_num_heads != k_num_heads || q_num_heads != v_num_heads;
if (debug && is_gqa)
TORCH_WARN(
"scale_dot_product_attention with gqa is not supported for gradient computation on xpu.");
bool attn_mask_needs_grad =
params.attn_mask.has_value() && params.attn_mask.value().requires_grad();
if (debug && attn_mask_needs_grad) {
TORCH_WARN(
"scale_dot_product_attention on xpu is not supported when attn_mask.requires_grad() == True.");
bool check_no_grad(sdp::sdp_params const& params, bool debug) {
const bool any_inputs_require_grad = params.query.requires_grad() ||
params.key.requires_grad() || params.value.requires_grad();
const bool gradmode_enabled = at::GradMode::is_enabled();
if (debug && any_inputs_require_grad && gradmode_enabled) {
TORCH_WARN("Backward or grad to be supported.");
}
return !is_gqa && !attn_mask_needs_grad;
return !any_inputs_require_grad || !gradmode_enabled;
}
bool can_use_overrideable_attention(sdp::sdp_params const& params, bool debug) {
@ -88,7 +65,7 @@ bool can_use_overrideable_attention(sdp::sdp_params const& params, bool debug) {
sdp::check_nonzero_sequence_lengths_dense,
sdp::check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim*/>,
check_head_dim_size_xpu,
check_grad);
check_no_grad);
for (auto& constraint : constraints) {
if (!constraint(params, debug)) {
return false;
@ -248,11 +225,10 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale,
bool compute_logsumexp) {
std::optional<double> scale) {
TORCH_INTERNAL_ASSERT(
query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
"scaled_dot_product_fused_attention_overrideable_xpu: Accept only 4 dims inputs shape of {B, H, T, K}");
"scaled_dot_product_fused_attention_overrideable_xpu: Accept only 4 dims inputs shape of {(B), H, T, K}");
TORCH_INTERNAL_ASSERT(
(key.size(0) == value.size(0)) && (key.size(1) == value.size(1)) &&
(key.size(2) == value.size(2)),
@ -269,9 +245,6 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
TORCH_INTERNAL_ASSERT(
!(attn_bias.has_value() && is_causal),
"scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot present with is_causal");
TORCH_INTERNAL_ASSERT(
!(attn_bias.has_value() && attn_bias.value().requires_grad()),
"scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot have requires_grad=True");
const int64_t batch_size = query.size(0);
const int64_t num_head_q = query.size(1);
@ -281,14 +254,11 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
const int64_t seq_len_q = query.size(2);
const int64_t seq_len_kv = key.size(2);
at::Tensor attention;
std::vector<int64_t> attention_shape = {
at::Tensor output;
std::vector<int64_t> output_shape = {
batch_size, num_head_q, seq_len_q, head_dim_v};
alloc_with_matching_layout(query, attention, attention_shape);
auto opts = query.options();
at::Tensor logsumexp =
at::empty({batch_size, num_head_q, seq_len_q}, opts.dtype(at::kFloat));
alloc_with_matching_layout(query, output, output_shape);
at::Tensor logsumexp, debug_attn_mask; // not supported
at::native::onednn::sdpa(
batch_size,
@ -304,15 +274,15 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
attn_bias,
is_causal,
scale.has_value() ? scale.value() : (1.0 / std::sqrt(head_dim_qk)),
attention,
compute_logsumexp,
output,
false,
logsumexp);
// rng not used
auto philox_seed = at::empty({}, at::dtype(at::kLong));
auto philox_offset = at::empty({}, at::dtype(at::kLong));
return std::make_tuple(
attention,
output,
logsumexp,
/* cum_seq_q */ at::Tensor(),
/* cum_seq_k */ at::Tensor(),
@ -320,106 +290,7 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
seq_len_kv,
philox_seed,
philox_offset,
/*debug_attn_mask */ at::Tensor());
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_scaled_dot_product_fused_attention_overrideable_backward_xpu(
const at::Tensor& grad_out,
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const at::Tensor& attn_bias,
std::array<bool, 4> grad_input_mask,
const at::Tensor& out,
const at::Tensor& logsumexp,
const at::Tensor& cum_seq_q,
const at::Tensor& cum_seq_k,
int64_t max_q,
int64_t max_k,
double dropout_p,
bool is_causal,
const at::Tensor& philox_seed,
const at::Tensor& philox_offset,
std::optional<double> scale) {
TORCH_INTERNAL_ASSERT(
grad_out.dim() == 4 && out.dim() == 4 &&
grad_out.size(0) == out.size(0) && grad_out.size(1) == out.size(1) &&
grad_out.size(2) == out.size(2) && grad_out.size(3) == out.size(3),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: grad_out and out should have the same shape of {B, H, T, K}");
TORCH_INTERNAL_ASSERT(
query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
"scaled_dot_product_fused_attention_overrideable_backward_xpu: Accept only 4 dims inputs shape of {B, H, T, K}");
TORCH_INTERNAL_ASSERT(
(key.size(0) == value.size(0)) && (key.size(1) == value.size(1)) &&
(key.size(2) == value.size(2)),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: K/V should have the same batch / seq / num_head");
TORCH_INTERNAL_ASSERT(
query.size(0) == grad_out.size(0) && query.size(1) == grad_out.size(1) &&
query.size(2) == grad_out.size(2),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: Q should have the same batch / num_head / seq_len as grad_out");
TORCH_INTERNAL_ASSERT(
query.size(3) == key.size(3),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: Q/K should have the same head_dim");
TORCH_INTERNAL_ASSERT(
value.size(3) == grad_out.size(3),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: V should have the same head_dim as grad_out");
TORCH_INTERNAL_ASSERT(
query.size(1) == key.size(1),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: number of heads in K/V must equal to number of heads in Q");
TORCH_INTERNAL_ASSERT(
dropout_p == 0.0,
"scaled_dot_product_fused_attention_overrideable_backward_xpu: Currently do not support dropout > 0");
TORCH_INTERNAL_ASSERT(
logsumexp.dim() == 3 && logsumexp.size(0) == query.size(0) &&
logsumexp.size(1) == query.size(1) &&
logsumexp.size(2) == query.size(2),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: logsumexp should have the shape of {B, H, T}");
std::optional<Tensor> attn_bias_opt;
if (attn_bias.defined()) {
attn_bias_opt = attn_bias;
}
const int64_t batch_size = query.size(0);
const int64_t num_head_q = query.size(1);
const int64_t num_head_kv = key.size(1);
const int64_t seq_len_q = query.size(2);
const int64_t seq_len_kv = key.size(2);
const int64_t head_dim_qk = query.size(3);
const int64_t head_dim_v = value.size(3);
auto grad_q = at::empty_like(query);
auto grad_k = at::empty_like(key);
auto grad_v = at::empty_like(value);
auto grad_attn_bias = attn_bias_opt.has_value()
? at::empty_like(attn_bias_opt.value())
: at::Tensor();
at::native::onednn::sdpa_backward(
batch_size,
num_head_q,
num_head_kv,
seq_len_q,
seq_len_kv,
head_dim_qk,
head_dim_v,
grad_out,
query,
key,
value,
out,
logsumexp,
attn_bias_opt,
is_causal,
scale.has_value() ? scale.value() : (1.0 / std::sqrt(query.size(3))),
grad_q,
grad_k,
grad_v);
return std::make_tuple(
std::move(grad_q),
std::move(grad_k),
std::move(grad_v),
std::move(grad_attn_bias));
debug_attn_mask);
}
REGISTER_XPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_xpu);

View File

@ -86,28 +86,6 @@ struct zeta_functor {
}
};
struct logaddexp_functor {
template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
inline T operator()(const T a, const T b) {
return c10::metal::logaddexp(a, b);
}
template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
inline float operator()(const T a, const T b) {
return c10::metal::logaddexp(float(a), float(b));
}
};
struct logaddexp2_functor {
template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
inline T operator()(const T a, const T b) {
return c10::metal::logaddexp2(a, b);
}
template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
inline float operator()(const T a, const T b) {
return c10::metal::logaddexp2(float(a), float(b));
}
};
struct xlog1py_functor {
template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
inline T operator()(const T a, const T b) {
@ -399,10 +377,6 @@ REGISTER_FLOAT_BINARY_OP(fmin);
REGISTER_FLOAT_BINARY_OP(nextafter);
REGISTER_FLOAT_BINARY_OP(zeta);
REGISTER_INT2FLOAT_BINARY_OP(zeta);
REGISTER_FLOAT_BINARY_OP(logaddexp);
REGISTER_INT2FLOAT_BINARY_OP(logaddexp);
REGISTER_FLOAT_BINARY_OP(logaddexp2);
REGISTER_INT2FLOAT_BINARY_OP(logaddexp2);
REGISTER_FLOAT_BINARY_OP(xlog1py);
REGISTER_INT2FLOAT_BINARY_OP(xlog1py);
REGISTER_FLOAT_BINARY_OP(chebyshev_polynomial_t);
@ -489,8 +463,6 @@ REGISTER_BINARY_OP(add, float2, float2);
REGISTER_BINARY_OP(add, half2, half2);
REGISTER_BINARY_OP(sub, float2, float2);
REGISTER_BINARY_OP(sub, half2, half2);
REGISTER_BINARY_OP(logaddexp, float2, float2);
REGISTER_BINARY_OP(logaddexp, half2, half2);
REGISTER_BINARY_ALPHA_OP(add_alpha, float2, float2, float2);
REGISTER_BINARY_ALPHA_OP(add_alpha, half2, half2, half2);
REGISTER_BINARY_ALPHA_OP(sub_alpha, float2, float2, float2);

View File

@ -89,14 +89,6 @@ static void zeta_mps_kernel(TensorIteratorBase& iter) {
lib.exec_binary_kernel(iter, "zeta");
}
static void logaddexp_mps_kernel(TensorIteratorBase& iter) {
lib.exec_binary_kernel(iter, "logaddexp");
}
static void logaddexp2_mps_kernel(TensorIteratorBase& iter) {
lib.exec_binary_kernel(iter, "logaddexp2");
}
static void xlog1py_mps_kernel(TensorIteratorBase& iter) {
TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "xlog1py_mps not implemented for non-floating types");
lib.exec_binary_kernel(iter, "xlog1py");
@ -219,8 +211,6 @@ REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
REGISTER_DISPATCH(copysign_stub, &copysign_mps_kernel)
REGISTER_DISPATCH(nextafter_stub, &nextafter_mps_kernel)
REGISTER_DISPATCH(zeta_stub, &zeta_mps_kernel)
REGISTER_DISPATCH(logaddexp_stub, &logaddexp_mps_kernel);
REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_mps_kernel);
REGISTER_DISPATCH(xlog1py_stub, &xlog1py_mps_kernel)
REGISTER_DISPATCH(chebyshev_polynomial_t_stub, &chebyshev_polynomial_t_mps_kernel)
REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_mps_kernel)

View File

@ -17,6 +17,8 @@
#include <ATen/ops/ge_native.h>
#include <ATen/ops/gt_native.h>
#include <ATen/ops/le_native.h>
#include <ATen/ops/logaddexp2_native.h>
#include <ATen/ops/logaddexp_native.h>
#include <ATen/ops/logical_and_native.h>
#include <ATen/ops/logical_or_native.h>
#include <ATen/ops/logical_xor_native.h>
@ -275,6 +277,30 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const
}
}
TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* sumTensor =
[mpsGraph additionWithPrimaryTensor:[mpsGraph exponentWithTensor:primaryCastTensor name:nil]
secondaryTensor:[mpsGraph exponentWithTensor:secondaryCastTensor name:nil]
name:nil];
return [mpsGraph logarithmWithTensor:sumTensor name:nil];
};
mps::binaryOpTensor(self, other, output, "logaddexp_out_mps", logaddexp_op_block);
}
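// Note: this graph evaluates log(exp(a) + exp(b)) directly; the equivalent
// max-shifted form, logaddexp(a, b) = max(a, b) + log1p(exp(-|a - b|)),
// avoids overflow for large inputs should numerical stability become a
// concern here.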
TORCH_IMPL_FUNC(logaddexp2_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock logaddexp2_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* sumTensor =
[mpsGraph additionWithPrimaryTensor:[mpsGraph exponentBase2WithTensor:primaryCastTensor name:nil]
secondaryTensor:[mpsGraph exponentBase2WithTensor:secondaryCastTensor name:nil]
name:nil];
return [mpsGraph logarithmBase2WithTensor:sumTensor name:nil];
};
mps::binaryOpTensor(self, other, output, "logaddexp2_out_mps", logaddexp2_op_block);
}
TORCH_IMPL_FUNC(xlogy_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock xlogy_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();

View File

@ -1028,18 +1028,15 @@ TORCH_IMPL_FUNC(prod_out_mps)
}
TORCH_IMPL_FUNC(amax_out_mps)(const Tensor& input_t, IntArrayRef dim, bool keepdim, const Tensor& output_t) {
TORCH_CHECK(!c10::isComplexType(input_t.scalar_type()), "amax is not defined for complex types");
reduction_out_mps(input_t, dim, keepdim, std::nullopt, output_t, MPSReductionType::AMAX, "amax_out_mps");
}
TORCH_IMPL_FUNC(amin_out_mps)(const Tensor& input_t, IntArrayRef dim, bool keepdim, const Tensor& output_t) {
TORCH_CHECK(!c10::isComplexType(input_t.scalar_type()), "amin is not defined for complex types");
reduction_out_mps(input_t, dim, keepdim, std::nullopt, output_t, MPSReductionType::AMIN, "amin_out_mps");
}
TORCH_IMPL_FUNC(aminmax_out_mps)
(const Tensor& input_t, std::optional<int64_t> dim_opt, bool keepdim, const Tensor& min_t, const Tensor& max_t) {
TORCH_CHECK(!c10::isComplexType(input_t.scalar_type()), "aminmax is not defined for complex types");
reduction_out_mps(input_t,
dim_opt.has_value() ? OptionalIntArrayRef({*dim_opt}) : std::nullopt,
keepdim,

View File

@ -31,7 +31,6 @@ void kthvalue_out_mps_impl(const Tensor& self, int64_t k, int64_t dim, Tensor& v
indices.copy_(values.toType(at::ScalarType::Long));
return;
}
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "kthvalue is not implemented for complex types");
// issue #154890, raising error to prevent crash within MPSGraph until
// workaround is implemented.
TORCH_CHECK(self.dim() - dim <= 4, "On-going issue on MPSGraph topk when ndims() - axis > 4, see issue #154890");

View File

@ -3622,7 +3622,8 @@
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA, MPS: logaddexp_out
CPU, CUDA: logaddexp_out
MPS: logaddexp_out_mps
tags: pointwise
- func: logaddexp(Tensor self, Tensor other) -> Tensor
@ -3634,7 +3635,8 @@
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA, MPS: logaddexp2_out
CPU, CUDA: logaddexp2_out
MPS: logaddexp2_out_mps
tags: pointwise
- func: logaddexp2(Tensor self, Tensor other) -> Tensor
@ -15095,7 +15097,7 @@
CPU: _scaled_dot_product_flash_attention_cpu
tags: nondeterministic_seeded
- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None, bool compute_log_sumexp=True) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
dispatch:
CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable
XPU: _scaled_dot_product_fused_attention_overrideable_xpu
@ -15119,7 +15121,6 @@
variants: function
dispatch:
CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable_backward
XPU: _scaled_dot_product_fused_attention_overrideable_backward_xpu
- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
dispatch:

View File

@ -467,28 +467,6 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, IntArrayRe
!options.has_layout() || options.layout() == kSparse,
"expected sparse layout, but got layout ",
options.layout());
if (indices.numel() > 0) {
Tensor min_indices =
std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
Tensor cpu_min_indices;
if (!indices.is_cpu()) {
cpu_min_indices = min_indices.to(at::DeviceType::CPU);
} else {
cpu_min_indices = min_indices;
}
auto cpu_min_indices_accessor = cpu_min_indices.accessor<int64_t, 1>();
for (const auto d : c10::irange(indices.size(0))) {
int64_t min_index_in_dim = cpu_min_indices_accessor[d];
TORCH_CHECK(
min_index_in_dim >= 0,
"found negative index ",
min_index_in_dim,
" for dim ",
d);
}
}
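The block removed here eagerly scanned the per-dimension minima to reject negative indices in the safe constructor. The invariant itself remains checkable from Python; a sketch using the public `check_invariants` flag:

```python
import torch

indices = torch.tensor([[0, 1, -1],   # one negative index in dim 0
                        [2, 0, 2]])
values = torch.ones(3)
try:
    torch.sparse_coo_tensor(indices, values, (2, 3), check_invariants=True)
except RuntimeError as e:
    print("rejected:", e)
```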
return at::native::_sparse_coo_tensor_unsafe(
indices,
values,

View File

@ -768,11 +768,8 @@ Tensor scaled_dot_product_attention(
return std::get<0>(out_and_lse);
}
case SDPBackend::overrideable: {
bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
compute_logsumexp = compute_logsumexp ||
(at::GradMode::is_enabled() && attn_mask.has_value() && attn_mask.value().requires_grad());
auto out_lse_softmax = at::_scaled_dot_product_fused_attention_overrideable(
query_, key, value, attn_mask, dropout_p, is_causal, false /*return_debug_mask*/, scale, compute_logsumexp);
query_, key, value, attn_mask, dropout_p, is_causal, false /*return_debug_mask*/, scale);
return std::get<0>(out_lse_softmax);
}
case SDPBackend::math: {
@ -1018,8 +1015,7 @@ _scaled_dot_product_fused_attention_overrideable(
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale,
bool compute_logsumexp) {
std::optional<double> scale) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "_scaled_dot_product_fused_attention_overrideable not implemented. This is an operator for privateuse1 backends, please use TORCH_LIBRARY_IMPL to override this function ");
}
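The `compute_logsumexp` argument is dropped from the private overrideable op; the public entry point is unchanged. For orientation, a minimal call through `F.scaled_dot_product_attention` (backend selection happens internally):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
print(out.shape)  # torch.Size([2, 4, 8, 16])
```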

View File

@ -1837,10 +1837,6 @@ class BenchmarkRunner:
def skip_models_for_cuda(self):
return set()
@property
def skip_models_for_xpu(self):
return set()
@property
def skip_models_for_cpu(self):
return set()
@ -3931,8 +3927,6 @@ def run(runner, args, original_dir=None):
runner.skip_models.update(runner.skip_models_for_cpu_aarch64)
elif args.devices == ["cuda"]:
runner.skip_models.update(runner.skip_models_for_cuda)
elif args.devices == ["xpu"]:
runner.skip_models.update(runner.skip_models_for_xpu)
if not args.multiprocess:
runner.skip_models.update(runner.skip_multiprocess_models)

View File

@ -124,10 +124,6 @@ class TorchBenchmarkRunner(BenchmarkRunner):
def skip_models_for_cuda(self):
return self._skip["device"]["cuda"]
@property
def skip_models_for_xpu(self):
return self._skip["device"]["xpu"]
@property
def skip_models_for_freezing_cuda(self):
return self._skip["freezing"]["cuda"]

View File

@ -217,9 +217,6 @@ skip:
cuda: []
xpu:
- *DETECTRON2_MODELS
test:
training:
- *DETECTRON2_MODELS

View File

@ -482,7 +482,6 @@ inductor_core_resources = [
"torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp",
"torch/csrc/inductor/inductor_ops.cpp",
"torch/csrc/jit/serialization/pickle.cpp",
"torch/csrc/shim_common.cpp",
]
libtorch_core_sources = sorted(

View File

@ -1,4 +1,4 @@
// Implementation of special math functions for Metal
// Implementation of specal math functions for Metal
#pragma once
#include <c10/metal/expm1f.h>
#include <c10/metal/igamma.h>
@ -624,64 +624,6 @@ inline T spherical_bessel_j0(T x) {
return static_cast<T>(::metal::sin(x) / x);
}
template <typename T>
inline ::metal::enable_if_t<is_scalar_floating_point_v<T>, T> logaddexp(
T a,
T b) {
float a0 = static_cast<float>(a);
float b0 = static_cast<float>(b);
if (::metal::isinf(a0) && a0 == b0) {
return static_cast<T>(a0);
} else {
float m0 = ::metal::max(a0, b0);
return static_cast<T>(
m0 + ::c10::metal::log1p(::metal::exp(-::metal::abs(a0 - b0))));
}
}
// The function is ported from mlx
template <typename T>
inline ::metal::enable_if_t<is_complex_v<T>, T> logaddexp(T a, T b) {
if (::metal::isnan(a.x) || ::metal::isnan(a.y) || ::metal::isnan(b.x) ||
::metal::isnan(b.y)) {
return T(NAN, NAN);
}
T maxval = a.x > b.x ? a : b;
T minval = a.x < b.x ? a : b;
constexpr auto inf = ::metal::numeric_limits<T>::infinity().x;
if (minval.x == -inf || maxval.x == inf) {
return maxval;
}
float2 maxval_ = static_cast<float2>(maxval);
float2 minval_ = static_cast<float2>(minval);
float m = ::metal::exp(minval_.x - maxval_.x);
float2 dexp{
m * ::metal::cos(minval_.y - maxval_.y),
m * ::metal::sin(minval_.y - maxval_.y),
};
return static_cast<T>(maxval_ + ::c10::metal::log1p(dexp));
}
template <typename T>
inline T logaddexp2(T a, T b) {
constexpr auto log_2 = float(0.693147180559945309417232121458176);
constexpr auto inv_log_2 = float(1) / log_2;
float a0 = static_cast<float>(a);
float b0 = static_cast<float>(b);
if (::metal::isinf(a0) && a0 == b0) {
return static_cast<T>(a0);
} else {
float m0 = ::metal::max(a0, b0);
return static_cast<T>(
m0 +
::c10::metal::log1p(::metal::pow(float(2), -::metal::abs(a0 - b0))) *
inv_log_2);
}
}
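The real-valued `logaddexp2` above uses the usual max-shift rewrite, m + log1p(2^-|a-b|)/ln 2, with an explicit guard so that logaddexp2(-inf, -inf) is -inf rather than NaN. A Python sketch of the same arithmetic (not the shipped kernel):

```python
import math

def logaddexp2_ref(a: float, b: float) -> float:
    if math.isinf(a) and a == b:          # both +inf or both -inf
        return a
    m = max(a, b)
    return m + math.log1p(2.0 ** (-abs(a - b))) / math.log(2.0)

assert abs(logaddexp2_ref(3.0, 4.0) - math.log2(2**3 + 2**4)) < 1e-12
```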
template <typename T>
inline float xlog1py(T x, T y) {
if (::metal::isnan(y)) {

View File

@ -322,24 +322,6 @@ inline float log1p(float x) {
return rc;
}
// The function is ported from mlx
inline float2 log1p(float2 in) {
float x = in.x;
float y = in.y;
float zabs = ::metal::precise::sqrt(x * x + y * y);
float theta = ::metal::atan2(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1p(r), theta};
} else {
auto z0 = ::metal::sqrt((x + 1) * (x + 1) + y * y);
return {::metal::log(z0), theta};
}
}
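The small-|z| branch exists because log(|1+z|) loses precision when z is tiny; since |1+z|^2 = 1 + (x(2+x) + y^2), the real part can be computed as log1p of that residual. A Python sketch of the branch structure, checked against `cmath.log`:

```python
import cmath
import math

def log1p_complex_ref(z: complex) -> complex:
    x, y = z.real, z.imag
    theta = math.atan2(y, x + 1.0)
    if abs(z) < 0.5:
        r = x * (2.0 + x) + y * y        # |1+z|**2 - 1, computed stably
        if r == 0.0:                     # underflow guard
            return complex(x, theta)
        return complex(0.5 * math.log1p(r), theta)
    return complex(math.log(math.hypot(x + 1.0, y)), theta)

z = 1e-8 + 1e-8j
assert abs(log1p_complex_ref(z) - cmath.log(1 + z)) < 1e-12
```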
template <typename T1, typename T2 = T1>
struct pair {
T1 first;

View File

@ -34,7 +34,7 @@ struct MemEvent {
bool overlaps(const MemBlock& a, const MemBlock& b) {
// two blocks dont overlap if
// |---a--------|--------------b--------|
// start_a end_a <= start_b end_b
// strat_a end_a <= start_b end_b
return !(
(a.end_offset <= b.start_offset) || (b.end_offset <= a.start_offset));
}
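The predicate treats blocks as half-open [start, end) ranges: two blocks are disjoint exactly when one ends at or before the other starts. The same check in Python:

```python
def overlaps(a_start: int, a_end: int, b_start: int, b_end: int) -> bool:
    # disjoint iff a ends before b starts, or b ends before a starts
    return not (a_end <= b_start or b_end <= a_start)

assert overlaps(0, 10, 5, 15)
assert not overlaps(0, 10, 10, 20)  # merely touching blocks do not overlap
```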

View File

@ -33,7 +33,7 @@ struct bitset final {
constexpr bitset() noexcept = default;
constexpr bitset(const bitset&) noexcept = default;
constexpr bitset(bitset&&) noexcept = default;
// there is an issue for gcc 5.3.0 when define default function as constexpr
// there is an issure for gcc 5.3.0 when define default function as constexpr
// see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=68754.
bitset& operator=(const bitset&) noexcept = default;
bitset& operator=(bitset&&) noexcept = default;

View File

@ -554,17 +554,6 @@ class DeviceCachingAllocator {
}
}
double getMemoryFraction() {
if (!set_fraction) {
return 1.0;
}
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device_index);
return static_cast<double>(allowed_memory_maximum) /
static_cast<double>(device_prop.global_mem_size);
}
void setMemoryFraction(double fraction) {
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device_index);
@ -735,11 +724,6 @@ class XPUAllocator : public DeviceAllocator {
device_allocators[device]->resetAccumulatedStats();
}
double getMemoryFraction(DeviceIndex device) {
assertValidDevice(device);
return device_allocators[device]->getMemoryFraction();
}
void setMemoryFraction(double fraction, DeviceIndex device) {
assertValidDevice(device);
TORCH_CHECK_VALUE(
@ -793,10 +777,6 @@ void recordStream(const DataPtr& dataPtr, XPUStream stream) {
return allocator.recordStream(dataPtr, stream);
}
double getMemoryFraction(DeviceIndex device) {
return allocator.getMemoryFraction(device);
}
void setMemoryFraction(double fraction, DeviceIndex device) {
return allocator.setMemoryFraction(fraction, device);
}
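The removed accessor reported the configured cap as a fraction of total device memory, defaulting to 1.0 when no fraction was ever set. Its arithmetic as a sketch, with hypothetical names standing in for the allocator state:

```python
def get_memory_fraction(allowed_memory_maximum: int,
                        global_mem_size: int,
                        set_fraction: bool) -> float:
    if not set_fraction:
        return 1.0                      # no cap configured: whole device
    return allowed_memory_maximum / global_mem_size

assert get_memory_fraction(8 * 2**30, 16 * 2**30, True) == 0.5
```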

View File

@ -25,8 +25,6 @@ C10_XPU_API void raw_delete(void* ptr);
C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream);
C10_XPU_API double getMemoryFraction(DeviceIndex device);
C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device);
} // namespace c10::xpu::XPUCachingAllocator

View File

@ -1,5 +1,4 @@
list(APPEND Caffe2_CPU_SRCS
"${CMAKE_CURRENT_SOURCE_DIR}/common.cc"
)
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)

View File

@ -38,7 +38,7 @@ uint32_t crc32_combine (uint32_t crcA, uint32_t crcB, size_t lengthB);
/// compute CRC32 (bitwise algorithm)
uint32_t crc32_bitwise (const void* data, size_t length, uint32_t previousCrc32 = 0);
/// compute CRC32 (half-byte algorithm)
/// compute CRC32 (half-byte algoritm)
uint32_t crc32_halfbyte(const void* data, size_t length, uint32_t previousCrc32 = 0);
#ifdef CRC32_USE_LOOKUP_TABLE_BYTE
@ -96,7 +96,7 @@ uint32_t crc32_16bytes_prefetch(const void* data, size_t length, uint32_t previo
#define __BIG_ENDIAN 4321
#endif
// define endianness and some integer data types
// define endianess and some integer data types
#if defined(_MSC_VER) || defined(__MINGW32__)
// Windows always little endian
#define __BYTE_ORDER __LITTLE_ENDIAN
@ -168,7 +168,7 @@ namespace
/// zlib's CRC32 polynomial
const uint32_t Polynomial = 0xEDB88320;
/// swap endianness
/// swap endianess
static inline uint32_t swap(uint32_t x)
{
#if defined(__GNUC__) || defined(__clang__)
@ -229,7 +229,7 @@ uint32_t crc32_bitwise(const void* data, size_t length, uint32_t previousCrc32)
}
/// compute CRC32 (half-byte algorithm)
/// compute CRC32 (half-byte algoritm)
uint32_t crc32_halfbyte(const void* data, size_t length, uint32_t previousCrc32)
{
uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF
@ -662,7 +662,7 @@ uint32_t crc32_combine(uint32_t crcA, uint32_t crcB, size_t lengthB)
// - if you append length(B) zeros to A and call it A' (think of it as AAAA000)
// and prepend length(A) zeros to B and call it B' (think of it as 0000BBB)
// then exists a C' = A' ^ B'
// - remember: if you XOR something with zero, it remains unchanged: X ^ 0 = X
// - remember: if you XOR someting with zero, it remains unchanged: X ^ 0 = X
// - that means C' = A concat B so that crc(A concat B) = crc(C') = crc(A') ^ crc(B')
// - the trick is to compute crc(A') based on crc(A)
// and crc(B') based on crc(B)
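The zero-padding trick can be checked with `zlib.crc32`, with one caveat the comment glosses over: zlib's pre/post-conditioning makes the map affine rather than linear, so XOR-combining two CRCs needs a correction term equal to the CRC of an all-zero block of the same total length. A sketch:

```python
import zlib

a, b = b"hello ", b"world"
a_padded = a + bytes(len(b))            # A' = A followed by len(B) zero bytes
b_padded = bytes(len(a)) + b            # B' = len(A) zero bytes followed by B
combined = (zlib.crc32(a_padded)
            ^ zlib.crc32(b_padded)
            ^ zlib.crc32(bytes(len(a) + len(b))))  # affine-offset correction
assert combined == zlib.crc32(a + b)
```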

View File

@ -76,7 +76,7 @@ typedef struct mz_zip_archive mz_zip_archive;
// 2) Writing with 1-pass sequential access
// -> We must take care not to require updating values that have already
// been written. We place the variable-length index at the end and do
// not put any index into the header to fulfill this constraint.
// not put any indicies into the header to fulfill this constraint.
// The model.json, which contains all the metadata information,
// should be written as the last file. One reason is that the size of tensor

View File

@ -519,7 +519,7 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoadWithAllocator) {
std::tie(data_ptr, size) = reader.getRecord("key1", &overrideAllocator);
EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1);
EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes);
// allocate with base allocator
// allcoate with base allocator
std::tie(data_ptr, size) = reader.getRecord("key1");
EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1);
EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes + kBytes1);

View File

@ -2,9 +2,9 @@
## Overview
The LibTorch Stable ABI (Application Binary Interface) provides a limited interface for extending PyTorch functionality without being tightly coupled to specific PyTorch versions. This enables the development of custom operators and extensions that remain compatible across PyTorch releases. This limited set of APIs is not intended to replace existing LibTorch, but rather to provide a stable foundation for a majority of custom extension use cases. If there is any API you would like to see added to the stable ABI, please file a request through a [new issue on the PyTorch repo](https://github.com/pytorch/pytorch/issues).
The LibTorch Stable ABI (Application Binary Interface) provides an interface for extending PyTorch functionality without being tightly coupled to specific PyTorch versions. This enables the development of custom operators and extensions that remain compatible across PyTorch releases.
The limited stable ABI consists of three main components:
The stable ABI consists of three main components:
1. **Stable C headers** - Low-level C API implemented by libtorch (primarily `torch/csrc/inductor/aoti_torch/c/shim.h`)
2. **Header-only C++ library** - Standalone utilities implemented in only headers such that there is no dependence on libtorch (`torch/headeronly/*`)
@ -14,8 +14,8 @@ We discuss each of these in detail
### `torch/headeronly`
The inlined C++ headers living in [`torch/headeronly`](https://github.com/pytorch/pytorch/tree/main/torch/headeronly) are completely decoupled from LibTorch. The headers consist of certain utilities that might be familiar to custom extension writers. For example, the
`c10::ScalarType` enum lives here as `torch::headeronly::ScalarType`, as well as a libtorch-independent version of `TORCH_CHECK` that is `STD_TORCH_CHECK`. You can trust all APIs in the `torch::headeronly` namespace to not depend on `libtorch.so`. These APIs are also globally listed in [torch/header_only_apis.txt](https://github.com/pytorch/pytorch/blob/main/torch/header_only_apis.txt).
This is a set of inlined C++ headers are completely decoupled from libtorch. The headers consist of certain utilities that might be familiar to custom extension writers. For example, the
`c10::ScalarType` enum lives here as `torch::headeronly::ScalarType`.
### `torch/csrc/stable`
@ -34,14 +34,8 @@ We are continuing to improve coverage in our `torch/csrc/stable` APIs. Please fi
### Stable C headers
The stable C headers started by AOTInductor form the foundation of the stable ABI. Presently, the available C headers include:
- [torch/csrc/inductor/aoti_torch/c/shim.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/c/shim.h): Includes C-style shim APIs for commonly used operations regarding Tensors, dtypes, CUDA, and the like.
- [torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h): Includes C-style shim APIs for ATen ops from `native_functions.yaml` (e.g. `aoti_torch_aten_new_empty`).
- [torch/csrc/inductor/aoti_torch/generated/c_shim_*.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/generated): Includes C-style shim APIs for specific backend kernels dispatched from `native_functions.yaml` (e.g. `aoti_torch_cuda_pad`). These APIs should only be used for the specific backend they are named after (e.g. `aoti_torch_cuda_pad` should only be used within CUDA kernels), as they opt out of the dispatcher.
- [torch/csrc/stable/c/shim.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/stable/c/shim.h): We are building out more ABIs to logically live in `torch/csrc/stable/c` instead of continuing the AOTI naming that no longer makes sense for our general use case.
These headers are promised to be ABI stable across releases and adhere to a stronger backwards compatibility policy than LibTorch. Specifically, we promise not to modify them for at least 2 years after they are released. However, this is **use at your own risk**. For example, users must handle the memory lifecycle of objects returned by certain APIs. Further, the stack-based APIs discussed below which allow the user to call into the PyTorch dispatcher do not provide strong guarantees on forward and backward compatibility of the underlying op that is called.
The stable C headers used by AOTInductor form the foundation of the stable ABI. However, this is **use at your own risk**. For example, users must handle the memory lifecycle of objects returned by certain APIs.
Further, the stack-based APIs discussed below which allow the user to call the PyTorch dispatcher don't provide strong guarantees on forward and backward compatibility.
Unless absolutely necessary, we recommend the high-level C++ API in `torch/csrc/stable`
which will handle all the rough edges of the C API for the user.
@ -128,38 +122,12 @@ The above is relevant in two places:
}
```
2. `torch_call_dispatcher`
2. `aoti_torch_call_dispatcher`
This API allows you to call the PyTorch dispatcher from C/C++ code. It has the following signature:
```cpp
torch_call_dispatcher(const char* opName, const char* overloadName, StableIValue* stack, uint64_t extension_build_version);
aoti_torch_call_dispatcher(const char* opName, const char* overloadName, StableIValue* stack);
```
`torch_call_dispatcher` will call the op overload defined by a given `opName`, `overloadName`, a stack of
StableIValues and the `TORCH_ABI_VERSION` of the user extension. This call will populate any return values of the
op into the stack in their StableIValue form, with `ret0` at index 0, `ret1` at index 1, and so on.
We caution against using this API to call functions that have been registered to the dispatcher by other extensions
unless the caller can guarantee that the signature they expect matches that which the custom extension has
registered.
### Versioning and Forward/Backward compatibility guarantees
We provide a `TORCH_ABI_VERSION` macro in `torch/headeronly/version.h` of the form
```
[ byte ][ byte ][ byte ][ byte ][ byte ][ byte ][ byte ][ byte ]
[MAJ ][ MIN ][PATCH ][ ABI TAG ]
```
In the present phase of development, APIs in the C-shim will be versioned based on major.minor.patch release that they are first introduced in, with 2.10 being the first release where this will be enforced. The ABI tag is reserved for future use.
Extensions can select the minimum abi version to be compatible with using:
```
#define TORCH_TARGET_VERSION (((0ULL + major) << 56) | ((0ULL + minor) << 48))
```
before including any stable headers or by passing the equivalent `-D` option to the compiler. Otherwise, the default will be the current `TORCH_ABI_VERSION`.
The above ensures that if a user defines `TORCH_TARGET_VERSION` to be 0x0209000000000000 (2.9) and attempts to use a C shim API `foo` that was introduced in version 2.10, a compilation error will be raised. Similarly, the C++ wrapper APIs in `torch/csrc/stable` are compatible with older libtorch binaries up to the TORCH_ABI_VERSION they are exposed in and forward compatible with newer libtorch binaries.
`aoti_torch_call_dispatcher` will call the op overload defined by a given `opName`, `overloadName`, and a stack of
StableIValues. This call will populate any return values of the op into the stack in their StableIValue form,
with `ret0` at index 0, `ret1` at index 1, and so on.
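The byte layout quoted above is easy to verify by hand: the major version occupies the top byte and the minor the next one. A worked check of the 2.9 example with plain integer arithmetic:

```python
major, minor = 2, 9
target = (major << 56) | (minor << 48)  # mirrors the TORCH_TARGET_VERSION macro
print(hex(target))                      # 0x209000000000000, i.e. version 2.9
```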

View File

@ -59,6 +59,7 @@ torch.special <special>
torch.overrides
torch.nativert <nativert>
torch.package <package>
torch.pytree <pytree>
profiler
nn.init
nn.attention
@ -76,6 +77,7 @@ sparse
storage
torch.testing <testing>
torch.utils <utils>
torch.utils.pytree
torch.utils.benchmark <benchmark_utils>
torch.utils.checkpoint <checkpoint>
torch.utils.cpp_extension <cpp_extension>

docs/source/pytree.rst (new file, 7 lines)
View File

@ -0,0 +1,7 @@
torch.pytree
============
.. currentmodule:: torch.pytree
.. automodule:: torch.pytree
:members:

View File

@ -0,0 +1,7 @@
torch.utils.pytree
==================
.. currentmodule:: torch.utils.pytree
.. automodule:: torch.utils.pytree
:members:
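The new `torch.pytree` / `torch.utils.pytree` docs surface what has long lived in the private `torch.utils._pytree` module. A small example against that private module (the stable-in-practice spelling on current releases):

```python
import torch
import torch.utils._pytree as pytree

tree = {"w": torch.ones(2), "nested": [torch.zeros(3), 5]}
leaves, spec = pytree.tree_flatten(tree)
print(len(leaves))  # 3 leaves: two tensors and the int
doubled = pytree.tree_map(
    lambda x: x * 2 if isinstance(x, torch.Tensor) else x, tree
)
print(doubled["nested"][1])  # 5: non-tensor leaves pass through unchanged
```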

View File

@ -76,7 +76,6 @@
:nosignatures:
empty_cache
get_per_process_memory_fraction
max_memory_allocated
max_memory_reserved
mem_get_info

View File

@ -29,6 +29,7 @@ files =
benchmarks/instruction_counts,
tools,
torch/profiler/_memory_profiler.py,
torch/utils/pytree/__init__.py,
torch/utils/_pytree.py,
torch/utils/_cxx_pytree.py,
torch/utils/benchmark/utils/common.py,

View File

@ -1106,7 +1106,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
continue
self.copy_file(source_lib, target_lib)
# Delete old rpath and add @loader_lib to the rpath
# This should prevent deallocate from attempting to package another instance
# This should prevent delocate from attempting to package another instance
# of OpenMP library in torch wheel as well as loading two libomp.dylib into
# the address space, as libraries are cached by their unresolved names
install_name_tool_args = [

View File

@ -687,6 +687,28 @@
"kineto_available",
"record_function"
],
"torch.pytree": [
"PyTreeSpec",
"register_node",
"all",
"all_only",
"any",
"any_only",
"flatten",
"iter",
"leaves",
"map",
"map_",
"map_only",
"map_only_",
"structure",
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance"
],
"torch.quantization": [
"ABC",
"DeQuantStub",

View File

@ -58,8 +58,7 @@ wrapper__scaled_dot_product_fused_attention_overrideable(
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale,
bool compute_log_sumexp) {
std::optional<double> scale) {
return at::native::openreg::_scaled_dot_product_fused_attention_overrideable(
query,
key,
@ -68,8 +67,7 @@ wrapper__scaled_dot_product_fused_attention_overrideable(
dropout_p,
is_causal,
return_debug_mask,
scale,
compute_log_sumexp);
scale);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>

View File

@ -47,8 +47,7 @@ _scaled_dot_product_fused_attention_overrideable(
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale,
bool compute_log_sumexp) {
std::optional<double> scale) {
const int64_t batch_size = query.size(0);
const int64_t num_heads = query.size(1);
const int64_t head_dim_v = value.size(3);

View File

@ -39,8 +39,7 @@ _scaled_dot_product_fused_attention_overrideable(
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale,
bool compute_log_sumexp);
std::optional<double> scale);
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_scaled_dot_product_fused_attention_overrideable_backward(
const at::Tensor& grad_out,

View File

@ -827,7 +827,7 @@ class TestFullyShardShardPlacementFnMultiProcess(FSDPTest):
torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
for _ in range(5):
for iter_idx in range(5):
ref_loss = ref_model(inp).sum()
loss = model(inp).sum()
self.assertEqual(ref_loss, loss)

View File

@ -31,17 +31,17 @@ if TEST_WITH_DEV_DBG_ASAN:
sys.exit(0)
_DISTRIBUTED_STATE_DICT_IMPLS = {
_DISTRIBUTED_STATE_DICT_IMPLS = (
StateDictType.LOCAL_STATE_DICT,
StateDictType.SHARDED_STATE_DICT,
}
)
class TestDistributedCheckpoint(FSDPTest):
@property
def world_size(self):
if torch.accelerator.is_available():
gpu_cnt = torch.accelerator.device_count()
if torch.cuda.is_available():
gpu_cnt = torch.cuda.device_count()
if gpu_cnt < 2:
return gpu_cnt
return 2
@ -93,9 +93,7 @@ class TestDistributedCheckpoint(FSDPTest):
# TODO: add resharding test case.
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(
TestDistributedCheckpoint, globals(), only_for=devices, allow_xpu=True
)
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestDistributedCheckpoint, globals(), only_for=devices)
if __name__ == "__main__":
run_tests()

View File

@ -36,8 +36,8 @@ device_type = torch.device(get_devtype())
class TestApply(FSDPTest):
@property
def world_size(self):
if torch.accelerator.is_available():
gpu_cnt = torch.accelerator.device_count()
if torch.cuda.is_available():
gpu_cnt = torch.cuda.device_count()
if gpu_cnt < 2:
return gpu_cnt
return 2

View File

@ -2,6 +2,7 @@
# Owner(s): ["oncall: distributed"]
import sys
from pathlib import Path
import torch
import torch.distributed as dist
@ -44,19 +45,53 @@ class TestInstantiator(TestCase):
self.assertEqual(return_type_str, "Tuple[Tensor, int, str]")
def test_instantiate_scripted_remote_module_template(self):
dir_path = Path(instantiator.INSTANTIATED_TEMPLATE_DIR_PATH)
# Cleanup.
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
for file_path in file_paths:
file_path.unlink()
# Check before run.
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
num_files_before = len(list(file_paths))
self.assertEqual(num_files_before, 0)
generated_module = instantiator.instantiate_scriptable_remote_module_template(
MyModuleInterface
)
self.assertTrue(hasattr(generated_module, "_remote_forward"))
self.assertTrue(hasattr(generated_module, "_generated_methods"))
# Check after run.
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
num_files_after = len(list(file_paths))
self.assertEqual(num_files_after, 1)
def test_instantiate_non_scripted_remote_module_template(self):
dir_path = Path(instantiator.INSTANTIATED_TEMPLATE_DIR_PATH)
# Cleanup.
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
for file_path in file_paths:
file_path.unlink()
# Check before run.
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
num_files_before = len(list(file_paths))
self.assertEqual(num_files_before, 0)
generated_module = (
instantiator.instantiate_non_scriptable_remote_module_template()
)
self.assertTrue(hasattr(generated_module, "_remote_forward"))
self.assertTrue(hasattr(generated_module, "_generated_methods"))
# Check after run.
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
num_files_after = len(list(file_paths))
self.assertEqual(num_files_after, 1)
if __name__ == "__main__":
run_tests()

View File

@ -64,10 +64,6 @@ class TestDTensorDebugMode(TestCase):
self.assertTrue(isinstance(debug_mode.operators[2], _RedistributeCall))
self.assertEqual(next(iter(debug_mode.operators[1])), torch.ops.aten.mm.default)
# check stringification
self.assertTrue(hasattr(debug_mode.operators[0], "args_str"))
self.assertFalse(hasattr(debug_mode.operators[0], "args"))
def test_debug_string_inside_context(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
@ -271,7 +267,6 @@ class TestDTensorDebugMode(TestCase):
record_torchfunction=True,
record_faketensor=True,
record_tensor_attributes=["a1", "a2"],
store_original_args=True,
) as debug_mode:
torch.matmul(y, x)
@ -284,9 +279,6 @@ class TestDTensorDebugMode(TestCase):
aten::_unsafe_view(t: f32[64, 8], [8, 8, 8])""",
)
self.assertTrue(hasattr(debug_mode.operators[0], "args"))
self.assertEqual(id(debug_mode.operators[0].args[0]), id(y))
@parametrize("has_inner_mode", [True, False])
@parametrize("has_outer_mode", [True, False])
def test_nested_debug_mode(self, has_inner_mode, has_outer_mode):

View File

@ -20,18 +20,18 @@ from torch.distributed.tensor.experimental._attention import (
_cp_options,
_disable_context_parallel_dispatcher,
_enable_context_parallel_dispatcher,
_HeadTailLoadBalancer,
_is_causal_behavior,
_LoadBalancer,
_PerDocumentHeadTailLoadBalancer,
_PTRRLoadBalancer,
_RotateMethod,
context_parallel,
context_parallel_unshard,
set_rotate_method,
)
from torch.distributed.tensor.experimental._context_parallel._cp_custom_ops import (
flex_cp_allgather,
from torch.distributed.tensor.experimental._cp_custom_ops import flex_cp_allgather
from torch.distributed.tensor.experimental._load_balancer import (
_HeadTailLoadBalancer,
_LoadBalancer,
_PerDocumentHeadTailLoadBalancer,
_PTRRLoadBalancer,
)
from torch.distributed.tensor.parallel import parallelize_module
from torch.nn.attention import sdpa_kernel, SDPBackend
@ -52,9 +52,7 @@ from torch.testing._internal.common_cuda import (
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
DTensorTestBase,
map_local_tensor_for_rank,
with_comms,
)
@ -802,47 +800,11 @@ class TestSharding(DTensorTestBase):
chunks = freqs_cis.chunk(self.world_size * 2)
self.assertEqual(
freqs_cis_shard,
map_local_tensor_for_rank(
chunks,
self.rank,
lambda chunks, rank: torch.cat(
[chunks[rank], chunks[self.world_size * 2 - rank - 1]],
dim=0,
),
torch.cat(
[chunks[self.rank], chunks[self.world_size * 2 - self.rank - 1]], dim=0
),
)
RingAttentionTestWithLocalTensor = create_local_tensor_test_class(
RingAttentionTest,
skipped_tests=[
# Need to make attention implementation local tensor friendly, e.g.
# rewrite "rank local" logic
"test_ring_attention_sdpa",
],
)
CPFlexAttentionTestWithLocalTensor = create_local_tensor_test_class(
CPFlexAttentionTest,
skipped_tests=[
# Missing support for batched tensors
"test_cp_flex_attention_causal_mask",
"test_cp_flex_attention_document_mask",
],
)
TestCPCustomOpsWithLocalTensor = create_local_tensor_test_class(
TestCPCustomOps,
skipped_tests=[
# Missing support for fake tensors
"test_flex_cp_custom_op",
],
)
TestShardingWithLocalTensor = create_local_tensor_test_class(
TestSharding,
)
if __name__ == "__main__":
run_tests()

View File

@ -16,7 +16,6 @@ from torch.distributed.tensor import (
from torch.nn import functional as F
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
DTensorTestBase,
skip_if_lt_x_gpu,
with_comms,
@ -204,42 +203,34 @@ class DistConvolutionOpsTest(DTensorTestBase):
self.assertTrue(b_dt.grad is not None)
self.assertTrue(x_dt.grad is None)
def _run_single_arg_fwd(self, model, arg) -> tuple[torch.Tensor, torch.Tensor]:
"""Given model and arg, runs fwd model local and distbuted given device_mesh"""
device_mesh = self.build_device_mesh()
model_copy = copy.deepcopy(model).to(device=self.device_type)
dist_model = distribute_module(model, device_mesh, _conv_fn)
arg_dt = DTensor.from_local(arg, device_mesh, [Replicate()])
out_dt = dist_model(arg_dt.to(device=self.device_type))
out = model_copy(arg)
return (out_dt.full_tensor(), out)
@with_comms
def test_conv1d(self):
device_mesh = self.build_device_mesh()
model = nn.Conv1d(64, 64, 3, padding=1)
x = torch.randn(1, 64, 8, device=self.device_type)
out_dt, out = self._run_single_arg_fwd(model, x)
model_gt = copy.deepcopy(model)
x = torch.randn(1, 64, 8)
x_dt = DTensor.from_local(x, device_mesh, [Replicate()])
model_dt = distribute_module(
model, device_mesh, _conv_fn, input_fn=None, output_fn=None
)
out_dt = model_dt(x_dt)
out = model_gt(x)
self.assertEqual(out_dt.shape, out.shape)
@with_comms
def test_conv3d(self):
device_mesh = self.build_device_mesh()
model = nn.Conv3d(64, 64, 3, padding=1)
model_gt = copy.deepcopy(model).to(device=self.device_type)
x = torch.randn(1, 64, 8, 8, 8, device=self.device_type)
out_dt, out = self._run_single_arg_fwd(model, x)
x_dt = DTensor.from_local(x, device_mesh, [Replicate()])
model_dt = distribute_module(
model, device_mesh, _conv_fn, input_fn=None, output_fn=None
)
out_dt = model_dt(x_dt)
out = model_gt(x)
self.assertEqual(out_dt.shape, out.shape)
DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
DistConvolutionOpsTest,
# Send / recv ops are not supported
skipped_tests=[
"test_conv1d",
"test_conv3d",
"test_conv_backward_none_grad_inp",
"test_depthwise_convolution",
"test_downsampling_convolution",
],
)
if __name__ == "__main__":
run_tests()
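The rewritten tests build the replicated input with `DTensor.from_local` inline rather than through the deleted `_run_single_arg_fwd` helper. A self-contained single-process sketch of that pattern, assuming a build with the gloo backend:

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
x = torch.randn(1, 64, 8)
x_dt = DTensor.from_local(x, mesh, [Replicate()])  # replicate across the mesh
print(x_dt.full_tensor().shape)  # torch.Size([1, 64, 8])
dist.destroy_process_group()
```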

View File

@ -464,6 +464,25 @@ def forward(self, b_parametrizations_buffer_original0, x):
run(g, 64, 8)
self.assertEqual(cnt.frame_count, 2)
def test_dtensor_requires_grad_recompile(self):
cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
@torch.compile(backend=cnt, fullgraph=True)
def f(x):
y = x * x
return y.to_local()
full_x = torch.randn(8, 8, requires_grad=False)
x = distribute_tensor(full_x, mesh, [Shard(0)])
f(x)
full_x = torch.randn(8, 8, requires_grad=True)
x = distribute_tensor(full_x, mesh, [Shard(0)])
f(x)
self.assertEqual(cnt.frame_count, 2)
def test_dtensor_attribute_access_on_intermediate(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

View File

@ -520,21 +520,6 @@ class DTensorExportTest(TestCase):
2,
)
def test_union_typed_annotation(self):
def fn(leaf: torch.Tensor | DTensor):
def nest_fn(leaf: torch.Tensor | DTensor):
# def nest_fn(leaf: Union[torch.Tensor, DTensor]): # this works
if isinstance(leaf, DTensor):
leaf = leaf.to_local()
return leaf
return nest_fn(leaf) + 1
z = torch.randn(16, 16)
gm = graph_capture_and_aot_export_joint_with_descriptors(fn, (z,))
self.assertEqual(fn(z), gm(z)[0])
instantiate_parametrized_tests(DTensorExportTest)

View File

@ -60,9 +60,9 @@ class DistMathOpsTest(DTensorTestBase):
shard_spec = [Shard(0)]
tensor = torch.randn(12, 8, 8)
if op_str in ("any", "all"):
# Test bool tensor for any() and all() reduction ops
# Previously all() had a bug using sum reduction instead of product
# TODO: check `all` correctness and test `all` on a bool tensor
if op_str in ("any"):
# test out a bool tensor for any
tensor = tensor < 0
dtensor = distribute_tensor(tensor, device_mesh, shard_spec)

View File

@ -887,135 +887,6 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
correct = func(a, b, c, d, ranks=ranks)
self.assertTrue(same(test_out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_custom_estimation_with_fake_tensor_mode(self):
"""Test that custom estimation can use FakeTensorMode for analysis."""
from torch._subclasses.fake_tensor import FakeTensorMode
estimation_calls = 0
def estimate_with_fake_mode(fx_node, compute_multiplier=1.0):
with FakeTensorMode():
nonlocal estimation_calls
estimation_calls += 1
assert isinstance(torch.rand([20]), torch._subclasses.FakeTensor)
return 1.0
patches = get_bucket_patches()
patches["aten_distributed_optimizations.custom_runtime_estimation"] = (
estimate_with_fake_mode
)
def func(a, b, *, ranks):
# Two independent all_gathers that should be bucketed
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
# Matmul that can hide the collectives
mm1 = torch.matmul(a, a)
return ag1.sum() + ag2.sum() + mm1.sum()
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs_a = torch.ones(4, 4, dtype=torch.float, device=device_type)
inputs_b = torch.ones(4, 4, dtype=torch.float, device=device_type) * 2
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
with torch._inductor.config.patch(patches):
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(
compiled, inputs_a, inputs_b
)
# Verify the custom estimation was called
self.assertTrue(
estimation_calls > 0, "Custom estimation should have been called"
)
correct = func(inputs_a, inputs_b, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_multidtype_bucketing(self):
"""Test that all_gathers with different dtypes get bucketed together."""
def func(a, b, c, *, ranks):
# Three all_gathers with different dtypes
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks) # float32
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks) # float16
ag3 = _functional_collectives.all_gather_tensor(c, 0, ranks) # float16
# Use all results
return ag1.sum() + ag2.sum() + ag3.sum()
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
a = torch.ones(4, 4, dtype=torch.float32, device=device_type)
b = torch.ones(4, 4, dtype=torch.float16, device=device_type) * 2
c = torch.ones(4, 4, dtype=torch.float16, device=device_type) * 3
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c)
# Should have 1 bucketed all_gather despite different dtypes
FileCheck().check_count(
"torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True
).run(aten_graph_str)
# Verify correctness
correct = func(a, b, c, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_basic_all_reduce_bucketing(self):
"""Test that independent all_reduce operations get bucketed together."""
def func(a, b, c):
# Three independent all_reduces that should be bucketed
ar1 = _functional_collectives.all_reduce(a, "sum", "0")
ar2 = _functional_collectives.all_reduce(b, "sum", "0")
ar3 = _functional_collectives.all_reduce(c, "sum", "0")
return ar1.sum() + ar2.sum() + ar3.sum()
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
a = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
b = torch.ones(4, 4, dtype=torch.float, device=device_type) * 2
c = torch.ones(4, 4, dtype=torch.float, device=device_type) * 3
compiled = torch.compile(func)
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c)
# Should see a single bucketed all_reduce
FileCheck().check_count(
"torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True
).run(aten_graph_str)
# Verify correctness
correct = func(a, b, c)
self.assertTrue(same(out, correct))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
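The removed estimation test relied on entering `FakeTensorMode` inside the callback, under which tensor factories produce metadata-only fake tensors. The core behavior in isolation:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    t = torch.rand([20])  # no real storage is allocated under the mode
    print(isinstance(t, torch._subclasses.FakeTensor), t.shape)  # True torch.Size([20])
```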

View File

@ -1,572 +0,0 @@
# Owner(s): ["module: inductor"]
import unittest
import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case
import torch.distributed as dist
import torch.fx as fx
# for some reason importing functional collectives after dynamo breaks collectives handling!
from torch._C import FileCheck
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_distributed import requires_accelerator_dist_backend
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils._ordered_set import OrderedSet
# flake8: noqa: B950
# Owner(s): ["module: inductor"]
aten = torch.ops.aten
from torch.testing._internal.common_fsdp import get_devtype
device_type = str(get_devtype())
import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case
# for some reason importing functional collectives after dynamo breaks collectives handling!
@requires_accelerator_dist_backend(["nccl", "xccl"])
def build_collective_info(graph, hiding_annotations):
"""
Build CollectiveInfo dict from manual hiding annotations.
hiding_annotations: dict mapping collective_start -> hiding_compute_node
"""
from torch._inductor.fx_passes.overlap_scheduling import CollectiveInfo
collective_info = {}
# Find all collective starts and their corresponding waits
start_to_wait = {}
for node in graph.nodes:
if node.op == "call_function" and "wait_tensor" in str(node.target):
wait_input = node.args[0]
if isinstance(wait_input, fx.Node):
start_to_wait[wait_input] = node
# Build CollectiveInfo for each collective
for start_node, wait_node in start_to_wait.items():
hiding_node = hiding_annotations.get(start_node)
# Estimate size and time
size_bytes = 16 * 4 # 4x4 tensor of floats
estimated_time_ms = 1.0 # Dummy time
exposed_time_ms = 0.0 if hiding_node else 1.0 # Hidden if has hiding_node
collective_info[start_node] = CollectiveInfo(
start_node=start_node,
wait_node=wait_node,
size_bytes=size_bytes,
estimated_time_ms=estimated_time_ms,
exposed_time_ms=exposed_time_ms,
hiding_node=hiding_node,
)
return collective_info
def compute_ancestors(graph):
"""Compute ancestor sets for all nodes in the graph."""
node_ancestors = {}
for node in graph.nodes:
ancestors = OrderedSet()
stack = list(node.all_input_nodes)
visited = set()
while stack:
current = stack.pop()
if current in visited:
continue
visited.add(current)
ancestors.add(current)
stack.extend(current.all_input_nodes)
node_ancestors[node] = ancestors
return node_ancestors
@requires_accelerator_dist_backend()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@instantiate_parametrized_tests
class TestOverlapPreservingBucketing(InductorTestCase):
"""
Unit tests for overlap-preserving bucketing pass.
"""
@classmethod
def setUpClass(cls):
super().setUpClass()
from torch.testing._internal.distributed.fake_pg import FakeStore
store = FakeStore()
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
cls.device = "cuda"
@classmethod
def tearDownClass(cls):
super().tearDownClass()
dist.destroy_process_group()
def test_can_bucket_independent_collectives(self):
"""
Test that independent collectives with separate hiding nodes CAN bucket.
Graph structure:
ag1_start -> ag2_start -> mm1 (hides ag1) -> mm2 (hides ag2) -> ag1_wait -> ag2_wait
"""
def func(a, b):
group_name = "0"
group_size = 1
# Start both collectives
ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
a, group_size, group_name
)
ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
b, group_size, group_name
)
# Independent compute that can hide both
mm1 = torch.mm(a, a)
mm2 = torch.mm(b, b)
# Wait for both
ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)
ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)
return ag1_out.sum() + ag2_out.sum() + mm1.sum() + mm2.sum()
# Use fake mode to trace without executing
with FakeTensorMode():
a = torch.ones(4, 4, device=self.device)
b = torch.ones(4, 4, device=self.device) * 2
# Trace with make_fx
traced = make_fx(func)(a, b)
# Find nodes using find_nodes
ag1, ag2 = traced.graph.find_nodes(
op="call_function",
target=torch.ops._c10d_functional.all_gather_into_tensor.default,
)
mm1, mm2 = traced.graph.find_nodes(
op="call_function", target=torch.ops.aten.mm.default
)
# Manually annotate hiding relationships
hiding_annotations = {
ag1: mm1, # mm1 hides ag1
ag2: mm2, # mm2 hides ag2
}
# Build collective info and ancestors
collective_info = build_collective_info(traced.graph, hiding_annotations)
node_ancestors = compute_ancestors(traced.graph)
scheduled = OrderedSet(traced.graph.nodes)
# Run bucketing
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
OverlapPreservingBucketer,
)
bucketer = OverlapPreservingBucketer(
traced.graph,
collective_info,
node_ancestors,
scheduled,
)
bucketer.bucket_collectives()
# Verify: should have 1 bucketed collective (all_gather_into_tensor_out)
graph_str = str(traced.graph)
FileCheck().check_count("all_gather_into_tensor_out", 1, exactly=False).run(
graph_str
)
def test_cant_bucket_nested_hiding_intervals(self):
"""
Test that nested hiding intervals prevent bucketing.
Graph structure:
ag1_start -> ag2_start -> mm2 (hides ag2) -> ag2_wait -> mm1 (hides ag1) -> ag1_wait
ag2's hiding interval is nested inside ag1's hiding interval.
"""
def func(a, b):
group_name = "0"
group_size = 1
# ag1 starts first
ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
a, group_size, group_name
)
# ag2 starts (inside ag1's interval)
ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
b, group_size, group_name
)
# mm2 hides ag2
mm2 = torch.mm(b[:2, :2], b[:2, :2])
# ag2 waits (still inside ag1's interval)
ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)
# mm1 uses ag2's result and hides ag1
mm1 = torch.mm(a + ag2_out[:4, :4], a)
# ag1 waits last
ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)
return ag1_out.sum() + ag2_out.sum() + mm1.sum() + mm2.sum()
# Use fake mode to trace without executing
with FakeTensorMode():
a = torch.ones(4, 4, device=self.device)
b = torch.ones(4, 4, device=self.device) * 2
# Trace with make_fx
traced = make_fx(func)(a, b)
# Find nodes using find_nodes
ag1, ag2 = traced.graph.find_nodes(
op="call_function",
target=torch.ops._c10d_functional.all_gather_into_tensor.default,
)
mm_nodes = traced.graph.find_nodes(
op="call_function", target=torch.ops.aten.mm.default
)
# mm2 is the first mm, mm1 is the second (based on graph order)
mm2 = mm_nodes[0]
mm1 = mm_nodes[1]
# Manually annotate hiding relationships
hiding_annotations = {
ag1: mm1, # mm1 hides ag1
ag2: mm2, # mm2 hides ag2
}
# Build collective info and ancestors
collective_info = build_collective_info(traced.graph, hiding_annotations)
node_ancestors = compute_ancestors(traced.graph)
scheduled = OrderedSet(traced.graph.nodes)
# Run bucketing
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
OverlapPreservingBucketer,
)
bucketer = OverlapPreservingBucketer(
traced.graph,
collective_info,
node_ancestors,
scheduled,
)
bucketer.bucket_collectives()
# Verify: nested hiding intervals should prevent bucketing
# Should have 2 separate all_gathers, not 1 bucketed one
graph_str = str(traced.graph)
FileCheck().check_count("all_gather_into_tensor", 2, exactly=False).run(
graph_str
)
@parametrize("final_mm_hidden", (True, False))
def test_cant_bucket_ag_with_rs_hiding_interval_between(self, final_mm_hidden):
"""
Test that all_gathers can't bucket when a reduce_scatter's hiding interval is between them.
Graph structure:
ag1_start -> mm1 (hides ag1) -> ag1_wait ->
rs_start -> mm2 (hides rs) -> rs_wait ->
if final_mm_hidden:
ag2_start -> mm3 (hides ag2) -> ag2_wait
if final_mm_hidden:
Bucketing ag1 and ag2 would require moving one of them, which would break hiding relationships:
- Moving ag2 earlier would break ag2's hiding by mm3
- Moving ag1 later would break ag1's hiding by mm1
- The rs hiding interval creates an obstacle between them
otherwise, we can bucket
"""
def func(a, b, c):
group_name = dist.distributed_c10d._get_default_group().group_name
group_size = 1
# First all_gather
ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
a, group_size, group_name
)
mm1 = torch.mm(a, a) # hides ag1
ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)
# Reduce scatter in between
rs = torch.ops._c10d_functional.reduce_scatter_tensor(
b, "sum", group_size, group_name
)
mm2 = torch.mm(b[:4, :4], b[:4, :4]) # hides rs
rs_out = torch.ops._c10d_functional.wait_tensor(rs)
# Second all_gather
ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
c, group_size, group_name
)
mm3 = torch.mm(c, c) # hides ag2
ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)
return ag1_out.sum() + rs_out.sum() + ag2_out.sum(), mm1, mm2, mm3
# Use fake mode to trace without executing
with FakeTensorMode():
a = torch.ones(4, 4, device=self.device)
b = torch.ones(8, 4, device=self.device)
c = torch.ones(4, 4, device=self.device)
# Trace with make_fx
traced = make_fx(func)(a, b, c)
ag1, ag2 = traced.graph.find_nodes(
op="call_function",
target=torch.ops._c10d_functional.all_gather_into_tensor.default,
)
(rs,) = traced.graph.find_nodes(
op="call_function",
target=torch.ops._c10d_functional.reduce_scatter_tensor.default,
)
mm1, mm2, mm3 = traced.graph.find_nodes(
op="call_function", target=torch.ops.aten.mm.default
)
# Manually annotate hiding relationships
hiding_annotations = {
ag1: mm1, # mm1 hides ag1
# rs: mm2, # mm2 hides rs
ag2: mm3,
}
if final_mm_hidden:
hiding_annotations[rs] = mm2
# Build collective info and ancestors
collective_info = build_collective_info(traced.graph, hiding_annotations)
node_ancestors = compute_ancestors(traced.graph)
scheduled = OrderedSet(traced.graph.nodes)
# Run bucketing logic to find buckets (without applying them, which would require process groups)
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
OverlapPreservingBucketer,
)
bucketer = OverlapPreservingBucketer(
traced.graph,
collective_info,
node_ancestors,
scheduled,
)
bucketer.bucket_collectives()
graph_str = str(traced.graph)
# check order of mms preserved
FileCheck().check("%mm").check("%mm_1").check("%mm_2").run(graph_str)
if final_mm_hidden:
# Should NOT bucket - 2 separate all_gathers
# Count all_gather node names (works even when wrapped in control_deps)
FileCheck().check_count("%all_gather_into_tensor", 2, exactly=False).run(
graph_str
)
else:
# Should bucket - 1 bucketed all_gather (all_gather_into_tensor_out)
FileCheck().check_count(
"%all_gather_into_tensor_out", 1, exactly=False
).run(graph_str)
def test_can_bucket_all_reduce(self):
"""
Test that all_reduce operations CAN bucket together.
Graph structure:
ar1_start -> ar2_start -> mm1 (hides ar1) -> mm2 (hides ar2) -> ar1_wait -> ar2_wait
"""
def func(a, b):
group_name = "0"
# Start both all_reduce operations
ar1 = torch.ops._c10d_functional.all_reduce(a, "sum", group_name)
ar2 = torch.ops._c10d_functional.all_reduce(b, "sum", group_name)
# Independent compute that can hide both
mm1 = torch.mm(a, a)
mm2 = torch.mm(b, b)
# Wait for both
ar1_out = torch.ops._c10d_functional.wait_tensor(ar1)
ar2_out = torch.ops._c10d_functional.wait_tensor(ar2)
return ar1_out.sum() + ar2_out.sum() + mm1.sum() + mm2.sum()
# Use fake mode to trace without executing
with FakeTensorMode():
a = torch.ones(4, 4, device=self.device)
b = torch.ones(4, 4, device=self.device) * 2
# Trace with make_fx
traced = make_fx(func)(a, b)
# Find nodes
ar1, ar2 = traced.graph.find_nodes(
op="call_function",
target=torch.ops._c10d_functional.all_reduce.default,
)
mm1, mm2 = traced.graph.find_nodes(
op="call_function", target=torch.ops.aten.mm.default
)
# For all_reduce, start_node == wait_node (no separate wait)
hiding_annotations = {
ar1: mm1,
ar2: mm2,
}
# Build collective info
collective_info = build_collective_info(traced.graph, hiding_annotations)
node_ancestors = compute_ancestors(traced.graph)
scheduled = OrderedSet(traced.graph.nodes)
# Run bucketing
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
OverlapPreservingBucketer,
)
bucketer = OverlapPreservingBucketer(
traced.graph,
collective_info,
node_ancestors,
scheduled,
)
bucketer.bucket_collectives()
# Verify: should have 1 bucketed all_reduce
# After bucketing, there should be only one all_reduce node (the bucketed one)
graph_str = str(traced.graph)
FileCheck().check_count("%all_reduce", 1, exactly=True).check_count(
"%mm", 2
).run(graph_str)
def test_can_bucket_multidtype_collectives(self):
"""
Test that all_gathers with different dtypes CAN bucket together.
Graph structure:
ag1_float32 -> mm1 (hides ag1) -> ag1_wait
ag2_bfloat16 -> mm2 (hides ag2) -> ag2_wait
"""
def func(a, b):
group_name = "0"
group_size = 1
# Start both collectives with different dtypes
ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
a,
group_size,
group_name, # float32
)
ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
b,
group_size,
group_name, # bfloat16
)
# Independent compute that can hide both
mm1 = torch.mm(a, a)
mm2 = torch.mm(b.float(), b.float())
# Wait for both
ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)
ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)
return ag1_out.sum() + ag2_out.sum() + mm1.sum() + mm2.sum()
# Use fake mode to trace without executing
with FakeTensorMode():
a = torch.ones(4, 4, device=self.device, dtype=torch.float32)
b = torch.ones(4, 4, device=self.device, dtype=torch.bfloat16)
# Trace with make_fx
traced = make_fx(func)(a, b)
# Find nodes using find_nodes
ag1, ag2 = traced.graph.find_nodes(
op="call_function",
target=torch.ops._c10d_functional.all_gather_into_tensor.default,
)
mm_nodes = traced.graph.find_nodes(
op="call_function", target=torch.ops.aten.mm.default
)
mm1 = mm_nodes[0]
mm2 = mm_nodes[1]
# Manually annotate hiding relationships
hiding_annotations = {
ag1: mm1, # mm1 hides ag1
ag2: mm2, # mm2 hides ag2
}
# Build collective info and ancestors
collective_info = build_collective_info(traced.graph, hiding_annotations)
node_ancestors = compute_ancestors(traced.graph)
scheduled = OrderedSet(traced.graph.nodes)
# Run bucketing with multidtype mode
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
OverlapPreservingBucketer,
)
bucketer = OverlapPreservingBucketer(
traced.graph,
collective_info,
node_ancestors,
scheduled,
bucket_mode="custom_ops_multidtype",
)
bucketer.bucket_collectives()
# Verify: should have 1 bucketed collective (all_gather_into_tensor_out)
# even though dtypes are different
graph_str = str(traced.graph)
FileCheck().check_count("all_gather_into_tensor_out", 1, exactly=False).run(
graph_str
)
if __name__ == "__main__":
run_tests()
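The deleted file's tests all follow one tracing recipe: run the function under `FakeTensorMode` through `make_fx`, then locate collectives and matmuls with `Graph.find_nodes`. A minimal reproduction of that recipe without the distributed pieces:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx

def f(a, b):
    return torch.mm(a, a) + b

with FakeTensorMode():
    a = torch.ones(4, 4)
    b = torch.ones(4, 4)
    traced = make_fx(f)(a, b)       # trace to an FX graph without real compute

mm_nodes = traced.graph.find_nodes(
    op="call_function", target=torch.ops.aten.mm.default
)
print(len(mm_nodes))  # 1
```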

View File

@ -2064,23 +2064,6 @@ Detected recompile when torch.compile stance is 'fail_on_recompile'. filename: '
self.assertEqual(f(), 1)
def test_error_on_graph_break_nonempty_checkpoint(self):
cnts = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnts)
def fn(x):
x = x + 1
x = x + 1
x = x + 1
with torch._dynamo.error_on_graph_break(True):
torch._dynamo.graph_break()
return x + 1
with self.assertRaises(Unsupported):
fn(torch.ones(3))
self.assertEqual(cnts.frame_count, 0)
def test_nested_compile_fullgraph(self):
# Test that fullgraph=True cannot be toggled back by fullgraph=False
inp = torch.ones(3)

View File

@ -341,7 +341,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
def fn(x, d):
y = 0
for idx, value in enumerate(d.values()):
for idx, (key, value) in enumerate(d.items()):
if idx == 0:
y += torch.sin(x * value)
else:
@ -366,7 +366,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
def fn(x, d):
y = 0
for idx, value in enumerate(d.values()):
for idx, (key, value) in enumerate(d.items()):
if idx == 0:
y += torch.sin(x * value)
else:
@ -847,7 +847,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
d = {"a": 2, "b": 3, "c": 5 * x}
mp = types.MappingProxyType(d)
y = torch.sin(x * mp["a"])
for v in mp.values():
for k, v in mp.items(): # noqa: PERF102
y += torch.cos(x * v)
return mp
@ -864,7 +864,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
def fn(x):
mp = types.MappingProxyType(d)
y = torch.sin(x * mp["a"])
for v in mp.values():
for k, v in mp.items(): # noqa: PERF102
y += torch.cos(x * v)
d["d"] = 4
return mp
@ -885,7 +885,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
def fn(x, mp):
y = torch.sin(x * mp["a"])
for v in mp.values():
for k, v in mp.items(): # noqa: PERF102
y += torch.cos(x * v)
if isinstance(mp, types.MappingProxyType):
y *= 2
@ -1100,20 +1100,6 @@ class DictTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref, res)
def test_iter_default_dict(self):
def f(x):
d = defaultdict(list)
d[0] = 42
for k in d:
d[k] += 1
return x + 1, d
x = torch.ones(2)
ref = f(x)
res = torch.compile(f, backend="eager", fullgraph=True)(x)
self.assertEqual(ref, res)
@parametrize("op", ["or_", "and_", "xor", "sub"])
def test_dict_keys_binop(self, op):
op = getattr(operator, op)
@ -1637,12 +1623,6 @@ class DictMethodsTests(torch._dynamo.test_case.TestCase):
self.assertNotEqual(self.thetype, other)
self.assertTrue(self.thetype is not other, f"{self.thetype=}, {other=}")
@make_dynamo_test
def test_dict___iter__(self):
d = self.thetype({1: 2})
it = d.__iter__()
self.assertEqual(next(it), 1)
class DictSubclassMethodsTests(DictMethodsTests):
thetype = SimpleDict

View File

@ -147,8 +147,8 @@ class GraphModule(torch.nn.Module):
t: "f32[10]" = l_x_ + l_y_
trace_point_tensor_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_spec
trace_point_tensor_input_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_input_spec
trace_point_tensor_spec : torch.utils.pytree.PyTreeSpec = self.trace_point_tensor_spec
trace_point_tensor_input_spec : torch.utils.pytree.PyTreeSpec = self.trace_point_tensor_input_spec
res: "f32[10]" = torch.ops.higher_order.flat_apply(trace_point_tensor_spec, trace_point_tensor_input_spec, l_x_, l_y_, t); trace_point_tensor_spec = trace_point_tensor_input_spec = l_x_ = l_y_ = t = None
return (res,)
""", # NOQA: B950

View File

@ -363,40 +363,6 @@ class FxGraphRunnableTest(TestCase):
self._exec_and_verify_payload()
def test_metrics_context(self):
"""
When TORCH_COMPILE_DEBUG is set, provenance_tracking_level is set to 1, and
the generated fx_graph_runnable crashed with,
RuntimeError: Cannot add inductor_provenance outside of a MetricsContext
"""
import torch._inductor.config as inductor_config
def f(x):
return x * 2 + 1
# Enable provenance tracking to trigger the code path that adds metrics
with inductor_config.patch(
{"trace.enabled": True, "trace.provenance_tracking_level": 1}
):
x = torch.randn(4, 4)
torch.compile(f)(x)
self._exec_and_verify_payload()
@torch._dynamo.config.patch(assume_static_by_default=False)
def test_dynamic_expression(self):
"""
Test not emitting something like "s27*s53**2 = 36"
"""
def f(x):
return torch.ops.aten._adaptive_avg_pool2d(
x, (6, 6)
), torch.ops.aten._adaptive_avg_pool2d(x + 1, (2, 5))
x = torch.randn(2, 4, 16, 16)
torch.compile(f)(x)
self._exec_and_verify_payload()
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -2858,7 +2858,7 @@ class GraphModule(torch.nn.Module):
def fn(x):
return wrap(lambda x: model(x), x)
for _ in range(2):
for i in range(2):
# second iteration is key, hooks would have fired during aot trace
# on first iter
activations.clear()

Some files were not shown because too many files have changed in this diff.