Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-24 07:27:32 +08:00
Compare commits
135 Commits
mwizak/fix ... fixflashgi
| SHA1 | Author | Date | |
|---|---|---|---|
| 2f10f1b888 | |||
| 517f267085 | |||
| cac7242b91 | |||
| b54dc58cb5 | |||
| 4efdd216bd | |||
| 95654a32f5 | |||
| 2f5c2ccf7a | |||
| 813cae6074 | |||
| ef4730d5bb | |||
| 3ad3df90c3 | |||
| 257bf0e654 | |||
| 02d16522d8 | |||
| e6d3372157 | |||
| 3f3d86adf2 | |||
| 58478b0ab8 | |||
| 98e554222f | |||
| 700d608f4a | |||
| 1b27857415 | |||
| 73995b1b5e | |||
| 03d7c77071 | |||
| 019d9cda40 | |||
| 3620191a0a | |||
| 5a722ca130 | |||
| 8746e3cea2 | |||
| 8cd1996b57 | |||
| 73c23f3554 | |||
| 1da3d6f595 | |||
| b1033789fe | |||
| 07d896fa48 | |||
| 31681bcacc | |||
| e901866dd7 | |||
| 70d1043bdf | |||
| 69fa26d9b4 | |||
| d9c80ef97d | |||
| ac1bc51608 | |||
| ed90040d33 | |||
| 4dab208d97 | |||
| 9fd53a2bdc | |||
| 17ab99463a | |||
| eca6ac2293 | |||
| 12d4cb0122 | |||
| 590224f83c | |||
| cc8b14d09a | |||
| 96c3b9e275 | |||
| 9ddfc59b9b | |||
| 6d4dfa0878 | |||
| 11ccb95ccb | |||
| bd0907dc4c | |||
| 8bb71c07c4 | |||
| fa90090735 | |||
| 591997490a | |||
| 531f3bf5e1 | |||
| 2a5ce2feb4 | |||
| 3787a5a60e | |||
| c66d18d24d | |||
| e0f118585f | |||
| 10a005e87f | |||
| 1f3995cdc8 | |||
| abfcce58a4 | |||
| 5b1c39f5a1 | |||
| 8df3f2fa98 | |||
| 7a9119948e | |||
| 28c1d2f81b | |||
| c4bbc6433e | |||
| ad7e3c93b1 | |||
| 7f3dc45300 | |||
| ff715366aa | |||
| 60a4961ff4 | |||
| bec6541d84 | |||
| 1f1de20ba9 | |||
| 2810977d3a | |||
| ae4fd4ea75 | |||
| adc11a7634 | |||
| 99e28ffab3 | |||
| 01dd2c2b42 | |||
| d3bdf8c32e | |||
| 1ce9563ff6 | |||
| 9e631392dc | |||
| 1cce6efdb8 | |||
| 5a93f00c79 | |||
| e30f01b5b5 | |||
| ffc645c870 | |||
| 60f0a356fd | |||
| d2c5f231f6 | |||
| cc5d74c366 | |||
| a707042353 | |||
| d615f6b935 | |||
| 719b64ee8b | |||
| 1cf1b9138d | |||
| 5ed4672477 | |||
| 2600f8b3d1 | |||
| 9ce31e4278 | |||
| 0657de9c61 | |||
| 4ead8ebf70 | |||
| d4b785a6a7 | |||
| 9278b18ec0 | |||
| 008b0a9425 | |||
| 44677ad917 | |||
| 1c9987fdf4 | |||
| 7cbc011700 | |||
| 09c774145e | |||
| 763ab2a6ed | |||
| 4b8fe795f8 | |||
| 84e1cd7392 | |||
| 937869657e | |||
| 7d7ae4d7b2 | |||
| 906fe7b120 | |||
| 7edd18f0fd | |||
| 3564cd294c | |||
| 1412a4a42f | |||
| 96330f490d | |||
| 66abba8f49 | |||
| e88cca0691 | |||
| 5c020beba4 | |||
| edd9e07aff | |||
| 0fb89b84b9 | |||
| 79fcfd49d6 | |||
| 71b4fada57 | |||
| 46ec0664e3 | |||
| 410ed3006b | |||
| 77354e22e1 | |||
| 7f29c47a4f | |||
| ace6c76103 | |||
| 1310d6a1f9 | |||
| 7f4c3e7d2f | |||
| 6e5b4249a5 | |||
| 5274753873 | |||
| 7afcb030d8 | |||
| bbf6816f35 | |||
| ace89350fc | |||
| 7d59e37434 | |||
| 92108f4abd | |||
| 0b2fdc30a2 | |||
| 0d7994ca97 | |||
| c39357bab6 |
@@ -13,49 +13,6 @@ def list_dir(path: str) -> list[str]:
    return check_output(["ls", "-1", path]).decode().split("\n")


def build_ArmComputeLibrary() -> None:
    """
    Using ArmComputeLibrary for aarch64 PyTorch
    """
    print("Building Arm Compute Library")
    acl_build_flags = [
        "debug=0",
        "neon=1",
        "opencl=0",
        "os=linux",
        "openmp=1",
        "cppthreads=0",
        "arch=armv8a",
        "multi_isa=1",
        "fixed_format_kernels=1",
        "build=native",
    ]
    acl_install_dir = "/acl"
    acl_checkout_dir = os.getenv("ACL_SOURCE_DIR", "ComputeLibrary")
    if os.path.isdir(acl_install_dir):
        shutil.rmtree(acl_install_dir)
    if not os.path.isdir(acl_checkout_dir) or not len(os.listdir(acl_checkout_dir)):
        check_call(
            [
                "git",
                "clone",
                "https://github.com/ARM-software/ComputeLibrary.git",
                "-b",
                "v25.02",
                "--depth",
                "1",
                "--shallow-submodules",
            ]
        )

    check_call(
        ["scons", "Werror=1", f"-j{os.cpu_count()}"] + acl_build_flags,
        cwd=acl_checkout_dir,
    )
    for d in ["arm_compute", "include", "utils", "support", "src", "build"]:
        shutil.copytree(f"{acl_checkout_dir}/{d}", f"{acl_install_dir}/{d}")


def replace_tag(filename) -> None:
    with open(filename) as f:
        lines = f.readlines()

@@ -356,19 +313,13 @@ if __name__ == "__main__":
        build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1 "

    if enable_mkldnn:
        build_ArmComputeLibrary()
        print("build pytorch with mkldnn+acl backend")
        build_vars += (
            "USE_MKLDNN=ON USE_MKLDNN_ACL=ON "
            "ACL_ROOT_DIR=/acl "
            "LD_LIBRARY_PATH=/pytorch/build/lib:/acl/build:$LD_LIBRARY_PATH "
            "ACL_INCLUDE_DIR=/acl/build "
            "ACL_LIBRARY=/acl/build "
        )
        build_vars += "USE_MKLDNN=ON USE_MKLDNN_ACL=ON "
        build_vars += "ACL_ROOT_DIR=/acl "
        if enable_cuda:
            build_vars += "BLAS=NVPL "
        else:
            build_vars += "BLAS=OpenBLAS OpenBLAS_HOME=/OpenBLAS "
            build_vars += "BLAS=OpenBLAS OpenBLAS_HOME=/opt/OpenBLAS "
    else:
        print("build pytorch without mkldnn backend")
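For orientation, here is a minimal sketch of the environment the hunk above now assembles when building an aarch64 wheel with the MKLDNN+ACL backend. The variable values come from the diff itself; the standalone invocation is illustrative, not the script's exact command line.

```bash
# Sketch only: equivalent of the build_vars assembled above, assuming ACL was
# installed to /acl and OpenBLAS to /opt/OpenBLAS by the CI scripts in this PR.
export USE_MKLDNN=ON USE_MKLDNN_ACL=ON
export ACL_ROOT_DIR=/acl
export BLAS=OpenBLAS OpenBLAS_HOME=/opt/OpenBLAS
python3 -m build --wheel --no-isolation
```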
@@ -299,40 +299,6 @@ def install_condaforge_python(host: RemoteHost, python_version="3.8") -> None:
    )


def build_OpenBLAS(host: RemoteHost, git_clone_flags: str = "") -> None:
    print("Building OpenBLAS")
    host.run_cmd(
        f"git clone https://github.com/xianyi/OpenBLAS -b v0.3.28 {git_clone_flags}"
    )
    make_flags = "NUM_THREADS=64 USE_OPENMP=1 NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=ARMV8"
    host.run_cmd(
        f"pushd OpenBLAS && make {make_flags} -j8 && sudo make {make_flags} install && popd && rm -rf OpenBLAS"
    )


def build_ArmComputeLibrary(host: RemoteHost, git_clone_flags: str = "") -> None:
    print("Building Arm Compute Library")
    acl_build_flags = " ".join(
        [
            "debug=0",
            "neon=1",
            "opencl=0",
            "os=linux",
            "openmp=1",
            "cppthreads=0",
            "arch=armv8a",
            "multi_isa=1",
            "fixed_format_kernels=1",
            "build=native",
        ]
    )
    host.run_cmd(
        f"git clone https://github.com/ARM-software/ComputeLibrary.git -b v25.02 {git_clone_flags}"
    )

    host.run_cmd(f"cd ComputeLibrary && scons Werror=1 -j8 {acl_build_flags}")


def embed_libgomp(host: RemoteHost, use_conda, wheel_name) -> None:
    host.run_cmd("pip3 install auditwheel")
    host.run_cmd(

@@ -700,7 +666,6 @@ def start_build(
    configure_system(
        host, compiler=compiler, use_conda=use_conda, python_version=python_version
    )
    build_OpenBLAS(host, git_clone_flags)

    if host.using_docker():
        print("Move libgfortant.a into a standard location")

@@ -723,6 +688,8 @@ def start_build(
        f"git clone --recurse-submodules -b {branch} https://github.com/pytorch/pytorch {git_clone_flags}"
    )

    host.run_cmd("pytorch/.ci/docker/common/install_openblas.sh")

    print("Building PyTorch wheel")
    build_opts = ""
    if pytorch_build_number is not None:

@@ -743,16 +710,18 @@ def start_build(
    if host.using_docker():
        build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
    if enable_mkldnn:
        build_ArmComputeLibrary(host, git_clone_flags)
        host.run_cmd("pytorch/.ci/docker/common/install_acl.sh")
        print("build pytorch with mkldnn+acl backend")
        build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
        build_vars += " BLAS=OpenBLAS"
        build_vars += " OpenBLAS_HOME=/opt/OpenBLAS"
        build_vars += " ACL_ROOT_DIR=/acl"
    host.run_cmd(
        f"cd $HOME/pytorch && export ACL_ROOT_DIR=$HOME/ComputeLibrary && "
        f"{build_vars} python3 -m build --wheel --no-isolation{build_opts}"
        f"cd $HOME/pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
    )
    print("Repair the wheel")
    pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
    ld_library_path = "$HOME/acl/build:$HOME/pytorch/build/lib"
    ld_library_path = "/acl/build:$HOME/pytorch/build/lib"
    host.run_cmd(
        f"export LD_LIBRARY_PATH={ld_library_path} && auditwheel repair $HOME/pytorch/dist/{pytorch_wheel_name}"
    )

@@ -908,7 +877,7 @@ def terminate_instances(instance_type: str) -> None:
def parse_arguments():
    from argparse import ArgumentParser

    parser = ArgumentParser("Builid and test AARCH64 wheels using EC2")
    parser = ArgumentParser("Build and test AARCH64 wheels using EC2")
    parser.add_argument("--key-name", type=str)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--build-only", action="store_true")

@@ -1 +1 @@
bbb06c0334a6772b92d24bde54956e675c8c6604
27664085f804afc83df26f740bb46c365854f2c4
.ci/docker/common/install_acl.sh (27, Normal file → Executable file)
@@ -1,16 +1,27 @@
set -euo pipefail
#!/bin/bash
# Script used only in CD pipeline

readonly version=v25.02
readonly src_host=https://github.com/ARM-software
readonly src_repo=ComputeLibrary
set -eux

ACL_VERSION=${ACL_VERSION:-"v25.02"}
ACL_INSTALL_DIR="/acl"

# Clone ACL
[[ ! -d ${src_repo} ]] && git clone ${src_host}/${src_repo}.git
cd ${src_repo}

git checkout $version
git clone https://github.com/ARM-software/ComputeLibrary.git -b "${ACL_VERSION}" --depth 1 --shallow-submodules

ACL_CHECKOUT_DIR="ComputeLibrary"
# Build with scons
pushd $ACL_CHECKOUT_DIR
scons -j8 Werror=0 debug=0 neon=1 opencl=0 embed_kernels=0 \
  os=linux arch=armv8a build=native multi_isa=1 \
  fixed_format_kernels=1 openmp=1 cppthreads=0
popd

# Install ACL
sudo mkdir -p ${ACL_INSTALL_DIR}
for d in arm_compute include utils support src build
do
  sudo cp -r ${ACL_CHECKOUT_DIR}/${d} ${ACL_INSTALL_DIR}/${d}
done

rm -rf $ACL_CHECKOUT_DIR
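A hedged usage sketch for the new script above: ACL_VERSION is the only input it reads and the /acl install prefix is hard-coded, so a typical invocation looks like this.

```bash
# Build Arm Compute Library and stage it under /acl (sudo is needed for the copy step).
ACL_VERSION=v25.02 bash .ci/docker/common/install_acl.sh
ls /acl/build    # built libraries land here; ACL_ROOT_DIR=/acl points PyTorch at them
```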
.ci/docker/common/install_openblas.sh (12, Normal file → Executable file)
@@ -3,8 +3,10 @@

set -ex

cd /
git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION:-v0.3.30}" --depth 1 --shallow-submodules
OPENBLAS_VERSION=${OPENBLAS_VERSION:-"v0.3.30"}

# Clone OpenBLAS
git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" --depth 1 --shallow-submodules

OPENBLAS_CHECKOUT_DIR="OpenBLAS"
OPENBLAS_BUILD_FLAGS="

@@ -17,5 +19,7 @@ CFLAGS=-O3
BUILD_BFLOAT16=1
"

make -j8 ${OPENBLAS_BUILD_FLAGS} -C ${OPENBLAS_CHECKOUT_DIR}
make -j8 ${OPENBLAS_BUILD_FLAGS} install -C ${OPENBLAS_CHECKOUT_DIR}
make -j8 ${OPENBLAS_BUILD_FLAGS} -C $OPENBLAS_CHECKOUT_DIR
sudo make install -C $OPENBLAS_CHECKOUT_DIR

rm -rf $OPENBLAS_CHECKOUT_DIR
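Similarly, a hedged usage sketch for the OpenBLAS script: the version falls back to v0.3.30 when unset, and the Dockerfile hunk below copies the result from /opt/OpenBLAS, which is assumed here to be OpenBLAS's default install prefix.

```bash
# Clone, build, and install OpenBLAS for aarch64; OPENBLAS_VERSION is optional.
OPENBLAS_VERSION=v0.3.30 bash .ci/docker/common/install_openblas.sh
ls /opt/OpenBLAS/lib    # assumed default prefix, matching COPY --from=openblas below
```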
@@ -62,6 +62,13 @@ ARG OPENBLAS_VERSION
ADD ./common/install_openblas.sh install_openblas.sh
RUN bash ./install_openblas.sh && rm install_openblas.sh

# Install Arm Compute Library
FROM base as arm_compute
# use python3.9 to install scons
RUN python3.9 -m pip install scons==4.7.0
RUN ln -sf /opt/python/cp39-cp39/bin/scons /usr/local/bin
COPY ./common/install_acl.sh install_acl.sh
RUN bash ./install_acl.sh && rm install_acl.sh
FROM base as final

# remove unnecessary python versions

@@ -70,4 +77,5 @@ RUN rm -rf /opt/python/cp26-cp26mu /opt/_internal/cpython-2.6.9-ucs4
RUN rm -rf /opt/python/cp33-cp33m /opt/_internal/cpython-3.3.6
RUN rm -rf /opt/python/cp34-cp34m /opt/_internal/cpython-3.4.6
COPY --from=openblas /opt/OpenBLAS/ /opt/OpenBLAS/
ENV LD_LIBRARY_PATH=/opt/OpenBLAS/lib:$LD_LIBRARY_PATH
COPY --from=arm_compute /acl /acl
ENV LD_LIBRARY_PATH=/opt/OpenBLAS/lib:/acl/build/:$LD_LIBRARY_PATH

@@ -86,6 +86,15 @@ FROM base as nvpl
ADD ./common/install_nvpl.sh install_nvpl.sh
RUN bash ./install_nvpl.sh && rm install_nvpl.sh

# Install Arm Compute Library
FROM base as arm_compute
# use python3.9 to install scons
RUN python3.9 -m pip install scons==4.7.0
RUN ln -sf /opt/python/cp39-cp39/bin/scons /usr/local/bin
COPY ./common/install_acl.sh install_acl.sh
RUN bash ./install_acl.sh && rm install_acl.sh
FROM base as final

FROM final as cuda_final
ARG BASE_CUDA_VERSION
RUN rm -rf /usr/local/cuda-${BASE_CUDA_VERSION}

@@ -93,5 +102,7 @@ COPY --from=cuda /usr/local/cuda-${BASE_CUDA_VERSION} /usr/local/cuda-${BAS
COPY --from=magma /usr/local/cuda-${BASE_CUDA_VERSION} /usr/local/cuda-${BASE_CUDA_VERSION}
COPY --from=nvpl /opt/nvpl/lib/ /usr/local/lib/
COPY --from=nvpl /opt/nvpl/include/ /usr/local/include/
COPY --from=arm_compute /acl /acl
RUN ln -sf /usr/local/cuda-${BASE_CUDA_VERSION} /usr/local/cuda
ENV PATH=/usr/local/cuda/bin:$PATH
ENV LD_LIBRARY_PATH=/acl/build/:$LD_LIBRARY_PATH
@@ -28,6 +28,7 @@ fi
MANY_LINUX_VERSION=${MANY_LINUX_VERSION:-}
DOCKERFILE_SUFFIX=${DOCKERFILE_SUFFIX:-}
OPENBLAS_VERSION=${OPENBLAS_VERSION:-}
ACL_VERSION=${ACL_VERSION:-}

case ${image} in
    manylinux2_28-builder:cpu)

@@ -41,7 +42,6 @@ case ${image} in
        GPU_IMAGE=arm64v8/almalinux:8
        DOCKER_GPU_BUILD_ARG=" --build-arg DEVTOOLSET_VERSION=13 --build-arg NINJA_VERSION=1.12.1"
        MANY_LINUX_VERSION="2_28_aarch64"
        OPENBLAS_VERSION="v0.3.30"
        ;;
    manylinuxs390x-builder:cpu-s390x)
        TARGET=final

@@ -119,7 +119,8 @@ tmp_tag=$(basename "$(mktemp -u)" | tr '[:upper:]' '[:lower:]')
DOCKER_BUILDKIT=1 docker build \
    ${DOCKER_GPU_BUILD_ARG} \
    --build-arg "GPU_IMAGE=${GPU_IMAGE}" \
    --build-arg "OPENBLAS_VERSION=${OPENBLAS_VERSION}" \
    --build-arg "OPENBLAS_VERSION=${OPENBLAS_VERSION:-}" \
    --build-arg "ACL_VERSION=${ACL_VERSION:-}" \
    --target "${TARGET}" \
    -t "${tmp_tag}" \
    $@ \
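A hedged sketch of how the new ACL_VERSION build argument flows through: leaving it unset passes an empty build arg, so the install script's own default (v25.02) applies. The target and tag below are illustrative, not values from this PR.

```bash
# Illustrative only: mirroring the docker build fragment above.
DOCKER_BUILDKIT=1 docker build \
  --build-arg "OPENBLAS_VERSION=${OPENBLAS_VERSION:-}" \
  --build-arg "ACL_VERSION=${ACL_VERSION:-v25.02}" \
  --target final -t local/manylinux-aarch64:dev .
```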
@@ -52,10 +52,10 @@ flatbuffers==24.12.23
#Pinned versions: 24.12.23
#test that import:

hypothesis==5.35.1
hypothesis==6.56.4
# Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136
#Description: advanced library for generating parametrized tests
#Pinned versions: 5.35.1
#Pinned versions: 6.56.4
#test that import: test_xnnpack_integration.py, test_pruning_op.py, test_nn.py

junitparser==2.1.1

@@ -98,7 +98,7 @@ librosa==0.10.2 ; python_version == "3.12" and platform_machine != "s390x"
#Pinned versions:
#test that import:

mypy==1.16.0 ; platform_system != "Windows"
mypy==1.16.0 ; platform_system == "Linux"
# Pin MyPy version because new errors are likely to appear with each release
# Skip on Windows as lots of type annotations are POSIX specific
#Description: linter

@@ -169,7 +169,7 @@ optree==0.13.0

pillow==11.0.0
#Description: Python Imaging Library fork
#Pinned versions: 10.3.0
#Pinned versions: 11.0.0
#test that import:

protobuf==5.29.5

@@ -217,7 +217,7 @@ pytest-subtests==0.13.1
#Pinned versions:
#test that import:

xdoctest==1.1.0
xdoctest==1.3.0
#Description: runs doctests in pytest
#Pinned versions: 1.1.0
#test that import:

@@ -268,7 +268,7 @@ scipy==1.14.1 ; python_version >= "3.12"
#test that import:

# needed by torchgen utils
typing-extensions>=4.10.0
typing-extensions==4.12.2
#Description: type hints for python
#Pinned versions:
#test that import:

@@ -361,9 +361,10 @@ pwlf==2.2.1
#test that import: test_sac_estimator.py

# To build PyTorch itself
pyyaml
pyyaml==6.0.2
pyzstd
setuptools>=70.1.0
setuptools==78.1.1
packaging==23.1
six

scons==4.5.2 ; platform_machine == "aarch64"

@@ -384,7 +385,10 @@ cmake==3.31.6
tlparse==0.4.0
#Description: required for log parsing

cuda-bindings>=12.0,<13.0 ; platform_machine != "s390x"
filelock==3.18.0
#Description: required for inductor testing

cuda-bindings>=12.0,<13.0 ; platform_machine != "s390x" and platform_system != "Darwin"
#Description: required for testing CUDAGraph::raw_cuda_graph(). See https://nvidia.github.io/cuda-python/cuda-bindings/latest/support.html for how this version was chosen. Note "Any fix in the latest bindings would be backported to the prior major version" means that only the newest version of cuda-bindings will get fixes. Depending on the latest version of 12.x is okay because all 12.y versions will be supported via "CUDA minor version compatibility". Pytorch builds against 13.z versions of cuda toolkit work with 12.x versions of cuda-bindings as well because newer drivers work with old toolkits.
#test that import: test_cuda.py
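The cuda-bindings note above carries the version-compatibility reasoning behind the new environment marker. For a local run, the same pin set can be installed straight from this file (path as referenced by the macOS workflow changes further down), though that is only a convenience sketch, not a CI-supported flow.

```bash
# Hedged example: reproduce the CI test pins locally.
pip install -r .ci/docker/requirements-ci.txt
```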
@@ -9,7 +9,7 @@ standard-imghdr==3.13.0; python_version >= "3.13"
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.

-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@d53b0ffb9b1cda68260693ea98f3483823c88d8e#egg=pytorch_sphinx_theme2
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
# something related to Docker setup. We can investigate this later.
@@ -107,6 +107,10 @@ if [[ $ROCM_INT -ge 60200 ]]; then
    ROCM_SO_FILES+=("librocm-core.so")
fi

if [[ $ROCM_INT -ge 70000 ]]; then
    ROCM_SO_FILES+=("librocroller.so")
fi

OS_NAME=`awk -F= '/^NAME/{print $2}' /etc/os-release`
if [[ "$OS_NAME" == *"CentOS Linux"* || "$OS_NAME" == *"AlmaLinux"* ]]; then
    LIBGOMP_PATH="/usr/lib64/libgomp.so.1"
@@ -89,7 +89,7 @@ fi
if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then
    export USE_MKLDNN=1
    export USE_MKLDNN_ACL=1
    export ACL_ROOT_DIR=/ComputeLibrary
    export ACL_ROOT_DIR=/acl
fi

if [[ "$BUILD_ENVIRONMENT" == *riscv64* ]]; then
@@ -26,6 +26,7 @@ if [[ "${SHARD_NUMBER:-2}" == "2" ]]; then
  time python test/run_test.py --verbose -i distributed/test_c10d_spawn_gloo
  time python test/run_test.py --verbose -i distributed/test_c10d_spawn_nccl
  time python test/run_test.py --verbose -i distributed/test_compute_comm_reordering
  time python test/run_test.py --verbose -i distributed/test_aten_comm_compute_reordering
  time python test/run_test.py --verbose -i distributed/test_store
  time python test/run_test.py --verbose -i distributed/test_symmetric_memory
  time python test/run_test.py --verbose -i distributed/test_pg_wrapper
@@ -435,7 +435,7 @@ test_inductor_distributed() {

  # this runs on both single-gpu and multi-gpu instance. It should be smart about skipping tests that aren't supported
  # with if required # gpus aren't available
  python test/run_test.py --include distributed/test_dynamo_distributed distributed/test_inductor_collectives distributed/test_compute_comm_reordering --verbose
  python test/run_test.py --include distributed/test_dynamo_distributed distributed/test_inductor_collectives distributed/test_aten_comm_compute_reordering distributed/test_compute_comm_reordering --verbose
  assert_git_not_dirty
}

.ci/pytorch/test_fa3_abi_stable.sh (32, Executable file)
@@ -0,0 +1,32 @@
#!/bin/bash

set -ex -o pipefail

# Suppress ANSI color escape sequences
export TERM=vt100

# shellcheck source=./common.sh
source "$(dirname "${BASH_SOURCE[0]}")/common.sh"
# shellcheck source=./common-build.sh
source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh"

echo "Environment variables"
env

echo "Testing FA3 stable wheel still works with currently built torch"

echo "Installing ABI Stable FA3 wheel"
# The wheel was built on https://github.com/Dao-AILab/flash-attention/commit/b3846b059bf6b143d1cd56879933be30a9f78c81
# on torch nightly torch==2.9.0.dev20250830+cu129
$MAYBE_SUDO pip -q install https://s3.amazonaws.com/ossci-linux/wheels/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl

pushd flash-attention/hopper
export PYTHONPATH=$PWD
pytest -v -s \
  "test_flash_attn.py::test_flash_attn_output[1-1-192-False-False-False-0.0-False-False-mha-dtype0]" \
  "test_flash_attn.py::test_flash_attn_varlen_output[511-1-64-True-False-False-0.0-False-False-gqa-dtype2]" \
  "test_flash_attn.py::test_flash_attn_kvcache[1-128-128-False-False-True-None-0.0-False-False-True-False-True-False-gqa-dtype0]" \
  "test_flash_attn.py::test_flash_attn_race_condition[97-97-192-True-dtype0]" \
  "test_flash_attn.py::test_flash_attn_combine[2-3-64-dtype1]" \
  "test_flash_attn.py::test_flash3_bw_compatibility"
popd
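For context, a hedged sketch of how the new workflow (later in this diff) exercises the script above inside the CUDA test container: it installs the freshly built torch wheel and then runs the FA3 smoke tests against it.

```bash
# Sketch of the workflow's test step (see _linux-test-stable-fa3.yml below).
python3 -m pip install dist/*.whl
bash .ci/pytorch/test_fa3_abi_stable.sh
```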
@@ -38,10 +38,12 @@ if errorlevel 1 goto fail
if not errorlevel 0 goto fail

:: Update CMake
:: TODO: Investigate why this helps MKL detection, even when CMake from choco is not used
call choco upgrade -y cmake --no-progress --installargs 'ADD_CMAKE_TO_PATH=System' --apply-install-arguments-to-dependencies --version=3.27.9
if errorlevel 1 goto fail
if not errorlevel 0 goto fail

:: TODO: Move to .ci/docker/requirements-ci.txt
call pip install mkl==2024.2.0 mkl-static==2024.2.0 mkl-include==2024.2.0
if errorlevel 1 goto fail
if not errorlevel 0 goto fail
@@ -37,27 +37,8 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then
    export PYTORCH_TESTING_DEVICE_ONLY_FOR="cuda"
fi

# TODO: Move both of them to Windows AMI
python -m pip install tensorboard==2.13.0 protobuf==5.29.4 pytest-subtests==0.13.1

# Copied from https://github.com/pytorch/test-infra/blob/be01a40157c36cd5a48391fdf44a7bc3ebd4c7e3/aws/ami/windows/scripts/Installers/Install-Pip-Dependencies.ps1#L16 with some adjustments
# pytest-rerunfailures==10.3 as 10.2 fails with INTERNALERROR> pluggy._manager.PluginValidationError: unknown hook 'pytest_configure_node'
# scipy from 1.6.3 to 1.10
# expecttest from 0.1.3 to 0.3.0
# xdoctest from 1.0.2 to 1.3.0
python -m pip install "future==0.18.2" "hypothesis==5.35.1" "expecttest==0.3.0" "librosa>=0.6.2" "scipy==1.10.1" "psutil==5.9.1" "pynvml==11.4.1" "pillow==9.2.0" "unittest-xml-reporting<=3.2.0,>=2.0.0" "pytest==7.1.3" "pytest-xdist==2.5.0" "pytest-flakefinder==1.1.0" "pytest-rerunfailures==10.3" "pytest-shard==0.1.2" "sympy==1.11.1" "xdoctest==1.3.0" "pygments==2.12.0" "opt-einsum>=3.3" "networkx==2.8.8" "mpmath==1.2.1" "pytest-cpp==2.3.0" "boto3==1.35.42"

# Install Z3 optional dependency for Windows builds.
python -m pip install z3-solver==4.15.1.0

# Install tlparse for test\dynamo\test_structured_trace.py UTs.
python -m pip install tlparse==0.4.0

# Install parameterized
python -m pip install parameterized==0.8.1

# Install pulp for testing ilps under torch\distributed\_tools
python -m pip install pulp==2.9.0
# TODO: Move this to .ci/docker/requirements-ci.txt
python -m pip install "psutil==5.9.1" "pynvml==11.4.1" "pytest-shard==0.1.2"

run_tests() {
    # Run nvidia-smi if available
.github/actions/teardown-win/action.yml (3, vendored)
@@ -23,9 +23,6 @@ runs:
      run: |
        .github\scripts\kill_active_ssh_sessions.ps1

    - name: Clean up leftover processes on non-ephemeral Windows runner
      uses: pytorch/test-infra/.github/actions/cleanup-runner@main

    # Cleaning up Windows workspace sometimes fails flakily with device or resource busy
    # error, meaning one or more processes haven't stopped completely yet. So trying to
    # retry this step several time similar to how checkout-pytorch GHA does
.github/ci_commit_pins/vllm.txt (2, vendored)
@@ -1 +1 @@
0307428d65acf5cf1a73a70a7722e076bbb83f22
78a47f87ce259a48f0391fa9ae15add05ea7432b
.github/ci_configs/vllm/Dockerfile.tmp_vllm (16, vendored)
@@ -202,7 +202,7 @@ ARG max_jobs=16
ENV MAX_JOBS=${max_jobs}
ARG nvcc_threads=4
ENV NVCC_THREADS=$nvcc_threads
ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
ARG torch_cuda_arch_list='8.0 8.6 8.9 9.0'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}

ARG USE_SCCACHE

@@ -297,16 +297,28 @@ RUN echo "[INFO] Listing current directory before torch install step:" && \
    echo "[INFO] Showing torch_build_versions.txt content:" && \
    cat torch_build_versions.txt

# Install build and runtime dependencies, this is needed for flashinfer install
COPY requirements/build.txt requirements/build.txt
COPY use_existing_torch.py use_existing_torch.py
RUN python3 use_existing_torch.py
RUN cat requirements/build.txt

# Install uv for faster pip installs if not existed
RUN --mount=type=cache,target=/root/.cache/uv \
    if ! python3 -m uv --version > /dev/null 2>&1; then \
        python3 -m pip install uv==0.8.4; \
    fi

ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts
ENV UV_LINK_MODE=copy

RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system -r requirements/build.txt

# Default mount file as placeholder, this just avoid the mount error
ARG TORCH_WHEELS_PATH="./requirements"
# Install torch, torchaudio and torchvision

@@ -332,13 +344,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# Install xformers wheel from previous stage
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system /wheels/xformers/*.whl --verbose

# Build flashinfer from source.
ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0'
# install package for build flashinfer
# see issue: https://github.com/flashinfer-ai/flashinfer/issues/738

RUN pip install build==1.3.0
RUN pip freeze | grep -E 'setuptools|packaging|build'

ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
@@ -1,9 +1,14 @@
import glob
import os


requires_files = glob.glob("requirements/*.txt")
requires_files += ["pyproject.toml"]

for file in requires_files:
    if not os.path.exists(file):
        print(f"!!! skipping missing {file}")
        continue
    print(f">>> cleaning {file}")
    with open(file) as f:
        lines = f.readlines()
.github/requirements/pip-requirements-macOS.txt (37, vendored)
@@ -1,37 +0,0 @@
boto3==1.35.42
build==1.2.2.post1
cmake==3.27.*
expecttest==0.3.0
fbscribelogger==0.1.7
filelock==3.18.0
hypothesis==6.56.4
librosa>=0.6.2
mpmath==1.3.0
networkx==2.8.7
ninja==1.10.2.4
numba==0.59.0
numpy==1.26.4
opt-einsum>=3.3
optree==0.13.0
packaging==23.1
parameterized==0.8.1
pillow==10.3.0
protobuf==5.29.5
psutil==5.9.8
pygments==2.15.0
pytest-cpp==2.3.0
pytest-flakefinder==1.1.0
pytest-rerunfailures==10.3
pytest-subtests==0.13.1
pytest-xdist==3.3.1
pytest==7.3.2
pyyaml==6.0.2
scipy==1.12.0
setuptools==78.1.1
sympy==1.13.3
tlparse==0.4.0
tensorboard==2.13.0
typing-extensions==4.12.2
unittest-xml-reporting<=3.2.0,>=2.0.0
xdoctest==1.1.0
z3-solver==4.15.1.0
.github/scripts/generate_ci_workflows.py (93, vendored)
@ -127,53 +127,6 @@ LINUX_BINARY_BUILD_WORFKLOWS = [
|
||||
),
|
||||
]
|
||||
|
||||
ROCM_SMOKE_WORKFLOWS = [
|
||||
BinaryBuildWorkflow(
|
||||
os=OperatingSystem.LINUX,
|
||||
package_type="manywheel",
|
||||
build_variant="rocm",
|
||||
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
|
||||
OperatingSystem.LINUX,
|
||||
arches=["6.4"],
|
||||
python_versions=["3.10"],
|
||||
),
|
||||
ciflow_config=CIFlowConfig(
|
||||
labels={
|
||||
LABEL_CIFLOW_BINARIES,
|
||||
LABEL_CIFLOW_BINARIES_WHEEL,
|
||||
LABEL_CIFLOW_ROCM,
|
||||
},
|
||||
isolated_workflow=True,
|
||||
),
|
||||
branches="main",
|
||||
),
|
||||
]
|
||||
|
||||
LINUX_BINARY_SMOKE_WORKFLOWS = [
|
||||
BinaryBuildWorkflow(
|
||||
os=OperatingSystem.LINUX,
|
||||
package_type="manywheel",
|
||||
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
|
||||
OperatingSystem.LINUX,
|
||||
arches=["13.0"],
|
||||
python_versions=["3.12"],
|
||||
),
|
||||
branches="main",
|
||||
),
|
||||
BinaryBuildWorkflow(
|
||||
os=OperatingSystem.LINUX,
|
||||
package_type="libtorch",
|
||||
build_variant=generate_binary_build_matrix.RELEASE,
|
||||
build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
|
||||
OperatingSystem.LINUX,
|
||||
generate_binary_build_matrix.RELEASE,
|
||||
arches=["cpu"],
|
||||
libtorch_variants=["shared-with-deps"],
|
||||
),
|
||||
branches="main",
|
||||
),
|
||||
]
|
||||
|
||||
WINDOWS_BINARY_BUILD_WORKFLOWS = [
|
||||
BinaryBuildWorkflow(
|
||||
os=OperatingSystem.WINDOWS,
|
||||
@ -259,39 +212,6 @@ WINDOWS_BINARY_BUILD_WORKFLOWS = [
|
||||
),
|
||||
]
|
||||
|
||||
WINDOWS_BINARY_SMOKE_WORKFLOWS = [
|
||||
BinaryBuildWorkflow(
|
||||
os=OperatingSystem.WINDOWS,
|
||||
package_type="libtorch",
|
||||
build_variant=generate_binary_build_matrix.RELEASE,
|
||||
build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
|
||||
OperatingSystem.WINDOWS,
|
||||
generate_binary_build_matrix.RELEASE,
|
||||
arches=["cpu"],
|
||||
libtorch_variants=["shared-with-deps"],
|
||||
),
|
||||
branches="main",
|
||||
ciflow_config=CIFlowConfig(
|
||||
isolated_workflow=True,
|
||||
),
|
||||
),
|
||||
BinaryBuildWorkflow(
|
||||
os=OperatingSystem.WINDOWS,
|
||||
package_type="libtorch",
|
||||
build_variant=generate_binary_build_matrix.DEBUG,
|
||||
build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
|
||||
OperatingSystem.WINDOWS,
|
||||
generate_binary_build_matrix.DEBUG,
|
||||
arches=["cpu"],
|
||||
libtorch_variants=["shared-with-deps"],
|
||||
),
|
||||
branches="main",
|
||||
ciflow_config=CIFlowConfig(
|
||||
isolated_workflow=True,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
MACOS_BINARY_BUILD_WORKFLOWS = [
|
||||
BinaryBuildWorkflow(
|
||||
os=OperatingSystem.MACOS_ARM64,
|
||||
@ -372,23 +292,10 @@ def main() -> None:
|
||||
jinja_env.get_template("linux_binary_build_workflow.yml.j2"),
|
||||
S390X_BINARY_BUILD_WORKFLOWS,
|
||||
),
|
||||
(
|
||||
# Give rocm it's own workflow file
|
||||
jinja_env.get_template("linux_binary_build_workflow.yml.j2"),
|
||||
ROCM_SMOKE_WORKFLOWS,
|
||||
),
|
||||
(
|
||||
jinja_env.get_template("linux_binary_build_workflow.yml.j2"),
|
||||
LINUX_BINARY_SMOKE_WORKFLOWS,
|
||||
),
|
||||
(
|
||||
jinja_env.get_template("windows_binary_build_workflow.yml.j2"),
|
||||
WINDOWS_BINARY_BUILD_WORKFLOWS,
|
||||
),
|
||||
(
|
||||
jinja_env.get_template("windows_binary_build_workflow.yml.j2"),
|
||||
WINDOWS_BINARY_SMOKE_WORKFLOWS,
|
||||
),
|
||||
(
|
||||
jinja_env.get_template("macos_binary_build_workflow.yml.j2"),
|
||||
MACOS_BINARY_BUILD_WORKFLOWS,
|
||||
|
||||
.github/workflows/_linux-test-stable-fa3.yml (255, vendored, Normal file)
@ -0,0 +1,255 @@
|
||||
# The point of this workflow is to test that a FA3 wheel that was built based off the
|
||||
# stable ABI as of torch nightly 20250830 can still run on the newer torch.
|
||||
#
|
||||
# This workflow is very similar to the _linux-test.yml workflow, with the following
|
||||
# differences:
|
||||
# 1. It is simpler (there is no test matrix)
|
||||
# 2. It pulls flash-attention as a secondary repository in order to access the tests.
|
||||
# Note that it does not BUILD anything from flash-attention, as we have a prebuilt
|
||||
# wheel. We pull flash-attention only to run a few tests.
|
||||
# 3. It runs only FA3 tests. No PyTorch tests are run.
|
||||
name: linux-test-stable-fa3
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
build-environment:
|
||||
required: true
|
||||
type: string
|
||||
description: Top-level label for what's being built/tested.
|
||||
docker-image:
|
||||
required: true
|
||||
type: string
|
||||
description: Docker image to run in.
|
||||
timeout-minutes:
|
||||
required: false
|
||||
type: number
|
||||
default: 30
|
||||
description: |
|
||||
Set the maximum (in minutes) how long the workflow should take to finish
|
||||
s3-bucket:
|
||||
description: S3 bucket to download artifact
|
||||
required: false
|
||||
type: string
|
||||
default: "gha-artifacts"
|
||||
secrets:
|
||||
HUGGING_FACE_HUB_TOKEN:
|
||||
required: false
|
||||
description: |
|
||||
HF Auth token to avoid rate limits when downloading models or datasets from hub
|
||||
VLLM_TEST_HUGGING_FACE_TOKEN:
|
||||
required: false
|
||||
description: |
|
||||
HF Auth token to test vllm
|
||||
SCRIBE_GRAPHQL_ACCESS_TOKEN:
|
||||
required: false
|
||||
description: |
|
||||
FB app token to write to scribe endpoint
|
||||
|
||||
env:
|
||||
GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
|
||||
jobs:
|
||||
test:
|
||||
# Don't run on forked repos
|
||||
if: github.repository_owner == 'pytorch'
|
||||
runs-on: linux.aws.h100
|
||||
timeout-minutes: ${{ inputs.timeout-minutes || 30 }}
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout PyTorch
|
||||
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
|
||||
with:
|
||||
no-sudo: true
|
||||
|
||||
- name: Checkout flash-attention as a secondary repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: Dao-AILab/flash-attention
|
||||
path: flash-attention
|
||||
|
||||
- name: Setup Linux
|
||||
uses: ./.github/actions/setup-linux
|
||||
|
||||
- name: Calculate docker image
|
||||
id: calculate-docker-image
|
||||
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
|
||||
with:
|
||||
docker-image-name: ${{ inputs.docker-image }}
|
||||
|
||||
- name: Use following to pull public copy of the image
|
||||
id: print-ghcr-mirror
|
||||
env:
|
||||
ECR_DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
shell: bash
|
||||
run: |
|
||||
tag=${ECR_DOCKER_IMAGE##*:}
|
||||
echo "docker pull ghcr.io/pytorch/ci-image:${tag/:/-}"
|
||||
|
||||
- name: Pull docker image
|
||||
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
|
||||
with:
|
||||
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
|
||||
- name: Check if in a container runner
|
||||
shell: bash
|
||||
id: check_container_runner
|
||||
run: echo "IN_CONTAINER_RUNNER=$(if [ -f /.inarc ] || [ -f /.incontainer ]; then echo true ; else echo false; fi)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Setup GPU_FLAG for docker run
|
||||
id: setup-gpu-flag
|
||||
run: echo "GPU_FLAG=--gpus all -e NVIDIA_DRIVER_CAPABILITIES=all" >> "${GITHUB_ENV}"
|
||||
|
||||
- name: Setup SCCACHE_SERVER_PORT environment for docker run when on container
|
||||
id: setup-sscache-port-flag
|
||||
run: echo "SCCACHE_SERVER_PORT_DOCKER_FLAG=-e SCCACHE_SERVER_PORT=$((RUNNER_UID + 4226))" >> "${GITHUB_ENV}"
|
||||
if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' }}
|
||||
|
||||
- name: Get workflow job id
|
||||
id: get-job-id
|
||||
uses: ./.github/actions/get-workflow-job-id
|
||||
if: always()
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Download build artifacts
|
||||
uses: ./.github/actions/download-build-artifacts
|
||||
with:
|
||||
name: ${{ inputs.build-environment }}
|
||||
s3-bucket: ${{ inputs.s3-bucket }}
|
||||
|
||||
- name: Parse ref
|
||||
id: parse-ref
|
||||
run: .github/scripts/parse_ref.py
|
||||
|
||||
- name: Set Test step time
|
||||
id: test-timeout
|
||||
shell: bash
|
||||
env:
|
||||
JOB_TIMEOUT: ${{ inputs.timeout-minutes }}
|
||||
run: |
|
||||
echo "timeout=$((JOB_TIMEOUT-30))" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
- name: Preserve github env variables for use in docker
|
||||
shell: bash
|
||||
run: |
|
||||
env | grep '^GITHUB' >> "/tmp/github_env_${GITHUB_RUN_ID}"
|
||||
env | grep '^CI' >> "/tmp/github_env_${GITHUB_RUN_ID}"
|
||||
|
||||
- name: Test
|
||||
id: test
|
||||
timeout-minutes: ${{ fromJson(steps.test-timeout.outputs.timeout) }}
|
||||
env:
|
||||
BUILD_ENVIRONMENT: ${{ inputs.build-environment }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
GITHUB_REPOSITORY: ${{ github.repository }}
|
||||
GITHUB_WORKFLOW: ${{ github.workflow }}
|
||||
GITHUB_JOB: ${{ github.job }}
|
||||
GITHUB_RUN_ID: ${{ github.run_id }}
|
||||
GITHUB_RUN_NUMBER: ${{ github.run_number }}
|
||||
GITHUB_RUN_ATTEMPT: ${{ github.run_attempt }}
|
||||
JOB_ID: ${{ steps.get-job-id.outputs.job-id }}
|
||||
JOB_NAME: ${{ steps.get-job-id.outputs.job-name }}
|
||||
BRANCH: ${{ steps.parse-ref.outputs.branch }}
|
||||
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
BASE_SHA: ${{ github.event.pull_request.base.sha || github.sha }}
|
||||
SHM_SIZE: '2g'
|
||||
DOCKER_IMAGE: ${{ inputs.docker-image }}
|
||||
VLLM_TEST_HUGGING_FACE_TOKEN: ${{ secrets.VLLM_TEST_HUGGING_FACE_TOKEN }}
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }}
|
||||
ARTIFACTS_FILE_SUFFIX: ${{ github.job }}-${{ steps.get-job-id.outputs.job-id }}
|
||||
run: |
|
||||
set -x
|
||||
|
||||
TEST_COMMAND=.ci/pytorch/test_fa3_abi_stable.sh
|
||||
|
||||
# Leaving 1GB for the runner and other things
|
||||
TOTAL_AVAILABLE_MEMORY_IN_GB=$(awk '/MemTotal/ { printf "%.3f \n", $2/1024/1024 - 1 }' /proc/meminfo)
|
||||
# https://docs.docker.com/engine/containers/resource_constraints/#--memory-swap-details, the 3GB swap
|
||||
# comes from https://github.com/pytorch/test-infra/pull/6058
|
||||
TOTAL_MEMORY_WITH_SWAP=$(("${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}" + 3))
|
||||
|
||||
|
||||
SHM_OPTS="--shm-size=${SHM_SIZE}"
|
||||
JENKINS_USER="--user jenkins"
|
||||
DOCKER_SHELL_CMD=
|
||||
|
||||
# detached container should get cleaned up by teardown_ec2_linux
|
||||
# TODO: Stop building test binaries as part of the build phase
|
||||
# Used for GPU_FLAG, SHM_OPTS, JENKINS_USER and DOCKER_SHELL_CMD since that doesn't play nice
|
||||
# shellcheck disable=SC2086,SC2090
|
||||
container_name=$(docker run \
|
||||
${GPU_FLAG:-} \
|
||||
${SCCACHE_SERVER_PORT_DOCKER_FLAG:-} \
|
||||
-e BUILD_ENVIRONMENT \
|
||||
-e PR_NUMBER \
|
||||
-e GITHUB_ACTIONS \
|
||||
-e GITHUB_REPOSITORY \
|
||||
-e GITHUB_WORKFLOW \
|
||||
-e GITHUB_JOB \
|
||||
-e GITHUB_RUN_ID \
|
||||
-e GITHUB_RUN_NUMBER \
|
||||
-e GITHUB_RUN_ATTEMPT \
|
||||
-e JOB_ID \
|
||||
-e JOB_NAME \
|
||||
-e BASE_SHA \
|
||||
-e BRANCH \
|
||||
-e SHA1 \
|
||||
-e MAX_JOBS="$(nproc --ignore=2)" \
|
||||
-e HUGGING_FACE_HUB_TOKEN \
|
||||
-e VLLM_TEST_HUGGING_FACE_TOKEN \
|
||||
-e SCRIBE_GRAPHQL_ACCESS_TOKEN \
|
||||
-e ARTIFACTS_FILE_SUFFIX \
|
||||
--memory="${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}g" \
|
||||
--memory-swap="${TOTAL_MEMORY_WITH_SWAP}g" \
|
||||
--env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
|
||||
--security-opt seccomp=unconfined \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--ipc=host \
|
||||
${SHM_OPTS} \
|
||||
--tty \
|
||||
--detach \
|
||||
--name="${container_name}" \
|
||||
${JENKINS_USER} \
|
||||
-v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \
|
||||
-w /var/lib/jenkins/workspace \
|
||||
"${DOCKER_IMAGE}" \
|
||||
${DOCKER_SHELL_CMD}
|
||||
)
|
||||
|
||||
echo "DOCKER_CONTAINER_ID=${container_name}" >> "${GITHUB_ENV}"
|
||||
|
||||
docker exec -t "${container_name}" sh -c "python3 -m pip install $(echo dist/*.whl)[opt-einsum] && ${TEST_COMMAND}"
|
||||
|
||||
- name: Collect backtraces from coredumps (if any)
|
||||
if: always()
|
||||
run: |
|
||||
# shellcheck disable=SC2156
|
||||
find . -iname "core.[1-9]*" -exec docker exec "${DOCKER_CONTAINER_ID}" sh -c "gdb python {} -ex 'bt' -ex 'q'" \;
|
||||
|
||||
- name: Store Core dumps on S3
|
||||
uses: seemethere/upload-artifact-s3@baba72d0712b404f646cebe0730933554ebce96a # v5.1.0
|
||||
if: failure()
|
||||
with:
|
||||
name: coredumps-fa3-stable-abi-smoke-tests
|
||||
retention-days: 14
|
||||
if-no-files-found: ignore
|
||||
path: ./**/core.[1-9]*
|
||||
|
||||
- name: Upload utilization stats
|
||||
if: ${{ always() && steps.test.conclusion && steps.test.conclusion != 'skipped' }}
|
||||
continue-on-error: true
|
||||
uses: ./.github/actions/upload-utilization-stats
|
||||
with:
|
||||
job_id: ${{ steps.get-job-id.outputs.job-id }}
|
||||
job_name: ${{ steps.get-job-id.outputs.job-name }}
|
||||
workflow_name: ${{ github.workflow }}
|
||||
workflow_run_id: ${{github.run_id}}
|
||||
workflow_attempt: ${{github.run_attempt}}
|
||||
|
||||
- name: Teardown Linux
|
||||
uses: pytorch/test-infra/.github/actions/teardown-linux@main
|
||||
if: always() && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false'
|
||||
.github/workflows/_mac-build.yml (2, vendored)
@@ -85,7 +85,7 @@ jobs:
        uses: pytorch/test-infra/.github/actions/setup-python@main
        with:
          python-version: ${{ inputs.python-version }}
          pip-requirements-file: .github/requirements/pip-requirements-macOS.txt
          pip-requirements-file: .ci/docker/requirements-ci.txt

      - name: Install sccache (only for non-forked PRs, and pushes to trunk)
        uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
.github/workflows/_mac-test.yml (2, vendored)
@@ -122,7 +122,7 @@ jobs:
        uses: pytorch/test-infra/.github/actions/setup-python@main
        with:
          python-version: ${{ inputs.python-version }}
          pip-requirements-file: .github/requirements/pip-requirements-macOS.txt
          pip-requirements-file: .ci/docker/requirements-ci.txt

      - name: Start monitoring script
        id: monitor-script
.github/workflows/_win-build.yml (3, vendored)
@@ -84,9 +84,6 @@ jobs:
          # in https://github.com/actions/checkout/issues/1018
          git config --global core.fsmonitor false

      - name: Clean up leftover processes on non-ephemeral Windows runner
        uses: pytorch/test-infra/.github/actions/cleanup-runner@main

      - name: Setup SSH (Click me for login details)
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
        with:
.github/workflows/_win-test.yml (24, vendored)
@ -77,9 +77,6 @@ jobs:
|
||||
# in https://github.com/actions/checkout/issues/1018
|
||||
git config --global core.fsmonitor false
|
||||
|
||||
- name: Clean up leftover processes on non-ephemeral Windows runner
|
||||
uses: pytorch/test-infra/.github/actions/cleanup-runner@main
|
||||
|
||||
- name: Setup SSH (Click me for login details)
|
||||
uses: pytorch/test-infra/.github/actions/setup-ssh@main
|
||||
with:
|
||||
@ -106,18 +103,6 @@ jobs:
|
||||
with:
|
||||
cuda-version: ${{ inputs.cuda-version }}
|
||||
|
||||
# TODO: Move to a requirements.txt file for windows
|
||||
- name: Install pip dependencies
|
||||
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
|
||||
with:
|
||||
shell: bash
|
||||
timeout_minutes: 5
|
||||
max_attempts: 5
|
||||
retry_wait_seconds: 30
|
||||
command: |
|
||||
set -eu
|
||||
python3 -m pip install 'xdoctest>=1.1.0'
|
||||
|
||||
- name: Get workflow job id
|
||||
id: get-job-id
|
||||
uses: ./.github/actions/get-workflow-job-id
|
||||
@ -272,15 +257,6 @@ jobs:
|
||||
shell: bash
|
||||
run: python3 .github/scripts/parse_ref.py
|
||||
|
||||
- name: Uninstall PyTorch
|
||||
if: always()
|
||||
continue-on-error: true
|
||||
shell: bash
|
||||
run: |
|
||||
# This step removes PyTorch installed by the test to give a clean slate
|
||||
# to the next job
|
||||
python3 -mpip uninstall -y torch
|
||||
|
||||
- name: Teardown Windows
|
||||
uses: ./.github/actions/teardown-win
|
||||
if: always()
|
||||
|
||||
.github/workflows/build-almalinux-images.yml (2, vendored)
@@ -36,7 +36,7 @@ jobs:
    runs-on: linux.9xlarge.ephemeral
    strategy:
      matrix:
        tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm6.3", "rocm6.4", "rocm7.0", "cpu"]
        tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm6.4", "rocm7.0", "cpu"]
    steps:
      - name: Build docker image
        uses: pytorch/pytorch/.github/actions/binary-docker-build@main
.github/workflows/generated-linux-binary-libtorch-release-main.yml (87, generated, vendored)
@ -1,87 +0,0 @@
|
||||
# @generated DO NOT EDIT MANUALLY
|
||||
|
||||
# Template is at: .github/templates/linux_binary_build_workflow.yml.j2
|
||||
# Generation script: .github/scripts/generate_ci_workflows.py
|
||||
name: linux-binary-libtorch-release
|
||||
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
tags:
|
||||
- 'ciflow/trunk/*'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
|
||||
env:
|
||||
# Needed for conda builds
|
||||
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
|
||||
AWS_DEFAULT_REGION: us-east-1
|
||||
BINARY_ENV_FILE: /tmp/env
|
||||
BUILD_ENVIRONMENT: linux-binary-libtorch-release
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
|
||||
PYTORCH_ROOT: /pytorch
|
||||
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
SKIP_ALL_TESTS: 0
|
||||
concurrency:
|
||||
group: linux-binary-libtorch-release-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
get-label-type:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
libtorch-cpu-shared-with-deps-release-build:
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
uses: ./.github/workflows/_binary-build-linux.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
PYTORCH_ROOT: /pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cpu
|
||||
GPU_ARCH_TYPE: cpu
|
||||
DOCKER_IMAGE: libtorch-cxx11-builder
|
||||
DOCKER_IMAGE_TAG_PREFIX: cpu
|
||||
LIBTORCH_CONFIG: release
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build_name: libtorch-cpu-shared-with-deps-release
|
||||
build_environment: linux-binary-libtorch-release
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
libtorch-cpu-shared-with-deps-release-test: # Testing
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs:
|
||||
- libtorch-cpu-shared-with-deps-release-build
|
||||
- get-label-type
|
||||
uses: ./.github/workflows/_binary-test-linux.yml
|
||||
with:
|
||||
PYTORCH_ROOT: /pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cpu
|
||||
GPU_ARCH_TYPE: cpu
|
||||
DOCKER_IMAGE: libtorch-cxx11-builder
|
||||
DOCKER_IMAGE_TAG_PREFIX: cpu
|
||||
LIBTORCH_CONFIG: release
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
build_name: libtorch-cpu-shared-with-deps-release
|
||||
build_environment: linux-binary-libtorch-release
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
runs_on: linux.4xlarge
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
.github/workflows/generated-linux-binary-manywheel-main.yml (88, generated, vendored)
@ -1,88 +0,0 @@
|
||||
# @generated DO NOT EDIT MANUALLY
|
||||
|
||||
# Template is at: .github/templates/linux_binary_build_workflow.yml.j2
|
||||
# Generation script: .github/scripts/generate_ci_workflows.py
|
||||
name: linux-binary-manywheel
|
||||
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
tags:
|
||||
- 'ciflow/trunk/*'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
|
||||
env:
|
||||
# Needed for conda builds
|
||||
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
|
||||
AWS_DEFAULT_REGION: us-east-1
|
||||
BINARY_ENV_FILE: /tmp/env
|
||||
BUILD_ENVIRONMENT: linux-binary-manywheel
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
|
||||
PYTORCH_ROOT: /pytorch
|
||||
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
SKIP_ALL_TESTS: 0
|
||||
concurrency:
|
||||
group: linux-binary-manywheel-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
get-label-type:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
manywheel-py3_12-cuda13_0-build:
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
uses: ./.github/workflows/_binary-build-linux.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
PYTORCH_ROOT: /pytorch
|
||||
PACKAGE_TYPE: manywheel
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cu130
|
||||
GPU_ARCH_VERSION: "13.0"
|
||||
GPU_ARCH_TYPE: cuda
|
||||
DOCKER_IMAGE: manylinux2_28-builder
|
||||
DOCKER_IMAGE_TAG_PREFIX: cuda13.0
|
||||
DESIRED_PYTHON: "3.12"
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build_name: manywheel-py3_12-cuda13_0
|
||||
build_environment: linux-binary-manywheel
|
||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.48; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.48; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.0.0.19; platform_system == 'Linux' | nvidia-cufft==12.0.0.15; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.3.29; platform_system == 'Linux' | nvidia-cusparse==12.6.2.49; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.28.3; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.3.24; platform_system == 'Linux' | nvidia-nvtx==13.0.39; platform_system == 'Linux' | nvidia-nvjitlink==13.0.39; platform_system == 'Linux' | nvidia-cufile==1.15.0.42; platform_system == 'Linux'
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
manywheel-py3_12-cuda13_0-test: # Testing
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs:
|
||||
- manywheel-py3_12-cuda13_0-build
|
||||
- get-label-type
|
||||
uses: ./.github/workflows/_binary-test-linux.yml
|
||||
with:
|
||||
PYTORCH_ROOT: /pytorch
|
||||
PACKAGE_TYPE: manywheel
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cu130
|
||||
GPU_ARCH_VERSION: "13.0"
|
||||
GPU_ARCH_TYPE: cuda
|
||||
DOCKER_IMAGE: manylinux2_28-builder
|
||||
DOCKER_IMAGE_TAG_PREFIX: cuda13.0
|
||||
DESIRED_PYTHON: "3.12"
|
||||
build_name: manywheel-py3_12-cuda13_0
|
||||
build_environment: linux-binary-manywheel
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8+ builds need sm_70+ runner
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
.github/workflows/generated-linux-binary-manywheel-rocm-main.yml (136, generated, vendored)
@ -1,136 +0,0 @@
|
||||
# @generated DO NOT EDIT MANUALLY
|
||||
|
||||
# Template is at: .github/templates/linux_binary_build_workflow.yml.j2
|
||||
# Generation script: .github/scripts/generate_ci_workflows.py
|
||||
name: linux-binary-manywheel-rocm
|
||||
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
tags:
|
||||
- 'ciflow/binaries/*'
|
||||
- 'ciflow/binaries_wheel/*'
|
||||
- 'ciflow/rocm/*'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
|
||||
env:
|
||||
# Needed for conda builds
|
||||
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
|
||||
AWS_DEFAULT_REGION: us-east-1
|
||||
BINARY_ENV_FILE: /tmp/env
|
||||
BUILD_ENVIRONMENT: linux-binary-manywheel-rocm
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
|
||||
PYTORCH_ROOT: /pytorch
|
||||
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
SKIP_ALL_TESTS: 0
|
||||
concurrency:
|
||||
group: linux-binary-manywheel-rocm-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
get-label-type:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
manywheel-py3_10-rocm6_4-build:
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
uses: ./.github/workflows/_binary-build-linux.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
PYTORCH_ROOT: /pytorch
|
||||
PACKAGE_TYPE: manywheel
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: rocm6.4
|
||||
GPU_ARCH_VERSION: "6.4"
|
||||
GPU_ARCH_TYPE: rocm
|
||||
DOCKER_IMAGE: manylinux2_28-builder
|
||||
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
|
||||
DESIRED_PYTHON: "3.10"
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
timeout-minutes: 300
|
||||
build_name: manywheel-py3_10-rocm6_4
|
||||
build_environment: linux-binary-manywheel-rocm
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
manywheel-py3_10-rocm6_4-test: # Testing
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs:
|
||||
- manywheel-py3_10-rocm6_4-build
|
||||
- get-label-type
|
||||
runs-on: linux.rocm.gpu.mi250
|
||||
timeout-minutes: 240
|
||||
env:
|
||||
PYTORCH_ROOT: /pytorch
|
||||
PACKAGE_TYPE: manywheel
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: rocm6.4
|
||||
GPU_ARCH_VERSION: "6.4"
|
||||
GPU_ARCH_TYPE: rocm
|
||||
SKIP_ALL_TESTS: 1
|
||||
DOCKER_IMAGE: manylinux2_28-builder
|
||||
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
|
||||
DESIRED_PYTHON: "3.10"
|
||||
steps:
|
||||
- name: Setup ROCm
|
||||
uses: ./.github/actions/setup-rocm
|
||||
- uses: actions/download-artifact@v4.1.7
|
||||
name: Download Build Artifacts
|
||||
with:
|
||||
name: manywheel-py3_10-rocm6_4
|
||||
path: "${{ runner.temp }}/artifacts/"
|
||||
- name: Checkout PyTorch
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
submodules: recursive
|
||||
path: pytorch
|
||||
show-progress: false
|
||||
- name: Clean PyTorch checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: pytorch
|
||||
- name: ROCm set GPU_FLAG
|
||||
run: |
|
||||
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
|
||||
- name: configure aws credentials
|
||||
id: aws_creds
|
||||
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }}
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
aws-region: us-east-1
|
||||
role-duration-seconds: 18000
|
||||
- name: Calculate docker image
|
||||
id: calculate-docker-image
|
||||
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
|
||||
with:
|
||||
docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }}
|
||||
docker-image-name: manylinux2_28-builder
|
||||
custom-tag-prefix: rocm6.4
|
||||
docker-build-dir: .ci/docker
|
||||
working-directory: pytorch
|
||||
- name: Pull Docker image
|
||||
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
|
||||
with:
|
||||
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
- name: Test Pytorch binary
|
||||
uses: ./pytorch/.github/actions/test-pytorch-binary
|
||||
env:
|
||||
DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
- name: Teardown ROCm
|
||||
uses: ./.github/actions/teardown-rocm
|
||||
261
.github/workflows/generated-windows-binary-libtorch-debug-main.yml
generated
vendored
@ -1,261 +0,0 @@
|
||||
# @generated DO NOT EDIT MANUALLY
|
||||
|
||||
# Template is at: .github/templates/windows_binary_build_workflow.yml.j2
|
||||
# Generation script: .github/scripts/generate_ci_workflows.py
|
||||
name: windows-binary-libtorch-debug
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
# Needed for conda builds
|
||||
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
|
||||
AWS_DEFAULT_REGION: us-east-1
|
||||
BUILD_ENVIRONMENT: windows-binary-libtorch-debug
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
SKIP_ALL_TESTS: 1
|
||||
OS: windows
|
||||
concurrency:
|
||||
group: windows-binary-libtorch-debug-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
get-label-type:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
libtorch-cpu-shared-with-deps-debug-build:
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs: get-label-type
|
||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
|
||||
timeout-minutes: 360
|
||||
env:
|
||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cpu
|
||||
GPU_ARCH_TYPE: cpu
|
||||
SKIP_ALL_TESTS: 1
|
||||
LIBTORCH_CONFIG: debug
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
# This is a dummy value for libtorch to work correctly with our batch scripts
|
||||
# without this value pip does not get installed for some reason
|
||||
DESIRED_PYTHON: "3.10"
|
||||
steps:
|
||||
# NOTE: These environment variables are put here so that they can be applied on every job equally
|
||||
# They are also here because setting them at a workflow level doesn't give us access to the
|
||||
# runner.temp variable, which we need.
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
|
||||
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
|
||||
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
|
||||
- name: Display EC2 information
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
function get_ec2_metadata() {
|
||||
# Pulled from instance metadata endpoint for EC2
|
||||
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
|
||||
category=$1
|
||||
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
|
||||
}
|
||||
echo "ami-id: $(get_ec2_metadata ami-id)"
|
||||
echo "instance-id: $(get_ec2_metadata instance-id)"
|
||||
echo "instance-type: $(get_ec2_metadata instance-type)"
|
||||
echo "system info $(uname -a)"
|
||||
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
|
||||
uses: pytorch/test-infra/.github/actions/setup-ssh@main
|
||||
continue-on-error: true
|
||||
with:
|
||||
github-secret: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
|
||||
shell: bash
|
||||
run: |
|
||||
git config --global core.longpaths true
|
||||
git config --global core.symlinks true
|
||||
|
||||
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
|
||||
# the directory on Windows and prevent GHA from checking out as reported
|
||||
# in https://github.com/actions/checkout/issues/1018
|
||||
git config --global core.fsmonitor false
|
||||
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
|
||||
- name: Enable long paths on Windows
|
||||
shell: powershell
|
||||
run: |
|
||||
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
|
||||
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
|
||||
# removed once Windows Defender is removed from the AMI
|
||||
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
|
||||
continue-on-error: true
|
||||
shell: powershell
|
||||
run: |
|
||||
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
|
||||
# Let's both exclude the path and disable Windows Defender completely just to be sure
|
||||
# that it doesn't interfere
|
||||
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
|
||||
- name: Checkout PyTorch
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
submodules: recursive
|
||||
path: pytorch
|
||||
show-progress: false
|
||||
- name: Clean PyTorch checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: pytorch
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
|
||||
- name: Build PyTorch binary
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
|
||||
- uses: actions/upload-artifact@v4.4.0
|
||||
if: always()
|
||||
with:
|
||||
name: libtorch-cpu-shared-with-deps-debug
|
||||
retention-days: 14
|
||||
if-no-files-found: error
|
||||
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
|
||||
- name: Wait until all sessions have drained
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
timeout-minutes: 120
|
||||
run: |
|
||||
.github\scripts\wait_for_ssh_to_drain.ps1
|
||||
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
run: |
|
||||
.github\scripts\kill_active_ssh_sessions.ps1
|
||||
|
||||
libtorch-cpu-shared-with-deps-debug-test: # Testing
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs:
|
||||
- libtorch-cpu-shared-with-deps-debug-build
|
||||
- get-label-type
|
||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
|
||||
timeout-minutes: 360
|
||||
env:
|
||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cpu
|
||||
GPU_ARCH_TYPE: cpu
|
||||
SKIP_ALL_TESTS: 1
|
||||
LIBTORCH_CONFIG: debug
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
# This is a dummy value for libtorch to work correctly with our batch scripts
|
||||
# without this value pip does not get installed for some reason
|
||||
DESIRED_PYTHON: "3.10"
|
||||
steps:
|
||||
- name: Display EC2 information
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
function get_ec2_metadata() {
|
||||
# Pulled from instance metadata endpoint for EC2
|
||||
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
|
||||
category=$1
|
||||
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
|
||||
}
|
||||
echo "ami-id: $(get_ec2_metadata ami-id)"
|
||||
echo "instance-id: $(get_ec2_metadata instance-id)"
|
||||
echo "instance-type: $(get_ec2_metadata instance-type)"
|
||||
echo "system info $(uname -a)"
|
||||
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
|
||||
uses: pytorch/test-infra/.github/actions/setup-ssh@main
|
||||
continue-on-error: true
|
||||
with:
|
||||
github-secret: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
|
||||
shell: bash
|
||||
run: |
|
||||
git config --global core.longpaths true
|
||||
git config --global core.symlinks true
|
||||
|
||||
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
|
||||
# the directory on Windows and prevent GHA from checking out as reported
|
||||
# in https://github.com/actions/checkout/issues/1018
|
||||
git config --global core.fsmonitor false
|
||||
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
|
||||
- name: Enable long paths on Windows
|
||||
shell: powershell
|
||||
run: |
|
||||
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
|
||||
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
|
||||
# removed once Windows Defender is removed from the AMI
|
||||
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
|
||||
continue-on-error: true
|
||||
shell: powershell
|
||||
run: |
|
||||
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
|
||||
# Let's both exclude the path and disable Windows Defender completely just to be sure
|
||||
# that it doesn't interfere
|
||||
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
|
||||
- name: Checkout PyTorch
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
submodules: recursive
|
||||
path: pytorch
|
||||
show-progress: false
|
||||
- name: Clean PyTorch checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: pytorch
|
||||
# NOTE: These environment variables are put here so that they can be applied on every job equally
|
||||
# They are also here because setting them at a workflow level doesn't give us access to the
|
||||
# runner.temp variable, which we need.
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
|
||||
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
|
||||
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
|
||||
- uses: actions/download-artifact@v4.1.7
|
||||
name: Download Build Artifacts
|
||||
with:
|
||||
name: libtorch-cpu-shared-with-deps-debug
|
||||
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
|
||||
- name: Test PyTorch binary
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
|
||||
- name: Wait until all sessions have drained
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
timeout-minutes: 120
|
||||
run: |
|
||||
.github\scripts\wait_for_ssh_to_drain.ps1
|
||||
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
run: |
|
||||
.github\scripts\kill_active_ssh_sessions.ps1
|
||||
261
.github/workflows/generated-windows-binary-libtorch-release-main.yml
generated
vendored
@ -1,261 +0,0 @@
|
||||
# @generated DO NOT EDIT MANUALLY
|
||||
|
||||
# Template is at: .github/templates/windows_binary_build_workflow.yml.j2
|
||||
# Generation script: .github/scripts/generate_ci_workflows.py
|
||||
name: windows-binary-libtorch-release
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
# Needed for conda builds
|
||||
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
|
||||
AWS_DEFAULT_REGION: us-east-1
|
||||
BUILD_ENVIRONMENT: windows-binary-libtorch-release
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
SKIP_ALL_TESTS: 1
|
||||
OS: windows
|
||||
concurrency:
|
||||
group: windows-binary-libtorch-release-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
get-label-type:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
libtorch-cpu-shared-with-deps-release-build:
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs: get-label-type
|
||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
|
||||
timeout-minutes: 360
|
||||
env:
|
||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cpu
|
||||
GPU_ARCH_TYPE: cpu
|
||||
SKIP_ALL_TESTS: 1
|
||||
LIBTORCH_CONFIG: release
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
# This is a dummy value for libtorch to work correctly with our batch scripts
|
||||
# without this value pip does not get installed for some reason
|
||||
DESIRED_PYTHON: "3.10"
|
||||
steps:
|
||||
# NOTE: These environment variables are put here so that they can be applied on every job equally
|
||||
# They are also here because setting them at a workflow level doesn't give us access to the
|
||||
# runner.temp variable, which we need.
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
|
||||
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
|
||||
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
|
||||
- name: Display EC2 information
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
function get_ec2_metadata() {
|
||||
# Pulled from instance metadata endpoint for EC2
|
||||
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
|
||||
category=$1
|
||||
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
|
||||
}
|
||||
echo "ami-id: $(get_ec2_metadata ami-id)"
|
||||
echo "instance-id: $(get_ec2_metadata instance-id)"
|
||||
echo "instance-type: $(get_ec2_metadata instance-type)"
|
||||
echo "system info $(uname -a)"
|
||||
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
|
||||
uses: pytorch/test-infra/.github/actions/setup-ssh@main
|
||||
continue-on-error: true
|
||||
with:
|
||||
github-secret: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
|
||||
shell: bash
|
||||
run: |
|
||||
git config --global core.longpaths true
|
||||
git config --global core.symlinks true
|
||||
|
||||
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
|
||||
# the directory on Windows and prevent GHA from checking out as reported
|
||||
# in https://github.com/actions/checkout/issues/1018
|
||||
git config --global core.fsmonitor false
|
||||
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
|
||||
- name: Enable long paths on Windows
|
||||
shell: powershell
|
||||
run: |
|
||||
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
|
||||
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
|
||||
# removed once Windows Defender is removed from the AMI
|
||||
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
|
||||
continue-on-error: true
|
||||
shell: powershell
|
||||
run: |
|
||||
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
|
||||
# Let's both exclude the path and disable Windows Defender completely just to be sure
|
||||
# that it doesn't interfere
|
||||
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
|
||||
- name: Checkout PyTorch
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
submodules: recursive
|
||||
path: pytorch
|
||||
show-progress: false
|
||||
- name: Clean PyTorch checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: pytorch
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
|
||||
- name: Build PyTorch binary
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
|
||||
- uses: actions/upload-artifact@v4.4.0
|
||||
if: always()
|
||||
with:
|
||||
name: libtorch-cpu-shared-with-deps-release
|
||||
retention-days: 14
|
||||
if-no-files-found: error
|
||||
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
|
||||
- name: Wait until all sessions have drained
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
timeout-minutes: 120
|
||||
run: |
|
||||
.github\scripts\wait_for_ssh_to_drain.ps1
|
||||
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
run: |
|
||||
.github\scripts\kill_active_ssh_sessions.ps1
|
||||
|
||||
libtorch-cpu-shared-with-deps-release-test: # Testing
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs:
|
||||
- libtorch-cpu-shared-with-deps-release-build
|
||||
- get-label-type
|
||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
|
||||
timeout-minutes: 360
|
||||
env:
|
||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||
PACKAGE_TYPE: libtorch
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: cpu
|
||||
GPU_ARCH_TYPE: cpu
|
||||
SKIP_ALL_TESTS: 1
|
||||
LIBTORCH_CONFIG: release
|
||||
LIBTORCH_VARIANT: shared-with-deps
|
||||
# This is a dummy value for libtorch to work correctly with our batch scripts
|
||||
# without this value pip does not get installed for some reason
|
||||
DESIRED_PYTHON: "3.10"
|
||||
steps:
|
||||
- name: Display EC2 information
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
function get_ec2_metadata() {
|
||||
# Pulled from instance metadata endpoint for EC2
|
||||
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
|
||||
category=$1
|
||||
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
|
||||
}
|
||||
echo "ami-id: $(get_ec2_metadata ami-id)"
|
||||
echo "instance-id: $(get_ec2_metadata instance-id)"
|
||||
echo "instance-type: $(get_ec2_metadata instance-type)"
|
||||
echo "system info $(uname -a)"
|
||||
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
|
||||
uses: pytorch/test-infra/.github/actions/setup-ssh@main
|
||||
continue-on-error: true
|
||||
with:
|
||||
github-secret: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
|
||||
shell: bash
|
||||
run: |
|
||||
git config --global core.longpaths true
|
||||
git config --global core.symlinks true
|
||||
|
||||
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
|
||||
# the directory on Windows and prevent GHA from checking out as reported
|
||||
# in https://github.com/actions/checkout/issues/1018
|
||||
git config --global core.fsmonitor false
|
||||
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
|
||||
- name: Enable long paths on Windows
|
||||
shell: powershell
|
||||
run: |
|
||||
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
|
||||
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
|
||||
# removed once Windows Defender is removed from the AMI
|
||||
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
|
||||
continue-on-error: true
|
||||
shell: powershell
|
||||
run: |
|
||||
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
|
||||
# Let's both exclude the path and disable Windows Defender completely just to be sure
|
||||
# that it doesn't interfere
|
||||
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
|
||||
- name: Checkout PyTorch
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
submodules: recursive
|
||||
path: pytorch
|
||||
show-progress: false
|
||||
- name: Clean PyTorch checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: pytorch
|
||||
# NOTE: These environment variables are put here so that they can be applied on every job equally
|
||||
# They are also here because setting them at a workflow level doesn't give us access to the
|
||||
# runner.temp variable, which we need.
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
|
||||
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
|
||||
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
|
||||
- uses: actions/download-artifact@v4.1.7
|
||||
name: Download Build Artifacts
|
||||
with:
|
||||
name: libtorch-cpu-shared-with-deps-release
|
||||
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
|
||||
- name: Populate binary env
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
|
||||
- name: Test PyTorch binary
|
||||
shell: bash
|
||||
run: |
|
||||
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
|
||||
- name: Wait until all sessions have drained
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
timeout-minutes: 120
|
||||
run: |
|
||||
.github\scripts\wait_for_ssh_to_drain.ps1
|
||||
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
|
||||
shell: powershell
|
||||
working-directory: pytorch
|
||||
if: always()
|
||||
run: |
|
||||
.github\scripts\kill_active_ssh_sessions.ps1
|
||||
12
.github/workflows/test-h100.yml
vendored
@ -61,3 +61,15 @@ jobs:
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build.outputs.test-matrix }}
secrets: inherit

linux-jammy-cuda12_8-py3_10-gcc11-sm90-FA3-ABI-stable-test:
name: linux-jammy-cuda12_8-py3_10-gcc11-sm90-FA3-ABI-stable-test
uses: ./.github/workflows/_linux-test-stable-fa3.yml
needs:
- linux-jammy-cuda12_8-py3_10-gcc11-sm90-build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build.outputs.docker-image }}
timeout-minutes: 30
s3-bucket: gha-artifacts
secrets: inherit
1
.github/workflows/update-viablestrict.yml
vendored
@ -49,5 +49,6 @@ jobs:
pip install awscli==1.29.40
aws s3 cp "/tmp/${LATEST_SHA}.json" "s3://ossci-raw-job-status/stable_pushes/pytorch/pytorch/${LATEST_SHA}.json"
# Push new viable/strict tag
cd pytorch/pytorch
git push origin "${LATEST_SHA}:refs/tags/viable/strict/${TIME}"
fi
2
.github/workflows/vllm.yml
vendored
@ -42,7 +42,7 @@ jobs:
build-external-packages: "vllm"
build-environment: linux-jammy-cuda12.8-py3.12-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm
cuda-arch-list: '8.0;8.9;9.0'
cuda-arch-list: '8.0 8.9 9.0'
runner: linux.24xlarge.memory
test-matrix: |
{ include: [
@ -1260,6 +1260,7 @@ exclude_patterns = [
'test/test_masked.py',
'test/test_maskedtensor.py',
'test/test_matmul_cuda.py',
'test/test_scaled_matmul_cuda.py',
'test/test_meta.py',
'test/test_metal.py',
'test/test_mkl_verbose.py',
@ -81,7 +81,7 @@ git remote add upstream git@github.com:pytorch/pytorch.git
make setup-env
# Or run `make setup-env-cuda` for pre-built CUDA binaries
# Or run `make setup-env-rocm` for pre-built ROCm binaries
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
source venv/bin/activate # or `. .\venv\Scripts\activate` on Windows
```

### Tips and Debugging
@ -182,28 +182,36 @@ You can use this script to check out a new nightly branch with the following:

```bash
./tools/nightly.py checkout -b my-nightly-branch
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
source venv/bin/activate # or `. .\venv\Scripts\activate` on Windows
```

To install the nightly binaries built with CUDA, you can pass in the flag `--cuda`:

```bash
./tools/nightly.py checkout -b my-nightly-branch --cuda
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
source venv/bin/activate # or `. .\venv\Scripts\activate` on Windows
```

To install the nightly binaries built with ROCm, you can pass in the flag `--rocm`:

```bash
./tools/nightly.py checkout -b my-nightly-branch --rocm
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
source venv/bin/activate # or `. .\venv\Scripts\activate` on Windows
```

You can also use this tool to pull the nightly commits into the current branch:

```bash
./tools/nightly.py pull -p my-env
source my-env/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
./tools/nightly.py pull
source venv/bin/activate # or `. .\venv\Scripts\activate` on Windows
```

To create the virtual environment with a specific Python interpreter, you can
pass in the `--python` argument:

```bash
./tools/nightly.py --python /path/to/python3.12
source venv/bin/activate # or `. .\venv\Scripts\activate` on Windows
```

Pulling will recreate a fresh virtual environment and reinstall the development
@ -6,6 +6,7 @@
|
||||
#include <c10/core/thread_pool.h>
|
||||
#include <c10/util/flat_hash_map.h>
|
||||
#include <c10/util/llvmMathExtras.h>
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
|
||||
#include <deque>
|
||||
@ -75,6 +76,9 @@ struct TORCH_API HostStats {
|
||||
|
||||
// COUNT: number of times cudaHostFree/cudaHostUnregister was called.
|
||||
int64_t num_host_free = 0; // This is derived from segment or timing
|
||||
|
||||
// Count of cudaHostFree/cudaHostUnregister per bucket
|
||||
std::vector<int64_t> bucket_allocation = std::vector<int64_t>(MAX_SIZE_INDEX);
|
||||
};
|
||||
|
||||
// Struct containing memory allocator summary statistics for host, as they
|
||||
@ -196,27 +200,7 @@ struct CachingHostAllocatorImpl {
|
||||
// background.
|
||||
if (!pinned_use_background_threads()) {
|
||||
process_events();
|
||||
}
|
||||
|
||||
// Round up the allocation to the nearest power of two to improve reuse.
|
||||
// These power of two sizes are also used to index into the free list.
|
||||
size_t roundSize = c10::llvm::PowerOf2Ceil(size);
|
||||
|
||||
// First, try to allocate from the free list
|
||||
auto* block = get_free_block(roundSize);
|
||||
if (block) {
|
||||
return {block->ptr_, reinterpret_cast<void*>(block)};
|
||||
}
|
||||
|
||||
// Check in the recently freed blocks with pending events to see if we
|
||||
// can reuse them. Call get_free_block again after processing events
|
||||
if (pinned_use_background_threads()) {
|
||||
process_events_for_specific_size(roundSize);
|
||||
block = get_free_block(roundSize);
|
||||
if (block) {
|
||||
return {block->ptr_, reinterpret_cast<void*>(block)};
|
||||
}
|
||||
|
||||
} else {
|
||||
// Launch the background thread and process events in a loop.
|
||||
static bool background_thread_flag [[maybe_unused]] = [this] {
|
||||
getBackgroundThreadPool()->run([&]() {
|
||||
@ -229,6 +213,16 @@ struct CachingHostAllocatorImpl {
|
||||
}();
|
||||
}
|
||||
|
||||
// Round up the allocation to the nearest power of two to improve reuse.
|
||||
// These power of two sizes are also used to index into the free list.
|
||||
size_t roundSize = c10::llvm::PowerOf2Ceil(size);
|
||||
|
||||
// First, try to allocate from the free list
|
||||
auto* block = get_free_block(roundSize);
|
||||
if (block) {
|
||||
return {block->ptr_, reinterpret_cast<void*>(block)};
|
||||
}
|
||||
|
||||
// Slow path: if we can't allocate from the cached free list, we need
|
||||
// to create a new block.
|
||||
void* ptr = nullptr;
|
||||
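The allocation path above first rounds the requested size up to the nearest power of two via `c10::llvm::PowerOf2Ceil` and uses that rounded size to index the per-bucket free list. A minimal Python sketch of that rounding and bucketing logic; the helper names are illustrative, not PyTorch APIs:

```python
# Sketch of the power-of-two rounding and free-list bucketing described above.

def round_up_pow2(size: int) -> int:
    """Round size (>= 1) up to the nearest power of two, like PowerOf2Ceil."""
    return 1 << (size - 1).bit_length()

def bucket_index(rounded_size: int) -> int:
    """Free-list bucket for a rounded size: log2 of that size."""
    return rounded_size.bit_length() - 1

for size in (1, 3, 1000, 4096, 5000):
    r = round_up_pow2(size)
    print(f"request {size:>5} B -> rounded {r:>5} B -> bucket {bucket_index(r)}")
```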
@ -278,8 +272,6 @@ struct CachingHostAllocatorImpl {
|
||||
auto index = size_index(block->size_);
|
||||
std::lock_guard<std::mutex> g(free_list_[index].mutex_);
|
||||
free_list_[index].list_.push_back(block);
|
||||
stats_.allocation_bucket_stats[index].decrease(1);
|
||||
stats_.allocated_bytes_bucket_stats[index].decrease(block->size_);
|
||||
} else {
|
||||
// restore these events that record by used streams.
|
||||
std::lock_guard<std::mutex> g(events_mutex_);
|
||||
@ -339,9 +331,12 @@ struct CachingHostAllocatorImpl {
|
||||
for (auto* block : blocks_to_remove) {
|
||||
blocks_.erase(block);
|
||||
ptr_to_block_.erase(block->ptr_);
|
||||
auto index = size_index(block->size_);
|
||||
free_block(block);
|
||||
stats_.allocation.decrease(1);
|
||||
stats_.allocated_bytes.decrease(block->size_);
|
||||
free_block(block);
|
||||
stats_.allocation_bucket_stats[index].decrease(1);
|
||||
stats_.allocated_bytes_bucket_stats[index].decrease(block->size_);
|
||||
delete block;
|
||||
}
|
||||
}
|
||||
@ -398,6 +393,7 @@ struct CachingHostAllocatorImpl {
|
||||
// a best effort manner, since we can't really replay the cached events per bucket.
|
||||
add_bucket_stats(stats.allocation, stats_.allocation_bucket_stats[i]);
|
||||
add_bucket_stats(stats.allocated_bytes, stats_.allocated_bytes_bucket_stats[i]);
|
||||
stats.bucket_allocation[i] = stats_.allocation_bucket_stats[i].allocated;
|
||||
}
|
||||
|
||||
// Get the timing stats
|
||||
@ -488,8 +484,6 @@ struct CachingHostAllocatorImpl {
|
||||
B* block = free_list_[index].list_.back();
|
||||
free_list_[index].list_.pop_back();
|
||||
block->allocated_ = true;
|
||||
stats_.allocation_bucket_stats[index].increase(1);
|
||||
stats_.allocated_bytes_bucket_stats[index].increase(size);
|
||||
return block;
|
||||
}
|
||||
return nullptr;
|
||||
@ -583,8 +577,6 @@ struct CachingHostAllocatorImpl {
|
||||
auto index = size_index(block->size_);
|
||||
std::lock_guard<std::mutex> g(free_list_[index].mutex_);
|
||||
free_list_[index].list_.push_back(block);
|
||||
stats_.allocation_bucket_stats[index].decrease(1);
|
||||
stats_.allocated_bytes_bucket_stats[index].decrease(size);
|
||||
if (size != -1) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
#include <c10/core/impl/PythonDispatcherTLS.h>
|
||||
#include <ATen/core/PythonFallbackKernel.h>
|
||||
#include <c10/core/SafePyObject.h>
|
||||
#include <ATen/record_function.h>
|
||||
|
||||
namespace {
|
||||
|
||||
@ -53,20 +54,24 @@ void pythonFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_
|
||||
TORCH_INTERNAL_ASSERT(tls_on_entry.has_value());
|
||||
// c10::impl::ForceDispatchKeyGuard dispatcher_guard(tls_on_entry.value());
|
||||
// StashTLSOnEntryGuard stash_guard;
|
||||
c10::impl::ExcludeDispatchKeyGuard guard(after_Python_keyset);
|
||||
c10::impl::ExcludeDispatchKeyGuard exclude_guard(after_Python_keyset);
|
||||
|
||||
const auto& schema = op.schema();
|
||||
const auto num_arguments = schema.arguments().size();
|
||||
|
||||
// If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
|
||||
const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
|
||||
if (mode_stack_len > 0) {
|
||||
RECORD_FUNCTION("PythonDispatchMode", torch::jit::last(*stack, num_arguments));
|
||||
const auto& cur_torch_dispatch_mode_state = c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
|
||||
cur_torch_dispatch_mode_state->pyinterpreter()->dispatch(op, stack);
|
||||
return;
|
||||
}
|
||||
|
||||
RECORD_FUNCTION("PythonSubclass", torch::jit::last(*stack, num_arguments));
|
||||
|
||||
// Otherwise, find a PyInterpreter on a Tensor
|
||||
const auto& schema = op.schema();
|
||||
const auto num_arguments = schema.arguments().size();
|
||||
|
||||
// It is safe to dispatch on the very first Tensor with a pyobj_interpreter
|
||||
// without checking the interpreters of any of the arguments, because when
|
||||
// we actually run dispatch(), we will take out PyObjects in the context
|
||||
|
||||
@ -1,22 +1,32 @@
|
||||
#include <ATen/core/PythonOpRegistrationTrampoline.h>
|
||||
#include <c10/core/impl/PyInterpreterHooks.h>
|
||||
|
||||
// TODO: delete this
|
||||
namespace at::impl {
|
||||
|
||||
c10::impl::PyInterpreter* PythonOpRegistrationTrampoline::interpreter_ = nullptr;
|
||||
// The strategy is that all python interpreters attempt to register themselves
|
||||
// as the main interpreter, but only one wins. Only that interpreter is
|
||||
// allowed to interact with the C++ dispatcher. Furthermore, when we execute
|
||||
// logic on that interpreter, we do so hermetically, never setting pyobj field
|
||||
// on Tensor.
|
||||
|
||||
std::atomic<c10::impl::PyInterpreter*>
|
||||
PythonOpRegistrationTrampoline::interpreter_{nullptr};
|
||||
|
||||
c10::impl::PyInterpreter* PythonOpRegistrationTrampoline::getInterpreter() {
|
||||
return c10::impl::getGlobalPyInterpreter();
|
||||
return PythonOpRegistrationTrampoline::interpreter_.load();
|
||||
}
|
||||
|
||||
bool PythonOpRegistrationTrampoline::registerInterpreter(
|
||||
c10::impl::PyInterpreter* interp) {
|
||||
if (interpreter_ != nullptr) {
|
||||
c10::impl::PyInterpreter* expected = nullptr;
|
||||
interpreter_.compare_exchange_strong(expected, interp);
|
||||
if (expected != nullptr) {
|
||||
// This is the second (or later) Python interpreter, which means we need
|
||||
// non-trivial hermetic PyObject TLS
|
||||
c10::impl::HermeticPyObjectTLS::init_state();
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
interpreter_ = interp;
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace at::impl
|
||||
|
||||
@ -2,21 +2,19 @@
|
||||
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
|
||||
// TODO: We can get rid of this
|
||||
// TODO: this can probably live in c10
|
||||
|
||||
|
||||
namespace at::impl {
|
||||
|
||||
// Manages the single Python interpreter instance for PyTorch.
|
||||
class TORCH_API PythonOpRegistrationTrampoline final {
|
||||
static c10::impl::PyInterpreter* interpreter_;
|
||||
static std::atomic<c10::impl::PyInterpreter*> interpreter_;
|
||||
|
||||
public:
|
||||
// Register the Python interpreter. Returns true on first registration,
|
||||
// false if an interpreter was already registered.
|
||||
// Returns true if you successfully registered yourself (that means
|
||||
// you are in the hot seat for doing the operator registrations!)
|
||||
static bool registerInterpreter(c10::impl::PyInterpreter*);
|
||||
|
||||
// Returns the registered interpreter via the global PyInterpreter hooks.
|
||||
// Returns nullptr if no interpreter has been registered yet.
|
||||
static c10::impl::PyInterpreter* getInterpreter();
|
||||
};
|
||||
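Both trampoline hunks above move interpreter registration from a plain pointer to a `std::atomic` updated with `compare_exchange_strong`, so that when several Python interpreters race, exactly one becomes the registrant and later callers fall back to hermetic PyObject state. A rough Python sketch of that first-registrant-wins pattern (class and method names are illustrative only):

```python
import threading

class RegistrationTrampoline:
    """First successful register() wins; later callers are told they lost."""
    _lock = threading.Lock()
    _interpreter = None  # stands in for the atomic PyInterpreter* above

    @classmethod
    def register(cls, interp) -> bool:
        # Emulates compare_exchange_strong(expected=nullptr, desired=interp).
        with cls._lock:
            if cls._interpreter is None:
                cls._interpreter = interp
                return True   # this caller now owns operator registration
            return False      # lost the race: switch to hermetic PyObject state

    @classmethod
    def get(cls):
        return cls._interpreter

print(RegistrationTrampoline.register("interpreter-A"))  # True
print(RegistrationTrampoline.register("interpreter-B"))  # False
print(RegistrationTrampoline.get())                      # interpreter-A
```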
|
||||
@ -151,11 +151,6 @@ struct CUDACachingHostAllocatorImpl
|
||||
}
|
||||
|
||||
bool query_event(EventPool::Event& event) override {
|
||||
// Do not call cudaEventQuery if capturing is underway
|
||||
if (at::cuda::currentStreamCaptureStatusMayInitCtx() !=
|
||||
at::cuda::CaptureStatus::None) {
|
||||
return false;
|
||||
}
|
||||
cudaError_t err = cudaEventQuery(*event);
|
||||
if (err == cudaErrorNotReady) {
|
||||
(void)cudaGetLastError(); // clear CUDA error
|
||||
|
||||
@ -90,6 +90,10 @@ public:
|
||||
allocator_->setMemoryFraction(fraction, device);
|
||||
}
|
||||
|
||||
std::vector<HIPCachingAllocator::StreamSegmentSize> getExpandableSegmentSizes(c10::DeviceIndex device) override {
|
||||
return allocator_->getExpandableSegmentSizes(device);
|
||||
}
|
||||
|
||||
void enable(bool value) override {
|
||||
allocator_->enable(value);
|
||||
}
|
||||
|
||||
@ -670,6 +670,8 @@ Tensor rrelu_with_noise_backward(
|
||||
}
|
||||
|
||||
Tensor rrelu(const Tensor & self, const Scalar& lower, const Scalar& upper, bool training, std::optional<Generator> generator) {
|
||||
TORCH_CHECK(std::isfinite(lower.to<double>()), "rrelu: lower bound must be finite, got ", lower.to<double>());
|
||||
TORCH_CHECK(std::isfinite(upper.to<double>()), "rrelu: upper bound must be finite, got ", upper.to<double>());
|
||||
TORCH_CHECK(lower.to<double>() <= upper.to<double>(), "Lower bound should be less than or equal to the upper bound")
|
||||
auto noise = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
return at::rrelu_with_noise(self, noise, lower, upper, training, std::move(generator));
|
||||
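The `rrelu` hunk above adds validation that both bounds are finite (the `lower <= upper` check was already present). A small example of how that surfaces at the Python level, assuming a build that contains these checks:

```python
# Assumes a PyTorch build that includes the finite-bound checks added above.
import torch
import torch.nn.functional as F

x = torch.randn(4)

# Valid call: finite bounds with lower <= upper (defaults are 1/8 and 1/3).
print(F.rrelu(x, lower=0.1, upper=0.3, training=True))

# Non-finite or inverted bounds are rejected by the TORCH_CHECKs above.
for lower, upper in [(float("inf"), 0.3), (0.5, 0.1)]:
    try:
        F.rrelu(x, lower=lower, upper=upper, training=True)
    except RuntimeError as err:
        print("rejected:", err)
```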
|
||||
@ -2801,6 +2801,7 @@ Tensor matrix_exp(const Tensor& a) {
|
||||
// TODO This should be deprecated in favor of linalg_matrix_exp_differential
|
||||
// in FunctionsManual.cpp
|
||||
Tensor matrix_exp_backward(const Tensor& self, const Tensor& grad) {
|
||||
squareCheckInputs(self, "matrix_exp_backward");
|
||||
NoTF32Guard disable_tf32;
|
||||
return backward_analytic_function_of_a_matrix(
|
||||
self, grad,
|
||||
|
||||
@ -1880,43 +1880,34 @@ Tensor repeat(const Tensor& self, IntArrayRef repeats) {
|
||||
|
||||
Tensor xtensor = self.expand(padded_size);
|
||||
|
||||
Tensor urtensor;
|
||||
if (self.is_quantized()) {
|
||||
urtensor = at::empty_quantized(target_size, self);
|
||||
} else {
|
||||
urtensor = at::empty(target_size, self.options());
|
||||
}
|
||||
|
||||
// return an empty tensor if one of the repeat dimensions is zero
|
||||
if (zero_tensor) {
|
||||
return urtensor;
|
||||
return self.is_quantized() ? at::empty_quantized(target_size, self)
|
||||
: at::empty(target_size, self.options());
|
||||
}
|
||||
|
||||
// Create view of shape [r0, s0, r1, s1, ...]
|
||||
// where ri is repeat[i], si is self.size(i).
|
||||
Tensor view = xtensor;
|
||||
auto expand_shape = std::vector<int64_t>();
|
||||
expand_shape.reserve(xtensor.dim() * 2);
|
||||
for (const auto i : c10::irange(xtensor.dim())) {
|
||||
// can't unfold with step 0, so make sure step is at least 1
|
||||
// (it doesn't matter what it is in that case, because the size is 0).
|
||||
auto size_i = xtensor.sizes()[i];
|
||||
urtensor = urtensor.unfold(i, size_i, std::max<int64_t>(size_i, 1));
|
||||
view = view.unsqueeze(2 * i);
|
||||
expand_shape.push_back(repeats[i]);
|
||||
expand_shape.push_back(xtensor.size(i));
|
||||
}
|
||||
// expanded_view is non-contiguous because .expand set stride to 0.
|
||||
auto expanded_view = view.expand(expand_shape);
|
||||
|
||||
urtensor.copy_(xtensor.expand_as(urtensor));
|
||||
// copy to contiguous tensor.
|
||||
auto contiguous_copy = at::empty(
|
||||
expanded_view.sizes(),
|
||||
expanded_view.options(),
|
||||
at::MemoryFormat::Contiguous);
|
||||
contiguous_copy.copy_(expanded_view);
|
||||
|
||||
// Combine the dimensions to produce the target_size.
|
||||
// xtensor dims: [a0, ..., ad-1]
|
||||
// urtensor dims: [a0, ..., ad-1, b0, ..., bd-1]
|
||||
// b dims are produced by unfold.
|
||||
// Transform urtensor to [a0 * b0, ..., ad-1 * bd-1]
|
||||
const int64_t n_dims = xtensor.dim();
|
||||
auto range_a = at::arange(xtensor.dim(), at::TensorOptions(at::kLong));
|
||||
auto range_b = range_a + n_dims;
|
||||
auto stacked = stack({std::move(range_a), std::move(range_b)}, 1).flatten();
|
||||
auto permutation = IntArrayRef(stacked.data_ptr<int64_t>(), n_dims * 2);
|
||||
// Permute from [a0, ..., ad-1, b0, ..., bd-1] to [a0, b0, ..., ad-1, bd-1]
|
||||
urtensor = urtensor.permute(permutation);
|
||||
// Reshape from [a0, b0, ..., ad-1, bd-1] to [a0 * b0, ..., ad-1 * bd-1]
|
||||
urtensor = urtensor.reshape(target_size);
|
||||
|
||||
return urtensor;
|
||||
// Reshape to [s0 * r0, s1 * r1, ...].
|
||||
// No extra copy of data during reshape for a contiguous tensor.
|
||||
return contiguous_copy.view(target_size);
|
||||
}
|
||||
|
||||
Tensor tile_symint(const Tensor& self, SymIntArrayRef reps) {
|
||||
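The `repeat` rewrite above replaces the unfold/permute formulation with an unsqueeze-expand-copy-view pipeline: build a `[r0, s0, r1, s1, ...]` view, materialise it with a single contiguous copy, then view the result as `[r0*s0, r1*s1, ...]`. A Python sketch of the equivalent tensor algebra (the function name is mine; the shipped implementation is the C++ above):

```python
import torch

def repeat_via_expand(x, repeats):
    """Sketch of x.repeat(*repeats) for the case len(repeats) == x.dim()."""
    assert len(repeats) == x.dim()
    view, expand_shape = x, []
    for i, (r, s) in enumerate(zip(repeats, x.shape)):
        view = view.unsqueeze(2 * i)      # size-1 repeat axis in front of dim i
        expand_shape += [r, s]            # target shape [r0, s0, r1, s1, ...]
    expanded = view.expand(*expand_shape) # stride-0 view, no data copied yet
    out = expanded.contiguous()           # the single real copy
    target = [r * s for r, s in zip(repeats, x.shape)]
    return out.view(*target)              # free reshape on contiguous memory

x = torch.arange(6).reshape(2, 3)
assert torch.equal(repeat_via_expand(x, [2, 3]), x.repeat(2, 3))
```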
|
||||
@ -1831,6 +1831,37 @@ std::optional<c10::ScalarType> out_dtype) {
|
||||
return out;
|
||||
}
|
||||
|
||||
static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const at::ScalarType out_dtype, bool is_bmm, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
|
||||
// ref ATen/native/LinearAlgebra.cpp common_checks_baddbmm_bmm
|
||||
TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
|
||||
TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");
|
||||
|
||||
const auto batch1_sizes = batch1.sizes();
|
||||
const auto batch2_sizes = batch2.sizes();
|
||||
|
||||
int64_t bs = batch1_sizes[0];
|
||||
int64_t contraction_size = batch1_sizes[2];
|
||||
int64_t res_rows = batch1_sizes[1];
|
||||
int64_t res_cols = batch2_sizes[2];
|
||||
std::vector<int64_t> output_size {bs, res_rows, res_cols};
|
||||
|
||||
TORCH_CHECK(batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size,
|
||||
"Expected size for first two dimensions of batch2 tensor to be: [",
|
||||
bs, ", ", contraction_size, "] but got: [", batch2_sizes[0], ", ", batch2_sizes[1], "].");
|
||||
|
||||
TORCH_CHECK(batch1.scalar_type() == batch2.scalar_type(), "batch1 and batch2 must have the same dtype");
|
||||
|
||||
TORCH_CHECK(out_dtype == batch1.scalar_type() ||
|
||||
(out_dtype == at::ScalarType::Float && (batch1.scalar_type() == at::ScalarType::Half || batch1.scalar_type() == at::ScalarType::BFloat16)),
|
||||
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
|
||||
|
||||
if (!is_bmm && self_baddbmm.has_value()) {
|
||||
const auto& self = self_baddbmm.value();
|
||||
TORCH_CHECK(self.dim() == 3, "self must be a 3D tensor");
|
||||
TORCH_CHECK(self.sizes() == output_size, "self must have the same shape as the output");
|
||||
}
|
||||
}
|
||||
|
||||
Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype) {
|
||||
IntArrayRef batch1_sizes = batch1.sizes();
|
||||
IntArrayRef batch2_sizes = batch2.sizes();
|
||||
@ -1840,12 +1871,7 @@ Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::Sca
|
||||
}
|
||||
|
||||
Tensor& _bmm_out_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, Tensor &out) {
|
||||
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
|
||||
|
||||
TORCH_CHECK(out_dtype == batch1.scalar_type() ||
|
||||
(out_dtype == at::ScalarType::Float && (batch1.scalar_type() == at::ScalarType::Half || batch1.scalar_type() == at::ScalarType::BFloat16)),
|
||||
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
|
||||
|
||||
baddbmm_bmm_out_dtype_checks(batch1, batch2, 0.0, 1.0, out_dtype, true);
|
||||
Scalar beta(0.0);
|
||||
Scalar alpha(1.0);
|
||||
{
|
||||
@ -1864,12 +1890,7 @@ Tensor _baddbmm_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tenso
|
||||
}
|
||||
|
||||
Tensor& _baddbmm_out_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) {
|
||||
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
|
||||
|
||||
TORCH_CHECK(out_dtype == batch1.scalar_type() ||
|
||||
(out_dtype == at::ScalarType::Float && (batch1.scalar_type() == at::ScalarType::Half || batch1.scalar_type() == at::ScalarType::BFloat16)),
|
||||
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
|
||||
|
||||
baddbmm_bmm_out_dtype_checks(batch1, batch2, beta, alpha, out_dtype, false, self);
|
||||
{
|
||||
NoNamesGuard guard;
|
||||
baddbmm_out_cuda_impl(out, out, batch1, batch2, beta, alpha);
|
||||
@ -1884,6 +1905,12 @@ Tensor _mm_dtype_cuda(const Tensor& self, const Tensor& mat2, const at::ScalarTy
|
||||
}
|
||||
|
||||
Tensor& _mm_dtype_out_cuda(const Tensor& self, const Tensor& mat2, const at::ScalarType out_dtype, Tensor &out) {
|
||||
TORCH_CHECK(self.dim() == 2, "self must be a matrix, got ", self.dim(), "-D tensor");
|
||||
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
|
||||
TORCH_CHECK(
|
||||
self.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
|
||||
self.sizes()[0], "x", self.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
|
||||
|
||||
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
|
||||
TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "input dtypes must be the same");
|
||||
TORCH_CHECK(out_dtype == self.scalar_type() ||
|
||||
@ -1903,6 +1930,14 @@ Tensor _addmm_dtype_cuda(const Tensor& self, const Tensor& mat1, const Tensor& m
|
||||
}
|
||||
|
||||
Tensor& _addmm_dtype_out_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) {
|
||||
TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "self and mat2 must have the same dtype, but got ", self.scalar_type(), " and ", mat2.scalar_type());
|
||||
TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type());
|
||||
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
|
||||
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
|
||||
TORCH_CHECK(
|
||||
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
|
||||
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
|
||||
|
||||
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
|
||||
TORCH_CHECK(out_dtype == self.scalar_type() ||
|
||||
(out_dtype == at::ScalarType::Float && (self.scalar_type() == at::ScalarType::Half || self.scalar_type() == at::ScalarType::BFloat16)),
|
||||
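The checks above (and the ones factored into `baddbmm_bmm_out_dtype_checks`) enforce one rule for the `out_dtype` matmul variants: the requested output dtype must equal the input dtype, or be float32 when the inputs are float16/bfloat16. A Python sketch of that rule with a hypothetical helper name:

```python
import torch

def check_out_dtype(input_dtype, out_dtype):
    """Mirrors the C++ rule: same dtype, or fp32 output for fp16/bf16 inputs."""
    ok = out_dtype == input_dtype or (
        out_dtype == torch.float32
        and input_dtype in (torch.float16, torch.bfloat16)
    )
    if not ok:
        raise RuntimeError(
            "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs"
        )

check_out_dtype(torch.bfloat16, torch.float32)     # allowed: bf16 in, fp32 out
check_out_dtype(torch.float16, torch.float16)      # allowed: same dtype
try:
    check_out_dtype(torch.float32, torch.float16)  # rejected: would downcast
except RuntimeError as err:
    print(err)
```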
|
||||
@ -10256,6 +10256,7 @@
|
||||
structured: True
|
||||
dispatch:
|
||||
CPU, CUDA: all_all_out
|
||||
MTIA: all_all_out_mtia
|
||||
MPS: all_all_out_mps
|
||||
|
||||
- func: any(Tensor self) -> Tensor
|
||||
|
||||
@ -101,6 +101,9 @@ __device__ inline bool isinf_device(float v) {
|
||||
__device__ inline bool isinf_device(c10::BFloat16 v) {
|
||||
return ::isinf(static_cast<float>(v));
|
||||
}
|
||||
__device__ inline bool isinf_device(at::Half v) {
|
||||
return ::isinf(static_cast<float>(v));
|
||||
}
|
||||
|
||||
// CUDA kernel to compute Moving Average Min/Max of the tensor.
|
||||
// It uses the running_min and running_max along with averaging const, c.
|
||||
@ -160,8 +163,8 @@ void _calculate_moving_average(
|
||||
std::tie(x_min, x_max) = at::aminmax(x, 1);
|
||||
int num_threads = std::min(size, (int64_t)512);
|
||||
const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads);
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(
|
||||
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::kBFloat16, at::kHalf, x.scalar_type(), "aminmax_kernel", [&] {
|
||||
scalar_t* x_min_data = x_min.data_ptr<scalar_t>();
|
||||
scalar_t* x_max_data = x_max.data_ptr<scalar_t>();
|
||||
|
||||
@ -181,8 +184,8 @@ void _calculate_moving_average(
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
} else {
|
||||
std::tie(x_min, x_max) = at::aminmax(x);
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(
|
||||
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::kBFloat16, at::kHalf, x.scalar_type(), "aminmax_kernel", [&] {
|
||||
scalar_t* x_min_data = x_min.data_ptr<scalar_t>();
|
||||
scalar_t* x_max_data = x_max.data_ptr<scalar_t>();
|
||||
|
||||
@ -221,8 +224,8 @@ void _calc_moving_avg_qparams_helper(
|
||||
cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
|
||||
int64_t* fake_quant_on_data = fake_quant_on.data_ptr<int64_t>();
|
||||
if (per_row_fq) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(
|
||||
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::kBFloat16, at::kHalf, x.scalar_type(), "aminmax_kernel", [&] {
|
||||
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
|
||||
scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
|
||||
int num_threads = std::min(size, (int64_t)512);
|
||||
@ -244,8 +247,8 @@ void _calc_moving_avg_qparams_helper(
|
||||
});
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
} else {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(
|
||||
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::kBFloat16, at::kHalf, x.scalar_type(), "aminmax_kernel", [&] {
|
||||
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
|
||||
scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
|
||||
ChooseQuantizationParamsKernelImpl<<<1, 1, 0, cuda_stream>>>(
|
||||
|
||||
@ -316,6 +316,12 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
|
||||
return false;
|
||||
#endif
|
||||
#else
|
||||
if (!at::cuda::is_available()) {
|
||||
if (debug) {
|
||||
TORCH_WARN("flash attention requires a CUDA device, which is not available.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
if (!check_sm_version<sm80, sm121>(dprops)) {
|
||||
if (debug) {
|
||||
@ -367,6 +373,12 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
|
||||
return false;
|
||||
#endif
|
||||
#else
|
||||
if (!at::cuda::is_available()) {
|
||||
if (debug) {
|
||||
TORCH_WARN("Mem Efficient attention requires a CUDA device, which is not available.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
if (!check_sm_version<sm50, sm121>(dprops)) {
|
||||
if (debug) {
|
||||
@ -597,6 +609,12 @@ bool check_cudnn_layout(sdp_params const& params, bool debug) {
|
||||
bool check_cudnn_hardware_support(sdp_params const& params, bool debug) {
|
||||
using sm80 = SMVersion<8, 0>;
|
||||
using sm121 = SMVersion<12, 1>;
|
||||
if (!at::cuda::is_available()) {
|
||||
if (debug) {
|
||||
TORCH_WARN("cuDNN SDPA requires a CUDA device, which is not available.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
if (!check_sm_version<sm80, sm121>(dprops)) {
|
||||
if (debug) {
|
||||
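The hunks above make the flash, mem-efficient, and cuDNN hardware checks return false (with a debug warning) when no CUDA device is available, instead of unconditionally querying device properties. A user-level Python sketch of the same guard; the SDPA dispatcher performs these checks internally, the helper name is mine, and only the sm80 lower bound is mirrored here:

```python
import torch
import torch.nn.functional as F

def flash_attention_supported() -> bool:
    # Mirror of the new guard: no CUDA device -> unsupported, otherwise
    # require at least sm80 before expecting the fused flash kernels.
    if not torch.cuda.is_available():
        print("flash attention requires a CUDA device, which is not available.")
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8

q = k = v = torch.randn(1, 2, 4, 8)
device = "cuda" if flash_attention_supported() else "cpu"
# scaled_dot_product_attention falls back to the math backend where needed.
out = F.scaled_dot_product_attention(q.to(device), k.to(device), v.to(device))
print(device, out.shape)
```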
|
||||
@ -5,6 +5,7 @@
|
||||
#include <torch/csrc/profiler/orchestration/vulkan.h>
|
||||
#endif // USE_KINETO
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
|
||||
@ -1,10 +1,83 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
|
||||
#include <ATen/test/allocator_clone_test.h>
|
||||
|
||||
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
|
||||
|
||||
std::unordered_map<void*, size_t> allocation_sizes;
|
||||
|
||||
void* logging_malloc(size_t size, int device, cudaStream_t stream) {
|
||||
void* ptr;
|
||||
cudaMalloc(&ptr, size);
|
||||
allocation_sizes[ptr] = size;
|
||||
return ptr;
|
||||
}
|
||||
|
||||
void logging_free(void* ptr, size_t size, int device, cudaStream_t stream) {
|
||||
if (allocation_sizes.find(ptr) != allocation_sizes.end()) {
|
||||
if (allocation_sizes[ptr] != size) {
|
||||
throw std::runtime_error("free mismatch");
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("free of unknown ptr");
|
||||
}
|
||||
cudaFree(ptr);
|
||||
allocation_sizes.erase(ptr);
|
||||
}
|
||||
|
||||
TEST(TestTorchUnique, UniqueComparisonTest) {
|
||||
if (!at::cuda::is_available()) return;
|
||||
auto custom_allocator =
|
||||
torch::cuda::CUDAPluggableAllocator::createCustomAllocator(logging_malloc, logging_free);
|
||||
torch::cuda::CUDAPluggableAllocator::changeCurrentAllocator(custom_allocator);
|
||||
// Run the command 3 times; the first 2 will pass and the third invocation will have
|
||||
// different sizes in alloc and free if the test fails.
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
// Initialize simple sorted tensor with repeats
|
||||
at::Tensor sorted_tensor =
|
||||
at::tensor({0, 0, 0, 1, 1, 2, 3, 3, 3, 3, 5},
|
||||
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA));
|
||||
|
||||
// This operation will call malloc/free with different sizes on the same pointer
|
||||
auto unique_dim_result = at::unique_consecutive(sorted_tensor, false, true, 0);
|
||||
|
||||
// Everything below is only there to validate correct results
|
||||
auto unique_dim_values = std::get<0>(unique_dim_result);
|
||||
auto unique_dim_counts = std::get<2>(unique_dim_result);
|
||||
|
||||
// Check tensor sizes
|
||||
EXPECT_EQ(unique_dim_values.size(0), 5);
|
||||
EXPECT_EQ(unique_dim_counts.size(0), 5);
|
||||
|
||||
// Copy to CPU before accessing elements
|
||||
at::Tensor cpu_values = unique_dim_values.cpu();
|
||||
at::Tensor cpu_counts = unique_dim_counts.cpu();
|
||||
|
||||
// Use accessors on the CPU tensors
|
||||
auto values_accessor = cpu_values.accessor<float, 1>();
|
||||
auto counts_accessor = cpu_counts.accessor<int64_t, 1>();
|
||||
|
||||
// Check individual values using accessors
|
||||
EXPECT_EQ(values_accessor[0], 0.0f);
|
||||
EXPECT_EQ(values_accessor[1], 1.0f);
|
||||
EXPECT_EQ(values_accessor[2], 2.0f);
|
||||
EXPECT_EQ(values_accessor[3], 3.0f);
|
||||
EXPECT_EQ(values_accessor[4], 5.0f);
|
||||
|
||||
// Check count values using accessors
|
||||
EXPECT_EQ(counts_accessor[0], 3);
|
||||
EXPECT_EQ(counts_accessor[1], 2);
|
||||
EXPECT_EQ(counts_accessor[2], 1);
|
||||
EXPECT_EQ(counts_accessor[3], 4);
|
||||
EXPECT_EQ(counts_accessor[4], 1);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(AllocatorTestCUDA, test_clone) {
|
||||
if (!at::cuda::is_available()) return;
|
||||
test_allocator_clone(c10::cuda::CUDACachingAllocator::get());
|
||||
}
|
||||
|
||||
@@ -50,6 +50,7 @@ run_if_exists cuda_complex_test
|
||||
run_if_exists cuda_complex_math_test
|
||||
run_if_exists cuda_cub_test
|
||||
run_if_exists cuda_atomic_ops_test
|
||||
run_if_exists cuda_allocator_test
|
||||
|
||||
if [ "$VALGRIND" == "ON" ]; then
|
||||
# NB: As these tests are invoked by valgrind, let's leave them for now as it's
|
||||
|
||||
@@ -156,7 +156,7 @@ ROOT = "//" if IS_OSS else "//xplat/caffe2"
|
||||
# for targets in subfolders
|
||||
ROOT_PATH = "//" if IS_OSS else "//xplat/caffe2/"
|
||||
|
||||
C10 = "//c10:c10" if IS_OSS else ("//xplat/caffe2/c10:c10_ovrsource" if is_arvr_mode() else "//xplat/caffe2/c10:c10")
|
||||
C10 = "//c10:c10" if IS_OSS else "//xplat/caffe2/c10:c10"
|
||||
|
||||
# a dictionary maps third party library name to fbsource and oss target
|
||||
THIRD_PARTY_LIBS = {
|
||||
|
||||
@@ -1,100 +1,16 @@
|
||||
#pragma once
|
||||
|
||||
// This is directly synchronized with caffe2/proto/caffe2.proto, but
|
||||
// doesn't require me to figure out how to get Protobuf headers into
|
||||
// ATen/core (which would require a lot more build system hacking.)
|
||||
// If you modify me, keep me synchronized with that file.
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
// If you modified DeviceType in caffe2/proto/caffe2.proto, please also sync
|
||||
// your changes into torch/headeronly/core/DeviceType.h.
|
||||
#include <torch/headeronly/core/DeviceType.h>
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
// These contain all device types that also have a BackendComponent
|
||||
// and therefore participate in per-backend functionality dispatch keys.
|
||||
// This is most backends except PrivateUse2 and PrivateUse3
|
||||
#define C10_FORALL_BACKEND_DEVICE_TYPES(_, extra) \
|
||||
_(CPU, extra) \
|
||||
_(CUDA, extra) \
|
||||
_(HIP, extra) \
|
||||
_(XLA, extra) \
|
||||
_(MPS, extra) \
|
||||
_(IPU, extra) \
|
||||
_(XPU, extra) \
|
||||
_(HPU, extra) \
|
||||
_(VE, extra) \
|
||||
_(Lazy, extra) \
|
||||
_(Meta, extra) \
|
||||
_(MTIA, extra) \
|
||||
_(PrivateUse1, extra)
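The `C10_FORALL_BACKEND_DEVICE_TYPES` X-macro above applies its first argument once per backend device type. As a rough sketch of how such an X-macro can be consumed (hypothetical, not part of this patch; `PRINT_DEVICE` is an assumed helper name, and the macro is assumed to be in scope):

```cpp
#include <iostream>

// Expand the X-macro into one print statement per backend device type.
#define PRINT_DEVICE(dev, unused) std::cout << #dev << '\n';

inline void list_backend_device_types() {
  C10_FORALL_BACKEND_DEVICE_TYPES(PRINT_DEVICE, /*extra=*/0)
}

#undef PRINT_DEVICE
```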
|
||||
|
||||
enum class DeviceType : int8_t {
|
||||
CPU = 0,
|
||||
CUDA = 1, // CUDA.
|
||||
MKLDNN = 2, // Reserved for explicit MKLDNN
|
||||
OPENGL = 3, // OpenGL
|
||||
OPENCL = 4, // OpenCL
|
||||
IDEEP = 5, // IDEEP.
|
||||
HIP = 6, // AMD HIP
|
||||
FPGA = 7, // FPGA
|
||||
MAIA = 8, // ONNX Runtime / Microsoft
|
||||
XLA = 9, // XLA / TPU
|
||||
Vulkan = 10, // Vulkan
|
||||
Metal = 11, // Metal
|
||||
XPU = 12, // XPU
|
||||
MPS = 13, // MPS
|
||||
Meta = 14, // Meta (tensors with no data)
|
||||
HPU = 15, // HPU / HABANA
|
||||
VE = 16, // SX-Aurora / NEC
|
||||
Lazy = 17, // Lazy Tensors
|
||||
IPU = 18, // Graphcore IPU
|
||||
MTIA = 19, // Meta training and inference devices
|
||||
PrivateUse1 = 20, // PrivateUse1 device
|
||||
// NB: If you add more devices:
|
||||
// - Change the implementations of DeviceTypeName and isValidDeviceType
|
||||
// in DeviceType.cpp
|
||||
// - Change the number below
|
||||
COMPILE_TIME_MAX_DEVICE_TYPES = 21,
|
||||
};
|
||||
|
||||
constexpr DeviceType kCPU = DeviceType::CPU;
|
||||
constexpr DeviceType kCUDA = DeviceType::CUDA;
|
||||
constexpr DeviceType kHIP = DeviceType::HIP;
|
||||
constexpr DeviceType kFPGA = DeviceType::FPGA;
|
||||
constexpr DeviceType kMAIA = DeviceType::MAIA;
|
||||
constexpr DeviceType kXLA = DeviceType::XLA;
|
||||
constexpr DeviceType kMPS = DeviceType::MPS;
|
||||
constexpr DeviceType kMeta = DeviceType::Meta;
|
||||
constexpr DeviceType kVulkan = DeviceType::Vulkan;
|
||||
constexpr DeviceType kMetal = DeviceType::Metal;
|
||||
constexpr DeviceType kXPU = DeviceType::XPU;
|
||||
constexpr DeviceType kHPU = DeviceType::HPU;
|
||||
constexpr DeviceType kVE = DeviceType::VE;
|
||||
constexpr DeviceType kLazy = DeviceType::Lazy;
|
||||
constexpr DeviceType kIPU = DeviceType::IPU;
|
||||
constexpr DeviceType kMTIA = DeviceType::MTIA;
|
||||
constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1;
|
||||
|
||||
// define explicit int constant
|
||||
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
|
||||
static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
|
||||
|
||||
static_assert(
|
||||
COMPILE_TIME_MAX_DEVICE_TYPES <= 21,
|
||||
"Hey! You seem to be adding a lot of new DeviceTypes. The intent was "
|
||||
"for this constant to reflect the actual number of DeviceTypes we support "
|
||||
"in PyTorch; it's important that this number is not too large as we "
|
||||
"use this to allocate stack arrays in some places in our code. If you "
|
||||
"are indeed just adding the 20th device type, feel free to change "
|
||||
"the check to 32; but if you are adding some sort of extensible device "
|
||||
"types registration, please be aware that you are affecting code that "
|
||||
"this number is small. Try auditing uses of this constant.");
|
||||
|
||||
C10_API std::string DeviceTypeName(DeviceType d, bool lower_case = false);
|
||||
|
||||
C10_API bool isValidDeviceType(DeviceType d);
|
||||
@@ -108,15 +24,6 @@ C10_API bool is_privateuse1_backend_registered();
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace std {
|
||||
template <>
|
||||
struct hash<c10::DeviceType> {
|
||||
std::size_t operator()(c10::DeviceType k) const {
|
||||
return std::hash<int>()(static_cast<int>(k));
|
||||
}
|
||||
};
|
||||
} // namespace std
|
||||
|
||||
namespace torch {
|
||||
// NOLINTNEXTLINE(misc-unused-using-decls)
|
||||
using c10::DeviceType;
|
||||
|
||||
21
c10/core/impl/HermeticPyObjectTLS.cpp
Normal file
@@ -0,0 +1,21 @@
|
||||
#include <c10/core/impl/HermeticPyObjectTLS.h>
|
||||
|
||||
namespace c10::impl {
|
||||
|
||||
thread_local static std::atomic<bool> hermeticPyObjectState{false};
|
||||
|
||||
std::atomic<bool> HermeticPyObjectTLS::haveState_{false};
|
||||
|
||||
void HermeticPyObjectTLS::set_state(bool state) {
|
||||
hermeticPyObjectState = state;
|
||||
}
|
||||
|
||||
bool HermeticPyObjectTLS::get_tls_state() {
|
||||
return hermeticPyObjectState;
|
||||
}
|
||||
|
||||
void HermeticPyObjectTLS::init_state() {
|
||||
haveState_ = true;
|
||||
}
|
||||
|
||||
} // namespace c10::impl
|
||||
62
c10/core/impl/HermeticPyObjectTLS.h
Normal file
@@ -0,0 +1,62 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
#include <atomic>
|
||||
|
||||
namespace c10::impl {
|
||||
|
||||
// This TLS controls whether or not we permanently associate PyObject
|
||||
// with Tensor the first time it is allocated. When hermetic PyObject
|
||||
// TLS is enabled (state is true), we DO NOT save PyObjects to Tensor,
|
||||
// meaning you get a distinct PyObject whenever you execute the code in
|
||||
// question.
|
||||
struct C10_API HermeticPyObjectTLS {
|
||||
static void set_state(bool state);
|
||||
static bool get_state() {
|
||||
// Hypothetical fastpath if torchdeploy/multipy // codespell:ignore multipy
|
||||
// isn't used. Per
|
||||
// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
|
||||
// this qualifies relaxed access because it is a single-location data
|
||||
// structure (only the boolean here).
|
||||
//
|
||||
// Forgetting about data races for a moment, is there a logical race?
|
||||
//
|
||||
// - Boolean only ever transitions from false to true. So the
|
||||
// critical situation is when one interpreter is already running
|
||||
// when a second interpreter switches haveState from false to true.
|
||||
//
|
||||
// - The first interpreter is indifferent whether or not it sees
|
||||
// hasState true/false; obviously false works (this is what the
|
||||
// interpreter was previously using; more directly, the interpreter
|
||||
// calls into itself as the handler, so being hermetic is not
|
||||
// required), and true simply means serviced python operator calls will
|
||||
// be hermetic; in these cases it is expected to be functionally
|
||||
// equivalent.
|
||||
//
|
||||
// - The second interpreter MUST see hasState true (as its requests will
|
||||
// be forwarded to the first interpreter), but it is assumed that there
|
||||
// is a synchronization between the interpreter initialization, and
|
||||
// when we actually perform operations, so it is guaranteed to see
|
||||
// hasState true.
|
||||
//
|
||||
// QED.
|
||||
//
|
||||
// This fastpath is currently disabled so that we can more easily test that
|
||||
// hermetic mode works correctly even on stock build of PyTorch.
|
||||
if (false && !haveState_.load(std::memory_order_relaxed))
|
||||
return false;
|
||||
return get_tls_state();
|
||||
}
|
||||
// Call this from the multipy/torchdeploy // codespell:ignore multipy
|
||||
// top level
|
||||
static void init_state();
|
||||
|
||||
private:
|
||||
// This is only flipped once from false to true during
|
||||
// torchdeploy/multipy initialization, // codespell:ignore multipy
|
||||
// and never again.
|
||||
static std::atomic<bool> haveState_;
|
||||
static bool get_tls_state();
|
||||
};
|
||||
|
||||
} // namespace c10::impl
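To make the intent of this TLS concrete, here is a minimal sketch (not part of the patch; `maybe_cache_pyobject` is a hypothetical helper) of how a caller might consult it before permanently associating a PyObject with a tensor:

```cpp
#include <c10/core/impl/HermeticPyObjectTLS.h>

// Hypothetical caller: only cache the PyObject when NOT in hermetic mode.
void maybe_cache_pyobject(/* PyObject* obj, c10::TensorImpl* impl */) {
  if (c10::impl::HermeticPyObjectTLS::get_state()) {
    // Hermetic mode: skip caching so each call hands back a fresh PyObject wrapper.
    return;
  }
  // Non-hermetic mode: it is safe to store the PyObject on the tensor
  // (the actual storage of the pointer is elided here).
}
```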
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/impl/HermeticPyObjectTLS.h>
|
||||
#include <c10/core/impl/PyInterpreter.h>
|
||||
#include <c10/core/impl/PyInterpreterHooks.h>
|
||||
#include <c10/util/python_stub.h>
|
||||
@@ -41,15 +42,32 @@ struct C10_API PyObjectSlot {
|
||||
|
||||
PyObject* _unchecked_untagged_pyobj() const;
|
||||
|
||||
// Test the interpreter / PyObj as they may be null
|
||||
// Test the interpreter tag. If tagged for the current interpreter, return
|
||||
// a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
|
||||
// returns a nullopt. If it is definitely invalid, raises an error.
|
||||
//
|
||||
// If `ignore_hermetic_tls` is false and this function is called from a
|
||||
// hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
|
||||
// nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
|
||||
// context is ignored, allowing you to check the interpreter tag of a
|
||||
// nonhermetic PyObject from within a hermetic context. This is necessary
|
||||
// because there are some cases where the deallocator function of a
|
||||
// nonhermetic PyObject is called from within a hermetic context, so it must
|
||||
// be properly treated as a nonhermetic PyObject.
|
||||
//
|
||||
// NB: this lives in header so that we can avoid actually creating the
|
||||
// std::optional
|
||||
|
||||
// @todo alban: I'm not too sure what's going on here, we can probably delete
|
||||
// it but it's worthwhile making sure
|
||||
std::optional<PyObject*> check_pyobj() const {
|
||||
impl::PyInterpreter* interpreter = getGlobalPyInterpreter();
|
||||
if (interpreter == nullptr || pyobj_ == nullptr) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (c10::impl::HermeticPyObjectTLS::get_state()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return _unchecked_untagged_pyobj();
|
||||
}
|
||||
|
||||
|
||||
@@ -382,6 +382,7 @@ struct ExpandableSegment {
|
||||
peers_(std::move(peers)) {
|
||||
cudaDeviceProp prop{};
|
||||
C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_));
|
||||
mapped_size_ = 0;
|
||||
// we allocate enough address space for 1 1/8 the total memory on the GPU.
|
||||
// This allows for some cases where we have to unmap pages earlier in the
|
||||
// segment to put them at the end.
|
||||
@@ -493,6 +494,7 @@ struct ExpandableSegment {
|
||||
return SegmentRange{range.ptr, 0};
|
||||
}
|
||||
unmapHandles(begin, end);
|
||||
mapped_size_ -= (end - begin) * segment_size_;
|
||||
return rangeFromHandles(begin, end);
|
||||
}
|
||||
|
||||
@@ -632,6 +634,18 @@ struct ExpandableSegment {
|
||||
return max_handles_ * segment_size_;
|
||||
}
|
||||
|
||||
cudaStream_t getStream() {
|
||||
return *stream_;
|
||||
}
|
||||
|
||||
size_t getMappedSize() {
|
||||
return mapped_size_;
|
||||
}
|
||||
|
||||
size_t getSegmentSize() {
|
||||
return segment_size_;
|
||||
}
|
||||
|
||||
void addPeer(c10::DeviceIndex device) {
|
||||
peers_.push_back(device);
|
||||
forEachAllocatedRange(
|
||||
@@ -666,6 +680,7 @@ struct ExpandableSegment {
|
||||
handles_.at(i).value().handle,
|
||||
0ULL));
|
||||
}
|
||||
mapped_size_ += (end - begin) * segment_size_;
|
||||
setAccess(device_, begin, end);
|
||||
for (auto p : peers_) {
|
||||
setAccess(p, begin, end);
|
||||
@@ -734,6 +749,7 @@ struct ExpandableSegment {
|
||||
std::optional<cudaStream_t> stream_;
|
||||
CUdeviceptr ptr_{};
|
||||
size_t segment_size_;
|
||||
size_t mapped_size_;
|
||||
size_t max_handles_;
|
||||
struct Handle {
|
||||
CUmemGenericAllocationHandle handle;
|
||||
@@ -779,6 +795,17 @@ struct ExpandableSegment {
|
||||
size_t size() const {
|
||||
return 0;
|
||||
}
|
||||
cudaStream_t getStream() {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
size_t getMappedSize() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
size_t getSegmentSize() {
|
||||
return 0;
|
||||
}
|
||||
void addPeer(c10::DeviceIndex device) {}
|
||||
};
|
||||
#endif
|
||||
@@ -1183,6 +1210,16 @@ class DeviceCachingAllocator {
|
||||
// ends.
|
||||
ska::flat_hash_map<Block*, std::vector<cudaGraphNode_t>> deferred_blocks;
|
||||
|
||||
// Incremental reverse-traversal state cached per graph.
|
||||
// We never re-traverse nodes we've already seen
|
||||
struct GraphReuseContext {
|
||||
ska::flat_hash_map<cudaStream_t, ska::flat_hash_set<cudaGraphNode_t>>
|
||||
visited;
|
||||
};
|
||||
ska::flat_hash_map<MempoolId_t, CaptureId_t, MempoolIdHash>
|
||||
mempool_to_capture_id;
|
||||
ska::flat_hash_map<CaptureId_t, GraphReuseContext> graph_reuse_context;
|
||||
|
||||
// outstanding cuda events
|
||||
ska::flat_hash_map<
|
||||
cuda::CUDAStream,
|
||||
@@ -1638,44 +1675,70 @@ class DeviceCachingAllocator {
|
||||
return block;
|
||||
}
|
||||
|
||||
// Insert "free marker" (empty nodes) into the CUDA graph for all streams that
|
||||
struct CaptureInfo {
|
||||
cudaGraph_t graph{};
|
||||
CaptureId_t capture_id{0};
|
||||
const cudaGraphNode_t* terminals{nullptr};
|
||||
size_t num_terminals{0};
|
||||
cudaStreamCaptureStatus status{cudaStreamCaptureStatusNone};
|
||||
};
|
||||
|
||||
inline CaptureInfo stream_get_capture_info(cudaStream_t stream) {
|
||||
CaptureInfo info{};
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
|
||||
C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
|
||||
stream,
|
||||
&info.status,
|
||||
&info.capture_id,
|
||||
&info.graph,
|
||||
&info.terminals,
|
||||
nullptr,
|
||||
&info.num_terminals));
|
||||
#else
|
||||
C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
|
||||
stream,
|
||||
&info.status,
|
||||
&info.capture_id,
|
||||
&info.graph,
|
||||
&info.terminals,
|
||||
&info.num_terminals));
|
||||
#endif
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
info.status != cudaStreamCaptureStatusInvalidated,
|
||||
"Invalid stream capture status");
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
// Record "free markers" in the CUDA graph for all streams that
|
||||
// have used the block, including the allocation stream. These nodes mark the
|
||||
// last use of the block in the capture graph. Returns a vector of the
|
||||
// inserted nodes, or an empty vector if any stream is not capturing.
|
||||
std::vector<cudaGraphNode_t> insert_free_marker(Block* block) {
|
||||
std::vector<cudaGraphNode_t> empty_nodes;
|
||||
std::vector<cudaGraphNode_t> record_free_markers(Block* block) {
|
||||
// It is possible to have the same marker recorded multiple times, so we use
|
||||
// a set to avoid duplicates
|
||||
ska::flat_hash_set<cudaGraphNode_t> markers;
|
||||
cudaGraph_t owning_graph = nullptr;
|
||||
|
||||
auto try_add_empty_node = [&](cudaStream_t stream) -> bool {
|
||||
cudaStreamCaptureStatus status{};
|
||||
cudaGraph_t graph{};
|
||||
const cudaGraphNode_t* deps = nullptr;
|
||||
size_t num_deps = 0;
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
|
||||
C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
|
||||
stream, &status, nullptr, &graph, &deps, nullptr, &num_deps));
|
||||
#else
|
||||
C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
|
||||
stream, &status, nullptr, &graph, &deps, &num_deps));
|
||||
#endif
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
status != cudaStreamCaptureStatusInvalidated,
|
||||
"Invalid stream capture status");
|
||||
|
||||
if (status == cudaStreamCaptureStatusNone) {
|
||||
return false;
|
||||
auto try_record = [&](cudaStream_t s) -> bool {
|
||||
auto info = stream_get_capture_info(s);
|
||||
if (info.status == cudaStreamCaptureStatusNone) {
|
||||
return false; // not capturing on this stream -> must defer
|
||||
}
|
||||
|
||||
cudaGraphNode_t node{};
|
||||
C10_CUDA_CHECK(cudaGraphAddEmptyNode(&node, graph, deps, num_deps));
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
|
||||
C10_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
|
||||
stream, &node, nullptr, 1, cudaStreamSetCaptureDependencies));
|
||||
#else
|
||||
C10_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
|
||||
stream, &node, 1, cudaStreamSetCaptureDependencies));
|
||||
#endif
|
||||
empty_nodes.push_back(node);
|
||||
if (owning_graph == nullptr) {
|
||||
owning_graph = info.graph;
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
info.graph == owning_graph,
|
||||
"All streams in the same capture should agree on the graph");
|
||||
|
||||
// Use current terminals as the free markers for the stream
|
||||
for (size_t i = 0; i < info.num_terminals; ++i) {
|
||||
auto terminal = info.terminals[i];
|
||||
markers.insert(terminal);
|
||||
}
|
||||
owning_graph = info.graph; // all streams in the same capture should agree
|
||||
return true;
|
||||
};
|
||||
|
||||
@@ -1683,81 +1746,34 @@ class DeviceCachingAllocator {
|
||||
// An empty vector indicates that the block should be deferred for freeing
|
||||
// until after capture.
|
||||
|
||||
// Attempt to add an empty node for the allocation stream.
|
||||
if (!try_add_empty_node(block->stream)) {
|
||||
// Allocation stream
|
||||
if (!try_record(block->stream)) {
|
||||
return {};
|
||||
}
|
||||
// Attempt to add empty nodes for all streams that have used the block.
|
||||
// Any extra streams that used this block
|
||||
for (const auto& s : block->stream_uses) {
|
||||
if (!try_add_empty_node(s.stream())) {
|
||||
if (!try_record(s.stream())) {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
return empty_nodes;
|
||||
return std::vector<cudaGraphNode_t>(markers.begin(), markers.end());
|
||||
}
|
||||
|
||||
// Returns the current set of "terminal" nodes in the CUDA graph for a given
|
||||
// stream. These represent the current endpoints of the stream, and may
|
||||
// include additional nodes if the graph branches. Any new work captured will
|
||||
// be attached after one or more of these terminals.
|
||||
std::vector<cudaGraphNode_t> get_terminals(cudaStream_t stream) {
|
||||
std::vector<cudaGraphNode_t> result;
|
||||
|
||||
cudaStreamCaptureStatus status{};
|
||||
cudaGraph_t graph{};
|
||||
const cudaGraphNode_t* dependencies = nullptr;
|
||||
size_t num_dependencies = 0;
|
||||
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
|
||||
C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
|
||||
stream,
|
||||
&status,
|
||||
nullptr,
|
||||
&graph,
|
||||
&dependencies,
|
||||
nullptr,
|
||||
&num_dependencies));
|
||||
#else
|
||||
C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
|
||||
stream, &status, nullptr, &graph, &dependencies, &num_dependencies));
|
||||
#endif
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
status == cudaStreamCaptureStatusActive,
|
||||
"Invalid stream capture status");
|
||||
|
||||
for (size_t i = 0; i < num_dependencies; i++) {
|
||||
auto node = dependencies[i];
|
||||
if (node != nullptr) {
|
||||
result.push_back(node);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the set of "reusable" free markers (empty nodes) in the current
|
||||
// Returns the set of "reusable" free markers in the current
|
||||
// CUDA graph capture. A free marker is considered reusable if it is a
|
||||
// predecessor of every terminal node.
|
||||
// This ensures that all future captured work will occur after the free
|
||||
// marker, making it safe to reuse.
|
||||
ska::flat_hash_set<cudaGraphNode_t> get_reusable_empty_nodes(
|
||||
cudaStream_t stream) {
|
||||
auto terminals = get_terminals(stream);
|
||||
if (terminals.empty()) {
|
||||
// No terminal nodes found; nothing to free.
|
||||
return {};
|
||||
}
|
||||
|
||||
auto get_dependencies = [](cudaGraphNode_t node,
|
||||
cudaGraphNode_t* pDependencies,
|
||||
size_t* pNumDependencies) -> void {
|
||||
void update_visited(
|
||||
const CaptureInfo& info,
|
||||
ska::flat_hash_set<cudaGraphNode_t>& visited) {
|
||||
// This is the versioned cudaGraphNodeGetDependencies helper function.
|
||||
auto node_get_dependencies =
|
||||
[](cudaGraphNode_t n, cudaGraphNode_t* deps, size_t* count) -> void {
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
|
||||
C10_CUDA_CHECK(cudaGraphNodeGetDependencies(
|
||||
node, pDependencies, nullptr, pNumDependencies));
|
||||
C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, nullptr, count));
|
||||
#else
|
||||
C10_CUDA_CHECK(
|
||||
cudaGraphNodeGetDependencies(node, pDependencies, pNumDependencies));
|
||||
C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, count));
|
||||
#endif
|
||||
};
|
||||
|
||||
@@ -1765,62 +1781,43 @@ class DeviceCachingAllocator {
|
||||
auto get_parents =
|
||||
[&](cudaGraphNode_t node) -> std::vector<cudaGraphNode_t> {
|
||||
size_t count = 0;
|
||||
get_dependencies(node, nullptr, &count);
|
||||
|
||||
node_get_dependencies(node, nullptr, &count);
|
||||
std::vector<cudaGraphNode_t> out(count);
|
||||
if (count) {
|
||||
get_dependencies(node, out.data(), &count);
|
||||
node_get_dependencies(node, out.data(), &count);
|
||||
out.resize(count);
|
||||
}
|
||||
return out;
|
||||
};
|
||||
|
||||
// Helper to determine if a node is an empty node (used as a free marker).
|
||||
auto is_empty_node = [](cudaGraphNode_t n) -> bool {
|
||||
cudaGraphNodeType type{};
|
||||
C10_CUDA_CHECK(cudaGraphNodeGetType(n, &type));
|
||||
return type == cudaGraphNodeTypeEmpty;
|
||||
};
|
||||
|
||||
// For each terminal node, perform a reverse DFS to count, for each empty
|
||||
// node, how many terminals it can reach (i.e., for how many terminals it is
|
||||
// a predecessor). An empty node is reusable if it is a predecessor of all
|
||||
// terminal nodes.
|
||||
ska::flat_hash_map<cudaGraphNode_t, size_t> num_terminals_reachable;
|
||||
|
||||
for (auto terminal : terminals) {
|
||||
ska::flat_hash_set<cudaGraphNode_t> visited;
|
||||
ska::flat_hash_set<cudaGraphNode_t> empty_nodes;
|
||||
|
||||
std::function<void(cudaGraphNode_t)> reverse_dfs =
|
||||
[&](cudaGraphNode_t node) {
|
||||
if (!visited.insert(node).second)
|
||||
return;
|
||||
|
||||
if (is_empty_node(node)) {
|
||||
num_terminals_reachable[node]++;
|
||||
empty_nodes.insert(node);
|
||||
}
|
||||
auto parents = get_parents(node);
|
||||
for (auto p : parents) {
|
||||
reverse_dfs(p);
|
||||
}
|
||||
};
|
||||
|
||||
reverse_dfs(terminal);
|
||||
// For each terminal node, perform a reverse DFS to count, for each free
|
||||
// marker, how many terminals it can reach (i.e., for how many terminals it
|
||||
// is a predecessor). A free marker is reusable if it is a predecessor of
|
||||
// all terminal nodes.
|
||||
std::deque<cudaGraphNode_t> dfs;
|
||||
for (size_t i = 0; i < info.num_terminals; ++i) {
|
||||
dfs.push_back(info.terminals[i]);
|
||||
}
|
||||
|
||||
ska::flat_hash_set<cudaGraphNode_t> reusable_empty_nodes;
|
||||
for (auto [node, count] : num_terminals_reachable) {
|
||||
if (count == terminals.size()) {
|
||||
reusable_empty_nodes.insert(node);
|
||||
while (!dfs.empty()) {
|
||||
auto v = dfs.back();
|
||||
dfs.pop_back();
|
||||
|
||||
if (visited.count(v)) {
|
||||
continue;
|
||||
}
|
||||
visited.insert(v);
|
||||
|
||||
auto parents = get_parents(v);
|
||||
for (auto p : parents) {
|
||||
dfs.push_back(p);
|
||||
}
|
||||
}
|
||||
|
||||
return reusable_empty_nodes;
|
||||
}
|
||||
|
||||
// A block is considered reusable during CUDA graph capture if every free
|
||||
// marker (empty node) associated with the block is a predecessor of every
|
||||
// marker associated with the block is a predecessor of every
|
||||
// terminal node.
|
||||
//
|
||||
// This ensures that any new operation added to the graph will be attached
|
||||
@@ -1829,36 +1826,52 @@ class DeviceCachingAllocator {
|
||||
// on every stream, so the block's previous lifetime ends before any new
|
||||
// lifetime begins. This check relies solely on the DAG topology and does not
|
||||
// require event queries, making it safe to use during capture.
|
||||
//
|
||||
// This function iterates over all deferred blocks, determines if their empty
|
||||
// nodes are reusable according to the above criteria, and frees the block if
|
||||
// so.
|
||||
void free_safe_blocks_in_capture(
|
||||
const std::shared_ptr<GatheredContext>& context,
|
||||
cudaStream_t stream) {
|
||||
auto reusable_empty_nodes = get_reusable_empty_nodes(stream);
|
||||
auto info = stream_get_capture_info(stream);
|
||||
|
||||
// If there are no reusable empty nodes (e.g., not currently capturing),
|
||||
// there is nothing to do.
|
||||
if (reusable_empty_nodes.empty()) {
|
||||
if (info.status == cudaStreamCaptureStatusNone || info.num_terminals == 0) {
|
||||
return;
|
||||
}
|
||||
if (graph_reuse_context.find(info.capture_id) ==
|
||||
graph_reuse_context.end()) {
|
||||
bool found = false;
|
||||
for (auto& entry : captures_underway) {
|
||||
if (entry.second(stream)) {
|
||||
auto graph_pool = graph_pools.find(entry.first);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
graph_pool != graph_pools.end(),
|
||||
"Could not find graph pool for capture.");
|
||||
auto mempool_id = graph_pool->first;
|
||||
graph_reuse_context[info.capture_id] = GraphReuseContext{};
|
||||
mempool_to_capture_id[mempool_id] = info.capture_id;
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
found, "Could not find memory pool id for capture.");
|
||||
}
|
||||
auto& graph_context = graph_reuse_context[info.capture_id];
|
||||
auto& visited = graph_context.visited[stream];
|
||||
update_visited(info, visited);
|
||||
|
||||
std::vector<Block*> blocks_to_erase;
|
||||
|
||||
for (auto& [block, inserted_empty_nodes] : deferred_blocks) {
|
||||
// Skip this block if it has no empty nodes, as we defer its freeing until
|
||||
for (auto& [block, markers] : deferred_blocks) {
|
||||
// Skip this block if it has no markers, as we defer its freeing until
|
||||
// after graph capture. Also skip if the block was not allocated on the
|
||||
// current stream; such blocks will be freed when
|
||||
// free_safe_blocks_in_capture is attempted on that stream.
|
||||
if (inserted_empty_nodes.empty() || block->stream != stream) {
|
||||
if (markers.empty() || block->stream != stream) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool is_reusable = true;
|
||||
|
||||
for (const auto& node : inserted_empty_nodes) {
|
||||
if (reusable_empty_nodes.find(node) == reusable_empty_nodes.end()) {
|
||||
for (auto m : markers) {
|
||||
if (!visited.count(m)) {
|
||||
is_reusable = false;
|
||||
break;
|
||||
}
|
||||
@@ -1919,11 +1932,11 @@ class DeviceCachingAllocator {
|
||||
if (!block->stream_uses.empty()) {
|
||||
if (C10_UNLIKELY(!captures_underway.empty())) {
|
||||
if (CUDAAllocatorConfig::graph_capture_record_stream_reuse()) {
|
||||
// insert_free_marker returns a vector of free markers,
|
||||
// record_free_markers returns a vector of free markers,
|
||||
// or an empty vector if any associated stream is not currently
|
||||
// capturing. The empty vector means that we will defer the free until
|
||||
// capture is finished.
|
||||
deferred_blocks.emplace(block, insert_free_marker(block));
|
||||
deferred_blocks.emplace(block, record_free_markers(block));
|
||||
} else {
|
||||
// If graph_capture_record_stream_reuse is not enabled, always defer
|
||||
// the free until capture is finished.
|
||||
@@ -2025,6 +2038,22 @@ class DeviceCachingAllocator {
|
||||
set_fraction = true;
|
||||
}
|
||||
|
||||
/** get expandable segment sizes for all the streams on this device **/
|
||||
std::vector<StreamSegmentSize> getExpandableSegmentSizes() {
|
||||
std::lock_guard<std::recursive_mutex> lock(mutex);
|
||||
std::vector<StreamSegmentSize> sizes;
|
||||
for (auto& segment : expandable_segments_) {
|
||||
if (!segment->getStream()) {
|
||||
continue;
|
||||
}
|
||||
sizes.emplace_back(
|
||||
segment->getStream(),
|
||||
segment->getSegmentSize() == kSmallBuffer,
|
||||
segment->getMappedSize());
|
||||
}
|
||||
return sizes;
|
||||
}
|
||||
|
||||
/** returns cached blocks to the system allocator **/
|
||||
void emptyCache(MempoolId_t mempool_id) {
|
||||
auto context = maybeGatherContext(RecordContext::ALL);
|
||||
@@ -2511,6 +2540,21 @@ class DeviceCachingAllocator {
|
||||
// Called by CUDAGraph::capture_end
|
||||
void endAllocateToPool(MempoolId_t mempool_id) {
|
||||
std::lock_guard<std::recursive_mutex> lock(mutex);
|
||||
|
||||
if (CUDAAllocatorConfig::graph_capture_record_stream_reuse() &&
|
||||
!graph_reuse_context.empty()) {
|
||||
auto capture_id = mempool_to_capture_id[mempool_id];
|
||||
auto graph_context = graph_reuse_context[capture_id];
|
||||
for (auto& [stream, _] : graph_context.visited) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
stream_get_capture_info(stream).status ==
|
||||
cudaStreamCaptureStatusNone,
|
||||
"This stream should not be capturing when the capture is ended");
|
||||
}
|
||||
graph_reuse_context.erase(capture_id);
|
||||
mempool_to_capture_id.erase(mempool_id);
|
||||
}
|
||||
|
||||
for (auto it = captures_underway.begin(); it != captures_underway.end();
|
||||
++it) {
|
||||
if (it->first == mempool_id) {
|
||||
@@ -3837,6 +3881,16 @@ class NativeCachingAllocator : public CUDAAllocator {
|
||||
device_allocator[device]->setMemoryFraction(fraction);
|
||||
}
|
||||
|
||||
std::vector<StreamSegmentSize> getExpandableSegmentSizes(
|
||||
c10::DeviceIndex device) override {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
0 <= device && static_cast<size_t>(device) < device_allocator.size(),
|
||||
"Allocator not initialized for device ",
|
||||
device,
|
||||
": did you call init?");
|
||||
return device_allocator[device]->getExpandableSegmentSizes();
|
||||
}
|
||||
|
||||
void recordHistory(
|
||||
bool enabled,
|
||||
CreateContextFn context_recorder,
|
||||
|
||||
@@ -203,6 +203,14 @@ struct ShareableHandle {
|
||||
std::string handle;
|
||||
};
|
||||
|
||||
struct StreamSegmentSize {
|
||||
StreamSegmentSize(cudaStream_t s, bool small, size_t sz)
|
||||
: stream(s), is_small_pool(small), total_size(sz) {}
|
||||
cudaStream_t stream;
|
||||
bool is_small_pool;
|
||||
size_t total_size;
|
||||
};
|
||||
|
||||
class CUDAAllocator : public DeviceAllocator {
|
||||
public:
|
||||
virtual void* raw_alloc(size_t nbytes) = 0;
|
||||
@@ -211,6 +219,8 @@ class CUDAAllocator : public DeviceAllocator {
|
||||
virtual void init(int device_count) = 0;
|
||||
virtual double getMemoryFraction(c10::DeviceIndex device) = 0;
|
||||
virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0;
|
||||
virtual std::vector<StreamSegmentSize> getExpandableSegmentSizes(
|
||||
c10::DeviceIndex device) = 0;
|
||||
virtual void enable(bool value) = 0;
|
||||
virtual bool isEnabled() const = 0;
|
||||
virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0;
|
||||
@@ -365,6 +375,11 @@ inline void setMemoryFraction(double fraction, c10::DeviceIndex device) {
|
||||
return get()->setMemoryFraction(fraction, device);
|
||||
}
|
||||
|
||||
inline std::vector<StreamSegmentSize> getExpandableSegmentSizes(
|
||||
c10::DeviceIndex device) {
|
||||
return get()->getExpandableSegmentSizes(device);
|
||||
}
|
||||
|
||||
inline void emptyCache(MempoolId_t mempool_id = {0, 0}) {
|
||||
return get()->emptyCache(mempool_id);
|
||||
}
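A small usage sketch of the query added above, assuming the `c10::cuda::CUDACachingAllocator` wrapper shown in this hunk (`dump_expandable_segments` is illustrative only, not part of the patch):

```cpp
#include <c10/cuda/CUDACachingAllocator.h>
#include <iostream>

// Print one line per expandable segment: its stream, pool kind, and mapped bytes.
void dump_expandable_segments(c10::DeviceIndex device) {
  const auto sizes =
      c10::cuda::CUDACachingAllocator::getExpandableSegmentSizes(device);
  for (const auto& s : sizes) {
    std::cout << "stream=" << s.stream
              << " small_pool=" << (s.is_small_pool ? "yes" : "no")
              << " mapped_bytes=" << s.total_size << '\n';
  }
}
```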
|
||||
|
||||
@@ -495,6 +495,13 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
|
||||
// introduces performance nondeterminism.
|
||||
}
|
||||
|
||||
std::vector<StreamSegmentSize> getExpandableSegmentSizes(
|
||||
c10::DeviceIndex device) override {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"CUDAMallocAsyncAllocator does not yet support getExpandableSegmentSizes.");
|
||||
}
|
||||
|
||||
void emptyCache(/*unused*/ MempoolId_t mempool_id) override {
|
||||
std::lock_guard<std::mutex> lk(general_mutex);
|
||||
|
||||
|
||||
@@ -16,21 +16,11 @@ cuda_supported_platforms = [
|
||||
"ovr_config//os:windows-cuda",
|
||||
]
|
||||
|
||||
# rocktenn apparently has its own copy of glog that comes with libmp.dll, so we
|
||||
# had better not try to use glog from c10 lest the glog symbols not be eliminated.
|
||||
C10_USE_GLOG = native.read_config("c10", "use_glog", "1") == "1"
|
||||
|
||||
# If you don't use any functionality that relies on static initializer in c10 (the
|
||||
# most notable ones are the allocators), you can turn off link_whole this way.
|
||||
# In practice, this is only used by rocktenn as well.
|
||||
C10_LINK_WHOLE = native.read_config("c10", "link_whole", "1") == "1"
|
||||
|
||||
def define_c10_ovrsource(name, is_mobile):
|
||||
pp_flags = []
|
||||
if is_mobile:
|
||||
pp_flags.append("-DC10_MOBILE=1")
|
||||
if C10_USE_GLOG:
|
||||
pp_flags.append("-DC10_USE_GLOG")
|
||||
pp_flags = ["-DC10_MOBILE=1"]
|
||||
else:
|
||||
pp_flags = []
|
||||
|
||||
oxx_static_library(
|
||||
name = name,
|
||||
@@ -41,7 +31,6 @@ def define_c10_ovrsource(name, is_mobile):
|
||||
"util/*.cpp",
|
||||
]),
|
||||
compatible_with = cpu_supported_platforms,
|
||||
link_whole = C10_LINK_WHOLE,
|
||||
compiler_flags = select({
|
||||
"DEFAULT": [],
|
||||
"ovr_config//compiler:cl": [
|
||||
@@ -88,7 +77,6 @@ def define_c10_ovrsource(name, is_mobile):
|
||||
"//arvr/third-party/gflags:gflags",
|
||||
"//third-party/cpuinfo:cpuinfo",
|
||||
"//third-party/fmt:fmt",
|
||||
# For some godforsaken reason, this is always required even when not C10_USE_GLOG
|
||||
"//third-party/glog:glog",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -2,175 +2,126 @@
|
||||
#include <arm_neon.h>
|
||||
#include <arm_neon_sve_bridge.h>
|
||||
#include <arm_sve.h>
|
||||
#include <cfloat>
|
||||
#include <cmath>
|
||||
|
||||
#include "c10/macros/Macros.h"
|
||||
|
||||
// Log and exp approximations inspired from ACL implementation
|
||||
/// Select `svlog` accuracy:
|
||||
/// - 0: original.
|
||||
/// - 1: more accurate, similar performance.
|
||||
/// - 2: very high accuracy, a bit lower speed.
|
||||
#define SVLOG_ACCURACY 2
|
||||
|
||||
inline float32x4_t vtaylor_polyq_for_log_f32(float32x4_t x) {
|
||||
const float32x4_t log_tab_1 = vdupq_n_f32(-2.29561495781f);
|
||||
const float32x4_t log_tab_2 = vdupq_n_f32(-2.47071170807f);
|
||||
const float32x4_t log_tab_3 = vdupq_n_f32(-5.68692588806f);
|
||||
const float32x4_t log_tab_4 = vdupq_n_f32(-0.165253549814f);
|
||||
const float32x4_t log_tab_5 = vdupq_n_f32(5.17591238022f);
|
||||
const float32x4_t log_tab_6 = vdupq_n_f32(0.844007015228f);
|
||||
const float32x4_t log_tab_7 = vdupq_n_f32(4.58445882797f);
|
||||
const float32x4_t log_tab_8 = vdupq_n_f32(0.0141278216615f);
|
||||
/// Handle special cases in `svexp`:
|
||||
/// - 0: original.
|
||||
/// - 1: use clamp, better performance.
|
||||
/// - 2: no special case handling.
|
||||
#define SVEXP_SPECIAL_CLAMP 1
|
||||
|
||||
float32x4_t A = vmlaq_f32(log_tab_1, log_tab_5, x);
|
||||
float32x4_t B = vmlaq_f32(log_tab_3, log_tab_7, x);
|
||||
float32x4_t C = vmlaq_f32(log_tab_2, log_tab_6, x);
|
||||
float32x4_t x2 = vmulq_f32(x, x);
|
||||
float32x4_t D = svget_neonq(svmad_f32_x(
|
||||
svptrue_b8(),
|
||||
svset_neonq(svundef_f32(), x),
|
||||
svset_neonq(svundef_f32(), log_tab_8),
|
||||
svset_neonq(svundef_f32(), log_tab_4)));
|
||||
float32x4_t x4 = vmulq_f32(x2, x2);
|
||||
float32x4_t res = vmlaq_f32(vmlaq_f32(A, B, x2), vmlaq_f32(C, D, x2), x4);
|
||||
return res;
|
||||
#if SVLOG_ACCURACY == 2
|
||||
static inline svfloat32_t svlog(svfloat32_t x) {
|
||||
const svbool_t ptrue = svptrue_b8();
|
||||
|
||||
svint32_t u = svreinterpret_s32(x) - 0x3F2AAAAB;
|
||||
|
||||
svfloat32_t r = svreinterpret_f32((u & 0x007FFFFF) + 0x3F2AAAAB) - 1.0f;
|
||||
svfloat32_t n = svcvt_f32_x(ptrue, u >> 23);
|
||||
asm("" : "+w"(r)); // NOTE: can improve instruction scheduling.
|
||||
|
||||
svfloat32_t r2 = r * r;
|
||||
svfloat32_t p = -0x1.4F9934p-3f + r * 0x1.5A9AA2p-3f;
|
||||
svfloat32_t q = -0x1.00187Cp-2f + r * 0x1.961348p-3f;
|
||||
svfloat32_t y = -0x1.FFFFC8p-2f + r * 0x1.555D7Cp-2f;
|
||||
return (r + n * 0x1.62E43p-1f) +
|
||||
(y + (q + (p + -0x1.3E737Cp-3f * r2) * r2) * r2) * r2;
|
||||
}
|
||||
#elif SVLOG_ACCURACY == 1
|
||||
static inline svfloat32_t svlog(svfloat32_t x) {
|
||||
const svbool_t ptrue = svptrue_b8();
|
||||
|
||||
inline float32x4_t vlogq_f32(float32x4_t x) {
|
||||
const float32x4_t CONST_LN2 = vdupq_n_f32(0.6931471805f); // ln(2)
|
||||
svint32_t u = svreinterpret_s32(x) - 0x3F2AAAAB;
|
||||
|
||||
// Extract exponent
|
||||
int32x4_t m = svget_neonq(svsub_n_s32_x(
|
||||
svptrue_b8(),
|
||||
svset_neonq(
|
||||
svundef_s32(),
|
||||
vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_f32(x), 23))),
|
||||
127));
|
||||
float32x4_t val = vreinterpretq_f32_s32(
|
||||
vsubq_s32(vreinterpretq_s32_f32(x), vshlq_n_s32(m, 23)));
|
||||
svfloat32_t r = svreinterpret_f32((u & 0x007FFFFF) + 0x3F2AAAAB) - 1.0f;
|
||||
svfloat32_t n = svcvt_f32_x(ptrue, u >> 23);
|
||||
asm("" : "+w"(r)); // NOTE: can improve instruction scheduling.
|
||||
|
||||
// Polynomial Approximation
|
||||
float32x4_t poly = vtaylor_polyq_for_log_f32(val);
|
||||
svfloat32_t r2 = r * r;
|
||||
svfloat32_t A = -0x1.923814p-3f + r * 0x1.689E5Ep-3f;
|
||||
svfloat32_t B = -0x1.FC0968p-3f + r * 0x1.93BF0Cp-3f;
|
||||
svfloat32_t C = -0x1.000478p-1f + r * 0x1.556906p-2f;
|
||||
|
||||
// Reconstruct
|
||||
poly = vmlaq_f32(poly, vcvtq_f32_s32(m), CONST_LN2);
|
||||
return (r + n * 0x1.62E43p-1f) + (C + (B + A * r2) * r2) * r2;
|
||||
}
|
||||
#elif SVLOG_ACCURACY == 0
|
||||
static inline svfloat32_t svlog(svfloat32_t x) {
|
||||
const svbool_t ptrue = svptrue_b8();
|
||||
|
||||
svint32_t u = svsra_n_s32(svdup_n_s32(-127), svreinterpret_s32(x), 23);
|
||||
|
||||
svfloat32_t n = svcvt_f32_x(ptrue, u);
|
||||
svfloat32_t r = svreinterpret_f32(svreinterpret_s32(x) - (u << 23));
|
||||
|
||||
svfloat32_t D = -0.165253549814f + r * 0.0141278216615f;
|
||||
svfloat32_t C = -2.47071170807f + r * 0.844007015228f;
|
||||
svfloat32_t B = -5.68692588806f + r * 4.58445882797f;
|
||||
svfloat32_t A = -2.29561495781f + r * 5.17591238022f;
|
||||
|
||||
svfloat32_t r2 = r * r;
|
||||
return (A + n * 0.6931471805f) + (B + (C + D * r2) * r2) * r2;
|
||||
}
|
||||
#endif
|
||||
|
||||
static inline svfloat32_t svexp(svfloat32_t x) {
|
||||
// Clamp interval set to prevent denormals!
|
||||
const svfloat32_t max_input = svdup_n_f32(88.722839f);
|
||||
const svfloat32_t min_input = svdup_n_f32(-87.33654f);
|
||||
const svfloat32_t shift = svdup_n_f32(0x1.0000FEp+23f);
|
||||
const svbool_t ptrue = svptrue_b8();
|
||||
|
||||
#if SVEXP_SPECIAL_CLAMP == 1
|
||||
x = svmax_x(ptrue, svmin_x(ptrue, x, max_input), min_input);
|
||||
#endif
|
||||
|
||||
svfloat32_t z = svmla_n_f32_x(ptrue, shift, x, 0x1.715476p+0f);
|
||||
svfloat32_t n = z - shift;
|
||||
svfloat32_t scale = svreinterpret_f32(svreinterpret_u32(z) << 23);
|
||||
|
||||
svfloat32_t r_hi = x - n * 0x1.62E400p-1f;
|
||||
svfloat32_t r = r_hi - n * 0x1.7F7D1Cp-20f;
|
||||
svfloat32_t r2 = r * r;
|
||||
|
||||
svfloat32_t C = 0x1.573E2Ep-5f + r * 0x1.0E4020p-7f;
|
||||
svfloat32_t B = 0x1.FFFDB6p-2f + r * 0x1.555E66p-3f;
|
||||
svfloat32_t A = r * 0x1.FFFFECp-1f;
|
||||
|
||||
svfloat32_t poly = scale + (A + (B + C * r2) * r2) * scale;
|
||||
|
||||
#if SVEXP_SPECIAL_CLAMP == 0
|
||||
const svfloat32_t inf = svdup_n_f32(std::numeric_limits<float>::infinity());
|
||||
poly = svsel_f32(svcmplt_f32(ptrue, x, min_input), svdup_n_f32(0.0f), poly);
|
||||
poly = svsel_f32(svcmpgt_f32(ptrue, x, max_input), inf, poly);
|
||||
#endif
|
||||
|
||||
return poly;
|
||||
}
|
||||
|
||||
inline float32x4_t vexpq_f32(float32x4_t x) {
|
||||
const auto c1 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3f7ffff6)));
|
||||
const auto c2 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3efffedb)));
|
||||
const auto c3 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3e2aaf33)));
|
||||
const auto c4 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3d2b9f17)));
|
||||
const auto c5 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3c072010)));
|
||||
|
||||
const auto shift = vreinterpretq_f32_u32(
|
||||
svget_neonq(svdup_n_u32(0x4b00007f))); // 2^23 + 127 = 0x1.0000fep23f
|
||||
const auto inv_ln2 = vreinterpretq_f32_u32(
|
||||
svget_neonq(svdup_n_u32(0x3fb8aa3b))); // 1 / ln(2) = 0x1.715476p+0f
|
||||
const auto neg_ln2_hi = vreinterpretq_f32_u32(svget_neonq(
|
||||
svdup_n_u32(0xbf317200))); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f
|
||||
const auto neg_ln2_lo = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(
|
||||
0xb5bfbe8e))); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f
|
||||
|
||||
const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
|
||||
const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
|
||||
const auto zero = svdup_n_f32(0.f);
|
||||
const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)
|
||||
|
||||
// Range reduction:
|
||||
// e^x = 2^n * e^r
|
||||
// where:
|
||||
// n = floor(x / ln(2))
|
||||
// r = x - n * ln(2)
|
||||
//
|
||||
// By adding x / ln(2) with 2^23 + 127 (shift):
|
||||
// * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127
|
||||
// forces decimal part
|
||||
// of x / ln(2) out of the result. The integer part of x / ln(2) (i.e. n)
|
||||
// + 127 will occupy the whole fraction part of z in FP32 format.
|
||||
// Subtracting 2^23 + 127 (shift) from z will result in the integer part
|
||||
// of x / ln(2) (i.e. n) because the decimal part has been pushed out and
|
||||
// lost.
|
||||
// * The addition of 127 makes the FP32 fraction part of z ready to be used
|
||||
// as the exponent
|
||||
// in FP32 format. Left shifting z by 23 bits will result in 2^n.
|
||||
const auto z = vfmaq_f32(shift, x, inv_ln2);
|
||||
const auto n = z - shift;
|
||||
const auto scale =
|
||||
vreinterpretq_f32_u32(vreinterpretq_u32_f32(z) << 23); // 2^n
|
||||
|
||||
// The calculation of n * ln(2) is done using 2 steps to achieve accuracy
|
||||
// beyond FP32. This outperforms longer Taylor series (3-4 tabs) both in term
|
||||
// of accuracy and performance.
|
||||
const auto r_hi = vfmaq_f32(x, n, neg_ln2_hi);
|
||||
const auto r = vfmaq_f32(r_hi, n, neg_ln2_lo);
|
||||
|
||||
// Compute the truncated Taylor series of e^r.
|
||||
// poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
|
||||
const auto r2 = r * r;
|
||||
|
||||
const auto p1 = c1 * r;
|
||||
const auto p23 = vfmaq_f32(c2, c3, r);
|
||||
const auto p45 = vfmaq_f32(c4, c5, r);
|
||||
const auto p2345 = vfmaq_f32(p23, p45, r2);
|
||||
const auto p12345 = vfmaq_f32(p1, p2345, r2);
|
||||
|
||||
auto poly = svset_neonq(svundef_f32(), vfmaq_f32(scale, p12345, scale));
|
||||
|
||||
auto pHigh = svcmpgt_f32(svptrue_b8(), svset_neonq(svundef_f32(), x), max_input);
|
||||
auto pLow = svcmplt_f32(svptrue_b8(), svset_neonq(svundef_f32(), x), min_input);
|
||||
|
||||
auto bound = svsel_f32(
|
||||
pHigh,
|
||||
inf,
|
||||
zero);
|
||||
|
||||
auto pCombined = svorr_b_z(svptrue_b8(), pLow, pHigh);
|
||||
|
||||
// Handle underflow and overflow.
|
||||
poly = svsel_f32(
|
||||
pCombined,
|
||||
bound,
|
||||
poly);
|
||||
|
||||
return svget_neonq(poly);
|
||||
}
|
||||
|
||||
// ln(x) = log2(x) * ln(2)
|
||||
// pow(x, n) = exp(n * ln(x))
|
||||
inline float32x4_t compute_batch_box_cox_vec_sve128_float(
|
||||
static inline svfloat32_t compute_batch_box_cox_vec_sve128_float(
|
||||
svfloat32_t lambda1_v,
|
||||
svfloat32_t lambda2_v,
|
||||
svfloat32_t data_v,
|
||||
svfloat32_t k_eps) {
|
||||
// sum_v = lambda2_v + data_v
|
||||
float32x4_t sum_v = vaddq_f32(svget_neonq(data_v), svget_neonq(lambda2_v));
|
||||
const svbool_t ptrue = svptrue_b8();
|
||||
|
||||
// test lambda1_v: predNZ == 1 iff lambda1_v != 0
|
||||
svbool_t predNZ = svcmpne_n_f32(svptrue_b8(), lambda1_v, 0.0f);
|
||||
|
||||
// clamp sum_v: sum_v = max(sum_v, k_eps)
|
||||
sum_v = vmaxq_f32(sum_v, svget_neonq(k_eps));
|
||||
|
||||
// lnData = log(sum_v)
|
||||
svfloat32_t lnData = svset_neonq(svundef_f32(), vlogq_f32(sum_v));
|
||||
|
||||
// if any lambda1 != 0, compute pow(sum_v, lambda1) using lnData
|
||||
// pow(sum_v, lambda1) == exp(lambda1 * ln(sum_v))
|
||||
svfloat32_t lnData = svlog(svmax_x(ptrue, data_v + lambda2_v, k_eps));
|
||||
svbool_t predNZ = svcmpne_n_f32(ptrue, lambda1_v, 0.0f);
|
||||
if (C10_LIKELY(svptest_any(predNZ, predNZ))) {
|
||||
// mult = lambda1 * ln(sum_v)
|
||||
float32x4_t mult = vmulq_f32(svget_neonq(lnData), svget_neonq(lambda1_v));
|
||||
|
||||
// lambda1_r = 1 / lambda1
|
||||
svfloat32_t lambda1_r = svdivr_f32_m(predNZ, lambda1_v, svdup_n_f32(1.0f));
|
||||
|
||||
// pow = exp(mult)
|
||||
float32x4_t pow = vexpq_f32(mult);
|
||||
|
||||
// merge results
|
||||
// lnData if lambda1 == 0, (lambda1_r * pow - lambda1_r) if lambda1 != 0
|
||||
svfloat32_t pow = svexp(lnData * lambda1_v);
|
||||
lnData = svsel_f32(predNZ, lambda1_r, lnData);
|
||||
lnData =
|
||||
svnmsb_f32_m(predNZ, lnData, svset_neonq(svundef_f32(), pow), lnData);
|
||||
lnData = svnmsb_f32_m(predNZ, lnData, pow, lnData);
|
||||
}
|
||||
return svget_neonq(lnData);
|
||||
return lnData;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -186,11 +137,11 @@ template <>
|
||||
void compute_batch_box_cox_vec_sve128(
|
||||
std::size_t N,
|
||||
std::size_t D,
|
||||
const float* data_ptr,
|
||||
const float* __restrict lambda1_ptr,
|
||||
const float* __restrict lambda2_ptr,
|
||||
float* output_ptr) {
|
||||
svfloat32_t k_eps = svdup_n_f32(static_cast<float>(1e-6));
|
||||
const float *data_ptr,
|
||||
const float *__restrict lambda1_ptr,
|
||||
const float *__restrict lambda2_ptr,
|
||||
float *output_ptr) {
|
||||
const svfloat32_t k_eps = svdup_n_f32(static_cast<float>(1e-6));
|
||||
|
||||
std::size_t remainder = D % 4;
|
||||
std::size_t loopBound = D - remainder;
|
||||
@@ -204,17 +155,17 @@ void compute_batch_box_cox_vec_sve128(
|
||||
svfloat32_t lambda2_v =
|
||||
svset_neonq(svundef_f32(), vld1q_f32(lambda2_ptr + j));
|
||||
svfloat32_t data_v = svset_neonq(svundef_f32(), vld1q_f32(data_ptr));
|
||||
float32x4_t result = compute_batch_box_cox_vec_sve128_float(
|
||||
svfloat32_t result = compute_batch_box_cox_vec_sve128_float(
|
||||
lambda1_v, lambda2_v, data_v, k_eps);
|
||||
vst1q_f32(output_ptr, result);
|
||||
vst1q_f32(output_ptr, svget_neonq(result));
|
||||
}
|
||||
if (C10_LIKELY(remainder > 0)) {
|
||||
svfloat32_t lambda1_v = svld1_f32(remainderPred, lambda1_ptr + loopBound);
|
||||
svfloat32_t lambda2_v = svld1_f32(remainderPred, lambda2_ptr + loopBound);
|
||||
svfloat32_t data_v = svld1_f32(remainderPred, data_ptr);
|
||||
float32x4_t result = compute_batch_box_cox_vec_sve128_float(
|
||||
svfloat32_t result = compute_batch_box_cox_vec_sve128_float(
|
||||
lambda1_v, lambda2_v, data_v, k_eps);
|
||||
svst1_f32(remainderPred, output_ptr, svset_neonq(svundef_f32(), result));
|
||||
svst1_f32(remainderPred, output_ptr, result);
|
||||
data_ptr += remainder;
|
||||
output_ptr += remainder;
|
||||
}
|
||||
|
||||
@@ -153,6 +153,7 @@ _ZN3c104impl12PyObjectSlot10owns_pyobjEv
|
||||
_ZN3c104impl12PyObjectSlot19maybe_destroy_pyobjEv
|
||||
_ZN3c104impl12PyObjectSlotC1Ev
|
||||
_ZN3c104impl12PyObjectSlotD2Ev
|
||||
_ZN3c104impl19HermeticPyObjectTLS13get_tls_stateEv
|
||||
_ZN3c104impl20TorchDispatchModeTLS13any_modes_setEb
|
||||
_ZN3c104impl23ExcludeDispatchKeyGuardC1ENS_14DispatchKeySetE
|
||||
_ZN3c104impl23ExcludeDispatchKeyGuardD2Ev
|
||||
|
||||
72
docs/source/accelerator/amp.md
Normal file
@@ -0,0 +1,72 @@
|
||||
# Automatic Mixed Precision
|
||||
|
||||
## Background
|
||||
|
||||
Automatic Mixed Precision (AMP) enables the use of both single precision (32-bit) and half precision (16-bit) floating point types during training or inference.
|
||||
|
||||
Key components include:
|
||||
|
||||
- [**Autocast**](https://docs.pytorch.org/docs/stable/amp.html#autocasting): Automatically casts operations to lower-precision (e.g., float16 or bfloat16) to improve performance while maintaining accuracy.
|
||||
- [**Gradient Scaling**](https://docs.pytorch.org/docs/stable/amp.html#gradient-scaling): Dynamically scales gradients during backpropagation to prevent underflow when training with mixed precision.
|
||||
|
||||
## Design
|
||||
|
||||
### Casting Strategy
|
||||
|
||||
The [`CastPolicy`](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L416-L438) is used to define type conversion rules. Each enum value represents a set of type conversion requirements for a group of operators, ensuring consistent handling of operations that prioritize either precision or performance.
|
||||
|
||||
| Policy | Explanation |
|
||||
| :--- | :--- |
|
||||
| **`lower_precision_fp`** | Cast all inputs to `lower_precision_fp` before executing the op. |
|
||||
| **`fp32`** | Cast all inputs to `at::kFloat` before running the op. |
|
||||
| **`fp32_set_opt_dtype`** | Execute in `at::kFloat`, while respecting a user-specified output dtype if provided. |
|
||||
| **`fp32_append_dtype`** | Append `at::kFloat` to the args and redispatch to the type-aware overload. |
|
||||
| **`promote`** | Promote all inputs to the “widest” dtype before execution. |
|
||||
|
||||
### Operators Lists
|
||||
|
||||
PyTorch defines a general list of operators for each of the casting strategies mentioned above, as a reference for developers of new accelerators.
|
||||
|
||||
| Policy | Operators List |
|
||||
| :--- | :--- |
|
||||
| **`lower_precision_fp`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L819-L852) |
|
||||
| **`fp32`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L854-L912) |
|
||||
| **`fp32_set_opt_dtype`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L914-L931) |
|
||||
| **`fp32_append_dtype`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L933-L958) |
|
||||
| **`promote`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L960-L971) |
|
||||
|
||||
## Implementation
|
||||
|
||||
### Python Integration
|
||||
|
||||
Implement the `get_amp_supported_dtype` method to return the data types supported by the new accelerator in the AMP context.
|
||||
|
||||
```{eval-rst}
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/amp/__init__.py
|
||||
:language: python
|
||||
:start-after: LITERALINCLUDE START: AMP GET_SUPPORTED_DTYPE
|
||||
:end-before: LITERALINCLUDE END: AMP GET_SUPPORTED_DTYPE
|
||||
:linenos:
|
||||
```
|
||||
|
||||
### C++ Integration
|
||||
|
||||
This section shows how AMP registers autocast kernels for the `AutocastPrivateUse1` dispatch key.
|
||||
|
||||
- Register a fallback that makes unhandled ops fall through to their normal implementations.
|
||||
- Register specific aten kernels under `AutocastPrivateUse1` using the `KERNEL_PRIVATEUSEONE` helper macro, which maps an op to the desired precision behavior via the `at::autocast::CastPolicy` enum (a sketch of this pattern follows the includes below).
|
||||
|
||||
```{eval-rst}
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/amp/autocast_mode.cpp
|
||||
:language: c++
|
||||
:start-after: LITERALINCLUDE START: AMP FALLTHROUTH
|
||||
:end-before: LITERALINCLUDE END: AMP FALLTHROUTH
|
||||
:linenos:
|
||||
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/amp/autocast_mode.cpp
|
||||
:language: c++
|
||||
:start-after: LITERALINCLUDE START: AMP IMPL
|
||||
:end-before: LITERALINCLUDE END: AMP IMPL
|
||||
:emphasize-lines: 3,6,8-10
|
||||
:linenos:
|
||||
```
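For orientation, a minimal sketch of the registration pattern described above, assuming a backend keyed on `AutocastPrivateUse1`; the specific ops and policies chosen here are illustrative and are not taken from the OpenReg sources:

```cpp
#include <ATen/autocast_mode.h>
#include <torch/library.h>

// Unhandled ops fall through to their regular kernels.
TORCH_LIBRARY_IMPL(_, AutocastPrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

// Map selected aten ops to a CastPolicy for this backend.
TORCH_LIBRARY_IMPL(aten, AutocastPrivateUse1, m) {
  // Matmul-like ops run in the backend's lower-precision dtype.
  KERNEL_PRIVATEUSEONE(mm, lower_precision_fp)
  // Precision-sensitive ops stay in float32.
  KERNEL_PRIVATEUSEONE(acos, fp32)
}
```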
|
||||
@@ -44,6 +44,7 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
|
||||
|
||||
autoload
|
||||
operators
|
||||
amp
|
||||
```
|
||||
|
||||
[OpenReg URL]: https://github.com/pytorch/pytorch/tree/main/test/cpp_extensions/open_registration_extension/torch_openreg "OpenReg URL"
|
||||
|
||||
@@ -339,13 +339,16 @@ XLA
|
||||
~~~
|
||||
|
||||
- Jack Cao (`JackCaoG <https://github.com/JackCaoG>`__)
|
||||
- Daniel Sohn (`jysohn23 <https://github.com/jysohn23>`__)
|
||||
- Zach Cain (`zcain117 <https://github.com/zcain117>`__)
|
||||
- Han Qi (`qihqi <https://github.com/qihqi>`__)
|
||||
- Yifei Teng (`tengyifei <https://github.com/tengyifei>`__)
|
||||
- Siyuan Liu (`lsy323 <https://github.com/lsy323>`__)
|
||||
- Brian Hirsh (`bdhirsh <https://github.com/bdhirsh>`__)
|
||||
- Gregory Chanan (`gchanan <https://github.com/gchanan>`__)
|
||||
- (emeritus) Gregory Chanan (`gchanan <https://github.com/gchanan>`__)
|
||||
- (emeritus) Ailing Zhang (`ailzhang <https://github.com/ailzhang>`__)
|
||||
- (emeritus) Davide Libenzi (`dlibenzi <https://github.com/dlibenzi>`__)
|
||||
- (emeritus) Alex Suhan (`asuhan <https://github.com/asuhan>`__)
|
||||
- (emeritus) Daniel Sohn (`jysohn23 <https://github.com/jysohn23>`__)
|
||||
- (emeritus) Zach Cain (`zcain117 <https://github.com/zcain117>`__)
|
||||
|
||||
TorchServe
|
||||
~~~~~~~~~~
|
||||
|
||||

@@ -613,8 +613,7 @@ Available options:
CUDA Graph capture by using the graph topology (instead of CUDA events) to determine
when a freed block is safe to reuse. This can reduce peak memory during long captures that free
and reallocate buffers across multiple streams, especially when the capture DAG frequently
reaches joined frontiers. Note: Enabling this option can significantly increase the time spent
capturing the graph.
reaches joined frontiers.

.. note::

@@ -4,6 +4,7 @@ set(AOTI_ABI_CHECK_TEST_ROOT ${TORCH_ROOT}/test/cpp/aoti_abi_check)
set(AOTI_ABI_CHECK_TEST_SRCS
  ${AOTI_ABI_CHECK_TEST_ROOT}/main.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_cast.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
@@ -27,7 +28,7 @@ add_executable(test_aoti_abi_check
target_compile_definitions(test_aoti_abi_check PRIVATE USE_GTEST)

# WARNING: DO NOT LINK torch!!!
# The purpose is to check if the used aten/c10 headers are writtern in a header-only way
# The purpose is to check if the used aten/c10 headers are written in a header-only way
target_link_libraries(test_aoti_abi_check PRIVATE gtest_main)
target_include_directories(test_aoti_abi_check PRIVATE ${ATen_CPU_INCLUDE})

35  test/cpp/aoti_abi_check/test_devicetype.cpp  (new file)
@@ -0,0 +1,35 @@
#include <gtest/gtest.h>

#include <torch/headeronly/core/DeviceType.h>

TEST(TestDeviceType, TestDeviceType) {
  using torch::headeronly::DeviceType;
  constexpr DeviceType expected_device_types[] = {
      torch::headeronly::kCPU,
      torch::headeronly::kCUDA,
      DeviceType::MKLDNN,
      DeviceType::OPENGL,
      DeviceType::OPENCL,
      DeviceType::IDEEP,
      torch::headeronly::kHIP,
      torch::headeronly::kFPGA,
      torch::headeronly::kMAIA,
      torch::headeronly::kXLA,
      torch::headeronly::kVulkan,
      torch::headeronly::kMetal,
      torch::headeronly::kXPU,
      torch::headeronly::kMPS,
      torch::headeronly::kMeta,
      torch::headeronly::kHPU,
      torch::headeronly::kVE,
      torch::headeronly::kLazy,
      torch::headeronly::kIPU,
      torch::headeronly::kMTIA,
      torch::headeronly::kPrivateUse1,
  };
  for (int8_t i = 0;
       i < static_cast<int8_t>(torch::headeronly::COMPILE_TIME_MAX_DEVICE_TYPES);
       i++) {
    EXPECT_EQ(static_cast<DeviceType>(i), expected_device_types[i]);
  }
}

@@ -25,6 +25,8 @@ The goal of `torch_openreg` is **not to implement a fully functional, high-perfo
torch_openreg/
├── CMakeLists.txt
├── csrc
│   ├── amp
│   │   └── autocast_mode.cpp
│   ├── aten
│   │   ├── native
│   │   │   ├── Extra.cpp
@@ -59,6 +61,8 @@ torch_openreg/
│   └── stub.c
├── __init__.py
└── openreg
    ├── amp
    │   └── __init__.py
    ├── __init__.py
    ├── meta.py
    └── random.py
@@ -95,11 +99,12 @@ There are 4 DSOs in torch_openreg, and the dependencies between them are as foll
**Key Directories**:

- `csrc/`: Core device implementation, including operator registration, runtime, etc.
  - `csrc/amp/`: AMP (Automatic Mixed Precision)
  - `csrc/aten/`: Operator registration
    - `csrc/aten/native/`: Specific operator implementations for the OpenReg device.
      - `csrc/aten/native/OpenRegMinimal.cpp`: The most minimal set of operator implementations (allowing for the creation of Tensors and related operations upon completion).
      - `csrc/aten/native/OpenRegExtra.cpp`: Implementations for other types of operators.
  - `csrc/runtime/`: Implementations for Host memory, device memory, Guard, Hooks, etc.
- `third_party/`: A C++ library that simulates a CUDA-like device using the CPU.
- `torch_openreg/`: Python interface implementation (Python code and C++ Bindings).
  - `torch_openreg/csrc/`: Python C++ binding code.
@@ -126,13 +131,18 @@ There are 4 DSOs in torch_openreg, and the dependencies between them are as foll

### Autoload

When `import torch` is executed, installed accelerators (such as `torch_openreg`) are loaded automatically, achieving the same experience as the built-in backends. (A minimal sketch of the first two items follows the list below.)

- Registering the backend with Python `entry points`: See `setup` in `setup.py`
- Adding a callable function for backend initialization: See `_autoload` in `torch_openreg/__init__.py`
- Dynamically loading the backend without explicit imports: See [Usage Example](#usage-example)
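
A minimal sketch of the first two pieces, using names from this example project (`torch_openreg`); an out-of-tree backend would substitute its own package and module names:

```python
# setup.py (sketch): expose an entry point in the "torch.backends" group,
# which is the group the autoload mechanism scans during `import torch`.
from setuptools import setup

setup(
    name="torch_openreg",
    entry_points={
        "torch.backends": [
            "torch_openreg = torch_openreg:_autoload",
        ],
    },
)
```

```python
# torch_openreg/__init__.py (sketch): the callable invoked by torch at import
# time; importing the C extension performs the PrivateUse1 registrations.
def _autoload():
    import torch_openreg._C  # noqa: F401
```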

### AMP (Automatic Mixed Precision)

`AMP` provides convenience methods for mixed precision, where some operations use the `torch.float32` datatype while others use a lower-precision floating point datatype: `torch.float16` or `torch.bfloat16`.

- Register specific operator conversion rules: See `autocast_mode.cpp` in `csrc/amp`.
- Add support for new data types for different accelerators: See `get_amp_supported_dtype` in `torch_openreg/openreg/amp/__init__.py` (a short usage sketch follows below).
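
Once registered, the backend's default and overridden autocast dtypes can be inspected from Python; a short sketch mirroring the accompanying tests:

```python
import torch

# The default fast dtype reported for the openreg backend is float16.
assert torch.get_autocast_dtype(device_type="openreg") == torch.float16

# It can be switched to any dtype returned by get_amp_supported_dtype().
torch.set_autocast_dtype("openreg", torch.bfloat16)
assert torch.get_autocast_dtype("openreg") == torch.bfloat16
```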

## Installation and Usage

@@ -168,11 +178,13 @@ print("Result z:\n", z)
print(f"Device of z: {z.device}")
```

## Documentation

Please refer to [this](https://docs.pytorch.org/docs/main/accelerator/index.html) for a series of documents on integrating new accelerators into PyTorch, which will be kept in sync with the `OpenReg` codebase as well.

## Future Plans

- **Enhance Features**:
  - Autoload
  - AMP
  - Device-agnostic APIs
  - Memory Management
  - Generator
@@ -180,5 +192,3 @@ print(f"Device of z: {z.device}")
  - Custom Tensor&Storage
  - ...
- **Improve Tests**: Add more test cases related to the integration mechanism.
- **Improve Documentation**: Add a new chapter on third-party device integration in the `Developer Notes` section of the PyTorch documentation.
- **Real-time Synchronization**: Keep the code and documentation updated iteratively and in sync.

@@ -0,0 +1,37 @@
#include <ATen/autocast_mode.h>

using at::Tensor;

Tensor binary_cross_entropy_banned(
    const Tensor&,
    const Tensor&,
    const std::optional<Tensor>&,
    int64_t) {
  TORCH_CHECK(
      false,
      "torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.\n"
      "Many models use a sigmoid layer right before the binary cross entropy layer.\n"
      "In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits\n"
      "or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are\n"
      "safe to autocast.");
}

// LITERALINCLUDE START: AMP FALLTHROUTH
TORCH_LIBRARY_IMPL(_, AutocastPrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}
// LITERALINCLUDE END: AMP FALLTHROUTH

// LITERALINCLUDE START: AMP IMPL
TORCH_LIBRARY_IMPL(aten, AutocastPrivateUse1, m) {
  // lower_precision_fp
  KERNEL_PRIVATEUSEONE(mm, lower_precision_fp)

  // fp32
  KERNEL_PRIVATEUSEONE(asin, fp32)

  m.impl(
      TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
      TORCH_FN((&binary_cross_entropy_banned)));
}
// LITERALINCLUDE END: AMP IMPL

@@ -0,0 +1,50 @@
# Owner(s): ["module: PrivateUse1"]

import torch
from torch.testing._internal.common_utils import run_tests, TestCase


class TestAutocast(TestCase):
    def test_autocast_with_unsupported_type(self):
        with self.assertWarnsRegex(
            UserWarning,
            "In openreg autocast, but the target dtype torch.float32 is not supported.",
        ):
            with torch.autocast(device_type="openreg", dtype=torch.float32):
                _ = torch.ones(10)

    def test_autocast_operator_not_supported(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.",
        ):
            x = torch.randn(2, 3, device="openreg")
            y = torch.randn(2, 3, device="openreg")
            with torch.autocast(device_type="openreg", dtype=torch.float16):
                _ = torch.nn.functional.binary_cross_entropy(x, y)

    def test_autocast_low_precision(self):
        with torch.amp.autocast(device_type="openreg", dtype=torch.float16):
            x = torch.randn(2, 3, device="openreg")
            y = torch.randn(3, 3, device="openreg")
            result = torch.mm(x, y)
            self.assertEqual(result.dtype, torch.float16)

    def test_autocast_fp32(self):
        with torch.amp.autocast(device_type="openreg"):
            x = torch.randn(2, device="openreg", dtype=torch.float16)
            result = torch.asin(x)
            self.assertEqual(result.dtype, torch.float32)

    def test_autocast_default_dtype(self):
        openreg_fast_dtype = torch.get_autocast_dtype(device_type="openreg")
        self.assertEqual(openreg_fast_dtype, torch.half)

    def test_autocast_set_dtype(self):
        for dtype in [torch.float16, torch.bfloat16]:
            torch.set_autocast_dtype("openreg", dtype)
            self.assertEqual(torch.get_autocast_dtype("openreg"), dtype)


if __name__ == "__main__":
    run_tests()

@@ -3,6 +3,7 @@ import torch
import torch_openreg._C  # type: ignore[misc]

from . import meta  # noqa: F401
from .amp import get_amp_supported_dtype  # noqa: F401


_initialized = False

@@ -0,0 +1,9 @@
import torch


# LITERALINCLUDE START: AMP GET_SUPPORTED_DTYPE
def get_amp_supported_dtype():
    return [torch.float16, torch.bfloat16]


# LITERALINCLUDE END: AMP GET_SUPPORTED_DTYPE
@ -15,6 +15,9 @@ import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed._composable import checkpoint, replicate
|
||||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||
apply_activation_checkpointing,
|
||||
)
|
||||
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
|
||||
from torch.distributed.fsdp import (
|
||||
FSDPModule,
|
||||
@ -58,6 +61,7 @@ from torch.testing._internal.common_fsdp import (
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_XPU, xfailIf
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
FeedForward,
|
||||
ModelArgs,
|
||||
Transformer,
|
||||
TransformerBlock,
|
||||
@ -1010,6 +1014,222 @@ class TestFullyShardPrefetch(FSDPTest):
|
||||
self.assertEqual(events, expected_backward_events)
|
||||
events.clear()
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_set_modules_to_backward_prefetch_inside_ac(self):
|
||||
n_layers = 3
|
||||
reshard_after_forward = True
|
||||
# use checkpoint wrapper instead of torch.utils
|
||||
model_args = ModelArgs(n_layers=n_layers, checkpoint_activations=False)
|
||||
model = Transformer(model_args)
|
||||
apply_activation_checkpointing(
|
||||
model, check_fn=lambda m: isinstance(m, TransformerBlock)
|
||||
)
|
||||
apply_activation_checkpointing(
|
||||
model, check_fn=lambda m: isinstance(m, FeedForward)
|
||||
)
|
||||
fully_shard([model.tok_embeddings, model.pos_embeddings])
|
||||
for layer in model.layers:
|
||||
# mimic fully_shard(layer.moe.experts)
|
||||
fully_shard(
|
||||
layer.feed_forward.w1, reshard_after_forward=reshard_after_forward
|
||||
)
|
||||
fully_shard(layer, reshard_after_forward=reshard_after_forward)
|
||||
fully_shard(
|
||||
[model.norm, model.output], reshard_after_forward=reshard_after_forward
|
||||
)
|
||||
fully_shard(model, reshard_after_forward=reshard_after_forward)
|
||||
inp = torch.randint(
|
||||
0,
|
||||
model_args.vocab_size,
|
||||
(2, model_args.max_seq_len),
|
||||
device=device_type.type,
|
||||
)
|
||||
|
||||
def set_backward_prefetch(model: Transformer) -> None:
|
||||
# tell pyre model.set_modules_to_backward_prefetch is available
|
||||
assert isinstance(model, FSDPModule)
|
||||
assert isinstance(model.output, FSDPModule)
|
||||
|
||||
# mimic deepseek MOE
|
||||
# prefetch layer - 1 and its feedforward before cpu sync during a2a
|
||||
reversed_transformer_blocks = list(reversed(model.layers))
|
||||
prev_transformer_blocks = reversed_transformer_blocks[1:] + [None]
|
||||
|
||||
if (
|
||||
model.norm is not None
|
||||
and model.output is not None
|
||||
and len(model.layers) > 0
|
||||
):
|
||||
assert isinstance(reversed_transformer_blocks[0], FSDPModule)
|
||||
model.output.set_modules_to_backward_prefetch(
|
||||
[reversed_transformer_blocks[0]]
|
||||
)
|
||||
|
||||
for transformer_block, prev_transformer_block in zip(
|
||||
reversed_transformer_blocks, prev_transformer_blocks
|
||||
):
|
||||
assert isinstance(transformer_block, FSDPModule)
|
||||
if prev_transformer_block is not None:
|
||||
assert isinstance(prev_transformer_block, FSDPModule)
|
||||
assert hasattr(prev_transformer_block.feed_forward, "w1")
|
||||
assert isinstance(
|
||||
prev_transformer_block.feed_forward.w1, FSDPModule
|
||||
)
|
||||
transformer_block.set_modules_to_backward_prefetch(
|
||||
[
|
||||
prev_transformer_block,
|
||||
prev_transformer_block.feed_forward.w1,
|
||||
]
|
||||
)
|
||||
elif model.tok_embeddings is not None:
|
||||
assert isinstance(model.tok_embeddings, FSDPModule)
|
||||
transformer_block.set_modules_to_backward_prefetch(
|
||||
[model.tok_embeddings]
|
||||
)
|
||||
|
||||
events: list[EventType] = []
|
||||
unshard_with_record = self._get_unshard_with_record(
|
||||
FSDPParamGroup.unshard, events
|
||||
)
|
||||
reshard_with_record = self._get_reshard_with_record(
|
||||
FSDPParamGroup.reshard, events
|
||||
)
|
||||
with (
|
||||
patch_unshard(unshard_with_record),
|
||||
patch_reshard(reshard_with_record),
|
||||
):
|
||||
loss = model(inp)
|
||||
events.clear()
|
||||
loss.sum().backward()
|
||||
expected_backward_events = [
|
||||
("unshard", "norm, output", TrainingState.PRE_BACKWARD),
|
||||
("unshard", "layers.2", TrainingState.PRE_BACKWARD),
|
||||
("reshard", "norm, output", TrainingState.POST_BACKWARD),
|
||||
# layers.2 prefetch w1
|
||||
(
|
||||
"unshard",
|
||||
"layers.2._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
|
||||
TrainingState.PRE_BACKWARD,
|
||||
),
|
||||
# layers.2.w1 prefetch layers.1
|
||||
("unshard", "layers.1", TrainingState.PRE_BACKWARD),
|
||||
(
|
||||
"reshard",
|
||||
"layers.2._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
|
||||
TrainingState.POST_BACKWARD,
|
||||
),
|
||||
("reshard", "layers.2", TrainingState.POST_BACKWARD),
|
||||
(
|
||||
"unshard",
|
||||
"layers.1._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
|
||||
TrainingState.PRE_BACKWARD,
|
||||
),
|
||||
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
|
||||
(
|
||||
"reshard",
|
||||
"layers.1._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
|
||||
TrainingState.POST_BACKWARD,
|
||||
),
|
||||
("reshard", "layers.1", TrainingState.POST_BACKWARD),
|
||||
(
|
||||
"unshard",
|
||||
"layers.0._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
|
||||
TrainingState.PRE_BACKWARD,
|
||||
),
|
||||
(
|
||||
"unshard",
|
||||
"tok_embeddings, pos_embeddings",
|
||||
TrainingState.PRE_BACKWARD,
|
||||
),
|
||||
(
|
||||
"reshard",
|
||||
"layers.0._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
|
||||
TrainingState.POST_BACKWARD,
|
||||
),
|
||||
("reshard", "layers.0", TrainingState.POST_BACKWARD),
|
||||
(
|
||||
"reshard",
|
||||
"tok_embeddings, pos_embeddings",
|
||||
TrainingState.POST_BACKWARD,
|
||||
),
|
||||
(
|
||||
"reshard",
|
||||
"tok_embeddings, pos_embeddings",
|
||||
TrainingState.POST_BACKWARD,
|
||||
),
|
||||
("reshard", "norm, output", TrainingState.POST_BACKWARD),
|
||||
]
|
||||
self.assertEqual(events, expected_backward_events)
|
||||
events.clear()
|
||||
|
||||
set_backward_prefetch(model)
|
||||
loss = model(inp)
|
||||
events.clear()
|
||||
loss.sum().backward()
|
||||
expected_backward_events = [
|
||||
("unshard", "norm, output", TrainingState.PRE_BACKWARD),
|
||||
# root explicit prefetch layers.2
|
||||
("unshard", "layers.2", TrainingState.PRE_BACKWARD),
|
||||
("reshard", "norm, output", TrainingState.POST_BACKWARD),
|
||||
# layers.2 prefetch layers.1 and feed_forward
|
||||
("unshard", "layers.1", TrainingState.PRE_BACKWARD),
|
||||
(
|
||||
"unshard",
|
||||
"layers.1._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
|
||||
TrainingState.PRE_BACKWARD,
|
||||
),
|
||||
# AC recompute_fn
|
||||
(
|
||||
"unshard",
|
||||
"layers.2._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
|
||||
TrainingState.FORWARD,
|
||||
),
|
||||
(
|
||||
"reshard",
|
||||
"layers.2._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
|
||||
TrainingState.POST_BACKWARD,
|
||||
),
|
||||
("reshard", "layers.2", TrainingState.POST_BACKWARD),
|
||||
# layers.1 prefetch layers.0
|
||||
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
|
||||
(
|
||||
"unshard",
|
||||
"layers.0._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
|
||||
TrainingState.PRE_BACKWARD,
|
||||
),
|
||||
(
|
||||
"reshard",
|
||||
"layers.1._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
|
||||
TrainingState.POST_BACKWARD,
|
||||
),
|
||||
("reshard", "layers.1", TrainingState.POST_BACKWARD),
|
||||
# layers.0 prefetch embeddings
|
||||
(
|
||||
"unshard",
|
||||
"tok_embeddings, pos_embeddings",
|
||||
TrainingState.PRE_BACKWARD,
|
||||
),
|
||||
(
|
||||
"reshard",
|
||||
"layers.0._checkpoint_wrapped_module.feed_forward._checkpoint_wrapped_module.w1",
|
||||
TrainingState.POST_BACKWARD,
|
||||
),
|
||||
("reshard", "layers.0", TrainingState.POST_BACKWARD),
|
||||
(
|
||||
"reshard",
|
||||
"tok_embeddings, pos_embeddings",
|
||||
TrainingState.POST_BACKWARD,
|
||||
),
|
||||
(
|
||||
"reshard",
|
||||
"tok_embeddings, pos_embeddings",
|
||||
TrainingState.POST_BACKWARD,
|
||||
),
|
||||
("reshard", "norm, output", TrainingState.POST_BACKWARD),
|
||||
]
|
||||
self.assertEqual(events, expected_backward_events)
|
||||
events.clear()
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_fully_shard_multi_module_backward_prefetch(self):
|
||||
n_layers = 5
|
||||
|
||||
626  test/distributed/_composable/test_replicate_mixed_precision.py  (new file)
@ -0,0 +1,626 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
import torch.nn as nn
|
||||
from torch.distributed._composable.replicate_with_fsdp import replicate
|
||||
from torch.distributed.fsdp import MixedPrecisionPolicy
|
||||
from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
|
||||
_get_gradient_divide_factors,
|
||||
)
|
||||
from torch.distributed.tensor import Shard
|
||||
from torch.testing._internal.common_distributed import (
|
||||
requires_nccl_version,
|
||||
SaveForwardInputsModel,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_fsdp import (
|
||||
check_sharded_parity,
|
||||
FSDPTest,
|
||||
FSDPTestMultiThread,
|
||||
get_devtype,
|
||||
MLP,
|
||||
patch_reduce_scatter,
|
||||
reduce_scatter_with_assert,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
skipIfRocmVersionLessThan,
|
||||
TEST_HPU,
|
||||
)
|
||||
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
|
||||
|
||||
class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return min(2, torch.get_device_module(device_type).device_count())
|
||||
|
||||
def _init_models_and_optims(
|
||||
self,
|
||||
reshard_after_forward: Union[bool, int],
|
||||
param_dtype: Optional[torch.dtype],
|
||||
reduce_dtype: Optional[torch.dtype],
|
||||
use_shard_placement_fn,
|
||||
):
|
||||
torch.manual_seed(42)
|
||||
model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
|
||||
ref_model = copy.deepcopy(model).to(device_type)
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
|
||||
|
||||
def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
|
||||
largest_dim = -1
|
||||
largest_dim_size = -1
|
||||
for dim, dim_size in enumerate(param.shape):
|
||||
if dim_size > largest_dim_size:
|
||||
largest_dim = dim
|
||||
largest_dim_size = dim_size
|
||||
assert largest_dim >= 0, f"{param.shape}"
|
||||
return Shard(largest_dim)
|
||||
|
||||
mp_policy = MixedPrecisionPolicy(
|
||||
param_dtype=param_dtype, reduce_dtype=reduce_dtype
|
||||
)
|
||||
shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None
|
||||
replicate_fn = functools.partial(
|
||||
replicate,
|
||||
reshard_after_forward=reshard_after_forward,
|
||||
mp_policy=mp_policy,
|
||||
shard_placement_fn=shard_placement_fn,
|
||||
)
|
||||
for mlp in model:
|
||||
replicate_fn(mlp)
|
||||
replicate_fn(model)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
|
||||
return ref_model, ref_optim, model, optim
|
||||
|
||||
def _get_use_shard_placement_fn_vals_for_bf16_reduce(self):
|
||||
use_shard_placement_fn_vals = [False]
|
||||
if self.world_size == 2:
|
||||
# For world size >2, gradient elements get reduced in different
|
||||
# orders for the baseline vs. dim-1 sharding, leading to numeric
|
||||
# differences for bf16 reduction, so only test world size 2.
|
||||
use_shard_placement_fn_vals.append(True)
|
||||
return use_shard_placement_fn_vals
|
||||
|
||||
@skipIfRocmVersionLessThan((7, 0))
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
|
||||
def test_compute_dtype(self):
|
||||
use_shard_placement_fn_vals = (
|
||||
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
|
||||
)
|
||||
self.run_subtests(
|
||||
{
|
||||
"param_dtype": [torch.bfloat16, torch.float16],
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_shard_placement_fn": use_shard_placement_fn_vals,
|
||||
},
|
||||
self._test_compute_dtype,
|
||||
)
|
||||
|
||||
def _test_compute_dtype(
|
||||
self,
|
||||
param_dtype: torch.dtype,
|
||||
reshard_after_forward: Union[bool, int],
|
||||
use_shard_placement_fn: bool,
|
||||
):
|
||||
ref_model, ref_optim, model, optim = self._init_models_and_optims(
|
||||
reshard_after_forward,
|
||||
param_dtype=param_dtype,
|
||||
reduce_dtype=None,
|
||||
use_shard_placement_fn=use_shard_placement_fn,
|
||||
)
|
||||
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
|
||||
def assert_fn(output: torch.Tensor):
|
||||
self.assertEqual(output.dtype, param_dtype)
|
||||
|
||||
reduce_scatter = functools.partial(
|
||||
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
|
||||
)
|
||||
predivide_factor, postdivide_factor, _, _ = _get_gradient_divide_factors(
|
||||
self.process_group, all_reduce_group=None, reduce_dtype=param_dtype
|
||||
)
|
||||
|
||||
torch.manual_seed(42 + self.rank + 1)
|
||||
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
|
||||
for iter_idx in range(10):
|
||||
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
fsdp_loss = model(inp).sum()
|
||||
with patch_reduce_scatter(reduce_scatter):
|
||||
fsdp_loss.backward()
|
||||
optim.step()
|
||||
|
||||
ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
ref_loss = ref_model_bf16(inp.to(param_dtype)).sum()
|
||||
ref_loss.backward()
|
||||
for param in ref_model_bf16.parameters():
|
||||
# Use reduce-scatter -> all-gather as all-reduce because for
|
||||
# world size >=4, NCCL all-reduce shows numeric differences
|
||||
# compared with NCCL reduce-scatter
|
||||
if predivide_factor is not None and predivide_factor > 1:
|
||||
param.grad.div_(predivide_factor)
|
||||
elif predivide_factor is None:
|
||||
param.grad.div_(self.world_size)
|
||||
output = torch.zeros_like(torch.chunk(param.grad, self.world_size)[0])
|
||||
dist.reduce_scatter_tensor(output, param.grad)
|
||||
dist.all_gather_into_tensor(param.grad, output)
|
||||
if postdivide_factor is not None and postdivide_factor > 1:
|
||||
param.grad.div_(postdivide_factor)
|
||||
for param_fp32, param_bf16 in zip(
|
||||
ref_model.parameters(), ref_model_bf16.parameters()
|
||||
):
|
||||
param_fp32.grad = param_bf16.grad.to(param_fp32.dtype)
|
||||
param_bf16.grad = None
|
||||
ref_optim.step() # fp32 optimizer step
|
||||
for param_fp32, param_bf16 in zip(
|
||||
ref_model.parameters(), ref_model_bf16.parameters()
|
||||
):
|
||||
param_bf16.detach().copy_(param_fp32)
|
||||
|
||||
self.assertEqual(fsdp_loss, ref_loss)
|
||||
check_sharded_parity(self, ref_model, model)
|
||||
|
||||
@skipIfRocmVersionLessThan((7, 0))
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
|
||||
def test_reduce_dtype(self):
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_shard_placement_fn": [False, True],
|
||||
},
|
||||
self._test_reduce_dtype_fp32_reduce,
|
||||
)
|
||||
use_shard_placement_fn_vals = (
|
||||
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
|
||||
)
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_shard_placement_fn": use_shard_placement_fn_vals,
|
||||
},
|
||||
self._test_reduce_dtype_bf16_reduce,
|
||||
)
|
||||
|
||||
def _test_reduce_dtype_fp32_reduce(
|
||||
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
|
||||
):
|
||||
if (
|
||||
self.world_size > 2
|
||||
and isinstance(reshard_after_forward, int)
|
||||
and use_shard_placement_fn
|
||||
):
|
||||
return
|
||||
param_dtype, reduce_dtype = torch.bfloat16, torch.float32
|
||||
ref_model, ref_optim, model, optim = self._init_models_and_optims(
|
||||
reshard_after_forward,
|
||||
param_dtype=param_dtype,
|
||||
reduce_dtype=reduce_dtype,
|
||||
use_shard_placement_fn=use_shard_placement_fn,
|
||||
)
|
||||
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
|
||||
def assert_fn(output: torch.Tensor):
|
||||
self.assertEqual(output.dtype, reduce_dtype)
|
||||
|
||||
reduce_scatter = functools.partial(
|
||||
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
|
||||
)
|
||||
torch.manual_seed(42 + self.rank + 1)
|
||||
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
|
||||
for iter_idx in range(10):
|
||||
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
fsdp_loss = model(inp).sum()
|
||||
with patch_reduce_scatter(reduce_scatter):
|
||||
fsdp_loss.backward()
|
||||
optim.step()
|
||||
|
||||
ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
ref_loss = ref_model_bf16(inp.to(param_dtype)).sum()
|
||||
ref_loss.backward()
|
||||
for param in ref_model_bf16.parameters():
|
||||
param.grad.data = param.grad.to(torch.float32)
|
||||
dist.all_reduce(param.grad) # fp32 reduction
|
||||
param.grad.div_(self.world_size)
|
||||
for param_fp32, param_bf16 in zip(
|
||||
ref_model.parameters(), ref_model_bf16.parameters()
|
||||
):
|
||||
param_fp32.grad = param_bf16.grad
|
||||
param_bf16.grad = None
|
||||
ref_optim.step() # fp32 optimizer step
|
||||
for param_fp32, param_bf16 in zip(
|
||||
ref_model.parameters(), ref_model_bf16.parameters()
|
||||
):
|
||||
param_bf16.detach().copy_(param_fp32)
|
||||
|
||||
self.assertEqual(fsdp_loss, ref_loss)
|
||||
check_sharded_parity(self, ref_model, model)
|
||||
|
||||
def _test_reduce_dtype_bf16_reduce(
|
||||
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
|
||||
):
|
||||
param_dtype, reduce_dtype = torch.float32, torch.bfloat16
|
||||
ref_model, ref_optim, model, optim = self._init_models_and_optims(
|
||||
reshard_after_forward,
|
||||
param_dtype=param_dtype,
|
||||
reduce_dtype=reduce_dtype,
|
||||
use_shard_placement_fn=use_shard_placement_fn,
|
||||
)
|
||||
group = dist.distributed_c10d._get_default_group()
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
|
||||
def assert_fn(output: torch.Tensor):
|
||||
self.assertEqual(output.dtype, reduce_dtype)
|
||||
|
||||
reduce_scatter = functools.partial(
|
||||
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
|
||||
)
|
||||
torch.manual_seed(42 + self.rank + 1)
|
||||
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
|
||||
for iter_idx in range(10):
|
||||
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
fsdp_loss = model(inp).sum()
|
||||
with patch_reduce_scatter(reduce_scatter):
|
||||
fsdp_loss.backward()
|
||||
optim.step()
|
||||
|
||||
ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
ref_loss = ref_model(inp).sum()
|
||||
ref_loss.backward()
|
||||
for param in ref_model.parameters():
|
||||
param_grad = param.grad.to(reduce_dtype)
|
||||
# Use reduce-scatter -> all-gather to implement all-reduce
|
||||
# since for world size >2, bf16 all-reduce and reduce-scatter
|
||||
# have numeric differences
|
||||
sharded_grad = funcol.reduce_scatter_tensor(
|
||||
param_grad, scatter_dim=0, reduceOp="avg", group=group
|
||||
) # bf16 reduction
|
||||
param.grad = funcol.all_gather_tensor(
|
||||
sharded_grad, gather_dim=0, group=group
|
||||
).to(param.dtype) # upcast to fp32
|
||||
ref_optim.step() # fp32 optimizer step
|
||||
|
||||
self.assertEqual(fsdp_loss, ref_loss)
|
||||
check_sharded_parity(self, ref_model, model)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_grad_acc_with_reduce_dtype(self):
|
||||
"""
|
||||
Tests that gradient accumulation without reduce-scatter when using
|
||||
bf16 compute and fp32 reduction accumulates the unsharded gradients in
|
||||
fp32.
|
||||
"""
|
||||
self.run_subtests(
|
||||
{"reshard_after_forward": [True, False]},
|
||||
self._test_grad_acc_with_reduce_dtype,
|
||||
)
|
||||
|
||||
def _test_grad_acc_with_reduce_dtype(self, reshard_after_forward: bool):
|
||||
torch.manual_seed(42)
|
||||
param_dtype, reduce_dtype = (torch.bfloat16, torch.float32)
|
||||
mp_policy = MixedPrecisionPolicy(
|
||||
param_dtype=param_dtype, reduce_dtype=reduce_dtype
|
||||
)
|
||||
model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
|
||||
# To emulate the mixed precision implementation where forward/backward
|
||||
# compute use bf16 and optimizer uses fp32, we maintain both an fp32
|
||||
# and a bf16 copy of the reference model
|
||||
ref_model = copy.deepcopy(model).to(device_type)
|
||||
ref_model_compute = copy.deepcopy(ref_model).to(param_dtype)
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
|
||||
for mlp in model:
|
||||
replicate(
|
||||
mlp, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
|
||||
)
|
||||
replicate(
|
||||
model, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
|
||||
)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
|
||||
def assert_fn(output: torch.Tensor):
|
||||
self.assertEqual(output.dtype, reduce_dtype)
|
||||
|
||||
reduce_scatter = functools.partial(
|
||||
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
|
||||
)
|
||||
torch.manual_seed(42 + self.rank + 1)
|
||||
device = device_type
|
||||
# Train on the same input to avoid loss explosion
|
||||
num_microbatches = 4
|
||||
inp = torch.randn((2 * num_microbatches, 16), device=device, dtype=param_dtype)
|
||||
for iter_idx in range(10):
|
||||
microbatch_inps = torch.chunk(inp, 4)
|
||||
for microbatch_idx in range(num_microbatches):
|
||||
is_last_microbatch = microbatch_idx == num_microbatches - 1
|
||||
model.set_requires_gradient_sync(is_last_microbatch)
|
||||
model.set_reshard_after_backward(
|
||||
is_last_microbatch or reshard_after_forward
|
||||
)
|
||||
losses: list[torch.Tensor] = []
|
||||
for _model in (ref_model_compute, model):
|
||||
losses.append(
|
||||
_model(microbatch_inps[microbatch_idx].detach()).sum()
|
||||
)
|
||||
self.assertEqual(losses[-1].dtype, param_dtype)
|
||||
with patch_reduce_scatter(reduce_scatter):
|
||||
losses[-1].backward()
|
||||
self.assertEqual(losses[0], losses[1])
|
||||
# Manually accumulate gradients into the base reference model
|
||||
# from the compute reference model in fp32
|
||||
for ref_param, ref_param_compute in zip(
|
||||
ref_model.parameters(), ref_model_compute.parameters()
|
||||
):
|
||||
self.assertTrue(ref_param_compute.grad is not None)
|
||||
self.assertEqual(ref_param.dtype, torch.float32)
|
||||
if ref_param.grad is not None:
|
||||
ref_param.grad += ref_param_compute.grad
|
||||
else:
|
||||
ref_param.grad = ref_param_compute.grad.to(ref_param.dtype)
|
||||
ref_param_compute.grad = None
|
||||
# Manually reduce gradients for the reference model on the last
|
||||
# microbatch to implement data parallelism
|
||||
if is_last_microbatch:
|
||||
for ref_param in ref_model.parameters():
|
||||
self.assertTrue(ref_param.grad is not None)
|
||||
dist.all_reduce(ref_param.grad)
|
||||
ref_param.grad /= self.world_size
|
||||
check_sharded_parity(self, ref_model, model)
|
||||
ref_optim.step()
|
||||
optim.step()
|
||||
ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
# Manually copy parameters from the base reference model to the
|
||||
# compute reference model to run the optimizer step for the latter
|
||||
for ref_param, ref_param_compute in zip(
|
||||
ref_model.parameters(), ref_model_compute.parameters()
|
||||
):
|
||||
ref_param_compute.detach().copy_(ref_param)
|
||||
|
||||
|
||||
class TestReplicateMixedPrecisionCasts(FSDPTestMultiThread):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return 2
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_float16_on_one_submodule(self):
|
||||
x = torch.zeros(2, 100, device=device_type)
|
||||
|
||||
# Subtest 1: use fp16 on the second child submodule -- does not require
|
||||
# any additional casting logic
|
||||
forward_inputs: dict[str, nn.Module] = {}
|
||||
model = SaveForwardInputsModel(
|
||||
forward_inputs,
|
||||
cast_forward_inputs=False,
|
||||
).to(device_type)
|
||||
replicate(model.c2, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
|
||||
replicate(model)
|
||||
model(x).sum().backward()
|
||||
self.assertEqual(forward_inputs[model].dtype, torch.float32)
|
||||
self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
|
||||
self.assertEqual(forward_inputs[model.c2].dtype, torch.float16)
|
||||
|
||||
# Subtest 2: use fp16 on the second child module, where the user module
|
||||
# owns the cast
|
||||
forward_inputs: dict[nn.Module, torch.Tensor] = {}
|
||||
model = SaveForwardInputsModel(
|
||||
forward_inputs=forward_inputs, cast_forward_inputs=True
|
||||
).to(device_type)
|
||||
replicate(
|
||||
model.c2,
|
||||
mp_policy=MixedPrecisionPolicy(
|
||||
param_dtype=torch.float16, cast_forward_inputs=False
|
||||
),
|
||||
)
|
||||
replicate(model)
|
||||
model(x).sum().backward()
|
||||
self.assertEqual(forward_inputs[model].dtype, torch.float32)
|
||||
self.assertEqual(forward_inputs[model.c1].dtype, torch.float32)
|
||||
self.assertEqual(forward_inputs[model.c2].dtype, torch.float32)
|
||||
|
||||
# Subtest 3: use fp16 on the first child module and specify its output
|
||||
# dtype so that the second child module does not need to cast
|
||||
forward_inputs: dict[nn.Module, torch.Tensor] = {}
|
||||
model = SaveForwardInputsModel(
|
||||
forward_inputs=forward_inputs, cast_forward_inputs=False
|
||||
).to(device_type)
|
||||
replicate(
|
||||
model.c1,
|
||||
mp_policy=MixedPrecisionPolicy(
|
||||
param_dtype=torch.float16, output_dtype=torch.float32
|
||||
),
|
||||
)
|
||||
replicate(model)
|
||||
model(x).sum().backward()
|
||||
self.assertEqual(forward_inputs[model].dtype, torch.float32)
|
||||
self.assertEqual(forward_inputs[model.c1].dtype, torch.float16)
|
||||
self.assertEqual(forward_inputs[model.c2].dtype, torch.float32)
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_submodules_with_external_inputs(self):
|
||||
self.run_subtests(
|
||||
{"enable_submodule_cast": [False, True]},
|
||||
self._test_submodules_with_external_inputs,
|
||||
)
|
||||
|
||||
def _test_submodules_with_external_inputs(self, enable_submodule_cast: bool):
|
||||
class ToyModule(nn.Module):
|
||||
def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
|
||||
super().__init__()
|
||||
self.l = nn.Linear(100, 100)
|
||||
self.forward_inputs = forward_inputs
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
self.forward_inputs["l2_input_x"] = x
|
||||
self.forward_inputs["l2_input_y"] = y
|
||||
return self.l(x)
|
||||
|
||||
class ToyModel(nn.Module):
|
||||
def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
|
||||
super().__init__()
|
||||
self.l1 = nn.Linear(100, 100)
|
||||
self.l2 = ToyModule(forward_inputs)
|
||||
self.forward_inputs = forward_inputs
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
self.forward_inputs["model_input_x"] = x
|
||||
y = torch.ones(
|
||||
2, 100, device=device_type.type, dtype=torch.float32
|
||||
) # external input
|
||||
return self.l2(self.l1(x), y)
|
||||
|
||||
forward_inputs: dict[str, torch.Tensor] = {}
|
||||
model = ToyModel(forward_inputs).to(device_type)
|
||||
x = torch.zeros(2, 100, device=device_type.type, dtype=torch.float32)
|
||||
replicate(
|
||||
model.l2,
|
||||
mp_policy=MixedPrecisionPolicy(
|
||||
param_dtype=torch.float16, cast_forward_inputs=enable_submodule_cast
|
||||
),
|
||||
)
|
||||
replicate(model, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
|
||||
model(x).sum().backward()
|
||||
|
||||
# If we enable `model.l2` to cast (as default), then `l2_input_y` gets
|
||||
# cast to fp16, and if we disable, then it stays as fp32.
|
||||
self.assertEqual(forward_inputs["model_input_x"].dtype, torch.float16)
|
||||
self.assertEqual(forward_inputs["l2_input_x"].dtype, torch.float16)
|
||||
self.assertEqual(
|
||||
forward_inputs["l2_input_y"].dtype,
|
||||
torch.float16 if enable_submodule_cast else torch.float32,
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
|
||||
def test_norm_modules_bf16(self):
|
||||
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
|
||||
self._test_norm_modules(mp_policy)
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_norm_modules_fp16(self):
|
||||
mp_policy = MixedPrecisionPolicy(param_dtype=torch.float16)
|
||||
self._test_norm_modules(mp_policy)
|
||||
|
||||
def _test_norm_modules(self, mp_policy: MixedPrecisionPolicy):
|
||||
def inner(model: nn.Module, x: torch.Tensor):
|
||||
# Run forward and backward to check for no type mismatch errors
|
||||
z = model(x)
|
||||
self.assertEqual(z.dtype, mp_policy.param_dtype)
|
||||
z.sum().backward()
|
||||
|
||||
# Layer norm
|
||||
model = nn.Sequential(nn.Linear(32, 32), nn.LayerNorm(32), nn.Linear(32, 32))
|
||||
for module in (model[0], model[1], model[2], model):
|
||||
replicate(module, mp_policy=mp_policy)
|
||||
inner(model, torch.randn((4, 32)))
|
||||
|
||||
# Batch norm 1D
|
||||
model = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32), nn.Linear(32, 32))
|
||||
for module in (model[0], model[1], model[2], model):
|
||||
replicate(module, mp_policy=mp_policy)
|
||||
inner(model, torch.randn((4, 32)))
|
||||
|
||||
# Batch norm 2D: error in backward from buffer dtype mismatch
|
||||
model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
|
||||
for module in (model[0], model[1], model[2], model):
|
||||
replicate(module, mp_policy=mp_policy)
|
||||
if TEST_HPU:
|
||||
inner(model, torch.randn((3, 1, 9, 9)))
|
||||
else:
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected running_mean to have type", # Error not seen on HPUs and hence it can be skipped
|
||||
):
|
||||
# Errors in batch norm 2D backward
|
||||
inner(model, torch.randn((3, 1, 9, 9)))
|
||||
|
||||
# Batch norm 2D: cast buffers down to lower precision
|
||||
model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
|
||||
for module in (model[0], model[1], model[2], model):
|
||||
replicate(module, mp_policy=mp_policy)
|
||||
# Casting batch norm buffers to the lower precision allows backward
|
||||
model[1].running_mean = model[1].running_mean.to(mp_policy.param_dtype)
|
||||
model[1].running_var = model[1].running_var.to(mp_policy.param_dtype)
|
||||
inner(model, torch.randn((3, 1, 9, 9)))
|
||||
|
||||
# Batch norm 2D: use special mixed precision policy
|
||||
model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
|
||||
bn_mp_policy = MixedPrecisionPolicy(output_dtype=mp_policy.param_dtype)
|
||||
replicate(model[1], mp_policy=bn_mp_policy)
|
||||
for module in (model[0], model[2], model):
|
||||
replicate(module, mp_policy=mp_policy)
|
||||
inner(model, torch.randn((3, 1, 9, 9)))
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_clamp_reduce_dtype(self):
|
||||
# Initialize the model directly in bf16
|
||||
init_dtype = torch.bfloat16
|
||||
model = nn.Sequential(
|
||||
nn.Linear(32, 32, dtype=init_dtype),
|
||||
nn.Linear(32, 32, dtype=init_dtype),
|
||||
).to(device_type.type)
|
||||
mp_policy = MixedPrecisionPolicy(
|
||||
param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16
|
||||
)
|
||||
# Check that we did not clamp the reduce dtype
|
||||
self.assertEqual(mp_policy.reduce_dtype, torch.bfloat16)
|
||||
for module in model:
|
||||
replicate((module), mp_policy=mp_policy)
|
||||
replicate(model, mp_policy=mp_policy)
|
||||
|
||||
# Check that the reduce-scatter runs in bf16 even after we change the
|
||||
# model from bf16 to fp32
|
||||
model.to(torch.float32)
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
|
||||
def assert_fn(output: torch.Tensor):
|
||||
self.assertEqual(output.dtype, torch.bfloat16)
|
||||
|
||||
reduce_scatter = functools.partial(
|
||||
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
|
||||
)
|
||||
with patch_reduce_scatter(reduce_scatter):
|
||||
inp = torch.randn((4, 32), device=device_type.type)
|
||||
loss = model(inp).sum()
|
||||
loss.backward()
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_dataclass_input(self):
|
||||
@dataclasses.dataclass
|
||||
class Input:
|
||||
x: torch.Tensor
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._layer = nn.Linear(10, 10)
|
||||
|
||||
def forward(self, input: Input):
|
||||
return self._layer(input.x)
|
||||
|
||||
mp_policy = MixedPrecisionPolicy(
|
||||
torch.bfloat16, torch.bfloat16, torch.bfloat16, True
|
||||
)
|
||||
model = Model()
|
||||
inp = Input(torch.randn(2, 10).cuda())
|
||||
|
||||
replicate(model, mp_policy=mp_policy)
|
||||
loss = model(inp).sum()
|
||||
loss.backward()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
@ -5,6 +5,7 @@ import copy
|
||||
import functools
|
||||
import itertools
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Union
|
||||
|
||||
@ -17,8 +18,20 @@ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||
_CHECKPOINT_PREFIX,
|
||||
apply_activation_checkpointing,
|
||||
)
|
||||
from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, OffloadPolicy
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
from torch.distributed.fsdp import (
|
||||
CPUOffloadPolicy,
|
||||
FSDPModule,
|
||||
OffloadPolicy,
|
||||
register_fsdp_forward_method,
|
||||
)
|
||||
from torch.distributed.tensor import DTensor, init_device_mesh
|
||||
from torch.distributed.tensor.debug import CommDebugMode
|
||||
from torch.distributed.tensor.parallel import (
|
||||
ColwiseParallel,
|
||||
parallelize_module,
|
||||
RowwiseParallel,
|
||||
)
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import (
|
||||
check_sharded_parity,
|
||||
@ -26,6 +39,7 @@ from torch.testing._internal.common_fsdp import (
|
||||
FSDPTest,
|
||||
FSDPTestMultiThread,
|
||||
MLP,
|
||||
MLPStack,
|
||||
patch_all_gather,
|
||||
patch_reduce_scatter,
|
||||
)
|
||||
@ -842,5 +856,385 @@ class TestReplicateSharedParams(FSDPTest):
|
||||
self.assertEqual(losses[0], losses[1])
|
||||
|
||||
|
||||
class TestReplicateGradientAccumulation(FSDPTest):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return min(4, torch.get_device_module(device_type).device_count())
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_gradient_accumulation(self):
|
||||
"""
|
||||
Tests gradient accumulation with/without gradient reduction and
|
||||
with/without resharding after backward.
|
||||
"""
|
||||
|
||||
shard_size, replicate_size = 1, self.world_size
|
||||
meshes = init_device_mesh(
|
||||
device_type.type,
|
||||
(replicate_size, shard_size),
|
||||
mesh_dim_names=("replicate", "shard"),
|
||||
)
|
||||
self.run_subtests(
|
||||
{
|
||||
"mesh": [meshes],
|
||||
"reshard_after_forward": [True, False],
|
||||
# "all": disable reduce-scatter for all modules
|
||||
# "root_only": disable reduce-scatter for root's linear only
|
||||
# "some_mlps": disable reduce-scatter for some MLPs
|
||||
"mode": ["all", "root_only", "some_mlps"],
|
||||
"reshard_after_backward": [False, True],
|
||||
"offload_policy": [OffloadPolicy(), CPUOffloadPolicy()],
|
||||
# For HSDP only:
|
||||
# `True`: reduce-scatter only (no all-reduce) each microbatch
|
||||
# until the last microbatch
|
||||
# `False`: neither reduce-scatter nor all-reduce each
|
||||
# microbatch until the last microbatch
|
||||
"reduce_scatter_only": [False, True],
|
||||
},
|
||||
self._test_gradient_accumulation,
|
||||
)
|
||||
|
||||
def _test_gradient_accumulation(
|
||||
self,
|
||||
mesh: DeviceMesh,
|
||||
reshard_after_forward: Union[bool, int],
|
||||
mode: str,
|
||||
reshard_after_backward: bool,
|
||||
offload_policy: OffloadPolicy,
|
||||
reduce_scatter_only: bool, # for HSDP
|
||||
):
|
||||
if (
|
||||
(
|
||||
not reshard_after_backward
|
||||
and (reshard_after_forward is not False or mode == "some_mlps")
|
||||
)
|
||||
or (
|
||||
isinstance(offload_policy, CPUOffloadPolicy)
|
||||
and reshard_after_forward is not True
|
||||
)
|
||||
or (
|
||||
mesh.ndim != 2
|
||||
) # may eventually need to change once decision on device mesh is made
|
||||
):
|
||||
return # skip since not common or applicable
|
||||
|
||||
torch.manual_seed(42)
|
||||
batch_size, lin_dim, num_mlps, num_microbatches = (2, 32, 3, 3)
|
||||
if mode == "some_mlps":
|
||||
num_mlps_to_disable_reduce_scatter = 2
|
||||
modules = [nn.Linear(lin_dim, lin_dim)]
|
||||
modules.extend(MLP(lin_dim) for _ in range(num_mlps))
|
||||
model = nn.Sequential(*modules)
|
||||
ref_model = copy.deepcopy(model).to(device_type)
|
||||
replicate_fn = functools.partial(
|
||||
replicate,
|
||||
device_mesh=mesh,
|
||||
reshard_after_forward=reshard_after_forward,
|
||||
offload_policy=offload_policy,
|
||||
)
|
||||
for mlp in model[1:]:
|
||||
replicate_fn(mlp)
|
||||
replicate_fn(model) # root gets the 1st linear
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||
|
||||
def set_grad_sync_flag(
|
||||
module: nn.Module, is_last_microbatch: bool, recurse: bool = True
|
||||
):
|
||||
if reduce_scatter_only:
|
||||
module.set_requires_all_reduce(is_last_microbatch, recurse=recurse)
|
||||
else:
|
||||
module.set_requires_gradient_sync(is_last_microbatch, recurse=recurse)
|
||||
|
||||
def set_backward_flags(_model: nn.Module, is_last_microbatch: bool):
|
||||
if mode == "all":
|
||||
set_grad_sync_flag(_model, is_last_microbatch)
|
||||
if not reshard_after_backward:
|
||||
_model.set_reshard_after_backward(is_last_microbatch)
|
||||
elif mode == "some_mlps":
|
||||
for mlp in model[1 : 1 + num_mlps_to_disable_reduce_scatter]:
|
||||
set_grad_sync_flag(mlp, is_last_microbatch)
|
||||
if not reshard_after_backward:
|
||||
mlp.set_reshard_after_backward(is_last_microbatch)
|
||||
elif mode == "root_only":
|
||||
set_grad_sync_flag(model, is_last_microbatch, recurse=False)
|
||||
if not reshard_after_backward:
|
||||
model.set_reshard_after_backward(is_last_microbatch, recurse=False)
|
||||
|
||||
torch.manual_seed(42 + self.rank + 1)
|
||||
for iter_idx in range(5):
|
||||
comm_count_list = []
|
||||
|
||||
for microbatch_idx in range(num_microbatches):
|
||||
is_last_microbatch = microbatch_idx == num_microbatches - 1
|
||||
set_backward_flags(model, is_last_microbatch)
|
||||
inp = torch.randn(batch_size, lin_dim, device=device_type.type)
|
||||
losses: list[torch.Tensor] = []
|
||||
for _model in (ref_model, model):
|
||||
with CommDebugMode() as comm_mode:
|
||||
losses.append(_model(inp).sum())
|
||||
losses[-1].backward()
|
||||
comm_count_list.append(comm_mode.get_comm_counts())
|
||||
self.assertEqual(losses[0], losses[1])
|
||||
|
||||
comm_counts = defaultdict(int)
|
||||
for comm_count_dict in comm_count_list:
|
||||
for collective, count in comm_count_dict.items():
|
||||
comm_counts[collective] += count
|
||||
|
||||
all_gather_count = comm_counts[c10d_ops._allgather_base_]
|
||||
# reduce_scatter_count = comm_counts[c10d_ops._reduce_scatter_base_]
|
||||
all_reduce_count = comm_counts[c10d_ops.allreduce_]
|
||||
|
||||
# Expect one reduce-scatter per MLP plus one for the root's linear
|
||||
# on the last microbatch
|
||||
# expected_reduce_scatter_count = 0
|
||||
expected_all_reduce_count = num_mlps + 1
|
||||
|
||||
if mode == "some_mlps":
|
||||
# Expect additional reduce-scatters for non-disabled MLPs and
|
||||
# the root's linear
|
||||
expected_all_reduce_count += (
|
||||
num_mlps - num_mlps_to_disable_reduce_scatter + 1
|
||||
) * (num_microbatches - 1)
|
||||
elif mode == "root_only":
|
||||
# Expect additional reduce-scatters for all MLPs
|
||||
expected_all_reduce_count += (num_mlps) * (num_microbatches - 1)
|
||||
|
||||
# self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count)
|
||||
self.assertEqual(all_reduce_count, expected_all_reduce_count)
|
||||
|
||||
# Expect one all-gather per MLP plus one for the root's linear in
|
||||
# the first microbatch's forward
|
||||
expected_all_gather_count = 0
|
||||
|
||||
self.assertEqual(all_gather_count, expected_all_gather_count)
|
||||
|
||||
for param in ref_model.parameters():
|
||||
if param.grad is not None:
|
||||
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
|
||||
check_sharded_parity(self, ref_model, model)
|
||||
for _optim in (optim, ref_optim):
|
||||
_optim.step()
|
||||
# When `set_to_none=False`, we are exercising mixing
|
||||
# gradient accumulation with and without communication
|
||||
_optim.zero_grad(set_to_none=(iter_idx % 2))
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_1f1b_microbatching(self):
|
||||
self.run_subtests(
|
||||
{
|
||||
"use_explicit_unshard": [False, True],
|
||||
"reshard_after_backward": [False, True],
|
||||
},
|
||||
self._test_1f1b_microbatching,
|
||||
)
|
||||
|
||||
def _test_1f1b_microbatching(
|
||||
self, use_explicit_unshard: bool, reshard_after_backward: bool
|
||||
):
|
||||
torch.manual_seed(42)
|
||||
model_args = ModelArgs(dropout_p=0.0)
|
||||
model = Transformer(model_args)
|
||||
ref_model = copy.deepcopy(model).to(device_type)
|
||||
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
|
||||
for module in model.modules():
|
||||
if isinstance(module, TransformerBlock):
|
||||
replicate(module, reshard_after_forward=False)
|
||||
replicate(model, reshard_after_forward=False)
|
||||
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
|
||||
|
||||
num_microbatches = 3
|
||||
local_batch_size = 2
|
||||
torch.manual_seed(42 + self.rank + 1)
|
||||
inps = [
|
||||
torch.randint(
|
||||
0,
|
||||
model_args.vocab_size,
|
||||
(local_batch_size, 16),
|
||||
device=device_type.type,
|
||||
)
|
||||
for _ in range(num_microbatches)
|
||||
]
|
||||
|
||||
# Before pipelining, we may prefer to issue all all-gathers ahead of
|
||||
# time to increase overlap opportunity at no difference in parameter
|
||||
# memory usage since we do not reshard after forward
|
||||
if use_explicit_unshard:
|
||||
for module in model.modules():
|
||||
if isinstance(module, FSDPModule):
|
||||
module.unshard(async_op=True)
|
||||
|
||||
# Emulate the 1f1b pipeline schedule and only reduce gradients on the
|
||||
# last microbatch
|
||||
losses: list[torch.Tensor] = []
|
||||
ref_losses: list[torch.Tensor] = []
|
||||
for inp_idx, inp in enumerate(inps):
|
||||
is_last_microbatch = inp_idx == num_microbatches - 1
|
||||
model.set_requires_gradient_sync(is_last_microbatch)
|
||||
model.set_is_last_backward(is_last_microbatch)
|
||||
if not reshard_after_backward:
|
||||
model.set_reshard_after_backward(is_last_microbatch)
|
||||
losses.append(model(inp).sum())
|
||||
losses[-1].backward()
|
||||
ref_losses.append(ref_model(inp).sum())
|
||||
ref_losses[-1].backward()
|
||||
for param in ref_model.parameters():
|
||||
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
|
||||
|
||||
for loss, ref_loss in zip(losses, ref_losses):
|
||||
self.assertEqual(loss, ref_loss)
|
||||
optim.step()
|
||||
ref_optim.step()
|
||||
check_sharded_parity(self, ref_model, model)
|
||||
|
||||
|
||||
class TestReplicateCustomForwardMethod(FSDPTest):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return min(torch.get_device_module(device_type).device_count(), 2)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_register_fsdp_forward_method(self):
|
||||
class VisionTransformer(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.patch_proj = nn.Conv2d(3, 1024, kernel_size=14, stride=14)
|
||||
|
||||
def forward_features(self, imgs: torch.Tensor) -> torch.Tensor:
|
||||
return self.patch_proj(imgs).flatten(2).transpose(1, 2)
|
||||
|
||||
def forward(self, imgs: torch.Tensor) -> torch.Tensor:
|
||||
return self.forward_features(imgs).sum(dim=1)
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.vit, self.projector = VisionTransformer(), nn.Linear(1024, 256)
|
||||
|
||||
def forward(self, imgs: torch.Tensor) -> torch.Tensor:
|
||||
# Run `vit.forward_features`, which is not `forward`!
|
||||
patch_embeddings = self.vit.forward_features(imgs)
|
||||
return self.projector(patch_embeddings)
|
||||
|
||||
torch.manual_seed(42)
|
||||
model = Model()
|
||||
ref_model = copy.deepcopy(model).to(device_type)
|
||||
replicate(model.vit)
|
||||
replicate(model.projector)
|
||||
replicate(model)
|
||||
register_fsdp_forward_method(model.vit, "forward_features")
|
||||
|
||||
torch.manual_seed(42 + self.rank + 1)
|
||||
inp = torch.randn(4, 3, 224, 224, device=device_type.type)
|
||||
ref_loss = ref_model(inp).sum()
|
||||
loss = model(inp).sum()
|
||||
self.assertEqual(ref_loss, loss)
|
||||
ref_loss.backward()
|
||||
loss.backward()
|
||||
for param in ref_model.parameters():
|
||||
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
|
||||
check_sharded_parity(self, ref_model, model)
|
||||
|
||||
|
||||
class TestReplicateTPTraining(FSDPTest):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return min(4, torch.get_device_module(device_type).device_count())
|
||||
|
||||
def init_global_mesh(self) -> DeviceMesh:
|
||||
return init_device_mesh(
|
||||
device_type.type,
|
||||
(2, 1, 2),
|
||||
mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(8)
|
||||
def test_replicate_tp(self):
|
||||
global_mesh = self.init_global_mesh()
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_activation_checkpointing": [False, True],
|
||||
"mlp_dim": [3, 5, 16, 17],
|
||||
"foreach": [False],
|
||||
},
|
||||
functools.partial(self._test_replicate_tp, global_mesh),
|
||||
)
|
||||
|
||||
def _test_replicate_tp(
|
||||
self,
|
||||
global_mesh: DeviceMesh,
|
||||
reshard_after_forward: bool,
|
||||
use_activation_checkpointing: bool,
|
||||
mlp_dim: int,
|
||||
foreach: bool,
|
||||
):
|
||||
dp_mesh, tp_mesh = global_mesh["dp_replicate", "dp_shard"], global_mesh["tp"]
|
||||
dp_pg = dp_mesh._flatten().get_group() # used for `replicate()`
|
||||
|
||||
torch.manual_seed(42)
|
||||
model = MLPStack(mlp_dim)
|
||||
ref_model = copy.deepcopy(model).to(device_type)
|
||||
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=foreach)
|
||||
|
||||
parallelize_plan = {
|
||||
# Pass `use_local_output=False` to keep as DTensor to preserve
|
||||
# uneven activation dims
|
||||
"0.in_proj": ColwiseParallel(use_local_output=False),
|
||||
"0.out_proj": RowwiseParallel(use_local_output=False),
|
||||
"1.in_proj": ColwiseParallel(use_local_output=False),
|
||||
"1.out_proj": RowwiseParallel(use_local_output=False),
|
||||
"2.in_proj": ColwiseParallel(use_local_output=False),
|
||||
"2.out_proj": (RowwiseParallel()),
|
||||
}
|
||||
|
||||
model = parallelize_module(model, tp_mesh, parallelize_plan)
|
||||
|
||||
for module in model:
|
||||
if isinstance(module, nn.LayerNorm):
|
||||
continue
|
||||
if use_activation_checkpointing:
|
||||
checkpoint(module)
|
||||
replicate(module, device_mesh=dp_mesh)
|
||||
replicate(model, device_mesh=dp_mesh)
|
||||
|
||||
# Checking that parameters match the original model is critical to validate that .full_tensor correctly replicates the
|
||||
# strided-sharded layers.
|
||||
for ref_p, p in zip(ref_model.parameters(), model.parameters()):
|
||||
self.assertIsInstance(p, DTensor)
|
||||
self.assertEqual(ref_p, p.full_tensor())
|
||||
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)
|
||||
|
||||
torch.manual_seed(42 + dp_pg.rank() + 1)
|
||||
device = device_type
|
||||
for iter_idx in range(10):
|
||||
inp = torch.randn((8, mlp_dim), device=device)
|
||||
losses: list[torch.Tensor] = []
|
||||
for _model in (ref_model, model):
|
||||
losses.append(_model(inp).sum())
|
||||
losses[-1].backward()
|
||||
|
||||
for param in ref_model.parameters():
|
||||
if param.grad is not None:
|
||||
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
|
||||
|
||||
for _optim in (ref_optim, optim):
|
||||
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
|
||||
_optim.step()
|
||||
self.assertEqual(losses[0], losses[1])
|
||||
check_sharded_parity(self, ref_model, model)
|
||||
|
||||
for _, p in model.named_parameters():
|
||||
self.assertIsInstance(p, DTensor)
|
||||
self.assertEqual(p.device_mesh.ndim, 3)
|
||||
self.assertEqual(len(p.placements), 3)
|
||||
self.assertEqual(
|
||||
p.device_mesh.mesh_dim_names, ("dp_replicate", "dp_shard", "tp")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
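The microbatch loop in the test above defers gradient reduction to the last microbatch through FSDP2-style controls (`set_requires_gradient_sync`, `set_is_last_backward`). For reference, a minimal sketch of the same accumulate-then-reduce pattern using plain DDP's `no_sync()`; this assumes an already-initialized process group and is only an illustration, not the `replicate()` API exercised above:

import contextlib

from torch.nn.parallel import DistributedDataParallel as DDP


def run_microbatches(ddp_model: DDP, microbatches, optimizer):
    # Skip the gradient all-reduce for every microbatch except the last one;
    # local grads accumulate until the final backward triggers the reduction.
    for i, batch in enumerate(microbatches):
        is_last = i == len(microbatches) - 1
        ctx = contextlib.nullcontext() if is_last else ddp_model.no_sync()
        with ctx:
            ddp_model(batch).sum().backward()
    optimizer.step()
    optimizer.zero_grad()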
158  test/distributed/tensor/test_dtensor_export.py  Normal file
@ -0,0 +1,158 @@
# Owner(s): ["oncall: distributed"]

import contextlib
import unittest

import torch
import torch.distributed as dist
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._functorch.partitioners import min_cut_rematerialization_partition
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    requires_cuda,
    run_tests,
    TestCase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore


class SimpleModel(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.mlp_0 = MLPModule(device)
        self.mlp_1 = MLPModule(device)

    def forward(self, input):
        return self.mlp_1(self.mlp_0(input))


def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
    # needed for strict export
    torch.utils._pytree.register_constant(DTensorSpec)

    # install_free_tensors is required for dynamo to work
    with torch._dynamo.config.patch(
        install_free_tensors=True, inline_inbuilt_nn_modules=True
    ):
        with torch._export.utils._disable_aten_to_metadata_assertions():
            ep = torch.export.export(model, (inputs,), strict=True)

    # joint_gm produced here is missing the backward region, due to an incompatibility
    # between ep.module() and aot_export_joint_with_descriptors.
    # Keeping this here to show the issue.
    return aot_export_joint_with_descriptors_alone(ep.module(), inputs)


def graph_capture_and_aot_export_joint_with_descriptors(model, inputs):
    with torch._dynamo.config.patch(install_free_tensors=True):
        # TODO: switch to use the official graph_capture API once it is ready
        gm = _dynamo_graph_capture_for_export(model)(inputs)

    return aot_export_joint_with_descriptors_alone(gm, inputs)


def aot_export_joint_with_descriptors_alone(model, inputs):
    with contextlib.ExitStack() as stack:
        joint_with_descriptors = aot_export_joint_with_descriptors(
            stack,
            model,
            (inputs,),
        )
        return joint_with_descriptors.graph_module


def _count_op(gm, target):
    return sum(1 for node in gm.graph.nodes if node.target == target)


@requires_cuda
class DTensorExportTest(TestCase):
    def tearDown(self):
        super().tearDown()
        dist.destroy_process_group()

    def setUp(self):
        super().setUp()
        self.world_size = 8
        store = FakeStore()
        dist.init_process_group(
            backend="fake", rank=0, world_size=self.world_size, store=store
        )
        self.device_type = "cuda"

    def _run_test(self, export_fn):
        dp_degree = 2
        tp_degree = self.world_size // dp_degree

        # 2-D mesh is [dp, tp]
        mesh_2d = init_device_mesh(
            self.device_type,
            mesh_shape=(dp_degree, tp_degree),
            mesh_dim_names=["dp", "tp"],
        )

        model = SimpleModel(self.device_type)
        parallelize_plan = {
            "mlp_0.net1": ColwiseParallel(),
            "mlp_0.net2": RowwiseParallel(),
            "mlp_1.net1": ColwiseParallel(),
            "mlp_1.net2": RowwiseParallel(),
        }
        tp_model = parallelize_module(model, mesh_2d["tp"], parallelize_plan)

        inputs = torch.rand(20, 10, device=self.device_type)
        inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])

        joint_gm = export_fn(tp_model, inputs)
        fw_gm, bw_gm = min_cut_rematerialization_partition(
            joint_gm, None, num_fwd_outputs=1
        )

        self.assertTrue(
            _count_op(joint_gm, torch.ops._c10d_functional.all_reduce.default),
            3,
        )
        self.assertTrue(
            _count_op(fw_gm, torch.ops._c10d_functional.all_reduce.default),
            2,
        )
        self.assertTrue(
            _count_op(bw_gm, torch.ops._c10d_functional.all_reduce.default),
            1,
        )

    @parametrize(
        "export_fn",
        [
            graph_capture_and_aot_export_joint_with_descriptors,
            aot_export_joint_with_descriptors_alone,
        ],
    )
    def test_export_parallelize_module_with_dtensor_input(
        self,
        export_fn,
    ):
        self._run_test(export_fn)

    # aot_export_joint_with_descriptors on strict-exported exported_program.module()
    # is producing a joint graph with the backward region missing
    @unittest.expectedFailure
    def test_strict_export_parallelize_module_with_dtensor_input(self):
        self._run_test(strict_export_and_aot_export_joint_with_descriptors)


instantiate_parametrized_tests(DTensorExportTest)


if __name__ == "__main__":
    run_tests()
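`_count_op` above is a plain walk over FX graph nodes. The same counting pattern works on any `GraphModule`; a standalone sketch on a toy module traced with `torch.fx.symbolic_trace` (the module and ops here are illustrative only, no process group or DTensor required):

import operator

import torch
import torch.fx


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + torch.relu(x * 2)


def count_op(gm: torch.fx.GraphModule, target) -> int:
    # One pass over the graph, tallying call nodes whose target matches.
    return sum(1 for node in gm.graph.nodes if node.target == target)


gm = torch.fx.symbolic_trace(Toy())
print(count_op(gm, torch.relu))    # 2
print(count_op(gm, operator.mul))  # 1 -- `x * 2` traces to operator.mul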
@ -1,6 +1,7 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import contextlib
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
@ -355,7 +356,7 @@ class RedistributeTest(DTensorTestBase):
|
||||
replica_spec = Replicate()
|
||||
# 1) test replicate -> partial forward
|
||||
replica_tensor = distribute_tensor(local_tensor, device_mesh, [replica_spec])
|
||||
with self.assertRaisesRegex(RuntimeError, "Can not redistribute to Partial"):
|
||||
with self.assertRaisesRegex(RuntimeError, "Can not redistribute"):
|
||||
partial_tensor = replica_tensor.redistribute(device_mesh, [partial_spec])
|
||||
|
||||
from torch.distributed.tensor._redistribute import Redistribute
|
||||
@ -619,6 +620,38 @@ class RedistributeTest(DTensorTestBase):
|
||||
self.assertEqual(comm_mode.get_total_counts(), 0)
|
||||
self.assertEqual(out.placements, [Shard(0), dst])
|
||||
|
||||
@with_comms
|
||||
def test_redistribute_to_partial(self):
|
||||
mesh = init_device_mesh(self.device_type, (2, 2))
|
||||
|
||||
tensor = torch.randn(12, 8, device=self.device_type)
|
||||
|
||||
test_cases = [
|
||||
# Partial to Partial is allowed
|
||||
([Partial(), Shard(0)], [Partial(), Shard(0)], True),
|
||||
([Partial(), Shard(0)], [Partial(), Shard(1)], True),
|
||||
([Shard(0), Partial()], [Replicate(), Partial()], True),
|
||||
([Shard(0), Partial("prod")], [Replicate(), Partial("prod")], True),
|
||||
# Non-Partial to Partial is NOT allowed
|
||||
([Shard(0), Replicate()], [Shard(0), Partial()], False),
|
||||
([Shard(0), Replicate()], [Replicate(), Partial()], False),
|
||||
([Shard(0), Shard(1)], [Replicate(), Partial()], False),
|
||||
# Partial to Partial is allowed only if the reduction op is the same
|
||||
([Shard(0), Partial("prod")], [Replicate(), Partial("sum")], False),
|
||||
]
|
||||
|
||||
for src, dst, allow in test_cases:
|
||||
dt = DTensor.from_local(tensor, mesh, src)
|
||||
raise_context = (
|
||||
self.assertRaisesRegex(RuntimeError, "Can not redistribute")
|
||||
if not allow
|
||||
else contextlib.nullcontext()
|
||||
)
|
||||
|
||||
with raise_context:
|
||||
out = dt.redistribute(mesh, dst)
|
||||
self.assertEqual(out.placements, dst)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(RedistributeTest)
|
||||
|
||||
|
||||
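The table in `test_redistribute_to_partial` above encodes one rule: a destination placement may be `Partial` only if the source placement on that mesh dimension is already the same `Partial` (same reduce op). A small standalone predicate restating that rule (an illustration of the test table, not the DTensor implementation itself):

from torch.distributed.tensor import Partial


def redistribute_to_partial_allowed(src_placements, dst_placements) -> bool:
    # Redistributing *to* Partial would have to invent partial values from
    # already-materialized ones, so it is only legal when the source placement
    # on that mesh dim is the identical Partial (including its reduce op).
    for src, dst in zip(src_placements, dst_placements):
        if isinstance(dst, Partial) and src != dst:
            return False
    return True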
757  test/distributed/test_aten_comm_compute_reordering.py  Normal file
@ -0,0 +1,757 @@
|
||||
# flake8: noqa: B950
|
||||
# Owner(s): ["module: inductor"]
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch._dynamo.logging
|
||||
import torch._dynamo.test_case
|
||||
|
||||
# for some reason importing functional collectives after dynamo breaks collectives handling!
|
||||
import torch.distributed._functional_collectives as _functional_collectives
|
||||
from torch._C import FileCheck
|
||||
from torch._dynamo.utils import counters, same
|
||||
from torch._inductor.utils import run_and_get_triton_code
|
||||
from torch.testing._internal.common_distributed import (
|
||||
_dynamo_dist_per_rank_init,
|
||||
at_least_x_gpu,
|
||||
DynamoDistributedMultiProcTestCase,
|
||||
requires_accelerator_dist_backend,
|
||||
)
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
import functools
|
||||
|
||||
from torch.testing._internal.common_fsdp import get_devtype
|
||||
from torch.testing._internal.common_utils import skipIfRocm
|
||||
from torch.testing._internal.inductor_utils import HAS_GPU
|
||||
|
||||
|
||||
def estimate_aten_runtime(fx_node, compute_multiplier=1.0):
|
||||
# for tests, assume a matmul can hide a single collective
|
||||
if "c10" in str(fx_node.target):
|
||||
return 1.0
|
||||
elif fx_node.target == aten.mm.default:
|
||||
return compute_multiplier
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
device_type = str(get_devtype())
|
||||
|
||||
|
||||
def apply_reordering_and_get_graph(graph, out_li) -> None:
|
||||
gm = graph.owning_module
|
||||
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
|
||||
|
||||
schedule_overlap_bucketing(gm)
|
||||
gm.graph.lint()
|
||||
out_li.append(str(gm.graph))
|
||||
|
||||
|
||||
def run_and_get_aten_graph(fn, *inputs):
|
||||
li = []
|
||||
apply = functools.partial(apply_reordering_and_get_graph, out_li=li)
|
||||
with torch._inductor.config.patch(post_grad_custom_post_pass=apply):
|
||||
out = fn(*inputs)
|
||||
|
||||
return out, li[0]
|
||||
|
||||
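`run_and_get_aten_graph` above captures the post-grad ATen graph by installing a custom post-grad pass through Inductor's config. A stripped-down sketch of just the capture mechanism, without the overlap scheduling, assuming an environment where `torch.compile` with the Inductor backend works (the toy function is an illustrative assumption):

import functools

import torch
import torch._inductor.config as inductor_config


captured = []


def record_graph(graph, out_li) -> None:
    # The hook receives the post-grad FX graph; its owning GraphModule can be
    # stringified (or mutated in place, as schedule_overlap_bucketing does above).
    out_li.append(str(graph.owning_module.graph))


def fn(x):
    return torch.relu(x @ x).sum()


hook = functools.partial(record_graph, out_li=captured)
with inductor_config.patch(post_grad_custom_post_pass=hook):
    torch.compile(fn)(torch.randn(8, 8))

print(captured[0])  # the ATen-level graph Inductor saw after its post-grad passes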
|
||||
def get_patches():
|
||||
return {
|
||||
"test_configs.estimate_aten_runtime": estimate_aten_runtime,
|
||||
"reorder_for_locality": False,
|
||||
"reorder_for_compute_comm_overlap_passes": [],
|
||||
"compile_threads": 1,
|
||||
"force_disable_caches": True,
|
||||
}
|
||||
|
||||
|
||||
@requires_accelerator_dist_backend()
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
"""
|
||||
Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
|
||||
|
||||
Note: these tests are a fork of test/distributed/test_compute_comm_reordering.py
|
||||
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
torch._dynamo.reset()
|
||||
torch._dynamo.utils.counters.clear()
|
||||
|
||||
def get_world_trs(self):
|
||||
return {
|
||||
"tag": "",
|
||||
"ranks": list(range(self.world_size)),
|
||||
"group_size": self.world_size,
|
||||
}
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
# hack: no matter whether we have 2 or 3 or 4 gpus, just run on 2
|
||||
# works around an issue with skip_if_lt_x_gpu(2) and workers with an unpredictable number of GPUs
|
||||
return 2
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_patches())
|
||||
def test_sink_waits(self):
|
||||
def func(a):
|
||||
ar = _functional_collectives.all_reduce(a, "sum", "0")
|
||||
b = torch.matmul(a, a)
|
||||
return torch.matmul(ar, b)
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
|
||||
|
||||
out, aten_graph_str = run_and_get_aten_graph(torch.compile(func), inputs)
|
||||
|
||||
# Verify that the wait_tensor is sunk below the 1st matmul but
|
||||
# above the 2nd matmul.
|
||||
(
|
||||
FileCheck()
|
||||
.check("all_reduce.default")
|
||||
.check("aten.mm.default")
|
||||
.check("wait_tensor.default")
|
||||
.check("aten.mm.default")
|
||||
.run(aten_graph_str)
|
||||
)
|
||||
correct = func(inputs)
|
||||
self.assertTrue(same(out, correct))
|
||||
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0)
|
||||
|
||||
@torch._inductor.config.patch(get_patches())
|
||||
def test_raise_comms(self):
|
||||
def func(a):
|
||||
b = torch.matmul(a, a)
|
||||
c = torch.relu(b)
|
||||
d = torch.matmul(c, c)
|
||||
e = _functional_collectives.all_reduce((b + 1), "sum", "0")
|
||||
return torch.matmul(d, e)
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
|
||||
compiled = torch.compile(func)
|
||||
out, aten_graph_str = run_and_get_aten_graph(torch.compile(func), inputs)
|
||||
# Verify that the all_reduce_ has been raised above the 2nd matmul
|
||||
# but below the 1st matmul. Note that the all_reduce_ directly
|
||||
# writes to the output buffer of the 1st matmul, which is an input
|
||||
# to the first relu. Therefore, the all_reduce_ should be scheduled
|
||||
# after the first relu.
|
||||
(
|
||||
FileCheck()
|
||||
.check("aten.mm")
|
||||
.check("all_reduce.default")
|
||||
.check("aten.mm")
|
||||
.check("wait_tensor.default")
|
||||
.check("aten.mm")
|
||||
.run(aten_graph_str)
|
||||
)
|
||||
out = compiled(inputs)
|
||||
correct = func(inputs)
|
||||
self.assertTrue(same(out, correct))
|
||||
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0)
|
||||
|
||||
@torch._inductor.config.patch(get_patches())
|
||||
def test_sink_waits_raise_comms(self):
|
||||
def func(a, *, tag, ranks, group_size):
|
||||
b = torch.matmul(a, a)
|
||||
c = torch.relu(b)
|
||||
d = torch.matmul(c, c)
|
||||
e = _functional_collectives.all_reduce(b, "sum", "0")
|
||||
f = torch.relu(d)
|
||||
g = torch.matmul(f, f)
|
||||
return torch.mm(e, g)
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs = torch.ones(
|
||||
4, 4, dtype=torch.float, device=device_type
|
||||
) # + self.rank
|
||||
kwargs = self.get_world_trs()
|
||||
func = functools.partial(func, **kwargs)
|
||||
compiled = torch.compile(func)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, inputs)
|
||||
# Things to verify:
|
||||
# - The all_reduce_ and its prologue should be raised above the 2nd
|
||||
# matmul but below the 1st matmul.
|
||||
# - The wait_tensor should be sinked below the 3rd matmul but above
|
||||
# the 4th matmul.
|
||||
|
||||
self.assertExpectedInline(
|
||||
aten_graph_str,
|
||||
"""\
|
||||
graph():
|
||||
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
|
||||
%mm : [num_users=2] = call_function[target=torch.ops.aten.mm.default](args = (%arg0_1, %arg0_1), kwargs = {})
|
||||
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%mm,), kwargs = {})
|
||||
%all_reduce : [num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce.default](args = (%mm, sum, 0), kwargs = {})
|
||||
%mm_1 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%relu, %relu), kwargs = {})
|
||||
%relu_1 : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%mm_1,), kwargs = {})
|
||||
%mm_2 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%relu_1, %relu_1), kwargs = {})
|
||||
%wait_tensor : [num_users=1] = call_function[target=torch.ops._c10d_functional.wait_tensor.default](args = (%all_reduce,), kwargs = {})
|
||||
%mm_3 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%wait_tensor, %mm_2), kwargs = {})
|
||||
return (mm_3,)""",
|
||||
)
|
||||
|
||||
# Note: this triggered an all_reduce_ bug
|
||||
correct = func(inputs, **self.get_world_trs())
|
||||
self.assertTrue(same(out, correct))
|
||||
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0)
|
||||
|
||||
@torch._inductor.config.patch(get_patches())
|
||||
def test_reorder_compute_for_overlap_mul(self):
|
||||
def func(a, *, tag, ranks, group_size):
|
||||
ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
|
||||
g = torch.matmul(a, a)
|
||||
c = torch.relu(a)
|
||||
d = torch.matmul(c, c)
|
||||
f = d * c * ar
|
||||
fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
|
||||
e = torch.matmul(d + ar + fr, g)
|
||||
return (e,)
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
|
||||
func_c = functools.partial(func, **self.get_world_trs())
|
||||
compiled = torch.compile(func_c)
|
||||
out_c, aten_graph_str = run_and_get_aten_graph(compiled, inputs)
|
||||
# Note: because we have given collectives and mms equal estimation,
|
||||
# we overlap each collective with a single mm.
|
||||
# Same schedule as in test_reorder_compute_for_overlap_custom_runtime_estimation
|
||||
# although there is an exposed collective
|
||||
(
|
||||
FileCheck()
|
||||
.check("all_reduce.default")
|
||||
.check("aten.mm")
|
||||
.check("aten.mm")
|
||||
.check("wait_tensor.default")
|
||||
.check("aten.mul")
|
||||
.check("all_reduce.default")
|
||||
.check("wait_tensor.default")
|
||||
.check("aten.mm")
|
||||
.run(aten_graph_str)
|
||||
)
|
||||
correct = func(inputs, **self.get_world_trs())
|
||||
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 1)
|
||||
self.assertTrue(same(out_c, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skipIfRocm
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
@unittest.skipIf(True, "Logic not yet implemented")
|
||||
@torch._inductor.config.patch(get_patches())
|
||||
def test_grouped_scheduler_node(self):
|
||||
def func(a, *, tag, ranks, group_size):
|
||||
add = a + a
|
||||
div = add / a
|
||||
ar = _functional_collectives.all_reduce(div, "sum", ranks, tag)
|
||||
# Normally, we would fuse `add = a + a`, `div = add / a` and `mul = a * a` together into a single fused op,
|
||||
# but here in this unit test, we intentionally put `add`, `div` and `ar` computation
|
||||
# into a GroupedSchedulerNode, which prevents them from being fused with any other ops.
|
||||
mul = a * a
|
||||
mm = torch.matmul(mul, ar)
|
||||
return (mm,)
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
|
||||
# Expectations:
|
||||
# 1. `add = a + a` and `div = add / a` are still fused, which means fusion
|
||||
# still happens among nodes within a GroupedSchedulerNode.
|
||||
# 2. `mul = a * a` is not fused with `add` or `div`, because the latter two are within
|
||||
# GroupedSchedulerNode and thus are prevented from being fused with any outside ops.
|
||||
FileCheck().check("triton_poi_fused_add_all_reduce_div_0.").check(
|
||||
"_c10d_functional.all_reduce_."
|
||||
).check("triton_poi_fused_mul_1.").run(code)
|
||||
out = compiled(inputs, **self.get_world_trs())
|
||||
correct = func(inputs, **self.get_world_trs())
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_patches())
|
||||
def test_inductor_default_comms_ordering(self):
|
||||
pg_info = self.get_world_trs()
|
||||
tag = pg_info["tag"]
|
||||
ranks = pg_info["ranks"]
|
||||
group_size = pg_info["group_size"]
|
||||
|
||||
g1 = torch.ones(10, 10, device=device_type)
|
||||
g2 = torch.ones(11, 11, device=device_type)
|
||||
g3 = torch.ones(12, 12, device=device_type)
|
||||
|
||||
@torch.compile
|
||||
def fn(g1, g2, g3):
|
||||
handle1 = torch.ops.c10d_functional.all_reduce(
|
||||
g1, "avg", tag, ranks, group_size
|
||||
)
|
||||
handle2 = torch.ops.c10d_functional.all_reduce(
|
||||
g2, "avg", tag, ranks, group_size
|
||||
)
|
||||
handle3 = torch.ops.c10d_functional.all_reduce(
|
||||
g3, "avg", tag, ranks, group_size
|
||||
)
|
||||
|
||||
# wait on them in a different order
|
||||
grad3 = torch.ops._c10d_functional.wait_tensor.default(handle3)
|
||||
grad2 = torch.ops._c10d_functional.wait_tensor.default(handle2)
|
||||
grad1 = torch.ops._c10d_functional.wait_tensor.default(handle1)
|
||||
return grad3, grad2, grad1
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank, self.world_size, self.backend(device_type), fake_pg=True
|
||||
):
|
||||
# all_reduces remain in order!
|
||||
# note: this isn't actually an invariant of the pass currently,
# but we should keep collectives stable when there are no reordering opportunities
|
||||
|
||||
_, code = run_and_get_aten_graph(fn, g1, g2, g3)
|
||||
|
||||
FileCheck().check("all_reduce").check_same("arg0_1").check(
|
||||
"all_reduce"
|
||||
).check_same("arg1_1").check("all_reduce").check_same("arg2_1").run(code)
|
||||
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 3)
|
||||
# these have no overlap opportunities
|
||||
self.assertEqual(counters["inductor"]["overlap_scheduling_bad_exposed"], 0)
|
||||
|
||||
|
||||
def get_bucket_patches(compute_multiplier=1.0):
|
||||
estimate_aten_runtime_part = functools.partial(
|
||||
estimate_aten_runtime, compute_multiplier=compute_multiplier
|
||||
)
|
||||
return {
|
||||
"test_configs.estimate_aten_runtime": estimate_aten_runtime_part,
|
||||
"test_configs.aten_fx_overlap_preserving_bucketing": True,
|
||||
"reorder_for_locality": False,
|
||||
"reorder_for_compute_comm_overlap_passes": [],
|
||||
"compile_threads": 1,
|
||||
"force_disable_caches": True,
|
||||
}
|
||||
|
||||
|
||||
class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches())
|
||||
def test_basic_all_gather_bucketing(self):
|
||||
"""Test that independent all_gather operations get bucketed together."""
|
||||
|
||||
def func(a, b, c, *, ranks):
|
||||
# Three independent all_gathers that should be bucketed
|
||||
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks) + 3
|
||||
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks) + 4
|
||||
ag3 = _functional_collectives.all_gather_tensor(c, 0, ranks) + 5
|
||||
return ag1 + ag2 + ag3
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs_a = (
|
||||
torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
|
||||
)
|
||||
inputs_b = torch.ones(4, 4, dtype=torch.float, device=device_type) * 2
|
||||
inputs_c = torch.ones(4, 4, dtype=torch.float, device=device_type) * 3
|
||||
ranks = list(range(self.world_size))
|
||||
|
||||
func_c = functools.partial(func, ranks=ranks)
|
||||
compiled = torch.compile(func_c)
|
||||
out, aten_graph_str = run_and_get_aten_graph(
|
||||
compiled, inputs_a, inputs_b, inputs_c
|
||||
)
|
||||
|
||||
# Should see a single bucketed all_gather
|
||||
FileCheck().check_count(
|
||||
"torch.ops._c10d_functional.all_gather_into_tensor", 1, exactly=True
|
||||
).run(aten_graph_str)
|
||||
|
||||
correct = func(inputs_a, inputs_b, inputs_c, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches())
|
||||
def test_reduce_scatter_bucketing(self):
|
||||
"""Test bucketing of reduce_scatter operations."""
|
||||
|
||||
def func(a, b, c):
|
||||
rs1 = _functional_collectives.reduce_scatter_tensor(a, "sum", 0, "0")
|
||||
rs2 = _functional_collectives.reduce_scatter_tensor(b, "sum", 0, "0")
|
||||
rs3 = _functional_collectives.reduce_scatter_tensor(c, "sum", 0, "0")
|
||||
return torch.cat([rs1, rs2, rs3])
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs_a = torch.ones(8, 4, dtype=torch.float, device=device_type)
|
||||
inputs_b = torch.ones(8, 4, dtype=torch.float, device=device_type) * 2
|
||||
inputs_c = torch.ones(8, 4, dtype=torch.float, device=device_type) * 3
|
||||
|
||||
out, aten_graph_str = run_and_get_aten_graph(
|
||||
torch.compile(func), inputs_a, inputs_b, inputs_c
|
||||
)
|
||||
|
||||
# Should bucket reduce_scatter ops
|
||||
FileCheck().check_count(
|
||||
"torch.ops._c10d_functional.reduce_scatter_tensor", 1, exactly=True
|
||||
).run(aten_graph_str)
|
||||
|
||||
# TODO: debug - on ci this fails.
|
||||
# correct = func(inputs_a, inputs_b, inputs_c)
|
||||
# self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches())
|
||||
def test_no_bucketing_with_dependent_hiding_nodes(self):
|
||||
"""Test that collectives with dependent hiding nodes don't get bucketed."""
|
||||
|
||||
def func(a, b, *, ranks):
|
||||
# ag1 could be hidden by mm1
|
||||
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
|
||||
mm1 = torch.matmul(a, a)
|
||||
|
||||
# ag2 can be hidden by mm2, but mm2 depends on ag1's result
|
||||
# ag2 start
|
||||
mm2 = torch.matmul(ag1[:4], b)
|
||||
# ag2 end
|
||||
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
|
||||
|
||||
return ag1.sum() * ag2.sum() * mm1 * mm2
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs_a = torch.ones(4, 4, dtype=torch.float, device=device_type)
|
||||
inputs_b = torch.ones(4, 4, dtype=torch.float, device=device_type)
|
||||
ranks = list(range(self.world_size))
|
||||
|
||||
func_c = functools.partial(func, ranks=ranks)
|
||||
compiled = torch.compile(func_c)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, inputs_a, inputs_b)
|
||||
|
||||
# mm2 depends on ag1, so if mm2 is to hide ag2, we can't bucket ag1 and ag2
|
||||
# because that would create a dependency issue, even though we could bucket them
|
||||
FileCheck().check_count(
|
||||
"torch.ops._c10d_functional.all_gather_into_tensor", 2, exactly=True
|
||||
).run(aten_graph_str)
|
||||
|
||||
correct = func(inputs_a, inputs_b, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches())
|
||||
def test_no_bucketing_when_collective_depends_on_hiding_node(self):
|
||||
"""Test that collectives don't get bucketed when one depends on another's hiding node."""
|
||||
|
||||
def func(a, *, ranks):
|
||||
# ag1 hidden by mm1
|
||||
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
|
||||
mm1 = torch.matmul(a, a)
|
||||
|
||||
# ag2 depends on mm1 (which hides ag1)
|
||||
b = mm1 * 2
|
||||
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
|
||||
|
||||
return ag1.sum() * ag2.sum() * mm1
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type)
|
||||
ranks = list(range(self.world_size))
|
||||
|
||||
func_c = functools.partial(func, ranks=ranks)
|
||||
compiled = torch.compile(func_c)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, inputs)
|
||||
|
||||
# ag2 depends on mm1 (ag1's hiding node), so they can't be bucketed
|
||||
FileCheck().check_count(
|
||||
"_c10d_functional.all_gather_into_tensor", 2, exactly=True
|
||||
).run(aten_graph_str)
|
||||
|
||||
correct = func(inputs, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches(2.0))
|
||||
def test_bucketing_wait_sink(self):
|
||||
"""Test that 4 independent all-gathers split bucketed."""
|
||||
|
||||
def func(a, b, c, d, *, ranks):
|
||||
# All 4 all-gathers are independent - COULD be bucketed together
|
||||
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
|
||||
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
|
||||
ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
|
||||
ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)
|
||||
|
||||
# First compute - can hide ag1 and ag2
|
||||
e = a * 5
|
||||
mm1 = torch.matmul(e, e.T)
|
||||
|
||||
# Second compute - can hide ag3 and ag4
|
||||
f = b * 6
|
||||
mm2 = torch.matmul(f, f.T)
|
||||
|
||||
# Use all collective results
|
||||
result = (
|
||||
ag1.sum() * 1.1
|
||||
+ ag2.sum() * 1.2
|
||||
+ ag3.sum() * 1.3
|
||||
+ ag4.sum() * 1.4
|
||||
+ mm1.sum()
|
||||
+ mm2.sum()
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
|
||||
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
|
||||
c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
|
||||
d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
|
||||
ranks = list(range(self.world_size))
|
||||
|
||||
func_c = functools.partial(func, ranks=ranks)
|
||||
compiled = torch.compile(func_c)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c, d)
|
||||
|
||||
# The 4 all gathers can be bucketed, and their waits should be sunk below the mms
|
||||
FileCheck().check_count(
|
||||
"_c10d_functional.all_gather_into_tensor", 1, exactly=True
|
||||
).check_count("ops.aten.mm", 2, exactly=True).check(
|
||||
"_c10d_functional.wait_tensor"
|
||||
).run(aten_graph_str)
|
||||
|
||||
correct = func(a, b, c, d, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches(2.0))
|
||||
def test_bucketing_split_for_overlap_blocking(self):
|
||||
"""Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute."""
|
||||
|
||||
def func(a, b, c, d, *, ranks):
|
||||
# All 4 all-gathers are independent - COULD be bucketed together
|
||||
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
|
||||
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
|
||||
ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
|
||||
ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)
|
||||
|
||||
# First compute - can hide ag1 and ag2
|
||||
e = a * 5 # Use a to avoid fusion
|
||||
mm1 = torch.matmul(e, e.T)
|
||||
|
||||
# Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred)
|
||||
# Use first 8x8 elements to match mm1's shape
|
||||
intermediate = ag1[:8, :8] + ag2[:8, :8]
|
||||
|
||||
# Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4
|
||||
mm2 = torch.matmul(mm1 + intermediate, c[:8])
|
||||
|
||||
# Use all results
|
||||
result = (
|
||||
ag1.sum() * 1.1
|
||||
+ ag2.sum() * 1.2
|
||||
+ ag3.sum() * 1.3
|
||||
+ ag4.sum() * 1.4
|
||||
+ mm1.sum()
|
||||
+ mm2.sum()
|
||||
)
|
||||
return result
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
|
||||
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
|
||||
c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
|
||||
d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
|
||||
ranks = list(range(self.world_size))
|
||||
|
||||
func_c = functools.partial(func, ranks=ranks)
|
||||
compiled = torch.compile(func_c)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c, d)
|
||||
|
||||
# The 4 all gathers can be bucketed, and the wait should be sunk below the mms
|
||||
FileCheck().check_count(
|
||||
"_c10d_functional.all_gather_into_tensor", 1, exactly=True
|
||||
).check_count("ops.aten.mm", 2, exactly=True).check_count(
|
||||
"_c10d_functional.wait_tensor", 1, exactly=True
|
||||
).run(aten_graph_str)
|
||||
|
||||
correct = func(a, b, c, d, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches(2.0))
|
||||
def test_bucketing_split_for_overlap(self):
|
||||
"""Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute."""
|
||||
|
||||
def func(a, b, c, d, *, ranks):
|
||||
# All 4 all-gathers are independent - COULD be bucketed together
|
||||
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
|
||||
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
|
||||
ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
|
||||
ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)
|
||||
|
||||
# First compute - can hide ag1 and ag2
|
||||
e = a * 5 # Use a to avoid fusion
|
||||
mm1 = torch.matmul(e, e.T)
|
||||
|
||||
# Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred)
|
||||
intermediate = ag1[:2, :2] + ag2[:2, :2] # Small slice to minimize compute
|
||||
|
||||
# Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4
|
||||
f = b * 6
|
||||
# Expand intermediate to match mm1's shape for broadcasting
|
||||
intermediate_expanded = torch.nn.functional.pad(intermediate, (0, 6, 0, 6))
|
||||
mm2 = torch.matmul(mm1 + intermediate_expanded, f.T)
|
||||
|
||||
# Use all results
|
||||
result = (
|
||||
ag1.sum() * 1.1
|
||||
+ ag2.sum() * 1.2
|
||||
+ ag3.sum() * 1.3
|
||||
+ ag4.sum() * 1.4
|
||||
+ mm1.sum()
|
||||
+ mm2.sum()
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
|
||||
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
|
||||
c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
|
||||
d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
|
||||
ranks = list(range(self.world_size))
|
||||
|
||||
func_c = functools.partial(func, ranks=ranks)
|
||||
compiled = torch.compile(func_c)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c, d)
|
||||
|
||||
# Should have 2 bucketed all-gathers (one for ag1+ag2, one for ag3+ag4)
|
||||
FileCheck().check_count(
|
||||
"_c10d_functional.all_gather_into_tensor_out", 2, exactly=True
|
||||
).run(aten_graph_str)
|
||||
|
||||
# Verify the ordering - first bucket, then mm1, then second bucket, then mm2
|
||||
FileCheck().check("_c10d_functional.all_gather_into_tensor_out").check(
|
||||
"ops.aten.mm"
|
||||
).check("_c10d_functional.all_gather_into_tensor_out").check(
|
||||
"ops.aten.mm"
|
||||
).run(aten_graph_str)
|
||||
|
||||
# Verify correctness
|
||||
correct = func(a, b, c, d, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches())
|
||||
def test_bucket_exposed_with_hidden_single_overlap(self):
|
||||
"""Test that exposed and hidden collectives bucket together when overlap is preserved."""
|
||||
|
||||
def func(a, b, c, *, ranks):
|
||||
# ag1 will be hidden by mm1
|
||||
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
|
||||
|
||||
# ag2 and ag3 are exposed (no compute to hide them)
|
||||
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
|
||||
ag3 = _functional_collectives.all_gather_tensor(c, 0, ranks)
|
||||
|
||||
# can only hide one collective
|
||||
mm1 = torch.matmul(a[:2], a[:2].T) # 2x2 matmul, hides only ag1
|
||||
|
||||
# All three can bucket together because:
|
||||
# bucketing ag1, ag2, ag3 together does not prevent ag1 being hidden by mm1.
|
||||
|
||||
return ag1.sum() + ag2.sum() + ag3.sum() + mm1.sum()
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
|
||||
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
|
||||
c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
|
||||
ranks = list(range(self.world_size))
|
||||
|
||||
func_c = functools.partial(func, ranks=ranks)
|
||||
compiled = torch.compile(func_c)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c)
|
||||
|
||||
# Should have 1 bucketed operation containing all 3 all-gathers
|
||||
FileCheck().check_count("wait_tensor.default", 1, exactly=True).run(
|
||||
aten_graph_str
|
||||
)
|
||||
|
||||
# Verify bucketed collective overlaps with mm1
|
||||
FileCheck().check("functional.all_gather_into_tensor").check(
|
||||
"aten.mm"
|
||||
).check("wait_tensor").run(aten_graph_str)
|
||||
|
||||
# Verify correctness
|
||||
correct = func(a, b, c, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
|
||||
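The assertions throughout this file rely on `FileCheck`, which scans a string for patterns in order. A tiny self-contained example of the directives used above (`check`, `check_same`, `check_count`) against a hand-written string rather than a captured graph:

from torch._C import FileCheck


graph_str = """\
%ag = all_gather_into_tensor(%arg0_1)
%mm = aten.mm(%a, %a)
%w = wait_tensor(%ag)
"""

# check() patterns must appear in order; check_same() must match on the same
# line as the previous hit; check_count(..., exactly=True) requires an exact
# number of occurrences.
FileCheck().check("all_gather_into_tensor").check_same("arg0_1").check(
    "aten.mm"
).check("wait_tensor").run(graph_str)
FileCheck().check_count("aten.mm", 1, exactly=True).run(graph_str)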
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
@ -95,6 +95,18 @@ class TestSerialization(TestCase):
|
||||
result = _streaming_load(file)
|
||||
torch.testing.assert_close(result, state_dict)
|
||||
|
||||
def test_empty_tensor(self) -> None:
|
||||
state_dict = {
|
||||
"empty": torch.zeros(0, 10),
|
||||
}
|
||||
|
||||
file = BytesIO()
|
||||
_streaming_save(state_dict, file)
|
||||
file.seek(0)
|
||||
|
||||
result = _streaming_load(file, weights_only=False)
|
||||
self.assertEqual(result, state_dict)
|
||||
|
||||
def test_dtensor(self) -> None:
|
||||
dist.init_process_group(
|
||||
backend="gloo", rank=0, world_size=1, store=dist.HashStore()
|
||||
|
||||
@ -4,7 +4,7 @@ import itertools
|
||||
import os
|
||||
import random
|
||||
from contextlib import nullcontext
|
||||
from unittest import skip, skipIf
|
||||
from unittest import skip, skipIf, skipUnless
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -25,6 +25,7 @@ from torch.distributed._symmetric_memory import (
|
||||
from torch.testing._internal.common_cuda import (
|
||||
_get_torch_cuda_version,
|
||||
SM100OrLater,
|
||||
SM89OrLater,
|
||||
SM90OrLater,
|
||||
xfailIfSM100OrLater,
|
||||
)
|
||||
@ -51,10 +52,6 @@ from torch.testing._internal.common_utils import (
|
||||
|
||||
test_contexts = [nullcontext, _test_mode]
|
||||
|
||||
# Set environment variable to disable multicast for all tests in this module
|
||||
# Workaround https://github.com/pytorch/pytorch/issues/162429
|
||||
os.environ["TORCH_SYMM_MEM_DISABLE_MULTICAST"] = "1"
|
||||
|
||||
# So that tests are written in device-agnostic way
|
||||
device_type = "cuda"
|
||||
device_module = torch.get_device_module(device_type)
|
||||
@ -430,6 +427,7 @@ class AsyncTPTest(MultiProcContinuousTest):
|
||||
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
|
||||
)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
|
||||
@parametrize("gather_dim", [0, 1])
|
||||
@parametrize(
|
||||
"scale_mode", ["tensor-wise", "row-wise-replicated", "row-wise-sharded"]
|
||||
@ -545,6 +543,7 @@ class AsyncTPTest(MultiProcContinuousTest):
|
||||
|
||||
@skip_if_rocm_multiprocess # AsyncTP support changed _fused_scaled_matmul_reduce_scatter_fallback API, need more changes
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
|
||||
@parametrize("scatter_dim", [0, 1])
|
||||
@parametrize("rowwise", [True, False])
|
||||
def test_fused_scaled_matmul_reduce_scatter(
|
||||
|
||||
@ -759,6 +759,38 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
),
|
||||
)
|
||||
|
||||
def test_sac_with_partial_context_fn(self):
|
||||
class CustomPolicy:
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, ctx, out, func, *args, **kwargs):
|
||||
return CheckpointPolicy.MUST_SAVE
|
||||
|
||||
def f(x, y):
|
||||
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
|
||||
|
||||
context_fn1 = functools.partial(
|
||||
create_selective_checkpoint_contexts, CustomPolicy()
|
||||
)
|
||||
|
||||
def fn(x, y):
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
f,
|
||||
x,
|
||||
y,
|
||||
use_reentrant=False,
|
||||
context_fn=context_fn1,
|
||||
)
|
||||
|
||||
opt_fn = torch.compile(fn, backend="aot_eager_decomp_partition", fullgraph=True)
|
||||
a = torch.randn(4, 4, requires_grad=True, device="cpu")
|
||||
b = torch.randn(4, 4, requires_grad=True, device="cpu")
|
||||
|
||||
expected = fn(a, b)
|
||||
result = opt_fn(a, b)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
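`test_sac_with_partial_context_fn` above builds its context_fn with `functools.partial` around a policy object. A minimal eager sketch of the same selective-checkpoint mechanism, written in the documented `policy_fn(ctx, op, *args, **kwargs)` form, saving only matmuls and recomputing the rest (the toy function here is an assumption for illustration):

import functools

import torch
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)


def save_only_matmuls(ctx, op, *args, **kwargs):
    # Keep matmul outputs for backward; everything else gets recomputed.
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


context_fn = functools.partial(create_selective_checkpoint_contexts, save_only_matmuls)


def f(x, y):
    return torch.sigmoid(x @ y).sum()


x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
checkpoint(f, x, y, use_reentrant=False, context_fn=context_fn).backward()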
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):
|
||||
|
||||
@ -203,6 +203,22 @@ class TestAOTCompile(torch._inductor.test_case.TestCase):
|
||||
actual = compiled_fn(*example_inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_aot_compile_disable_guard_check(self):
|
||||
def fn(x, y):
|
||||
return x + y
|
||||
|
||||
with torch.no_grad():
|
||||
compiled_fn = torch.compile(fn, fullgraph=True).aot_compile(
|
||||
((torch.randn(3, 4), torch.randn(3, 4)), {})
|
||||
)
|
||||
inputs = (torch.randn(3, 4), torch.randn(3, 4))
|
||||
expected = fn(*inputs)
|
||||
with self.assertRaisesRegex(RuntimeError, "GuardManager check failed"):
|
||||
compiled_fn(*inputs)
|
||||
compiled_fn.disable_guard_check()
|
||||
actual = compiled_fn(*inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_aot_compile_source_info(self):
|
||||
from torch._dynamo.package import SourceInfo
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ class CallbackTests(TestCase):
|
||||
|
||||
def test_callbacks_with_duplicate_prevention(self) -> None:
|
||||
trigger = CallbackTrigger.DYNAMO
|
||||
compile_id = CompileId(0, 0)
|
||||
compile_id = CompileId(frame_id=0, frame_compile_id=0)
|
||||
with (
|
||||
callback_handler.install_callbacks(trigger, compile_id),
|
||||
callback_handler.install_callbacks(trigger, compile_id),
|
||||
@ -40,7 +40,7 @@ class CallbackTests(TestCase):
|
||||
|
||||
def test_counter(self) -> None:
|
||||
trigger = CallbackTrigger.DYNAMO
|
||||
compile_id = CompileId(0, 0)
|
||||
compile_id = CompileId(frame_id=0, frame_compile_id=0)
|
||||
with callback_handler.install_callbacks(trigger, compile_id):
|
||||
self.assertEqual(
|
||||
callback_handler._CompilationCallbackHandler__pending_callbacks_counter,
|
||||
@ -56,7 +56,7 @@ class CallbackTests(TestCase):
|
||||
AssertionError, "Pending callbacks counter cannot become negative."
|
||||
):
|
||||
trigger = CallbackTrigger.DYNAMO
|
||||
compile_id = CompileId(0, 0)
|
||||
compile_id = CompileId(frame_id=0, frame_compile_id=0)
|
||||
with callback_handler.install_callbacks(trigger, str(compile_id)):
|
||||
pass
|
||||
self.assertEqual(
|
||||
|
||||
@ -216,7 +216,7 @@ Unsupported context manager
|
||||
Hint: If the context manager seems like it should be supported (e.g. torch.set_grad_enabled), then it may be the case that it was created outside the compiled region, which Dynamo does not support. Supported context managers can cross graph break boundaries only if they are local non-closure variables, or are intermediate values.
|
||||
Hint: File an issue to PyTorch. Simple context managers can potentially be supported, but note that context managers can't be supported in general
|
||||
|
||||
Developer debug context: Attempted SETUP_WITH/BEFORE_WITH on ConstantVariable(int: 3)
|
||||
Developer debug context: Attempted SETUP_WITH/BEFORE_WITH/LOAD_SPECIAL on ConstantVariable(int: 3)
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0142.html
|
||||
|
||||
|
||||
@ -127,6 +127,8 @@ class GraphModule(torch.nn.Module):
|
||||
def fn(x):
|
||||
local_rank = device_mesh.get_local_rank()
|
||||
global_rank = device_mesh.get_rank()
|
||||
if "dp" not in device_mesh.mesh_dim_names:
|
||||
x = x * 2
|
||||
return x + local_rank + global_rank
|
||||
|
||||
x = torch.ones(10)
|
||||
|
||||
@ -95,7 +95,11 @@ class FrameInitTests(torch._dynamo.test_case.TestCase):
|
||||
transformed_code = code_map1[frame.f_code]
|
||||
return wrap_guarded_code(
|
||||
GuardedCode(
|
||||
transformed_code, empty_guard_manager, CompileId(None, 0, 0)
|
||||
transformed_code,
|
||||
empty_guard_manager,
|
||||
CompileId(
|
||||
frame_id=None, frame_compile_id=0, compiled_autograd_id=0
|
||||
),
|
||||
)
|
||||
)
|
||||
return ConvertFrameReturn()
|
||||
@ -105,7 +109,11 @@ class FrameInitTests(torch._dynamo.test_case.TestCase):
|
||||
transformed_code = code_map2[frame.f_code]
|
||||
return wrap_guarded_code(
|
||||
GuardedCode(
|
||||
transformed_code, empty_guard_manager, CompileId(None, 0, 0)
|
||||
transformed_code,
|
||||
empty_guard_manager,
|
||||
CompileId(
|
||||
frame_id=None, frame_compile_id=0, compiled_autograd_id=0
|
||||
),
|
||||
)
|
||||
)
|
||||
return ConvertFrameReturn()
|
||||
|
||||
@ -329,7 +329,9 @@ class TestGuardSerializationBase(torch._inductor.test_case.TestCase):
|
||||
package=None,
|
||||
)
|
||||
with (
|
||||
compile_context(CompileContext(CompileId(0, 0))),
|
||||
compile_context(
|
||||
CompileContext(CompileId(frame_id=0, frame_compile_id=0))
|
||||
),
|
||||
tracing(tracer.output.tracing_context),
|
||||
tracer.set_current_tx(),
|
||||
get_metrics_context(),
|
||||
|
||||
@ -5448,7 +5448,8 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
|
||||
# check ShapeEnv counters compared to binding indices
|
||||
shape_env = _get_shape_env_from_gm(ep.graph_module)
|
||||
next_index = next(shape_env.unbacked_symint_counter)
|
||||
next_index = shape_env.unbacked_symint_counter
|
||||
shape_env.unbacked_symint_counter += 1
|
||||
for symbol in bound:
|
||||
self.assertTrue(symbol_is_type(symbol, SymT.UNBACKED_INT))
|
||||
self.assertTrue(
|
||||
@ -10293,6 +10294,28 @@ graph():
|
||||
ep = export(m, args)
|
||||
self.assertEqual(ep.module()(*args), m(*args))
|
||||
|
||||
def test_cdist_forward_compute_mode_zero_export(self):
|
||||
class CDistModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(CDistModel, self).__init__()
|
||||
|
||||
def forward(self, x, y, compute_mode):
|
||||
return torch.ops.aten._cdist_forward(
|
||||
x, y, p=2.0, compute_mode=compute_mode
|
||||
)
|
||||
|
||||
x = torch.ones([3, 3])
|
||||
y = torch.ones([3, 3])
|
||||
model = CDistModel()
|
||||
|
||||
expected_none = model(x, y, None)
|
||||
ep_none = torch.export.export(model, (x, y, None))
|
||||
self.assertTrue(torch.equal(ep_none.module()(x, y, None), expected_none))
|
||||
|
||||
expected_0 = model(x, y, 0)
|
||||
ep_0 = torch.export.export(model, (x, y, 0))
|
||||
self.assertTrue(torch.equal(ep_0.module()(x, y, 0), expected_0))
|
||||
|
||||
def test_export_then_compile_tensor_ctor(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, scores, mask):
|
||||
|
||||
@ -56,8 +56,6 @@ fake_export_failures = {
|
||||
xfail("masked.var"),
|
||||
xfail("nn.functional.grid_sample"),
|
||||
xfail("to_sparse"),
|
||||
# cannot xfail as it is passing for cpu-only build
|
||||
skip("nn.functional.scaled_dot_product_attention"),
|
||||
# following are failing due to OptionalDeviceGuard
|
||||
xfail("__getitem__"),
|
||||
xfail("nn.functional.batch_norm"),
|
||||
@ -80,8 +78,7 @@ def _test_export_helper(self, dtype, op):
|
||||
sample_inputs_itr = op.sample_inputs("cpu", dtype, requires_grad=False)
|
||||
|
||||
mode = FakeTensorMode(allow_non_fake_inputs=True)
|
||||
# intentionally avoid cuda:0 to flush out some bugs
|
||||
target_device = "cuda:1"
|
||||
target_device = "cuda:0"
|
||||
|
||||
def to_fake_device(x):
|
||||
return x.to(target_device)
|
||||
@ -135,8 +132,10 @@ instantiate_device_type_tests(TestExportOpInfo, globals(), only_for="cpu")
|
||||
selected_ops = {
|
||||
"__getitem__",
|
||||
# "nn.functional.batch_norm", # needs to fix
|
||||
"nn.functional.conv2d",
|
||||
"nn.functional.instance_norm",
|
||||
"nn.functional.multi_margin_loss",
|
||||
"nn.functional.scaled_dot_product_attention",
|
||||
"nonzero",
|
||||
}
|
||||
selected_op_db = [op for op in op_db if op.name in selected_ops]
|
||||
|
||||
@ -924,6 +924,26 @@ def forward(self, x):
|
||||
loaded_ep = load(buffer)
|
||||
self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs))
|
||||
|
||||
def test_non_float_weight(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.p = torch.nn.Parameter(
|
||||
torch.ones(2, 2, dtype=torch.int8), requires_grad=False
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.p
|
||||
|
||||
m = M()
|
||||
sample_inputs = (torch.randn(2, 2),)
|
||||
ep = torch.export.export(m, sample_inputs)
|
||||
buffer = io.BytesIO()
|
||||
save(ep, buffer)
|
||||
buffer.seek(0)
|
||||
loaded_ep = load(buffer)
|
||||
self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs))
|
||||
|
||||
def test_complex_constant(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
@ -1166,7 +1186,8 @@ class TestDeserialize(TestCase):
|
||||
|
||||
# check ShapeEnv counters
|
||||
shape_env = _get_shape_env_from_gm(loaded_ep.graph_module)
|
||||
next_index = next(shape_env.unbacked_symint_counter)
|
||||
next_index = shape_env.unbacked_symint_counter
|
||||
shape_env.unbacked_symint_counter += 1
|
||||
for symbol in bound:
|
||||
self.assertTrue(symbol_is_type(symbol, SymT.UNBACKED_INT))
|
||||
self.assertTrue(
|
||||
|
||||

@@ -42,6 +42,7 @@ from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import (
    _get_torch_cuda_version,
    IS_SM90,
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_FP8,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
@@ -1238,6 +1239,72 @@ class AOTInductorTestsTemplate:
            dynamic_shapes=dynamic_shapes,
        )

    @unittest.skipIf(
        TEST_WITH_ROCM or not IS_SM90,
        "scaled_grouped_mm is only supported on SM90",
    )
    @skipIfXpu
    def test_scaled_grouped_mm(self):
        # Test torch._scaled_grouped_mm AOTI lowering
        # cuda only
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, weight, scale_a, scale_b, offsets):
                # x: [num_groups, batch, in_features] - FP8 inputs
                # weight: [total_out_features, in_features] - FP8 weights (transposed)
                # scale_a: [num_groups] - input scales
                # scale_b: [num_groups] - weight scales
                # offsets: [num_groups] - cumulative output sizes
                output = torch._scaled_grouped_mm(
                    x,
                    weight.t(),
                    scale_a=scale_a,
                    scale_b=scale_b,
                    offs=offsets,
                    use_fast_accum=True,
                )
                return output.half()

        dtype = torch.float16
        num_groups = 3
        batch_size = 64
        in_features = 128
        out_features_list = [64, 128, 256]  # Different output sizes for each group

        device = GPU_TYPE

        # Calculate offsets (cumulative output sizes)
        offsets = torch.cumsum(torch.tensor(out_features_list), dim=0).to(
            device, dtype=torch.int32
        )
        total_out_features = sum(out_features_list)

        # Create FP8 input tensors - stacked for all groups
        x_fp16 = torch.randn(
            num_groups, batch_size, in_features, dtype=dtype, device=device
        )
        x_fp8 = x_fp16.to(torch.float8_e4m3fn)

        # Create FP8 weight tensor - concatenated and transposed
        weight_fp16 = torch.randn(
            total_out_features, in_features, dtype=dtype, device=device
        )
        weight_fp8 = weight_fp16.to(torch.float8_e4m3fn)

        # Create scales
        scale_a = torch.ones(num_groups, batch_size, device=device, dtype=torch.float32)
        scale_b = torch.ones(total_out_features, device=device, dtype=torch.float32)

        self.check_model(
            Model(),
            (x_fp8, weight_fp8, scale_a, scale_b, offsets),
        )

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FP8,
        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
@@ -7265,14 +7332,16 @@ class AOTInductorLoggingTest(LoggingTestCase):

class TestAOTInductorConfig(TestCase):
    def test_no_compile_standalone(self):
        with config.patch({"aot_inductor.compile_standalone": False}):
        with config.patch({"aot_inductor_mode.compile_standalone": False}):
            result = maybe_aoti_standalone_config({})
            self.assertEqual(result, {})

    def test_compile_standalone_sets_package_cpp(self):
        result = maybe_aoti_standalone_config({"aot_inductor.compile_standalone": True})
        result = maybe_aoti_standalone_config(
            {"aot_inductor_mode.compile_standalone": True}
        )
        self.assertEqual(result["aot_inductor.package_cpp_only"], True)
        self.assertEqual(result["aot_inductor.compile_standalone"], True)
        self.assertEqual(result["aot_inductor_mode.compile_standalone"], True)
        self.assertEqual(result["aot_inductor.embed_kernel_binary"], True)
        self.assertEqual(
            result["aot_inductor.emit_multi_arch_kernel"], not torch.version.hip
@@ -7280,12 +7349,15 @@ class TestAOTInductorConfig(TestCase):
        self.assertEqual(
            result["aot_inductor.model_name_for_generated_files"], "aoti_model"
        )
        self.assertEqual(result["aot_inductor.dynamic_linkage"], False)

    def test_compile_standalone_explicit_set(self):
        patches = {
            "aot_inductor.compile_standalone": True,
            "aot_inductor_mode.compile_standalone": True,
            "aot_inductor.package_cpp_only": True,
            "aot_inductor.embed_kernel_binary": True,
            "aot_inductor.dynamic_linkage": False,
            "aot_inductor.link_libtorch": False,
            "aot_inductor.emit_multi_arch_kernel": not torch.version.hip,
            "aot_inductor.model_name_for_generated_files": "aoti_model",
        }
@@ -7294,7 +7366,7 @@ class TestAOTInductorConfig(TestCase):

    def test_compile_standalone_package_cpp_false_raises(self):
        patches = {
            "aot_inductor.compile_standalone": True,
            "aot_inductor_mode.compile_standalone": True,
            "aot_inductor.package_cpp_only": False,
        }
        with self.assertRaises(RuntimeError):
@@ -7302,7 +7374,7 @@ class TestAOTInductorConfig(TestCase):

        with config.patch({"aot_inductor.package_cpp_only": False}):
            patches = {
                "aot_inductor.compile_standalone": True,
                "aot_inductor_mode.compile_standalone": True,
            }
            with self.assertRaises(RuntimeError):
                maybe_aoti_standalone_config(patches)

@@ -393,7 +393,7 @@ class TestAOTInductorPackage(TestCase):

        # Test compilation when no name is passed in
        options = {
            "aot_inductor.compile_standalone": True,
            "aot_inductor_mode.compile_standalone": True,
        }
        with (
            tempfile.TemporaryDirectory() as tmp_dir,
@@ -407,7 +407,7 @@ class TestAOTInductorPackage(TestCase):

        # Test compilation when model name is passed in
        options = {
            "aot_inductor.compile_standalone": True,
            "aot_inductor_mode.compile_standalone": True,
            "aot_inductor.model_name_for_generated_files": "linear",
        }
        with (
@@ -422,7 +422,7 @@ class TestAOTInductorPackage(TestCase):

        # test invalid model name
        options = {
            "aot_inductor.compile_standalone": True,
            "aot_inductor_mode.compile_standalone": True,
            "aot_inductor.model_name_for_generated_files": "linear/linear",
        }
        with self.assertRaisesRegex(Exception, "Invalid AOTI model name"):
@@ -448,7 +448,7 @@ class TestAOTInductorPackage(TestCase):

        # Test compilation when model name is passed in
        options = {
            "aot_inductor.compile_standalone": True,
            "aot_inductor_mode.compile_standalone": True,
            "aot_inductor.model_name_for_generated_files": "cos",
        }
        with (

test/inductor/test_aot_inductor_windows.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Owner(s): ["module: inductor"]
import tempfile
import unittest
import zipfile

import torch
import torch._inductor.config
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import IS_CI
from torch.testing._internal.inductor_utils import HAS_GPU, requires_gpu


class Simple(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 16)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(16, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x


class TestAOTInductorWindowsCrossCompilation(TestCase):
    @requires_gpu()
    def test_simple_so(self):
        if IS_CI:
            raise unittest.SkipTest("requires x86_64-w64-mingw32-gcc")

        # TODO: enable in CI
        with torch.no_grad():
            device = "cuda"
            model = Simple().to(device=device)
            example_inputs = (torch.randn(8, 10, device=device),)
            batch_dim = torch.export.Dim("batch", min=1, max=1024)
            exported = torch.export.export(
                model, example_inputs, dynamic_shapes={"x": {0: batch_dim}}
            )
            package_path = torch._inductor.aoti_compile_and_package(
                exported,
                inductor_configs={
                    "aot_inductor.model_name_for_generated_files": "model",
                    "aot_inductor.cross_target_platform": "windows",
                    "aot_inductor.link_libtorch": False,
                    "aot_inductor.aoti_shim_library": "executorch",
                    # no fallback ops
                    "max_autotune": True,
                    "max_autotune_gemm_backends": "TRITON,CPP",
                    "max_autotune_conv_backends": "TRITON,CPP",
                    # simplify things for now
                    "aot_inductor.precompile_headers": False,
                },
            )

            with tempfile.TemporaryDirectory() as tmpdir:
                with zipfile.ZipFile(package_path, "r") as zf:
                    zf.extractall(tmpdir)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
        run_tests(needs="filelock")

test/inductor/test_augmented_graph_helper.py (new file, 346 lines)
@@ -0,0 +1,346 @@
# Owner(s): ["module: inductor"]
import operator

import torch
import torch.fx as fx
from torch._inductor.augmented_graph_helper import AugmentedGraphHelper
from torch.testing._internal.common_utils import TestCase


class TestAugmentedGraphHelper(TestCase):
    """Test suite for AugmentedGraphHelper dependency and merge management."""

    def setUp(self):
        """Create a simple graph structure for testing."""
        # Create a torch.fx.Graph with multiple nodes
        self.graph = fx.Graph()

        # Create placeholder nodes (inputs)
        self.x = self.graph.placeholder("x")
        self.y = self.graph.placeholder("y")

        # Create computation nodes with specific names for easy reference
        self.node_a = self.graph.call_function(
            torch.add, args=(self.x, self.y), name="A"
        )
        self.node_b = self.graph.call_function(
            torch.mul, args=(self.node_a, self.x), name="B"
        )
        self.node_c = self.graph.call_function(
            torch.sub, args=(self.node_a, self.y), name="C"
        )
        self.node_d = self.graph.call_function(
            torch.div, args=(self.node_b, self.node_c), name="D"
        )
        self.node_e = self.graph.call_function(
            operator.neg, args=(self.node_d,), name="E"
        )
        self.node_f = self.graph.call_function(torch.abs, args=(self.node_e,), name="F")
        self.node_g = self.graph.call_function(
            torch.relu, args=(self.node_f,), name="G"
        )
        self.node_h = self.graph.call_function(
            torch.sigmoid, args=(self.node_g,), name="H"
        )

        # Create output
        self.graph.output(self.node_h)

        # Create a mapping of nodes by name for easier access in tests
        self.nodes = {}
        for node in self.graph.nodes:
            if hasattr(node, "name") and node.name in [
                "A",
                "B",
                "C",
                "D",
                "E",
                "F",
                "G",
                "H",
            ]:
                self.nodes[node.name] = node

        # Get all nodes and create tracker
        self.all_nodes = list(self.graph.nodes)
        self.tracker = AugmentedGraphHelper(self.graph)

    def get_deps(self, node):
        """Helper to get dependencies for a node."""
        return list(getattr(node, "args", []))

    # ========== Basic Functionality Tests ==========

    def test_initial_state(self):
        """Test that nodes start as singletons."""
        for node in self.all_nodes:
            merge_set = self.tracker.merge_sets[node]
            self.assertEqual(merge_set, {node})
            self.assertEqual(len(merge_set), 1)

    def test_simple_merge(self):
        """Test merging two nodes."""
        node_a = self.nodes["A"]
        node_b = self.nodes["B"]

        self.merge_nodes(self.tracker, [node_a, node_b])

        # Both should be in same merge set
        self.assertEqual(self.tracker.merge_sets[node_a], {node_a, node_b})
        self.assertEqual(self.tracker.merge_sets[node_b], {node_a, node_b})
        self.assertEqual(
            self.tracker.merge_sets[node_a], self.tracker.merge_sets[node_b]
        )

    def test_transitive_merge(self):
        """Test merging already merged nodes."""
        node_a = self.nodes["A"]
        node_b = self.nodes["B"]
        node_c = self.nodes["C"]
        node_d = self.nodes["D"]

        # Merge A-B and C-D separately
        for node in node_b, node_c, node_d:
            self.tracker.merge_to_set(node_a, node)

        expected_set = {node_a, node_b, node_c, node_d}
        for node in [node_a, node_b, node_c, node_d]:
            self.assertEqual(self.tracker.merge_sets[node], expected_set)

    def merge_nodes(self, tracker, nodes):
        for n in nodes[1:]:
            tracker.merge_to_set(nodes[0], n)

    def test_unmerge_node(self):
        """Test removing a node from its merge set."""
        node_a = self.nodes["A"]
        node_b = self.nodes["B"]
        node_c = self.nodes["C"]

        # Merge all three
        self.merge_nodes(self.tracker, [node_a, node_b, node_c])
        self.assertEqual(len(self.tracker.merge_sets[node_a]), 3)

        # Unmerge B
        self.tracker.unmerge_node(node_b)

        # B should be singleton
        self.assertEqual(self.tracker.merge_sets[node_b], {node_b})

        # A and C should still be together
        self.assertEqual(self.tracker.merge_sets[node_a], {node_a, node_c})
        self.assertEqual(self.tracker.merge_sets[node_c], {node_a, node_c})

    def test_unmerge_from_singleton(self):
        """Test unmerging a node that's already singleton."""
        node_a = self.nodes["A"]

        # Should be no-op
        self.tracker.unmerge_node(node_a)
        self.assertEqual(self.tracker.merge_sets[node_a], {node_a})

    # ========== Dependency Propagation Tests ==========

    def test_merged_deps_collection(self):
        """Test that dependencies are collected from all merged nodes."""
        node_a = self.nodes["A"]
        node_b = self.nodes["B"]
        node_c = self.nodes["C"]

        # B already depends on A (and x) from graph construction
        # C already depends on A (and y) from graph construction

        # Merge B and C
        self.merge_nodes(self.tracker, [node_b, node_c])

        # Get merged deps for B - should include deps from both B and C
        deps = self.tracker.get_merged_deps(node_b)

        # Should include all dependencies from both nodes
        self.assertIn(node_a, deps)  # From both B and C
        self.assertIn(self.x, deps)  # From B
        self.assertIn(self.y, deps)  # From C

    def test_extra_deps_with_merge(self):
        """Test extra dependencies work correctly with merged nodes."""
        node_a = self.nodes["A"]
        node_b = self.nodes["B"]
        node_c = self.nodes["C"]
        node_d = self.nodes["D"]

        # Add extra dep from A to C
        self.tracker.add_extra_dep(n=node_a, dep=node_c)

        # Merge A and B
        self.merge_nodes(self.tracker, [node_a, node_b])

        # Add extra dep from D to the merged node (via B)
        self.tracker.add_extra_dep(n=node_d, dep=node_b)

        # D should depend on B through extra deps
        deps = self.tracker.get_merged_deps(node_d)
        self.assertIn(node_b, deps)

        # A should still have its dep on C
        deps = self.tracker.get_merged_deps(node_a)
        self.assertIn(node_c, deps)

    # ========== Path Finding Tests ==========

    def test_has_path_direct(self):
        """Test path finding for direct dependencies."""
        # In our graph: B depends on A
        node_a = self.nodes["A"]
        node_b = self.nodes["B"]

        self.assertTrue(self.tracker.has_path(node_a, node_b))
        self.assertFalse(self.tracker.has_path(node_b, node_a))

    def test_has_path_transitive(self):
        """Test path finding through multiple nodes."""
        # In our graph: A -> B -> D and A -> C -> D -> E
        node_a = self.nodes["A"]
        node_e = self.nodes["E"]

        self.assertTrue(self.tracker.has_path(node_a, node_e))
        self.assertFalse(self.tracker.has_path(node_e, node_a))

    def test_has_path_through_merge(self):
        """Test path finding when nodes are merged."""
        # Create a new graph for this specific test
        graph2 = fx.Graph()
        x2 = graph2.placeholder("x")
        a2 = graph2.call_function(torch.neg, args=(x2,), name="A2")
        b2 = graph2.call_function(torch.abs, args=(a2,), name="B2")
        c2 = graph2.call_function(torch.relu, args=(x2,), name="C2")
        d2 = graph2.call_function(torch.sigmoid, args=(c2,), name="D2")
        graph2.output(d2)

        tracker2 = AugmentedGraphHelper(graph2)

        # Initially no path from B2 to D2
        self.assertFalse(tracker2.has_path(b2, d2))

        # Merge B2 and C2
        tracker2.merge_to_set(b2, c2)

        # Now there should be a path B2/C2 -> D2
        self.assertTrue(tracker2.has_path(b2, d2))

    def test_has_path_with_extra_deps(self):
        """Test path finding with extra dependencies."""

        graph2 = fx.Graph()
        x2 = graph2.placeholder("x")
        a2 = graph2.call_function(torch.neg, args=(x2,), name="A2")
        b2 = graph2.call_function(torch.abs, args=(a2,), name="B2")
        c2 = graph2.call_function(torch.relu, args=(x2,), name="C2")
        d2 = graph2.call_function(torch.sigmoid, args=(c2,), name="D2")
        graph2.output(d2)

        tracker2 = AugmentedGraphHelper(graph2)

        # Initially no path from B2 to D2
        self.assertFalse(tracker2.has_path(b2, d2))

        tracker2.add_extra_dep(n=c2, dep=b2)

        # Now there should be a path B2/C2 -> D2
        self.assertTrue(tracker2.has_path(b2, d2))

    # ========== Cycle Detection Tests ==========

    def test_no_cycle_in_dag(self):
        """Test that DAG has no cycles."""
        # Our original graph is a DAG, should have no cycles
        self.assertFalse(self.tracker.has_cycle())

    def test_simple_cycle_detection(self):
        """Test detection of simple cycle."""
        # Create a graph with a cycle
        graph3 = fx.Graph()
        x3 = graph3.placeholder("x")

        # We can't create true cycles in fx.Graph directly,
        # but we can simulate with extra_deps
        a3 = graph3.call_function(torch.neg, args=(x3,))
        b3 = graph3.call_function(torch.abs, args=(a3,))
        c3 = graph3.call_function(torch.relu, args=(b3,))
        graph3.output(c3)

        tracker3 = AugmentedGraphHelper(graph3)
        self.assertFalse(tracker3.has_cycle())

        # Add extra dep to create cycle: a3 -> c3
        tracker3.add_extra_dep(n=a3, dep=c3)

        self.assertTrue(tracker3.has_cycle())

    def test_cycle_through_merge(self):
        """Test that merging can create cycles."""
        # Create specific graph for this test
        graph4 = fx.Graph()
        x4 = graph4.placeholder("x")
        a4 = graph4.call_function(torch.neg, args=(x4,))
        b4 = graph4.call_function(torch.abs, args=(a4,))
        c4 = graph4.call_function(torch.relu, args=(x4,))
        d4 = graph4.call_function(torch.sigmoid, args=(c4,))
        graph4.output(d4)

        tracker4 = AugmentedGraphHelper(graph4)

        # Add extra dep d4 -> a4
        tracker4.add_extra_dep(n=a4, dep=d4)

        # Now: a4 -> b4, c4 -> d4 -> a4
        # Merging b4 and c4 would create cycle
        tracker4.merge_to_set(b4, c4)

        self.assertTrue(tracker4.has_cycle())

    def test_cycle_with_extra_deps(self):
        """Test cycle detection with extra dependencies."""
        node_a = self.nodes["A"]
        node_b = self.nodes["B"]

        # B already depends on A naturally
        # Add reverse dependency to create cycle
        self.tracker.add_extra_dep(n=node_a, dep=node_b)

        self.assertTrue(self.tracker.has_cycle())

    def test_multiple_merge_unmerge(self):
        """Test sequence of merge and unmerge operations."""
        nodes = [self.nodes[c] for c in ["A", "B", "C", "D", "E"]]

        # Merge A, B, C
        self.merge_nodes(self.tracker, nodes[:3])
        self.assertEqual(len(self.tracker.merge_sets[nodes[0]]), 3)

        # Merge D, E
        self.merge_nodes(self.tracker, nodes[3:5])
        self.assertEqual(len(self.tracker.merge_sets[nodes[3]]), 2)

        # Merge the two groups via B and D
        try:
            self.merge_nodes(self.tracker, [nodes[1], nodes[3]])
            thrown = False
        except AssertionError:
            thrown = True
        self.assertTrue(thrown)

        # Unmerge C
        self.tracker.unmerge_node(nodes[2])
        self.assertEqual(len(self.tracker.merge_sets[nodes[0]]), 2)
        self.assertEqual(self.tracker.merge_sets[nodes[2]], {nodes[2]})

        # Unmerge A
        self.tracker.unmerge_node(nodes[0])
        self.assertEqual(self.tracker.merge_sets[nodes[0]], {nodes[0]})
        self.assertEqual(len(self.tracker.merge_sets[nodes[1]]), 1)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()
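
Taken together, the tests above exercise a small AugmentedGraphHelper surface: merge_to_set, unmerge_node, add_extra_dep, get_merged_deps, has_path, and has_cycle. The sketch below is a minimal usage walkthrough assembled only from those calls; it assumes the torch._inductor.augmented_graph_helper module from this branch is importable and is illustrative, not part of the diff itself.

# Minimal usage sketch of the helper API exercised above (assumes this branch's
# torch._inductor.augmented_graph_helper module is importable; every call below
# mirrors one used in test_augmented_graph_helper.py).
import torch
import torch.fx as fx
from torch._inductor.augmented_graph_helper import AugmentedGraphHelper

graph = fx.Graph()
x = graph.placeholder("x")
a = graph.call_function(torch.neg, args=(x,), name="a")      # a depends on x
b = graph.call_function(torch.abs, args=(a,), name="b")      # b depends on a
c = graph.call_function(torch.relu, args=(x,), name="c")     # c depends on x
d = graph.call_function(torch.sigmoid, args=(c,), name="d")  # d depends on c
graph.output(d)

helper = AugmentedGraphHelper(graph)

# Every node starts in its own singleton merge set.
assert helper.merge_sets[b] == {b}

# Merging b and c makes them act as one unit: their dependencies are pooled,
# and paths can now flow through either member.
helper.merge_to_set(b, c)
assert helper.merge_sets[b] == {b, c}
assert x in helper.get_merged_deps(b)  # c's input is now a dep of the set
assert helper.has_path(b, d)           # d is reachable from the merged pair via c

# Extra (non-graph) edges participate in path finding and cycle detection;
# making a depend on d closes a cycle a -> b/c -> d -> a.
helper.add_extra_dep(n=a, dep=d)
assert helper.has_cycle()

# Unmerging restores b to a singleton set.
helper.unmerge_node(b)
assert helper.merge_sets[b] == {b}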