Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-14 22:25:03 +08:00
Compare commits
73 Commits
ciflow/tru
...
dev/joona/
| SHA1 | Author | Date | |
|---|---|---|---|
| 94b993d7e4 | |||
| 0c76b784d1 | |||
| 0cd0bd7217 | |||
| fe33d7cadf | |||
| a9542426d0 | |||
| f79cdc89db | |||
| 3d063519bf | |||
| 0b3bdb0d89 | |||
| 8f00ec31ca | |||
| 21f32e4af3 | |||
| 940979a229 | |||
| 4fc688625a | |||
| 23f4f323ea | |||
| 9ac3fc0d0a | |||
| 38806f381a | |||
| cfb3a6b3da | |||
| d8384e296e | |||
| d273422582 | |||
| fadb62f592 | |||
| e5eb89e111 | |||
| b5e0e6932a | |||
| 6ea779188c | |||
| 460c7e196c | |||
| 7aac506cdc | |||
| 374ee9e867 | |||
| 698aa0f3e5 | |||
| d3ca4a3a4f | |||
| c940b1fbbc | |||
| 4de24bcc56 | |||
| f2d0a472ef | |||
| 9ae0ecec7d | |||
| ce4f31f662 | |||
| 2c846bb614 | |||
| 8c86ccfbc9 | |||
| 8f96e7bc1d | |||
| 782fc3c72b | |||
| 1a67403fc6 | |||
| 3d801a4c01 | |||
| 2034ca99ae | |||
| 480b4ff882 | |||
| f570e589da | |||
| f9851af59b | |||
| eeebf9f664 | |||
| d9a50bf9a8 | |||
| 2984331c87 | |||
| 9b68682df2 | |||
| 8f5f89c9a0 | |||
| 8919f69362 | |||
| 19c867873a | |||
| e3dadb1d36 | |||
| c9b09a31e8 | |||
| 35571fe94b | |||
| 485f2b607a | |||
| 0c5d5c7e9a | |||
| 5f98a0363a | |||
| 2d739001d3 | |||
| 273babeec3 | |||
| a76dd6b7c6 | |||
| 2fa18d1545 | |||
| 537167aa1e | |||
| 0dac408f43 | |||
| 158e72427b | |||
| 0184ef291d | |||
| 2ca428c721 | |||
| 1311385f9d | |||
| 5f0a5b8f87 | |||
| 74e85c6944 | |||
| a6a0379b9c | |||
| a95eee68d9 | |||
| 2ad70c9446 | |||
| bc09a84150 | |||
| 760c901c9a | |||
| d105e3a198 |
@@ -1,19 +0,0 @@
# Aarch64 (ARM/Graviton) Support Scripts
Scripts for building aarch64 PyTorch PIP wheels. These scripts build the following wheels:
* torch
* torchvision
* torchaudio
* torchtext
* torchdata
## Aarch64_ci_build.sh
This script is designed to support CD operations within a PyPI manylinux aarch64 container and is meant to be executed inside that container. It prepares the container and then executes __aarch64_wheel_ci_build.py__ to build the wheels. The script assumes the PyTorch repo is located at ```/pytorch``` and will put the wheels into ```/artifacts```.
### Usage
```DESIRED_PYTHON=<PythonVersion> aarch64_ci_build.sh```

__NOTE:__ The CI build is currently __EXPERIMENTAL__.

## Build_aarch64_wheel.py
This script builds the wheels on AWS EC2 instances and requires the AWS CLI and Boto3, with AWS credentials, in order to launch EC2 instances for the wheel builds. It can be used in a CodeBuild CD pipeline or from a local system.

### Usage
```build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.8 --branch <RCtag>```
@@ -1,53 +0,0 @@
#!/bin/bash
set -eux -o pipefail

GPU_ARCH_VERSION=${GPU_ARCH_VERSION:-}

# Set CUDA architecture lists to match x86 build_cuda.sh
if [[ "$GPU_ARCH_VERSION" == *"12.6"* ]]; then
    export TORCH_CUDA_ARCH_LIST="8.0;9.0"
elif [[ "$GPU_ARCH_VERSION" == *"12.8"* ]]; then
    export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
elif [[ "$GPU_ARCH_VERSION" == *"12.9"* ]]; then
    export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
elif [[ "$GPU_ARCH_VERSION" == *"13.0"* ]]; then
    export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;11.0;12.0+PTX"
fi

# Compress the fatbin with -compress-mode=size for CUDA 13
if [[ "$DESIRED_CUDA" == *"13"* ]]; then
    export TORCH_NVCC_FLAGS="-compress-mode=size"
    # Bundle ptxas into the cu13 wheel, see https://github.com/pytorch/pytorch/issues/163801
    export BUILD_BUNDLE_PTXAS=1
fi

SCRIPTPATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
source $SCRIPTPATH/aarch64_ci_setup.sh

###############################################################################
# Run aarch64 builder python
###############################################################################
cd /
# adding safe directory for git as the permissions will be
# on the mounted pytorch repo
git config --global --add safe.directory /pytorch
pip install -r /pytorch/requirements.txt
pip install auditwheel==6.2.0 wheel
if [ "$DESIRED_CUDA" = "cpu" ]; then
    echo "BASE_CUDA_VERSION is not set. Building cpu wheel."
    python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn
else
    echo "BASE_CUDA_VERSION is set to: $DESIRED_CUDA"
    export USE_SYSTEM_NCCL=1

    # Check if we should use NVIDIA libs from PyPI (similar to x86 build_cuda.sh logic)
    if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then
        echo "Bundling CUDA libraries with wheel for aarch64."
    else
        echo "Using nvidia libs from pypi for aarch64."
        echo "Updated PYTORCH_EXTRA_INSTALL_REQUIREMENTS for aarch64: $PYTORCH_EXTRA_INSTALL_REQUIREMENTS"
        export USE_NVIDIA_PYPI_LIBS=1
    fi

    python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn --enable-cuda
fi
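For reference, the CUDA-version-to-architecture mapping used above can be expressed as a small lookup table. The sketch below is illustrative only; the version strings it matches simply mirror the shell `if/elif` chain and are not an official API.

```python
# Minimal sketch of the TORCH_CUDA_ARCH_LIST selection above (illustrative only).
def torch_cuda_arch_list(gpu_arch_version: str) -> "str | None":
    """Substring-match the CUDA version, mirroring the shell branches."""
    mapping = {
        "12.6": "8.0;9.0",
        "12.8": "8.0;9.0;10.0;12.0",
        "12.9": "8.0;9.0;10.0;12.0",
        "13.0": "8.0;9.0;10.0;11.0;12.0+PTX",
    }
    for key, arch_list in mapping.items():
        if key in gpu_arch_version:
            return arch_list
    return None  # leave TORCH_CUDA_ARCH_LIST unset, as the script does


if __name__ == "__main__":
    assert torch_cuda_arch_list("12.8") == "8.0;9.0;10.0;12.0"
    assert torch_cuda_arch_list("") is None
```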
@@ -1,21 +0,0 @@
#!/bin/bash
set -eux -o pipefail

# This script is used to prepare the Docker container for the aarch64_ci_wheel_build.py python script
# by creating symlinks from the desired /opt/python to /usr/local/bin/

NUMPY_VERSION=2.0.2
if [[ "$DESIRED_PYTHON" == "3.13" || "$DESIRED_PYTHON" == "3.13t" ]]; then
    NUMPY_VERSION=2.1.2
fi

SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"
source $SCRIPTPATH/../manywheel/set_desired_python.sh

pip install -q numpy==${NUMPY_VERSION} pyyaml==6.0.2 scons==4.7.0 ninja==1.11.1 patchelf==0.17.2

for tool in python python3 pip pip3 ninja scons patchelf; do
    ln -sf ${DESIRED_PYTHON_BIN_DIR}/${tool} /usr/local/bin;
done

python --version
@@ -1,333 +0,0 @@
#!/usr/bin/env python3
# encoding: UTF-8

import os
import shutil
from subprocess import check_call, check_output


def list_dir(path: str) -> list[str]:
    """
    Helper for getting paths for Python
    """
    return check_output(["ls", "-1", path]).decode().split("\n")


def replace_tag(filename) -> None:
    with open(filename) as f:
        lines = f.readlines()
    for i, line in enumerate(lines):
        if line.startswith("Tag:"):
            lines[i] = line.replace("-linux_", "-manylinux_2_28_")
            print(f"Updated tag from {line} to {lines[i]}")
            break

    with open(filename, "w") as f:
        f.writelines(lines)


def patch_library_rpath(
    folder: str,
    lib_name: str,
    use_nvidia_pypi_libs: bool = False,
    desired_cuda: str = "",
) -> None:
    """Apply patchelf to set RPATH for a library in torch/lib"""
    lib_path = f"{folder}/tmp/torch/lib/{lib_name}"

    if use_nvidia_pypi_libs:
        # For PyPI NVIDIA libraries, construct CUDA RPATH
        cuda_rpaths = [
            "$ORIGIN/../../nvidia/cudnn/lib",
            "$ORIGIN/../../nvidia/nvshmem/lib",
            "$ORIGIN/../../nvidia/nccl/lib",
            "$ORIGIN/../../nvidia/cusparselt/lib",
        ]

        if "130" in desired_cuda:
            cuda_rpaths.append("$ORIGIN/../../nvidia/cu13/lib")
        else:
            cuda_rpaths.extend(
                [
                    "$ORIGIN/../../nvidia/cublas/lib",
                    "$ORIGIN/../../nvidia/cuda_cupti/lib",
                    "$ORIGIN/../../nvidia/cuda_nvrtc/lib",
                    "$ORIGIN/../../nvidia/cuda_runtime/lib",
                    "$ORIGIN/../../nvidia/cufft/lib",
                    "$ORIGIN/../../nvidia/curand/lib",
                    "$ORIGIN/../../nvidia/cusolver/lib",
                    "$ORIGIN/../../nvidia/cusparse/lib",
                    "$ORIGIN/../../nvidia/nvtx/lib",
                    "$ORIGIN/../../nvidia/cufile/lib",
                ]
            )

        # Add $ORIGIN for local torch libs
        rpath = ":".join(cuda_rpaths) + ":$ORIGIN"
    else:
        # For bundled libraries, just use $ORIGIN
        rpath = "$ORIGIN"

    if os.path.exists(lib_path):
        os.system(
            f"cd {folder}/tmp/torch/lib/; "
            f"patchelf --set-rpath '{rpath}' --force-rpath {lib_name}"
        )


def copy_and_patch_library(
    src_path: str,
    folder: str,
    use_nvidia_pypi_libs: bool = False,
    desired_cuda: str = "",
) -> None:
    """Copy a library to torch/lib and patch its RPATH"""
    if os.path.exists(src_path):
        lib_name = os.path.basename(src_path)
        shutil.copy2(src_path, f"{folder}/tmp/torch/lib/{lib_name}")
        patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda)


def package_cuda_wheel(wheel_path, desired_cuda) -> None:
    """
    Package the cuda wheel libraries
    """
    folder = os.path.dirname(wheel_path)
    os.mkdir(f"{folder}/tmp")
    os.system(f"unzip {wheel_path} -d {folder}/tmp")
    # Delete original wheel since it will be repackaged
    os.system(f"rm {wheel_path}")

    # Check if we should use PyPI NVIDIA libraries or bundle system libraries
    use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"

    if use_nvidia_pypi_libs:
        print("Using nvidia libs from pypi - skipping CUDA library bundling")
        # For PyPI approach, we don't bundle CUDA libraries - they come from PyPI packages
        # We only need to bundle non-NVIDIA libraries
        minimal_libs_to_copy = [
            "/lib64/libgomp.so.1",
            "/usr/lib64/libgfortran.so.5",
            "/acl/build/libarm_compute.so",
            "/acl/build/libarm_compute_graph.so",
            "/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
            "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
            "/usr/local/lib/libnvpl_lapack_core.so.0",
            "/usr/local/lib/libnvpl_blas_core.so.0",
        ]

        # Copy minimal libraries to unzipped_folder/torch/lib
        for lib_path in minimal_libs_to_copy:
            copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda)

        # Patch torch libraries used for searching libraries
        torch_libs_to_patch = [
            "libtorch.so",
            "libtorch_cpu.so",
            "libtorch_cuda.so",
            "libtorch_cuda_linalg.so",
            "libtorch_global_deps.so",
            "libtorch_python.so",
            "libtorch_nvshmem.so",
            "libc10.so",
            "libc10_cuda.so",
            "libcaffe2_nvrtc.so",
            "libshm.so",
        ]
        for lib_name in torch_libs_to_patch:
            patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda)
    else:
        print("Bundling CUDA libraries with wheel")
        # Original logic for bundling system CUDA libraries
        # Common libraries for all CUDA versions
        common_libs = [
            # Non-NVIDIA system libraries
            "/lib64/libgomp.so.1",
            "/usr/lib64/libgfortran.so.5",
            "/acl/build/libarm_compute.so",
            "/acl/build/libarm_compute_graph.so",
            # Common CUDA libraries (same for all versions)
            "/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
            "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
            "/usr/local/lib/libnvpl_lapack_core.so.0",
            "/usr/local/lib/libnvpl_blas_core.so.0",
            "/usr/local/cuda/extras/CUPTI/lib64/libnvperf_host.so",
            "/usr/local/cuda/lib64/libcudnn.so.9",
            "/usr/local/cuda/lib64/libcusparseLt.so.0",
            "/usr/local/cuda/lib64/libcurand.so.10",
            "/usr/local/cuda/lib64/libnccl.so.2",
            "/usr/local/cuda/lib64/libnvshmem_host.so.3",
            "/usr/local/cuda/lib64/libcudnn_adv.so.9",
            "/usr/local/cuda/lib64/libcudnn_cnn.so.9",
            "/usr/local/cuda/lib64/libcudnn_graph.so.9",
            "/usr/local/cuda/lib64/libcudnn_ops.so.9",
            "/usr/local/cuda/lib64/libcudnn_engines_runtime_compiled.so.9",
            "/usr/local/cuda/lib64/libcudnn_engines_precompiled.so.9",
            "/usr/local/cuda/lib64/libcudnn_heuristic.so.9",
            "/usr/local/cuda/lib64/libcufile.so.0",
            "/usr/local/cuda/lib64/libcufile_rdma.so.1",
            "/usr/local/cuda/lib64/libcusparse.so.12",
        ]

        # CUDA version-specific libraries
        if "13" in desired_cuda:
            minor_version = desired_cuda[-1]
            version_specific_libs = [
                "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.13",
                "/usr/local/cuda/lib64/libcublas.so.13",
                "/usr/local/cuda/lib64/libcublasLt.so.13",
                "/usr/local/cuda/lib64/libcudart.so.13",
                "/usr/local/cuda/lib64/libcufft.so.12",
                "/usr/local/cuda/lib64/libcusolver.so.12",
                "/usr/local/cuda/lib64/libnvJitLink.so.13",
                "/usr/local/cuda/lib64/libnvrtc.so.13",
                f"/usr/local/cuda/lib64/libnvrtc-builtins.so.13.{minor_version}",
            ]
        elif "12" in desired_cuda:
            # Get the last character for libnvrtc-builtins version (e.g., "129" -> "9")
            minor_version = desired_cuda[-1]
            version_specific_libs = [
                "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12",
                "/usr/local/cuda/lib64/libcublas.so.12",
                "/usr/local/cuda/lib64/libcublasLt.so.12",
                "/usr/local/cuda/lib64/libcudart.so.12",
                "/usr/local/cuda/lib64/libcufft.so.11",
                "/usr/local/cuda/lib64/libcusolver.so.11",
                "/usr/local/cuda/lib64/libnvJitLink.so.12",
                "/usr/local/cuda/lib64/libnvrtc.so.12",
                f"/usr/local/cuda/lib64/libnvrtc-builtins.so.12.{minor_version}",
            ]
        else:
            raise ValueError(f"Unsupported CUDA version: {desired_cuda}.")

        # Combine all libraries
        libs_to_copy = common_libs + version_specific_libs

        # Copy libraries to unzipped_folder/torch/lib
        for lib_path in libs_to_copy:
            copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda)

    # Make sure the wheel is tagged with manylinux_2_28
    for f in os.scandir(f"{folder}/tmp/"):
        if f.is_dir() and f.name.endswith(".dist-info"):
            replace_tag(f"{f.path}/WHEEL")
            break

    os.system(f"wheel pack {folder}/tmp/ -d {folder}")
    os.system(f"rm -rf {folder}/tmp/")


def complete_wheel(folder: str) -> str:
    """
    Complete wheel build and put in artifact location
    """
    wheel_name = list_dir(f"/{folder}/dist")[0]

    # Please note for cuda we don't run auditwheel since we use a custom script to package
    # the cuda dependencies into the wheel file using the package_cuda_wheel() method.
    # However we need to make sure the filename reflects the correct manylinux platform.
    if "pytorch" in folder and not enable_cuda:
        print("Repairing Wheel with AuditWheel")
        check_call(["auditwheel", "repair", f"dist/{wheel_name}"], cwd=folder)
        repaired_wheel_name = list_dir(f"/{folder}/wheelhouse")[0]

        print(f"Moving {repaired_wheel_name} wheel to /{folder}/dist")
        os.rename(
            f"/{folder}/wheelhouse/{repaired_wheel_name}",
            f"/{folder}/dist/{repaired_wheel_name}",
        )
    else:
        repaired_wheel_name = list_dir(f"/{folder}/dist")[0]

    print(f"Copying {repaired_wheel_name} to artifacts")
    shutil.copy2(
        f"/{folder}/dist/{repaired_wheel_name}", f"/artifacts/{repaired_wheel_name}"
    )

    return repaired_wheel_name


def parse_arguments():
    """
    Parse inline arguments
    """
    from argparse import ArgumentParser

    parser = ArgumentParser("AARCH64 wheels python CD")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--build-only", action="store_true")
    parser.add_argument("--test-only", type=str)
    parser.add_argument("--enable-mkldnn", action="store_true")
    parser.add_argument("--enable-cuda", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    """
    Entry Point
    """
    args = parse_arguments()
    enable_mkldnn = args.enable_mkldnn
    enable_cuda = args.enable_cuda
    branch = check_output(
        ["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd="/pytorch"
    ).decode()

    print("Building PyTorch wheel")
    build_vars = ""
    # MAX_JOBS=5 is not required for CPU backend (see commit 465d98b)
    if enable_cuda:
        build_vars += "MAX_JOBS=5 "

    # Handle PyPI NVIDIA libraries vs bundled libraries
    use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"
    if use_nvidia_pypi_libs:
        print("Configuring build for PyPI NVIDIA libraries")
        # Configure for dynamic linking (matching x86 logic)
        build_vars += "ATEN_STATIC_CUDA=0 USE_CUDA_STATIC_LINK=0 USE_CUPTI_SO=1 "
    else:
        print("Configuring build for bundled NVIDIA libraries")
        # Keep existing static linking approach - already configured above

    override_package_version = os.getenv("OVERRIDE_PACKAGE_VERSION")
    desired_cuda = os.getenv("DESIRED_CUDA")
    if override_package_version is not None:
        version = override_package_version
        build_vars += (
            f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version} PYTORCH_BUILD_NUMBER=1 "
        )
    elif branch in ["nightly", "main"]:
        build_date = (
            check_output(["git", "log", "--pretty=format:%cs", "-1"], cwd="/pytorch")
            .decode()
            .replace("-", "")
        )
        version = (
            check_output(["cat", "version.txt"], cwd="/pytorch").decode().strip()[:-2]
        )
        if enable_cuda:
            build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date}+{desired_cuda} PYTORCH_BUILD_NUMBER=1 "
        else:
            build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1 "
    elif branch.startswith(("v1.", "v2.")):
        build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1 "

    if enable_mkldnn:
        print("build pytorch with mkldnn+acl backend")
        build_vars += "USE_MKLDNN=ON USE_MKLDNN_ACL=ON "
        build_vars += "ACL_ROOT_DIR=/acl "
        if enable_cuda:
            build_vars += "BLAS=NVPL "
        else:
            build_vars += "BLAS=OpenBLAS OpenBLAS_HOME=/opt/OpenBLAS "
    else:
        print("build pytorch without mkldnn backend")

    os.system(f"cd /pytorch; {build_vars} python3 -m build --wheel --no-isolation")
    if enable_cuda:
        print("Updating Cuda Dependency")
        filename = os.listdir("/pytorch/dist/")
        wheel_path = f"/pytorch/dist/{filename[0]}"
        package_cuda_wheel(wheel_path, desired_cuda)
    pytorch_wheel_name = complete_wheel("/pytorch/")
    print(f"Build Complete. Created {pytorch_wheel_name}..")
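As a reading aid for the script above, the sketch below shows in isolation how `patch_library_rpath()` assembles the RPATH string in the PyPI-libraries case. The directory entries are copied from the script; the printed `patchelf` command is only illustrative and nothing is executed here.

```python
# Standalone sketch of the RPATH string built for the PyPI-NVIDIA-libraries case.
def build_pypi_rpath(desired_cuda: str) -> str:
    cuda_rpaths = [
        "$ORIGIN/../../nvidia/cudnn/lib",
        "$ORIGIN/../../nvidia/nvshmem/lib",
        "$ORIGIN/../../nvidia/nccl/lib",
        "$ORIGIN/../../nvidia/cusparselt/lib",
    ]
    if "130" in desired_cuda:
        # CUDA 13.x wheels look up most components under a single cu13 directory
        cuda_rpaths.append("$ORIGIN/../../nvidia/cu13/lib")
    else:
        # CUDA 12.x keeps one PyPI package directory per component (cublas shown here)
        cuda_rpaths.append("$ORIGIN/../../nvidia/cublas/lib")
    # $ORIGIN last, so sibling torch libs are still found
    return ":".join(cuda_rpaths) + ":$ORIGIN"


if __name__ == "__main__":
    rpath = build_pypi_rpath("130")
    print(f"patchelf --set-rpath '{rpath}' --force-rpath libtorch_cuda.so")
```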
@@ -1,999 +0,0 @@
#!/usr/bin/env python3

# This script is for building AARCH64 wheels using AWS EC2 instances.
# To generate binaries for the release follow these steps:
# 1. Update mappings for each of the Domain Libraries by adding a new row to a table like this:
#    "v1.11.0": ("0.11.0", "rc1"),
# 2. Run the script with the following arguments for each of the supported python versions and the required tag, for example:
#    build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.8 --branch v1.11.0-rc3


import os
import subprocess
import sys
import time
from typing import Optional, Union

import boto3


# AMI images for us-east-1, change the following based on your ~/.aws/config
os_amis = {
    "ubuntu20_04": "ami-052eac90edaa9d08f",  # login_name: ubuntu
    "ubuntu22_04": "ami-0c6c29c5125214c77",  # login_name: ubuntu
    "redhat8": "ami-0698b90665a2ddcf1",  # login_name: ec2-user
}

ubuntu20_04_ami = os_amis["ubuntu20_04"]


def compute_keyfile_path(key_name: Optional[str] = None) -> tuple[str, str]:
    if key_name is None:
        key_name = os.getenv("AWS_KEY_NAME")
        if key_name is None:
            return os.getenv("SSH_KEY_PATH", ""), ""

    homedir_path = os.path.expanduser("~")
    default_path = os.path.join(homedir_path, ".ssh", f"{key_name}.pem")
    return os.getenv("SSH_KEY_PATH", default_path), key_name


ec2 = boto3.resource("ec2")


def ec2_get_instances(filter_name, filter_value):
    return ec2.instances.filter(
        Filters=[{"Name": filter_name, "Values": [filter_value]}]
    )


def ec2_instances_of_type(instance_type="t4g.2xlarge"):
    return ec2_get_instances("instance-type", instance_type)


def ec2_instances_by_id(instance_id):
    rc = list(ec2_get_instances("instance-id", instance_id))
    return rc[0] if len(rc) > 0 else None


def start_instance(
    key_name, ami=ubuntu20_04_ami, instance_type="t4g.2xlarge", ebs_size: int = 50
):
    inst = ec2.create_instances(
        ImageId=ami,
        InstanceType=instance_type,
        SecurityGroups=["ssh-allworld"],
        KeyName=key_name,
        MinCount=1,
        MaxCount=1,
        BlockDeviceMappings=[
            {
                "DeviceName": "/dev/sda1",
                "Ebs": {
                    "DeleteOnTermination": True,
                    "VolumeSize": ebs_size,
                    "VolumeType": "standard",
                },
            }
        ],
    )[0]
    print(f"Create instance {inst.id}")
    inst.wait_until_running()
    running_inst = ec2_instances_by_id(inst.id)
    print(f"Instance started at {running_inst.public_dns_name}")
    return running_inst


class RemoteHost:
    addr: str
    keyfile_path: str
    login_name: str
    container_id: Optional[str] = None
    ami: Optional[str] = None

    def __init__(self, addr: str, keyfile_path: str, login_name: str = "ubuntu"):
        self.addr = addr
        self.keyfile_path = keyfile_path
        self.login_name = login_name

    def _gen_ssh_prefix(self) -> list[str]:
        return [
            "ssh",
            "-o",
            "StrictHostKeyChecking=no",
            "-i",
            self.keyfile_path,
            f"{self.login_name}@{self.addr}",
            "--",
        ]

    @staticmethod
    def _split_cmd(args: Union[str, list[str]]) -> list[str]:
        return args.split() if isinstance(args, str) else args

    def run_ssh_cmd(self, args: Union[str, list[str]]) -> None:
        subprocess.check_call(self._gen_ssh_prefix() + self._split_cmd(args))

    def check_ssh_output(self, args: Union[str, list[str]]) -> str:
        return subprocess.check_output(
            self._gen_ssh_prefix() + self._split_cmd(args)
        ).decode("utf-8")

    def scp_upload_file(self, local_file: str, remote_file: str) -> None:
        subprocess.check_call(
            [
                "scp",
                "-i",
                self.keyfile_path,
                local_file,
                f"{self.login_name}@{self.addr}:{remote_file}",
            ]
        )

    def scp_download_file(
        self, remote_file: str, local_file: Optional[str] = None
    ) -> None:
        if local_file is None:
            local_file = "."
        subprocess.check_call(
            [
                "scp",
                "-i",
                self.keyfile_path,
                f"{self.login_name}@{self.addr}:{remote_file}",
                local_file,
            ]
        )

    def start_docker(self, image="quay.io/pypa/manylinux2014_aarch64:latest") -> None:
        self.run_ssh_cmd("sudo apt-get install -y docker.io")
        self.run_ssh_cmd(f"sudo usermod -a -G docker {self.login_name}")
        self.run_ssh_cmd("sudo service docker start")
        self.run_ssh_cmd(f"docker pull {image}")
        self.container_id = self.check_ssh_output(
            f"docker run -t -d -w /root {image}"
        ).strip()

    def using_docker(self) -> bool:
        return self.container_id is not None

    def run_cmd(self, args: Union[str, list[str]]) -> None:
        if not self.using_docker():
            return self.run_ssh_cmd(args)
        assert self.container_id is not None
        docker_cmd = self._gen_ssh_prefix() + [
            "docker",
            "exec",
            "-i",
            self.container_id,
            "bash",
        ]
        p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE)
        p.communicate(
            input=" ".join(["source .bashrc && "] + self._split_cmd(args)).encode(
                "utf-8"
            )
        )
        rc = p.wait()
        if rc != 0:
            raise subprocess.CalledProcessError(rc, docker_cmd)

    def check_output(self, args: Union[str, list[str]]) -> str:
        if not self.using_docker():
            return self.check_ssh_output(args)
        assert self.container_id is not None
        docker_cmd = self._gen_ssh_prefix() + [
            "docker",
            "exec",
            "-i",
            self.container_id,
            "bash",
        ]
        p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
        (out, err) = p.communicate(
            input=" ".join(["source .bashrc && "] + self._split_cmd(args)).encode(
                "utf-8"
            )
        )
        rc = p.wait()
        if rc != 0:
            raise subprocess.CalledProcessError(rc, docker_cmd, output=out, stderr=err)
        return out.decode("utf-8")

    def upload_file(self, local_file: str, remote_file: str) -> None:
        if not self.using_docker():
            return self.scp_upload_file(local_file, remote_file)
        tmp_file = os.path.join("/tmp", os.path.basename(local_file))
        self.scp_upload_file(local_file, tmp_file)
        self.run_ssh_cmd(
            ["docker", "cp", tmp_file, f"{self.container_id}:/root/{remote_file}"]
        )
        self.run_ssh_cmd(["rm", tmp_file])

    def download_file(self, remote_file: str, local_file: Optional[str] = None) -> None:
        if not self.using_docker():
            return self.scp_download_file(remote_file, local_file)
        tmp_file = os.path.join("/tmp", os.path.basename(remote_file))
        self.run_ssh_cmd(
            ["docker", "cp", f"{self.container_id}:/root/{remote_file}", tmp_file]
        )
        self.scp_download_file(tmp_file, local_file)
        self.run_ssh_cmd(["rm", tmp_file])

    def download_wheel(
        self, remote_file: str, local_file: Optional[str] = None
    ) -> None:
        if self.using_docker() and local_file is None:
            basename = os.path.basename(remote_file)
            local_file = basename.replace(
                "-linux_aarch64.whl", "-manylinux2014_aarch64.whl"
            )
        self.download_file(remote_file, local_file)

    def list_dir(self, path: str) -> list[str]:
        return self.check_output(["ls", "-1", path]).split("\n")


def wait_for_connection(addr, port, timeout=15, attempt_cnt=5):
    import socket

    for i in range(attempt_cnt):
        try:
            with socket.create_connection((addr, port), timeout=timeout):
                return
        except (ConnectionRefusedError, TimeoutError):  # noqa: PERF203
            if i == attempt_cnt - 1:
                raise
            time.sleep(timeout)


def update_apt_repo(host: RemoteHost) -> None:
    time.sleep(5)
    host.run_cmd("sudo systemctl stop apt-daily.service || true")
    host.run_cmd("sudo systemctl stop unattended-upgrades.service || true")
    host.run_cmd(
        "while systemctl is-active --quiet apt-daily.service; do sleep 1; done"
    )
    host.run_cmd(
        "while systemctl is-active --quiet unattended-upgrades.service; do sleep 1; done"
    )
    host.run_cmd("sudo apt-get update")
    time.sleep(3)
    host.run_cmd("sudo apt-get update")


def install_condaforge(
    host: RemoteHost, suffix: str = "latest/download/Miniforge3-Linux-aarch64.sh"
) -> None:
    print("Install conda-forge")
    host.run_cmd(f"curl -OL https://github.com/conda-forge/miniforge/releases/{suffix}")
    host.run_cmd(f"sh -f {os.path.basename(suffix)} -b")
    host.run_cmd(f"rm -f {os.path.basename(suffix)}")
    if host.using_docker():
        host.run_cmd("echo 'PATH=$HOME/miniforge3/bin:$PATH'>>.bashrc")
    else:
        host.run_cmd(
            [
                "sed",
                "-i",
                "'/^# If not running interactively.*/i PATH=$HOME/miniforge3/bin:$PATH'",
                ".bashrc",
            ]
        )


def install_condaforge_python(host: RemoteHost, python_version="3.8") -> None:
    if python_version == "3.6":
        # Python-3.6 EOLed and not compatible with conda-4.11
        install_condaforge(
            host, suffix="download/4.10.3-10/Miniforge3-4.10.3-10-Linux-aarch64.sh"
        )
        host.run_cmd(f"conda install -y python={python_version} numpy pyyaml")
    else:
        install_condaforge(
            host, suffix="download/4.11.0-4/Miniforge3-4.11.0-4-Linux-aarch64.sh"
        )
        # Pytorch-1.10 or older are not compatible with setuptools=59.6 or newer
        host.run_cmd(
            f"conda install -y python={python_version} numpy pyyaml setuptools>=59.5.0"
        )


def embed_libgomp(host: RemoteHost, use_conda, wheel_name) -> None:
    host.run_cmd("pip3 install auditwheel")
    host.run_cmd(
        "conda install -y patchelf" if use_conda else "sudo apt-get install -y patchelf"
    )
    from tempfile import NamedTemporaryFile

    with NamedTemporaryFile() as tmp:
        tmp.write(embed_library_script.encode("utf-8"))
        tmp.flush()
        host.upload_file(tmp.name, "embed_library.py")

    print("Embedding libgomp into wheel")
    if host.using_docker():
        host.run_cmd(f"python3 embed_library.py {wheel_name} --update-tag")
    else:
        host.run_cmd(f"python3 embed_library.py {wheel_name}")


def checkout_repo(
    host: RemoteHost,
    *,
    branch: str = "main",
    url: str,
    git_clone_flags: str,
    mapping: dict[str, tuple[str, str]],
) -> Optional[str]:
    for prefix in mapping:
        if not branch.startswith(prefix):
            continue
        tag = f"v{mapping[prefix][0]}-{mapping[prefix][1]}"
        host.run_cmd(f"git clone {url} -b {tag} {git_clone_flags}")
        return mapping[prefix][0]

    host.run_cmd(f"git clone {url} -b {branch} {git_clone_flags}")
    return None


def build_torchvision(
    host: RemoteHost,
    *,
    branch: str = "main",
    use_conda: bool = True,
    git_clone_flags: str,
    run_smoke_tests: bool = True,
) -> str:
    print("Checking out TorchVision repo")
    build_version = checkout_repo(
        host,
        branch=branch,
        url="https://github.com/pytorch/vision",
        git_clone_flags=git_clone_flags,
        mapping={
            "v1.7.1": ("0.8.2", "rc2"),
            "v1.8.0": ("0.9.0", "rc3"),
            "v1.8.1": ("0.9.1", "rc1"),
            "v1.9.0": ("0.10.0", "rc1"),
            "v1.10.0": ("0.11.1", "rc1"),
            "v1.10.1": ("0.11.2", "rc1"),
            "v1.10.2": ("0.11.3", "rc1"),
            "v1.11.0": ("0.12.0", "rc1"),
            "v1.12.0": ("0.13.0", "rc4"),
            "v1.12.1": ("0.13.1", "rc6"),
            "v1.13.0": ("0.14.0", "rc4"),
            "v1.13.1": ("0.14.1", "rc2"),
            "v2.0.0": ("0.15.1", "rc2"),
            "v2.0.1": ("0.15.2", "rc2"),
        },
    )
    print("Building TorchVision wheel")

    # Please note libpng and jpeg are required to build the image.so extension
    if use_conda:
        host.run_cmd("conda install -y libpng jpeg")
        # Remove .so files to force static linking
        host.run_cmd(
            "rm miniforge3/lib/libpng.so miniforge3/lib/libpng16.so miniforge3/lib/libjpeg.so"
        )
        # And patch setup.py to include libz dependency for libpng
        host.run_cmd(
            [
                'sed -i -e \'s/image_link_flags\\.append("png")/image_link_flags += ["png", "z"]/\' vision/setup.py'
            ]
        )

    build_vars = ""
    if branch == "nightly":
        version = host.check_output(
            ["if [ -f vision/version.txt ]; then cat vision/version.txt; fi"]
        ).strip()
        if len(version) == 0:
            # In older revisions, version was embedded in setup.py
            version = (
                host.check_output(["grep", '"version = \'"', "vision/setup.py"])
                .strip()
                .split("'")[1][:-2]
            )
        build_date = (
            host.check_output("cd vision && git log --pretty=format:%s -1")
            .strip()
            .split()[0]
            .replace("-", "")
        )
        build_vars += f"BUILD_VERSION={version}.dev{build_date}"
    elif build_version is not None:
        build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
    if host.using_docker():
        build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

    host.run_cmd(f"cd vision && {build_vars} python3 -m build --wheel --no-isolation")
    vision_wheel_name = host.list_dir("vision/dist")[0]
    embed_libgomp(host, use_conda, os.path.join("vision", "dist", vision_wheel_name))

    print("Copying TorchVision wheel")
    host.download_wheel(os.path.join("vision", "dist", vision_wheel_name))
    if run_smoke_tests:
        host.run_cmd(
            f"pip3 install {os.path.join('vision', 'dist', vision_wheel_name)}"
        )
        host.run_cmd("python3 vision/test/smoke_test.py")
    print("Delete vision checkout")
    host.run_cmd("rm -rf vision")

    return vision_wheel_name


def build_torchdata(
    host: RemoteHost,
    *,
    branch: str = "main",
    use_conda: bool = True,
    git_clone_flags: str = "",
) -> str:
    print("Checking out TorchData repo")
    git_clone_flags += " --recurse-submodules"
    build_version = checkout_repo(
        host,
        branch=branch,
        url="https://github.com/pytorch/data",
        git_clone_flags=git_clone_flags,
        mapping={
            "v1.13.1": ("0.5.1", ""),
            "v2.0.0": ("0.6.0", "rc5"),
            "v2.0.1": ("0.6.1", "rc1"),
        },
    )
    print("Building TorchData wheel")
    build_vars = ""
    if branch == "nightly":
        version = host.check_output(
            ["if [ -f data/version.txt ]; then cat data/version.txt; fi"]
        ).strip()
        build_date = (
            host.check_output("cd data && git log --pretty=format:%s -1")
            .strip()
            .split()[0]
            .replace("-", "")
        )
        build_vars += f"BUILD_VERSION={version}.dev{build_date}"
    elif build_version is not None:
        build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
    if host.using_docker():
        build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

    host.run_cmd(f"cd data && {build_vars} python3 -m build --wheel --no-isolation")
    wheel_name = host.list_dir("data/dist")[0]
    embed_libgomp(host, use_conda, os.path.join("data", "dist", wheel_name))

    print("Copying TorchData wheel")
    host.download_wheel(os.path.join("data", "dist", wheel_name))

    return wheel_name


def build_torchtext(
    host: RemoteHost,
    *,
    branch: str = "main",
    use_conda: bool = True,
    git_clone_flags: str = "",
) -> str:
    print("Checking out TorchText repo")
    git_clone_flags += " --recurse-submodules"
    build_version = checkout_repo(
        host,
        branch=branch,
        url="https://github.com/pytorch/text",
        git_clone_flags=git_clone_flags,
        mapping={
            "v1.9.0": ("0.10.0", "rc1"),
            "v1.10.0": ("0.11.0", "rc2"),
            "v1.10.1": ("0.11.1", "rc1"),
            "v1.10.2": ("0.11.2", "rc1"),
            "v1.11.0": ("0.12.0", "rc1"),
            "v1.12.0": ("0.13.0", "rc2"),
            "v1.12.1": ("0.13.1", "rc5"),
            "v1.13.0": ("0.14.0", "rc3"),
            "v1.13.1": ("0.14.1", "rc1"),
            "v2.0.0": ("0.15.1", "rc2"),
            "v2.0.1": ("0.15.2", "rc2"),
        },
    )
    print("Building TorchText wheel")
    build_vars = ""
    if branch == "nightly":
        version = host.check_output(
            ["if [ -f text/version.txt ]; then cat text/version.txt; fi"]
        ).strip()
        build_date = (
            host.check_output("cd text && git log --pretty=format:%s -1")
            .strip()
            .split()[0]
            .replace("-", "")
        )
        build_vars += f"BUILD_VERSION={version}.dev{build_date}"
    elif build_version is not None:
        build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
    if host.using_docker():
        build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

    host.run_cmd(f"cd text && {build_vars} python3 -m build --wheel --no-isolation")
    wheel_name = host.list_dir("text/dist")[0]
    embed_libgomp(host, use_conda, os.path.join("text", "dist", wheel_name))

    print("Copying TorchText wheel")
    host.download_wheel(os.path.join("text", "dist", wheel_name))

    return wheel_name


def build_torchaudio(
    host: RemoteHost,
    *,
    branch: str = "main",
    use_conda: bool = True,
    git_clone_flags: str = "",
) -> str:
    print("Checking out TorchAudio repo")
    git_clone_flags += " --recurse-submodules"
    build_version = checkout_repo(
        host,
        branch=branch,
        url="https://github.com/pytorch/audio",
        git_clone_flags=git_clone_flags,
        mapping={
            "v1.9.0": ("0.9.0", "rc2"),
            "v1.10.0": ("0.10.0", "rc5"),
            "v1.10.1": ("0.10.1", "rc1"),
            "v1.10.2": ("0.10.2", "rc1"),
            "v1.11.0": ("0.11.0", "rc1"),
            "v1.12.0": ("0.12.0", "rc3"),
            "v1.12.1": ("0.12.1", "rc5"),
            "v1.13.0": ("0.13.0", "rc4"),
            "v1.13.1": ("0.13.1", "rc2"),
            "v2.0.0": ("2.0.1", "rc3"),
            "v2.0.1": ("2.0.2", "rc2"),
        },
    )
    print("Building TorchAudio wheel")
    build_vars = ""
    if branch == "nightly":
        version = (
            host.check_output(["grep", '"version = \'"', "audio/setup.py"])
            .strip()
            .split("'")[1][:-2]
        )
        build_date = (
            host.check_output("cd audio && git log --pretty=format:%s -1")
            .strip()
            .split()[0]
            .replace("-", "")
        )
        build_vars += f"BUILD_VERSION={version}.dev{build_date}"
    elif build_version is not None:
        build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
    if host.using_docker():
        build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

    host.run_cmd(
        f"cd audio && export FFMPEG_ROOT=$(pwd)/third_party/ffmpeg && export USE_FFMPEG=1 \
        && ./packaging/ffmpeg/build.sh \
        && {build_vars} python3 -m build --wheel --no-isolation"
    )

    wheel_name = host.list_dir("audio/dist")[0]
    embed_libgomp(host, use_conda, os.path.join("audio", "dist", wheel_name))

    print("Copying TorchAudio wheel")
    host.download_wheel(os.path.join("audio", "dist", wheel_name))

    return wheel_name


def configure_system(
    host: RemoteHost,
    *,
    compiler: str = "gcc-8",
    use_conda: bool = True,
    python_version: str = "3.8",
) -> None:
    if use_conda:
        install_condaforge_python(host, python_version)

    print("Configuring the system")
    if not host.using_docker():
        update_apt_repo(host)
        host.run_cmd("sudo apt-get install -y ninja-build g++ git cmake gfortran unzip")
    else:
        host.run_cmd("yum install -y sudo")
        host.run_cmd("conda install -y ninja scons")

    if not use_conda:
        host.run_cmd(
            "sudo apt-get install -y python3-dev python3-yaml python3-setuptools python3-wheel python3-pip"
        )
    host.run_cmd("pip3 install dataclasses typing-extensions")
    if not use_conda:
        print("Installing Cython + numpy from PyPI")
        host.run_cmd("sudo pip3 install Cython")
        host.run_cmd("sudo pip3 install numpy")


def build_domains(
    host: RemoteHost,
    *,
    branch: str = "main",
    use_conda: bool = True,
    git_clone_flags: str = "",
) -> tuple[str, str, str, str]:
    vision_wheel_name = build_torchvision(
        host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
    )
    audio_wheel_name = build_torchaudio(
        host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
    )
    data_wheel_name = build_torchdata(
        host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
    )
    text_wheel_name = build_torchtext(
        host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
    )
    return (vision_wheel_name, audio_wheel_name, data_wheel_name, text_wheel_name)


def start_build(
    host: RemoteHost,
    *,
    branch: str = "main",
    compiler: str = "gcc-8",
    use_conda: bool = True,
    python_version: str = "3.8",
    pytorch_only: bool = False,
    pytorch_build_number: Optional[str] = None,
    shallow_clone: bool = True,
    enable_mkldnn: bool = False,
) -> tuple[str, str, str, str, str]:
    git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
    if host.using_docker() and not use_conda:
        print("Auto-selecting conda option for docker images")
        use_conda = True
    if not host.using_docker():
        print("Disable mkldnn for host builds")
        enable_mkldnn = False

    configure_system(
        host, compiler=compiler, use_conda=use_conda, python_version=python_version
    )

    if host.using_docker():
        print("Move libgfortran.a into a standard location")
        # HACK: pypa gfortran.a is compiled without PIC, which leads to the following error
        # libgfortran.a(error.o)(.text._gfortrani_st_printf+0x34): unresolvable R_AARCH64_ADR_PREL_PG_HI21 relocation against symbol `__stack_chk_guard@@GLIBC_2.17'  # noqa: E501, B950
        # Workaround by copying gfortran library from the host
        host.run_ssh_cmd("sudo apt-get install -y gfortran-8")
        host.run_cmd("mkdir -p /usr/lib/gcc/aarch64-linux-gnu/8")
        host.run_ssh_cmd(
            [
                "docker",
                "cp",
                "/usr/lib/gcc/aarch64-linux-gnu/8/libgfortran.a",
                f"{host.container_id}:/opt/rh/devtoolset-10/root/usr/lib/gcc/aarch64-redhat-linux/10/",
            ]
        )

    print("Checking out PyTorch repo")
    host.run_cmd(
        f"git clone --recurse-submodules -b {branch} https://github.com/pytorch/pytorch {git_clone_flags}"
    )

    host.run_cmd("pytorch/.ci/docker/common/install_openblas.sh")

    print("Building PyTorch wheel")
    build_opts = ""
    if pytorch_build_number is not None:
        build_opts += f" -C--build-option=--build-number={pytorch_build_number}"
    # Breakpad build fails on aarch64
    build_vars = "USE_BREAKPAD=0 "
    if branch == "nightly":
        build_date = (
            host.check_output("cd pytorch && git log --pretty=format:%s -1")
            .strip()
            .split()[0]
            .replace("-", "")
        )
        version = host.check_output("cat pytorch/version.txt").strip()[:-2]
        build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1"
    if branch.startswith(("v1.", "v2.")):
        build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1"
    if host.using_docker():
        build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
    if enable_mkldnn:
        host.run_cmd("pytorch/.ci/docker/common/install_acl.sh")
        print("build pytorch with mkldnn+acl backend")
        build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
        build_vars += " BLAS=OpenBLAS"
        build_vars += " OpenBLAS_HOME=/opt/OpenBLAS"
        build_vars += " ACL_ROOT_DIR=/acl"
        host.run_cmd(
            f"cd $HOME/pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
        )
        print("Repair the wheel")
        pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
        ld_library_path = "/acl/build:$HOME/pytorch/build/lib"
        host.run_cmd(
            f"export LD_LIBRARY_PATH={ld_library_path} && auditwheel repair $HOME/pytorch/dist/{pytorch_wheel_name}"
        )
        print("replace the original wheel with the repaired one")
        pytorch_repaired_wheel_name = host.list_dir("wheelhouse")[0]
        host.run_cmd(
            f"cp $HOME/wheelhouse/{pytorch_repaired_wheel_name} $HOME/pytorch/dist/{pytorch_wheel_name}"
        )
    else:
        print("build pytorch without mkldnn backend")
        host.run_cmd(
            f"cd pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
        )

    print("Deleting build folder")
    host.run_cmd("cd pytorch && rm -rf build")
    pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
    embed_libgomp(host, use_conda, os.path.join("pytorch", "dist", pytorch_wheel_name))
    print("Copying the wheel")
    host.download_wheel(os.path.join("pytorch", "dist", pytorch_wheel_name))

    print("Installing PyTorch wheel")
    host.run_cmd(f"pip3 install pytorch/dist/{pytorch_wheel_name}")

    if pytorch_only:
        return (pytorch_wheel_name, None, None, None, None)
    domain_wheels = build_domains(
        host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
    )

    return (pytorch_wheel_name, *domain_wheels)


embed_library_script = """
#!/usr/bin/env python3

from auditwheel.patcher import Patchelf
from auditwheel.wheeltools import InWheelCtx
from auditwheel.elfutils import elf_file_filter
from auditwheel.repair import copylib
from auditwheel.lddtree import lddtree
from subprocess import check_call
import os
import shutil
import sys
from tempfile import TemporaryDirectory


def replace_tag(filename):
    with open(filename, 'r') as f:
        lines = f.read().split("\\n")
    for i,line in enumerate(lines):
        if not line.startswith("Tag: "):
            continue
        lines[i] = line.replace("-linux_", "-manylinux2014_")
        print(f'Updated tag from {line} to {lines[i]}')

    with open(filename, 'w') as f:
        f.write("\\n".join(lines))


class AlignedPatchelf(Patchelf):
    def set_soname(self, file_name: str, new_soname: str) -> None:
        check_call(['patchelf', '--page-size', '65536', '--set-soname', new_soname, file_name])

    def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
        check_call(['patchelf', '--page-size', '65536', '--replace-needed', soname, new_soname, file_name])


def embed_library(whl_path, lib_soname, update_tag=False):
    patcher = AlignedPatchelf()
    out_dir = TemporaryDirectory()
    whl_name = os.path.basename(whl_path)
    tmp_whl_name = os.path.join(out_dir.name, whl_name)
    with InWheelCtx(whl_path) as ctx:
        torchlib_path = os.path.join(ctx._tmpdir.name, 'torch', 'lib')
        ctx.out_wheel=tmp_whl_name
        new_lib_path, new_lib_soname = None, None
        for filename, elf in elf_file_filter(ctx.iter_files()):
            if not filename.startswith('torch/lib'):
                continue
            libtree = lddtree(filename)
            if lib_soname not in libtree['needed']:
                continue
            lib_path = libtree['libs'][lib_soname]['path']
            if lib_path is None:
                print(f"Can't embed {lib_soname} as it could not be found")
                break
            if lib_path.startswith(torchlib_path):
                continue

            if new_lib_path is None:
                new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
            patcher.replace_needed(filename, lib_soname, new_lib_soname)
            print(f'Replacing {lib_soname} with {new_lib_soname} for {filename}')
        if update_tag:
            # Add manylinux2014 tag
            for filename in ctx.iter_files():
                if os.path.basename(filename) != 'WHEEL':
                    continue
                replace_tag(filename)
        shutil.move(tmp_whl_name, whl_path)


if __name__ == '__main__':
    embed_library(sys.argv[1], 'libgomp.so.1', len(sys.argv) > 2 and sys.argv[2] == '--update-tag')
"""


def run_tests(host: RemoteHost, whl: str, branch="main") -> None:
    print("Configuring the system")
    update_apt_repo(host)
    host.run_cmd("sudo apt-get install -y python3-pip git")
    host.run_cmd("sudo pip3 install Cython")
    host.run_cmd("sudo pip3 install numpy")
    host.upload_file(whl, ".")
    host.run_cmd(f"sudo pip3 install {whl}")
    host.run_cmd("python3 -c 'import torch;print(torch.rand((3,3))'")
    host.run_cmd(f"git clone -b {branch} https://github.com/pytorch/pytorch")
    host.run_cmd("cd pytorch/test; python3 test_torch.py -v")


def get_instance_name(instance) -> Optional[str]:
    if instance.tags is None:
        return None
    for tag in instance.tags:
        if tag["Key"] == "Name":
            return tag["Value"]
    return None


def list_instances(instance_type: str) -> None:
    print(f"All instances of type {instance_type}")
    for instance in ec2_instances_of_type(instance_type):
        ifaces = instance.network_interfaces
        az = ifaces[0].subnet.availability_zone if len(ifaces) > 0 else None
        print(
            f"{instance.id} {get_instance_name(instance)} {instance.public_dns_name} {instance.state['Name']} {az}"
        )


def terminate_instances(instance_type: str) -> None:
    print(f"Terminating all instances of type {instance_type}")
    instances = list(ec2_instances_of_type(instance_type))
    for instance in instances:
        print(f"Terminating {instance.id}")
        instance.terminate()
    print("Waiting for termination to complete")
    for instance in instances:
        instance.wait_until_terminated()


def parse_arguments():
    from argparse import ArgumentParser

    parser = ArgumentParser("Build and test AARCH64 wheels using EC2")
    parser.add_argument("--key-name", type=str)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--build-only", action="store_true")
    parser.add_argument("--test-only", type=str)
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--os", type=str, choices=list(os_amis.keys()))
    group.add_argument("--ami", type=str)
    parser.add_argument(
        "--python-version",
        type=str,
        choices=[f"3.{d}" for d in range(6, 12)],
        default=None,
    )
    parser.add_argument("--alloc-instance", action="store_true")
    parser.add_argument("--list-instances", action="store_true")
    parser.add_argument("--pytorch-only", action="store_true")
    parser.add_argument("--keep-running", action="store_true")
    parser.add_argument("--terminate-instances", action="store_true")
    parser.add_argument("--instance-type", type=str, default="t4g.2xlarge")
    parser.add_argument("--ebs-size", type=int, default=50)
    parser.add_argument("--branch", type=str, default="main")
    parser.add_argument("--use-docker", action="store_true")
    parser.add_argument(
        "--compiler",
        type=str,
        choices=["gcc-7", "gcc-8", "gcc-9", "clang"],
        default="gcc-8",
    )
    parser.add_argument("--use-torch-from-pypi", action="store_true")
    parser.add_argument("--pytorch-build-number", type=str, default=None)
    parser.add_argument("--disable-mkldnn", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_arguments()
    ami = (
        args.ami
        if args.ami is not None
        else os_amis[args.os]
        if args.os is not None
        else ubuntu20_04_ami
    )
    keyfile_path, key_name = compute_keyfile_path(args.key_name)

    if args.list_instances:
        list_instances(args.instance_type)
        sys.exit(0)

    if args.terminate_instances:
        terminate_instances(args.instance_type)
        sys.exit(0)

    if len(key_name) == 0:
        raise RuntimeError("""
            Cannot start build without key_name, please specify
            --key-name argument or AWS_KEY_NAME environment variable.""")
    if len(keyfile_path) == 0 or not os.path.exists(keyfile_path):
        raise RuntimeError(f"""
            Cannot find keyfile with name: [{key_name}] in path: [{keyfile_path}], please
            check `~/.ssh/` folder or manually set SSH_KEY_PATH environment variable.""")

    # Starting the instance
    inst = start_instance(
        key_name, ami=ami, instance_type=args.instance_type, ebs_size=args.ebs_size
    )
    instance_name = f"{args.key_name}-{args.os}"
    if args.python_version is not None:
        instance_name += f"-py{args.python_version}"
    inst.create_tags(
        DryRun=False,
        Tags=[
            {
                "Key": "Name",
                "Value": instance_name,
            }
        ],
    )
    addr = inst.public_dns_name
    wait_for_connection(addr, 22)
    host = RemoteHost(addr, keyfile_path)
    host.ami = ami
    if args.use_docker:
        update_apt_repo(host)
        host.start_docker()

    if args.test_only:
        run_tests(host, args.test_only)
        sys.exit(0)

    if args.alloc_instance:
        if args.python_version is None:
            sys.exit(0)
        install_condaforge_python(host, args.python_version)
        sys.exit(0)

    python_version = args.python_version if args.python_version is not None else "3.10"

    if args.use_torch_from_pypi:
        configure_system(host, compiler=args.compiler, python_version=python_version)
        print("Installing PyTorch wheel")
        host.run_cmd("pip3 install torch")
        build_domains(
            host, branch=args.branch, git_clone_flags=" --depth 1 --shallow-submodules"
        )
    else:
        start_build(
            host,
            branch=args.branch,
            compiler=args.compiler,
            python_version=python_version,
            pytorch_only=args.pytorch_only,
            pytorch_build_number=args.pytorch_build_number,
            enable_mkldnn=not args.disable_mkldnn,
        )
    if not args.keep_running:
        print(f"Waiting for instance {inst.id} to terminate")
        inst.terminate()
        inst.wait_until_terminated()
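Several of the `build_*` helpers above derive a nightly version string from `version.txt` plus a date taken from the latest commit. The sketch below illustrates that pattern in isolation; the two input strings are made-up examples, not real repository state.

```python
# Illustrative sketch of the nightly BUILD_VERSION pattern used above:
# <version.txt minus its trailing "a0"> + ".dev" + <commit date with dashes removed>.
version_txt = "2.1.0a0"          # hypothetical contents of version.txt
last_commit_date = "2023-06-07"  # hypothetical date string taken from the latest commit

version = version_txt.strip()[:-2]              # -> "2.1.0"
build_date = last_commit_date.replace("-", "")  # -> "20230607"
build_version = f"{version}.dev{build_date}"

assert build_version == "2.1.0.dev20230607"
```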
@@ -1,87 +0,0 @@
#!/usr/bin/env python3

import os
import shutil
import sys
from subprocess import check_call
from tempfile import TemporaryDirectory

from auditwheel.elfutils import elf_file_filter
from auditwheel.lddtree import lddtree
from auditwheel.patcher import Patchelf
from auditwheel.repair import copylib
from auditwheel.wheeltools import InWheelCtx


def replace_tag(filename):
    with open(filename) as f:
        lines = f.read().split("\n")
    for i, line in enumerate(lines):
        if not line.startswith("Tag: "):
            continue
        lines[i] = line.replace("-linux_", "-manylinux2014_")
        print(f"Updated tag from {line} to {lines[i]}")

    with open(filename, "w") as f:
        f.write("\n".join(lines))


class AlignedPatchelf(Patchelf):
    def set_soname(self, file_name: str, new_soname: str) -> None:
        check_call(
            ["patchelf", "--page-size", "65536", "--set-soname", new_soname, file_name]
        )

    def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
        check_call(
            [
                "patchelf",
                "--page-size",
                "65536",
                "--replace-needed",
                soname,
                new_soname,
                file_name,
            ]
        )


def embed_library(whl_path, lib_soname, update_tag=False):
    patcher = AlignedPatchelf()
    out_dir = TemporaryDirectory()
    whl_name = os.path.basename(whl_path)
    tmp_whl_name = os.path.join(out_dir.name, whl_name)
    with InWheelCtx(whl_path) as ctx:
        torchlib_path = os.path.join(ctx._tmpdir.name, "torch", "lib")
        ctx.out_wheel = tmp_whl_name
        new_lib_path, new_lib_soname = None, None
        for filename, _ in elf_file_filter(ctx.iter_files()):
            if not filename.startswith("torch/lib"):
                continue
            libtree = lddtree(filename)
            if lib_soname not in libtree["needed"]:
                continue
            lib_path = libtree["libs"][lib_soname]["path"]
            if lib_path is None:
                print(f"Can't embed {lib_soname} as it could not be found")
                break
            if lib_path.startswith(torchlib_path):
                continue

            if new_lib_path is None:
                new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
            patcher.replace_needed(filename, lib_soname, new_lib_soname)
            print(f"Replacing {lib_soname} with {new_lib_soname} for {filename}")
        if update_tag:
            # Add manylinux2014 tag
            for filename in ctx.iter_files():
                if os.path.basename(filename) != "WHEEL":
                    continue
                replace_tag(filename)
        shutil.move(tmp_whl_name, whl_path)


if __name__ == "__main__":
    embed_library(
        sys.argv[1], "libgomp.so.1", len(sys.argv) > 2 and sys.argv[2] == "--update-tag"
    )
@ -4,14 +4,17 @@ set -ex
|
||||
|
||||
SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||
|
||||
# Source the common build script for architecture-specific configurations (MKLDNN, ACL, etc.)
|
||||
source "${SCRIPTPATH}/../pytorch/build.sh" || true
|
||||
|
||||
case "${GPU_ARCH_TYPE:-BLANK}" in
|
||||
cuda)
|
||||
cuda | cuda-aarch64)
|
||||
bash "${SCRIPTPATH}/build_cuda.sh"
|
||||
;;
|
||||
rocm)
|
||||
bash "${SCRIPTPATH}/build_rocm.sh"
|
||||
;;
|
||||
cpu | cpu-cxx11-abi | cpu-s390x)
|
||||
cpu | cpu-cxx11-abi | cpu-aarch64 | cpu-s390x)
|
||||
bash "${SCRIPTPATH}/build_cpu.sh"
|
||||
;;
|
||||
xpu)
|
||||
|
||||
@ -18,12 +18,31 @@ retry () {
|
||||
$* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*)
|
||||
}
|
||||
|
||||
# Detect architecture first
|
||||
ARCH=$(uname -m)
|
||||
echo "Detected architecture: $ARCH"
|
||||
|
||||
PLATFORM=""
|
||||
# TODO move this into the Docker images
|
||||
OS_NAME=$(awk -F= '/^NAME/{print $2}' /etc/os-release)
|
||||
if [[ "$OS_NAME" == *"AlmaLinux"* ]]; then
|
||||
retry yum install -q -y zip openssl
|
||||
PLATFORM="manylinux_2_28_x86_64"
|
||||
# Set platform based on architecture
|
||||
case $ARCH in
|
||||
x86_64)
|
||||
PLATFORM="manylinux_2_28_x86_64"
|
||||
;;
|
||||
aarch64)
|
||||
PLATFORM="manylinux_2_28_aarch64"
|
||||
;;
|
||||
s390x)
|
||||
PLATFORM="manylinux_2_28_s390x"
|
||||
;;
|
||||
*)
|
||||
echo "Unsupported architecture: $ARCH"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then
|
||||
retry dnf install -q -y zip openssl
|
||||
elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
|
||||
@ -38,6 +57,8 @@ else
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Platform set to: $PLATFORM"
|
||||
|
||||
# We use the package name to test the package by passing this to 'pip install'
|
||||
# This is the env variable that setup.py uses to name the package. Note that
|
||||
# pip 'normalizes' the name first by changing all - to _
|
||||
@ -299,8 +320,8 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
|
||||
# ROCm workaround for roctracer dlopens
|
||||
if [[ "$DESIRED_CUDA" == *"rocm"* ]]; then
|
||||
patchedpath=$(fname_without_so_number $destpath)
|
||||
# Keep the so number for XPU dependencies and libgomp.so.1 to avoid twice load
|
||||
elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" ]]; then
|
||||
# Keep the so number for XPU dependencies, libgomp.so.1, ACL libraries, and NVPL libraries to avoid twice load
|
||||
elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" || "$filename" == libarm_compute* || "$filename" == libnvpl* || "$filename" == "libgfortran.so.5" ]]; then
|
||||
patchedpath=$destpath
|
||||
else
|
||||
patchedpath=$(fname_with_sha256 $destpath)
|
||||
@ -346,9 +367,22 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
|
||||
done
|
||||
|
||||
# create Manylinux 2_28 tag; this needs to happen before regenerating the RECORD
|
||||
if [[ $PLATFORM == "manylinux_2_28_x86_64" && $GPU_ARCH_TYPE != "cpu-s390x" && $GPU_ARCH_TYPE != "xpu" ]]; then
|
||||
# Support all architectures (x86_64, aarch64, s390x)
|
||||
if [[ "$IS_MANYLINUX2_28" == "1" && $GPU_ARCH_TYPE != "xpu" ]]; then
|
||||
wheel_file=$(echo $(basename $pkg) | sed -e 's/-cp.*$/.dist-info\/WHEEL/g')
|
||||
sed -i -e s#linux_x86_64#"${PLATFORM}"# $wheel_file;
|
||||
echo "Updating wheel tag for $ARCH architecture"
|
||||
# Replace linux_* with manylinux_2_28_* based on architecture
|
||||
case $ARCH in
|
||||
x86_64)
|
||||
sed -i -e 's#linux_x86_64#manylinux_2_28_x86_64#g' $wheel_file
|
||||
;;
|
||||
aarch64)
|
||||
sed -i -e 's#linux_aarch64#manylinux_2_28_aarch64#g' $wheel_file
|
||||
;;
|
||||
s390x)
|
||||
sed -i -e 's#linux_s390x#manylinux_2_28_s390x#g' $wheel_file
|
||||
;;
|
||||
esac
|
||||
fi
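As a sanity check, the rewrite can be verified from the WHEEL metadata itself; a hedged sketch (the wheel filename and version below are hypothetical, not taken from this diff):

# Hypothetical example: inspect the platform tag carried in the wheel's WHEEL file.
unzip -p torch-2.6.0-cp310-cp310-linux_aarch64.whl '*.dist-info/WHEEL' | grep '^Tag:'
#   Tag: cp310-cp310-linux_aarch64
# After the sed rewrite above, the same line reads:
#   Tag: cp310-cp310-manylinux_2_28_aarch64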
|
||||
|
||||
# regenerate the RECORD file with new hashes
|
||||
|
||||
@ -15,6 +15,10 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
|
||||
EXTRA_CAFFE2_CMAKE_FLAGS=()
|
||||
fi
|
||||
|
||||
# Detect architecture
|
||||
ARCH=$(uname -m)
|
||||
echo "Building CPU wheel for architecture: $ARCH"
|
||||
|
||||
WHEELHOUSE_DIR="wheelhousecpu"
|
||||
LIBTORCH_HOUSE_DIR="libtorch_housecpu"
|
||||
if [[ -z "$PYTORCH_FINAL_PACKAGE_DIR" ]]; then
|
||||
@ -34,8 +38,10 @@ elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then
|
||||
elif [[ "$OS_NAME" == *"AlmaLinux"* ]]; then
|
||||
LIBGOMP_PATH="/usr/lib64/libgomp.so.1"
|
||||
elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
|
||||
if [[ "$(uname -m)" == "s390x" ]]; then
|
||||
if [[ "$ARCH" == "s390x" ]]; then
|
||||
LIBGOMP_PATH="/usr/lib/s390x-linux-gnu/libgomp.so.1"
|
||||
elif [[ "$ARCH" == "aarch64" ]]; then
|
||||
LIBGOMP_PATH="/usr/lib/aarch64-linux-gnu/libgomp.so.1"
|
||||
else
|
||||
LIBGOMP_PATH="/usr/lib/x86_64-linux-gnu/libgomp.so.1"
|
||||
fi
|
||||
@ -49,6 +55,32 @@ DEPS_SONAME=(
|
||||
"libgomp.so.1"
|
||||
)
|
||||
|
||||
# Add ARM-specific library dependencies for CPU builds
|
||||
if [[ "$ARCH" == "aarch64" ]]; then
|
||||
echo "Adding ARM-specific CPU library dependencies"
|
||||
|
||||
# ARM Compute Library (if available)
|
||||
if [[ -d "/acl/build" ]]; then
|
||||
echo "Adding ARM Compute Library for CPU"
|
||||
DEPS_LIST+=(
|
||||
"/acl/build/libarm_compute.so"
|
||||
"/acl/build/libarm_compute_graph.so"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libarm_compute.so"
|
||||
"libarm_compute_graph.so"
|
||||
)
|
||||
fi
|
||||
|
||||
# ARM system libraries
|
||||
DEPS_LIST+=(
|
||||
"/usr/lib64/libgfortran.so.5"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libgfortran.so.5"
|
||||
)
|
||||
fi
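DEPS_LIST and DEPS_SONAME are parallel arrays, one soname per bundled library path; a small illustrative check (not part of the script):

# Hypothetical debug loop: print each bundled library next to the soname it keeps.
for i in "${!DEPS_LIST[@]}"; do
    echo "${DEPS_SONAME[$i]} <- ${DEPS_LIST[$i]}"
done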
|
||||
|
||||
rm -rf /usr/local/cuda*
|
||||
|
||||
SOURCE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )"
|
||||
|
||||
@ -29,6 +29,10 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
|
||||
EXTRA_CAFFE2_CMAKE_FLAGS=()
|
||||
fi
|
||||
|
||||
# Detect architecture
|
||||
ARCH=$(uname -m)
|
||||
echo "Building for architecture: $ARCH"
|
||||
|
||||
# Determine CUDA version and architectures to build for
|
||||
#
|
||||
# NOTE: We should first check `DESIRED_CUDA` when determining `CUDA_VERSION`,
|
||||
@ -53,34 +57,60 @@ fi
|
||||
cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.')
|
||||
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
|
||||
|
||||
# Function to remove architectures from a list
|
||||
remove_archs() {
|
||||
local result="$1"
|
||||
shift
|
||||
for arch in "$@"; do
|
||||
result="${result//${arch};/}"
|
||||
done
|
||||
echo "$result"
|
||||
}
|
||||
|
||||
# Function to filter CUDA architectures for aarch64
|
||||
# aarch64 ARM GPUs only support certain compute capabilities
|
||||
# Keep: 8.0 (A100), 9.0+ (Hopper, Grace Hopper, newer)
|
||||
# Remove: < 8.0 (no ARM GPUs), 8.6 (x86_64 RTX 3090/A6000 only)
|
||||
filter_aarch64_archs() {
|
||||
local arch_list="$1"
|
||||
# Explicitly remove architectures not needed on aarch64
|
||||
arch_list=$(remove_archs "$arch_list" "5.0" "6.0" "7.0" "7.5" "8.6")
|
||||
echo "$arch_list"
|
||||
}
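For illustration only, a minimal sketch of the filter's effect (the invocations below are hypothetical, not part of the build script):

# Hypothetical invocations: entries below 8.0 and the x86_64-only 8.6 are dropped.
filter_aarch64_archs "7.0;7.5;8.0;8.6;9.0"                # -> "8.0;9.0"
filter_aarch64_archs "7.0;7.5;8.0;8.6;9.0;10.0;12.0+PTX"  # -> "8.0;9.0;10.0;12.0+PTX"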
|
||||
|
||||
# Base: Common architectures across all modern CUDA versions
|
||||
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0"
|
||||
|
||||
case ${CUDA_VERSION} in
|
||||
# removing sm_50-sm_60 as these architectures are deprecated in CUDA 12.8/9 and will be removed in future releases
|
||||
# however we would like to keep the sm_70 architecture, see: https://github.com/pytorch/pytorch/issues/157517
|
||||
12.8)
|
||||
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0"
|
||||
;;
|
||||
12.9)
|
||||
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0+PTX"
|
||||
# WAR to resolve the ld error in libtorch build with CUDA 12.9
|
||||
12.6) TORCH_CUDA_ARCH_LIST="5.0;6.0;${TORCH_CUDA_ARCH_LIST}" ;; # Only 12.6 includes Legacy Maxwell/Pascal that will be removed in future releases
|
||||
12.8) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0" ;; # +Hopper/Blackwell support
|
||||
12.9) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0+PTX" # +Hopper/Blackwell support + PTX for forward compatibility
|
||||
if [[ "$PACKAGE_TYPE" == "libtorch" ]]; then
|
||||
TORCH_CUDA_ARCH_LIST="7.5;8.0;9.0;10.0;12.0+PTX"
|
||||
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//7.0;/}" # Remove 7.0 to resolve the ld error
|
||||
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//8.6;/}" # Remove 8.6 for libtorch
|
||||
fi
|
||||
;;
|
||||
13.0)
|
||||
TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;12.0+PTX"
|
||||
;;
|
||||
12.6)
|
||||
TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6;9.0"
|
||||
;;
|
||||
*)
|
||||
echo "unknown cuda version $CUDA_VERSION"
|
||||
exit 1
|
||||
TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;$([[ "$ARCH" == "aarch64" ]] && echo "11.0;" || echo "")12.0+PTX"
|
||||
export TORCH_NVCC_FLAGS="-compress-mode=size"
|
||||
export BUILD_BUNDLE_PTXAS=1
|
||||
;;
|
||||
*) echo "unknown cuda version $CUDA_VERSION"; exit 1 ;;
|
||||
esac
|
||||
|
||||
# Filter for aarch64: Remove < 8.0 and 8.6
|
||||
[[ "$ARCH" == "aarch64" ]] && TORCH_CUDA_ARCH_LIST=$(filter_aarch64_archs "$TORCH_CUDA_ARCH_LIST")
|
||||
|
||||
echo "TORCH_CUDA_ARCH_LIST set to: $TORCH_CUDA_ARCH_LIST"
|
||||
export TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST}
|
||||
echo "${TORCH_CUDA_ARCH_LIST}"
|
||||
|
||||
# Disable MAGMA for aarch64 as pre-built libraries are x86-64 only
|
||||
if [[ "$ARCH" == "aarch64" ]]; then
|
||||
echo "Disabling MAGMA for aarch64 architecture"
|
||||
export USE_MAGMA=0
|
||||
fi
|
||||
|
||||
# Package directories
|
||||
WHEELHOUSE_DIR="wheelhouse$cuda_version_nodot"
|
||||
LIBTORCH_HOUSE_DIR="libtorch_house$cuda_version_nodot"
|
||||
@ -244,6 +274,51 @@ else
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Add ARM-specific library dependencies
|
||||
if [[ "$ARCH" == "aarch64" ]]; then
|
||||
echo "Adding ARM-specific library dependencies"
|
||||
|
||||
# ARM Compute Library (if available)
|
||||
if [[ -d "/acl/build" ]]; then
|
||||
echo "Adding ARM Compute Library"
|
||||
DEPS_LIST+=(
|
||||
"/acl/build/libarm_compute.so"
|
||||
"/acl/build/libarm_compute_graph.so"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libarm_compute.so"
|
||||
"libarm_compute_graph.so"
|
||||
)
|
||||
fi
|
||||
|
||||
# ARM system libraries
|
||||
DEPS_LIST+=(
|
||||
"/lib64/libgomp.so.1"
|
||||
"/usr/lib64/libgfortran.so.5"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libgomp.so.1"
|
||||
"libgfortran.so.5"
|
||||
)
|
||||
|
||||
# NVPL libraries (ARM optimized BLAS/LAPACK)
|
||||
if [[ -d "/usr/local/lib" && -f "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0" ]]; then
|
||||
echo "Adding NVPL libraries for ARM"
|
||||
DEPS_LIST+=(
|
||||
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0"
|
||||
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0"
|
||||
"/usr/local/lib/libnvpl_lapack_core.so.0"
|
||||
"/usr/local/lib/libnvpl_blas_core.so.0"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libnvpl_lapack_lp64_gomp.so.0"
|
||||
"libnvpl_blas_lp64_gomp.so.0"
|
||||
"libnvpl_lapack_core.so.0"
|
||||
"libnvpl_blas_core.so.0"
|
||||
)
|
||||
fi
|
||||
fi
|
||||
|
||||
# run_tests.sh requires DESIRED_CUDA to know what tests to exclude
|
||||
export DESIRED_CUDA="$cuda_version_nodot"
|
||||
|
||||
@ -251,9 +326,11 @@ export DESIRED_CUDA="$cuda_version_nodot"
|
||||
rm -rf /usr/local/cuda || true
|
||||
ln -s "/usr/local/cuda-${CUDA_VERSION}" /usr/local/cuda
|
||||
|
||||
# Switch `/usr/local/magma` to the desired CUDA version
|
||||
rm -rf /usr/local/magma || true
|
||||
ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma
|
||||
# Switch `/usr/local/magma` to the desired CUDA version (skip for aarch64)
|
||||
if [[ "$ARCH" != "aarch64" ]]; then
|
||||
rm -rf /usr/local/magma || true
|
||||
ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma
|
||||
fi
|
||||
|
||||
export CUDA_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev) # 10.0.130
|
||||
export CUDA_VERSION_SHORT=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev | cut -f1,2 -d".") # 10.0
|
||||
|
||||
@ -86,10 +86,20 @@ else
|
||||
fi
|
||||
fi
|
||||
|
||||
# Enable MKLDNN with ARM Compute Library for ARM builds
|
||||
if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then
|
||||
export USE_MKLDNN=1
|
||||
|
||||
# ACL is required for aarch64 builds
|
||||
if [[ ! -d "/acl" ]]; then
|
||||
echo "ERROR: ARM Compute Library not found at /acl"
|
||||
echo "ACL is required for aarch64 builds. Check Docker image setup."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export USE_MKLDNN_ACL=1
|
||||
export ACL_ROOT_DIR=/acl
|
||||
echo "ARM Compute Library enabled for MKLDNN: ACL_ROOT_DIR=/acl"
|
||||
fi
|
||||
|
||||
if [[ "$BUILD_ENVIRONMENT" == *riscv64* ]]; then
|
||||
|
||||
@ -100,6 +100,337 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _compile_and_extract_symbols(
|
||||
cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None
|
||||
) -> list[str]:
|
||||
"""
|
||||
Helper to compile a C++ file and extract all symbols.
|
||||
|
||||
Args:
|
||||
cpp_content: C++ source code to compile
|
||||
compile_flags: Compilation flags
|
||||
exclude_list: List of symbol names to exclude. Defaults to ["main"].
|
||||
|
||||
Returns:
|
||||
List of all symbols found in the object file (excluding those in exclude_list).
|
||||
"""
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
if exclude_list is None:
|
||||
exclude_list = ["main"]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmppath = Path(tmpdir)
|
||||
cpp_file = tmppath / "test.cpp"
|
||||
obj_file = tmppath / "test.o"
|
||||
|
||||
cpp_file.write_text(cpp_content)
|
||||
|
||||
result = subprocess.run(
|
||||
compile_flags + [str(cpp_file), "-o", str(obj_file)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Compilation failed: {result.stderr}")
|
||||
|
||||
symbols = get_symbols(str(obj_file))
|
||||
|
||||
# Return all symbol names, excluding those in the exclude list
|
||||
return [name for _addr, _stype, name in symbols if name not in exclude_list]
|
||||
|
||||
|
||||
def check_stable_only_symbols(install_root: Path) -> None:
|
||||
"""
|
||||
Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.
|
||||
|
||||
This approach tests:
|
||||
1. WITHOUT macros -> many torch symbols exposed
|
||||
2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
|
||||
3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
|
||||
4. WITH both macros -> zero torch symbols (all hidden)
|
||||
"""
|
||||
include_dir = install_root / "include"
|
||||
assert include_dir.exists(), f"Expected {include_dir} to be present"
|
||||
|
||||
test_cpp_content = """
|
||||
// Main torch C++ API headers
|
||||
#include <torch/torch.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
// ATen tensor library
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
// Core c10 headers (commonly used)
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/core/TensorOptions.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
int main() { return 0; }
|
||||
"""
|
||||
|
||||
base_compile_flags = [
|
||||
"g++",
|
||||
"-std=c++17",
|
||||
f"-I{include_dir}",
|
||||
f"-I{include_dir}/torch/csrc/api/include",
|
||||
"-c", # Compile only, don't link
|
||||
]
|
||||
|
||||
# Compile WITHOUT any macros
|
||||
symbols_without = _compile_and_extract_symbols(
|
||||
cpp_content=test_cpp_content,
|
||||
compile_flags=base_compile_flags,
|
||||
)
|
||||
|
||||
# We expect constexpr symbols, inline functions used by other headers etc.
|
||||
# to produce symbols
|
||||
num_symbols_without = len(symbols_without)
|
||||
print(f"Found {num_symbols_without} symbols without any macros defined")
|
||||
assert num_symbols_without != 0, (
|
||||
"Expected a non-zero number of symbols without any macros"
|
||||
)
|
||||
|
||||
# Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
|
||||
compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"]
|
||||
|
||||
symbols_with_stable_only = _compile_and_extract_symbols(
|
||||
cpp_content=test_cpp_content,
|
||||
compile_flags=compile_flags_with_stable_only,
|
||||
)
|
||||
|
||||
num_symbols_with_stable_only = len(symbols_with_stable_only)
|
||||
assert num_symbols_with_stable_only == 0, (
|
||||
f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}"
|
||||
)
|
||||
|
||||
# Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
|
||||
compile_flags_with_target_version = base_compile_flags + [
|
||||
"-DTORCH_TARGET_VERSION=1"
|
||||
]
|
||||
|
||||
symbols_with_target_version = _compile_and_extract_symbols(
|
||||
cpp_content=test_cpp_content,
|
||||
compile_flags=compile_flags_with_target_version,
|
||||
)
|
||||
|
||||
num_symbols_with_target_version = len(symbols_with_target_version)
|
||||
assert num_symbols_with_target_version == 0, (
|
||||
f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}"
|
||||
)
|
||||
|
||||
# Compile WITH both macros (expect 0 symbols)
|
||||
compile_flags_with_both = base_compile_flags + [
|
||||
"-DTORCH_STABLE_ONLY",
|
||||
"-DTORCH_TARGET_VERSION=1",
|
||||
]
|
||||
|
||||
symbols_with_both = _compile_and_extract_symbols(
|
||||
cpp_content=test_cpp_content,
|
||||
compile_flags=compile_flags_with_both,
|
||||
)
|
||||
|
||||
num_symbols_with_both = len(symbols_with_both)
|
||||
assert num_symbols_with_both == 0, (
|
||||
f"Expected no symbols with both macros, but found {num_symbols_with_both}"
|
||||
)
|
||||
|
||||
|
||||
def check_stable_api_symbols(install_root: Path) -> None:
|
||||
"""
|
||||
Test that stable API headers still expose symbols with TORCH_STABLE_ONLY.
|
||||
The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols
|
||||
"""
|
||||
include_dir = install_root / "include"
|
||||
assert include_dir.exists(), f"Expected {include_dir} to be present"
|
||||
|
||||
stable_dir = include_dir / "torch" / "csrc" / "stable"
|
||||
assert stable_dir.exists(), f"Expected {stable_dir} to be present"
|
||||
|
||||
stable_headers = list(stable_dir.rglob("*.h"))
|
||||
if not stable_headers:
|
||||
raise RuntimeError("Could not find any stable headers")
|
||||
|
||||
includes = []
|
||||
for header in stable_headers:
|
||||
rel_path = header.relative_to(include_dir)
|
||||
includes.append(f"#include <{rel_path.as_posix()}>")
|
||||
|
||||
includes_str = "\n".join(includes)
|
||||
test_stable_content = f"""
|
||||
{includes_str}
|
||||
int main() {{ return 0; }}
|
||||
"""
|
||||
|
||||
compile_flags = [
|
||||
"g++",
|
||||
"-std=c++17",
|
||||
f"-I{include_dir}",
|
||||
f"-I{include_dir}/torch/csrc/api/include",
|
||||
"-c",
|
||||
"-DTORCH_STABLE_ONLY",
|
||||
]
|
||||
|
||||
symbols_stable = _compile_and_extract_symbols(
|
||||
cpp_content=test_stable_content,
|
||||
compile_flags=compile_flags,
|
||||
)
|
||||
num_symbols_stable = len(symbols_stable)
|
||||
print(f"Found {num_symbols_stable} symbols in torch/csrc/stable")
|
||||
assert num_symbols_stable > 0, (
|
||||
f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
|
||||
f"but found {num_symbols_stable} symbols"
|
||||
)
|
||||
|
||||
|
||||
def check_headeronly_symbols(install_root: Path) -> None:
|
||||
"""
|
||||
Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY.
|
||||
"""
|
||||
include_dir = install_root / "include"
|
||||
assert include_dir.exists(), f"Expected {include_dir} to be present"
|
||||
|
||||
# Find all headers in torch/headeronly
|
||||
headeronly_dir = include_dir / "torch" / "headeronly"
|
||||
assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present"
|
||||
headeronly_headers = list(headeronly_dir.rglob("*.h"))
|
||||
if not headeronly_headers:
|
||||
raise RuntimeError("Could not find any headeronly headers")
|
||||
|
||||
# Filter out platform-specific headers that may not compile everywhere
|
||||
platform_specific_keywords = [
|
||||
"cpu/vec",
|
||||
]
|
||||
|
||||
filtered_headers = []
|
||||
for header in headeronly_headers:
|
||||
rel_path = header.relative_to(include_dir).as_posix()
|
||||
if not any(
|
||||
keyword in rel_path.lower() for keyword in platform_specific_keywords
|
||||
):
|
||||
filtered_headers.append(header)
|
||||
|
||||
includes = []
|
||||
for header in filtered_headers:
|
||||
rel_path = header.relative_to(include_dir)
|
||||
includes.append(f"#include <{rel_path.as_posix()}>")
|
||||
|
||||
includes_str = "\n".join(includes)
|
||||
test_headeronly_content = f"""
|
||||
{includes_str}
|
||||
int main() {{ return 0; }}
|
||||
"""
|
||||
|
||||
compile_flags = [
|
||||
"g++",
|
||||
"-std=c++17",
|
||||
f"-I{include_dir}",
|
||||
f"-I{include_dir}/torch/csrc/api/include",
|
||||
"-c",
|
||||
"-DTORCH_STABLE_ONLY",
|
||||
]
|
||||
|
||||
symbols_headeronly = _compile_and_extract_symbols(
|
||||
cpp_content=test_headeronly_content,
|
||||
compile_flags=compile_flags,
|
||||
)
|
||||
num_symbols_headeronly = len(symbols_headeronly)
|
||||
print(f"Found {num_symbols_headeronly} symbols in torch/headeronly")
|
||||
assert num_symbols_headeronly > 0, (
|
||||
f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
|
||||
f"but found {num_symbols_headeronly} symbols"
|
||||
)
|
||||
|
||||
|
||||
def check_aoti_shim_symbols(install_root: Path) -> None:
|
||||
"""
|
||||
Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY.
|
||||
"""
|
||||
include_dir = install_root / "include"
|
||||
assert include_dir.exists(), f"Expected {include_dir} to be present"
|
||||
|
||||
# There are no constexpr symbols etc., so we need to actually use functions
|
||||
# so that some symbols are found.
|
||||
test_shim_content = """
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
int main() {
|
||||
int32_t (*fp1)() = &aoti_torch_device_type_cpu;
|
||||
int32_t (*fp2)() = &aoti_torch_dtype_float32;
|
||||
(void)fp1; (void)fp2;
|
||||
return 0;
|
||||
}
|
||||
"""
|
||||
|
||||
compile_flags = [
|
||||
"g++",
|
||||
"-std=c++17",
|
||||
f"-I{include_dir}",
|
||||
f"-I{include_dir}/torch/csrc/api/include",
|
||||
"-c",
|
||||
"-DTORCH_STABLE_ONLY",
|
||||
]
|
||||
|
||||
symbols_shim = _compile_and_extract_symbols(
|
||||
cpp_content=test_shim_content,
|
||||
compile_flags=compile_flags,
|
||||
)
|
||||
num_symbols_shim = len(symbols_shim)
|
||||
assert num_symbols_shim > 0, (
|
||||
f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
|
||||
f"but found {num_symbols_shim} symbols"
|
||||
)
|
||||
|
||||
|
||||
def check_stable_c_shim_symbols(install_root: Path) -> None:
|
||||
"""
|
||||
Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY.
|
||||
"""
|
||||
include_dir = install_root / "include"
|
||||
assert include_dir.exists(), f"Expected {include_dir} to be present"
|
||||
|
||||
# Check if the stable C shim exists
|
||||
stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
|
||||
if not stable_shim.exists():
|
||||
raise RuntimeError("Could not find stable c shim")
|
||||
|
||||
# There are no constexpr symbols etc., so we need to actually use functions
|
||||
# so that some symbols are found.
|
||||
test_stable_shim_content = """
|
||||
#include <torch/csrc/stable/c/shim.h>
|
||||
int main() {
|
||||
// Reference stable C API functions to create undefined symbols
|
||||
AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
|
||||
AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
|
||||
(void)fp1; (void)fp2;
|
||||
return 0;
|
||||
}
|
||||
"""
|
||||
|
||||
compile_flags = [
|
||||
"g++",
|
||||
"-std=c++17",
|
||||
f"-I{include_dir}",
|
||||
f"-I{include_dir}/torch/csrc/api/include",
|
||||
"-c",
|
||||
"-DTORCH_STABLE_ONLY",
|
||||
]
|
||||
|
||||
symbols_stable_shim = _compile_and_extract_symbols(
|
||||
cpp_content=test_stable_shim_content,
|
||||
compile_flags=compile_flags,
|
||||
)
|
||||
num_symbols_stable_shim = len(symbols_stable_shim)
|
||||
assert num_symbols_stable_shim > 0, (
|
||||
f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
|
||||
f"but found {num_symbols_stable_shim} symbols"
|
||||
)
|
||||
|
||||
|
||||
def check_lib_symbols_for_abi_correctness(lib: str) -> None:
|
||||
print(f"lib: {lib}")
|
||||
cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)
|
||||
@ -129,6 +460,13 @@ def main() -> None:
|
||||
check_lib_symbols_for_abi_correctness(libtorch_cpu_path)
|
||||
check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path)
|
||||
|
||||
# Check symbols when TORCH_STABLE_ONLY is defined
|
||||
check_stable_only_symbols(install_root)
|
||||
check_stable_api_symbols(install_root)
|
||||
check_headeronly_symbols(install_root)
|
||||
check_aoti_shim_symbols(install_root)
|
||||
check_stable_c_shim_symbols(install_root)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -353,6 +353,17 @@ def test_linalg(device="cpu") -> None:
|
||||
torch.linalg.svd(A)
|
||||
|
||||
|
||||
def test_sdpa(device="cpu", dtype=torch.float16) -> None:
|
||||
"""Regression test for https://github.com/pytorch/pytorch/issues/167602
|
||||
Without nvrtc_builtins, cuDNN-9.13 on CUDA-13 fails with ` No valid execution plans built.`
|
||||
"""
|
||||
print(f"Testing SDPA on {device} using type {dtype}")
|
||||
k, q, v = torch.rand(3, 1, 16, 77, 64, dtype=dtype, device=device).unbind(0)
|
||||
attn = torch.rand(1, 1, 77, 77, dtype=dtype, device=device)
|
||||
rc = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn)
|
||||
assert rc.isnan().any().item() is False
|
||||
|
||||
|
||||
def smoke_test_compile(device: str = "cpu") -> None:
|
||||
supported_dtypes = [torch.float16, torch.float32, torch.float64]
|
||||
|
||||
@ -489,10 +500,12 @@ def main() -> None:
|
||||
smoke_test_conv2d()
|
||||
test_linalg()
|
||||
test_numpy()
|
||||
test_sdpa()
|
||||
|
||||
if is_cuda_system:
|
||||
test_linalg("cuda")
|
||||
test_cuda_gds_errors_captured()
|
||||
test_sdpa("cuda")
|
||||
|
||||
if options.package == "all":
|
||||
smoke_test_modules()
|
||||
|
||||
@ -1680,6 +1680,22 @@ test_operator_microbenchmark() {
|
||||
done
|
||||
}
|
||||
|
||||
test_attention_microbenchmark() {
|
||||
TEST_REPORTS_DIR=$(pwd)/test/test-reports
|
||||
mkdir -p "$TEST_REPORTS_DIR"
|
||||
TEST_DIR=$(pwd)
|
||||
|
||||
# Install attention-gym dependency
|
||||
echo "Installing attention-gym..."
|
||||
python -m pip install git+https://github.com/meta-pytorch/attention-gym.git@main
|
||||
pip show triton
|
||||
|
||||
cd "${TEST_DIR}"/benchmarks/transformer
|
||||
|
||||
$TASKSET python score_mod.py --config configs/config_basic.yaml \
|
||||
--output-json-for-dashboard "${TEST_REPORTS_DIR}/attention_microbenchmark.json"
|
||||
}
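The new target is selected purely by a substring match on TEST_CONFIG; a minimal sketch of the condition (illustrative, not part of the script):

# Hypothetical example: any config name containing "attention_microbenchmark" dispatches here.
TEST_CONFIG="attention_microbenchmark_test"
[[ "${TEST_CONFIG}" == *attention_microbenchmark* ]] && echo "runs test_attention_microbenchmark"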
|
||||
|
||||
if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
|
||||
(cd test && python -c "import torch; print(torch.__config__.show())")
|
||||
(cd test && python -c "import torch; print(torch.__config__.parallel_info())")
|
||||
@ -1737,6 +1753,8 @@ elif [[ "${TEST_CONFIG}" == *operator_benchmark* ]]; then
|
||||
fi
|
||||
elif [[ "${TEST_CONFIG}" == *operator_microbenchmark* ]]; then
|
||||
test_operator_microbenchmark
|
||||
elif [[ "${TEST_CONFIG}" == *attention_microbenchmark* ]]; then
|
||||
test_attention_microbenchmark
|
||||
elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
|
||||
test_inductor_distributed
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
|
||||
|
||||
2
.github/scripts/generate_pytorch_version.py
vendored
@ -50,7 +50,7 @@ def get_tag() -> str:
|
||||
|
||||
def get_base_version() -> str:
|
||||
root = get_pytorch_root()
|
||||
dirty_version = open(root / "version.txt").read().strip()
|
||||
dirty_version = Path(root / "version.txt").read_text().strip()
|
||||
# Strips trailing a0 from version.txt, not too sure why it's there in the
|
||||
# first place
|
||||
return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)
|
||||
|
||||
7
.github/workflows/_binary-build-linux.yml
vendored
@ -260,11 +260,8 @@ jobs:
|
||||
"${DOCKER_IMAGE}"
|
||||
)
|
||||
docker exec -t -w "${PYTORCH_ROOT}" "${container_name}" bash -c "bash .circleci/scripts/binary_populate_env.sh"
|
||||
if [[ ${BUILD_ENVIRONMENT} == *"aarch64"* ]]; then
|
||||
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/aarch64_linux/aarch64_ci_build.sh"
|
||||
else
|
||||
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh"
|
||||
fi
|
||||
# Unified build script for all architectures (x86_64, aarch64, s390x)
|
||||
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh"
|
||||
|
||||
- name: Chown artifacts
|
||||
if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' && inputs.build_environment != 'linux-s390x-binary-manywheel' }}
|
||||
|
||||
73
.github/workflows/attention_op_microbenchmark.yml
vendored
Normal file
@ -0,0 +1,73 @@
|
||||
name: attention_op_microbenchmark
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- ciflow/op-benchmark/*
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
# Run at 07:00 UTC every day
|
||||
- cron: 0 7 * * *
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
attn-microbenchmark-build:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
with:
|
||||
runner: linux.12xlarge.memory
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '8.0 9.0'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "attention_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.a100" },
|
||||
{ config: "attention_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.h100" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
attn-microbenchmark-test:
|
||||
name: attn-microbenchmark-test
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: attn-microbenchmark-build
|
||||
with:
|
||||
timeout-minutes: 500
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
|
||||
docker-image: ${{ needs.attn-microbenchmark-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.attn-microbenchmark-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
# B200 runner
|
||||
opmicrobenchmark-build-b200:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: opmicrobenchmark-build-b200
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
with:
|
||||
runner: linux.12xlarge.memory
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '10.0'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
opmicrobenchmark-test-b200:
|
||||
name: opmicrobenchmark-test-b200
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: opmicrobenchmark-build-b200
|
||||
with:
|
||||
timeout-minutes: 500
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
|
||||
docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }}
|
||||
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
secrets: inherit
|
||||
@ -94,6 +94,11 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
|
||||
at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
|
||||
}
|
||||
|
||||
TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
|
||||
c10::DeviceIndex device_index) {
|
||||
const auto device_type = getAccelerator(true).value();
|
||||
return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
|
||||
}
|
||||
} // namespace at::accelerator
|
||||
|
||||
namespace at {
|
||||
|
||||
@ -18,6 +18,8 @@
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
|
||||
|
||||
namespace torch {
|
||||
class TORCH_API CustomClassHolder : public c10::intrusive_ptr_target {};
|
||||
namespace jit {
|
||||
@ -1630,4 +1632,6 @@ struct TORCH_API WeakOrStrongTypePtr {
|
||||
|
||||
} // namespace c10
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
#include <ATen/core/ivalue_inl.h> // IWYU pragma: keep
|
||||
|
||||
@ -29,6 +29,8 @@
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
struct Function;
|
||||
@ -2567,3 +2569,5 @@ TypePtr IValue::type() const {
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
@ -11,6 +11,8 @@
|
||||
#include <sleef.h>
|
||||
#endif
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
|
||||
|
||||
// Sleef offers vectorized versions of some transcendentals
|
||||
// such as sin, cos, tan etc..
|
||||
// However for now opting for STL, since we are not building
|
||||
@ -650,3 +652,5 @@ inline Vectorized<float> Vectorized<float>::erf() const {
|
||||
|
||||
} // namespace CPU_CAPABILITY
|
||||
} // namespace at::vec
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
||||
#include <ATen/cuda/CUDAGraph.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <ATen/cuda/MemPool.h>
|
||||
#include <ATen/Functions.h>
|
||||
#include <c10/cuda/CUDAFunctions.h>
|
||||
|
||||
@ -13,7 +14,7 @@ static bool _cuda_graphs_debug = false;
|
||||
MempoolId_t graph_pool_handle() {
|
||||
// Sets just the second value, to distinguish it from MempoolId_ts created from
|
||||
// cudaStreamGetCaptureInfo id_s in capture_begin.
|
||||
return c10::cuda::MemPool::graph_pool_handle();
|
||||
return at::cuda::MemPool::graph_pool_handle();
|
||||
}
|
||||
|
||||
/**
|
||||
@ -90,7 +91,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
|
||||
} else {
|
||||
// User did not ask us to share a mempool. Create graph pool handle using is_user_created=false.
|
||||
// Sets just the first value, to distinguish it from MempoolId_ts created by graph_pool_handle().
|
||||
mempool_id_ = c10::cuda::MemPool::graph_pool_handle(false);
|
||||
mempool_id_ = at::cuda::MemPool::graph_pool_handle(false);
|
||||
TORCH_INTERNAL_ASSERT(mempool_id_.first > 0);
|
||||
}
|
||||
|
||||
|
||||
69
aten/src/ATen/cuda/MemPool.cpp
Normal file
@ -0,0 +1,69 @@
|
||||
#include <ATen/core/CachingHostAllocator.h>
|
||||
#include <ATen/cuda/MemPool.h>
|
||||
|
||||
namespace at::cuda {
|
||||
|
||||
// uid_ is incremented when a user creates a MemPool,
|
||||
// for example: using graph_pool_handle() or c10::cuda::MemPool().
|
||||
//
|
||||
// uuid_ is incremented when CUDAGraph creates a MemPool
|
||||
// as a result of a user not providing a pool.
|
||||
//
|
||||
// MempoolId_t of {0, 0} is used to denote when no MemPool has been
|
||||
// passed to a function, either by user or CUDAGraphs. For example,
|
||||
// default value of MempoolId_t for capture_begin function is {0, 0}.
|
||||
// That's why uid_ and uuid_ start at 1.
|
||||
std::atomic<CaptureId_t> MemPool::uid_{1};
|
||||
std::atomic<CaptureId_t> MemPool::uuid_{1};
|
||||
|
||||
MemPool::MemPool(
|
||||
CUDACachingAllocator::CUDAAllocator* allocator,
|
||||
bool is_user_created,
|
||||
bool use_on_oom)
|
||||
: allocator_(allocator), is_user_created_(is_user_created) {
|
||||
if (is_user_created_) {
|
||||
id_ = {0, uid_++};
|
||||
} else {
|
||||
id_ = {uuid_++, 0};
|
||||
}
|
||||
device_ = c10::cuda::current_device();
|
||||
CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
|
||||
if (use_on_oom) {
|
||||
CUDACachingAllocator::setUseOnOOM(device_, id_);
|
||||
}
|
||||
}
|
||||
|
||||
MemPool::~MemPool() {
|
||||
// TORCH_INTERNAL_ASSERT(use_count() == 1);
|
||||
// We used to assert that TORCH_INTERNAL_ASSERT(use_count() == 1);
|
||||
// However, this assertion is not true if a memory pool is shared
|
||||
// with a cuda graph. That CUDAGraph will increase the use count
|
||||
// until it is reset.
|
||||
CUDACachingAllocator::releasePool(device_, id_);
|
||||
c10::cuda::CUDACachingAllocator::emptyCache(id_);
|
||||
}
|
||||
|
||||
MempoolId_t MemPool::id() {
|
||||
return id_;
|
||||
}
|
||||
|
||||
CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
|
||||
return allocator_;
|
||||
}
|
||||
|
||||
int MemPool::use_count() {
|
||||
return CUDACachingAllocator::getPoolUseCount(device_, id_);
|
||||
}
|
||||
|
||||
c10::DeviceIndex MemPool::device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
|
||||
if (is_user_created) {
|
||||
return {0, uid_++};
|
||||
}
|
||||
return {uuid_++, 0};
|
||||
}
|
||||
|
||||
} // namespace at::cuda
|
||||
44
aten/src/ATen/cuda/MemPool.h
Normal file
@ -0,0 +1,44 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/Allocator.h>
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
|
||||
namespace at::cuda {
|
||||
|
||||
// Keep BC only
|
||||
using c10::CaptureId_t;
|
||||
using c10::MempoolId_t;
|
||||
|
||||
// MemPool represents a pool of memory in a caching allocator. Currently,
|
||||
// it's just the ID of the pool object maintained in the CUDACachingAllocator.
|
||||
//
|
||||
// An allocator pointer can be passed to the MemPool to define how the
|
||||
// allocations should be done in the pool. For example: using a different
|
||||
// system allocator such as ncclMemAlloc.
|
||||
struct TORCH_CUDA_CPP_API MemPool {
|
||||
MemPool(
|
||||
c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
|
||||
bool is_user_created = true,
|
||||
bool use_on_oom = false);
|
||||
MemPool(const MemPool&) = delete;
|
||||
MemPool(MemPool&&) = default;
|
||||
MemPool& operator=(const MemPool&) = delete;
|
||||
MemPool& operator=(MemPool&&) = default;
|
||||
~MemPool();
|
||||
|
||||
MempoolId_t id();
|
||||
c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator();
|
||||
int use_count();
|
||||
c10::DeviceIndex device();
|
||||
static MempoolId_t graph_pool_handle(bool is_user_created = true);
|
||||
|
||||
private:
|
||||
static std::atomic<CaptureId_t> uid_;
|
||||
static std::atomic<CaptureId_t> uuid_;
|
||||
c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator_;
|
||||
bool is_user_created_;
|
||||
MempoolId_t id_;
|
||||
c10::DeviceIndex device_;
|
||||
};
|
||||
|
||||
} // namespace at::cuda
|
||||
@ -213,7 +213,7 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
|
||||
scan_op,
|
||||
num_items,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
C10_HIP_KERNEL_LAUNCH_CHECK();
|
||||
#else
|
||||
// non synchronizing cub call
|
||||
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
|
||||
@ -471,7 +471,7 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
|
||||
init_value,
|
||||
num_items,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
C10_HIP_KERNEL_LAUNCH_CHECK();
|
||||
#else
|
||||
// non synchronizing cub call
|
||||
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
|
||||
|
||||
239
aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h
Normal file
@ -0,0 +1,239 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/hip/HIPCachingAllocator.h>
|
||||
|
||||
// Use of c10::hip namespace here makes hipification easier, because
|
||||
// I don't have to also fix namespaces. Sorry!
|
||||
namespace c10::hip {
|
||||
|
||||
// Takes a valid HIPAllocator (of any sort) and turns it into
|
||||
// an allocator pretending to be a CUDA allocator. See
|
||||
// Note [Masquerading as CUDA]
|
||||
class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllocator {
|
||||
HIPCachingAllocator::HIPAllocator* allocator_;
|
||||
public:
|
||||
explicit HIPAllocatorMasqueradingAsCUDA(HIPCachingAllocator::HIPAllocator* allocator)
|
||||
: allocator_(allocator) {}
|
||||
|
||||
virtual ~HIPAllocatorMasqueradingAsCUDA() = default;
|
||||
|
||||
// From c10::Allocator
|
||||
|
||||
DataPtr allocate(size_t size) override {
|
||||
DataPtr r = allocator_->allocate(size);
|
||||
r.unsafe_set_device(Device(c10::DeviceType::CUDA, r.device().index()));
|
||||
return r;
|
||||
}
|
||||
|
||||
bool is_simple_data_ptr(const DataPtr& data_ptr) const override {
|
||||
return allocator_->is_simple_data_ptr(data_ptr);
|
||||
}
|
||||
|
||||
DeleterFnPtr raw_deleter() const override {
|
||||
return allocator_->raw_deleter();
|
||||
}
|
||||
|
||||
void copy_data(void* dest, const void* src, std::size_t count) const final {
|
||||
allocator_->copy_data(dest, src, count);
|
||||
}
|
||||
|
||||
// From DeviceAllocator
|
||||
|
||||
bool initialized() override {
|
||||
return allocator_->initialized();
|
||||
}
|
||||
|
||||
void emptyCache(MempoolId_t mempool_id = {0, 0}) override {
|
||||
allocator_->emptyCache(mempool_id);
|
||||
}
|
||||
|
||||
void recordStream(const DataPtr& ptr, c10::Stream stream) override {
|
||||
HIPStream hip_stream = HIPStream(stream);
|
||||
recordStream(ptr, hip_stream);
|
||||
}
|
||||
|
||||
CachingDeviceAllocator::DeviceStats getDeviceStats(c10::DeviceIndex device) override {
|
||||
return allocator_->getDeviceStats(device);
|
||||
}
|
||||
|
||||
void resetAccumulatedStats(c10::DeviceIndex device) override {
|
||||
allocator_->resetAccumulatedStats(device);
|
||||
}
|
||||
|
||||
void resetPeakStats(c10::DeviceIndex device) override {
|
||||
allocator_->resetPeakStats(device);
|
||||
}
|
||||
|
||||
// From CUDAAllocator
|
||||
|
||||
void* raw_alloc(size_t nbytes) override {
|
||||
return allocator_->raw_alloc(nbytes);
|
||||
}
|
||||
|
||||
void* raw_alloc_with_stream(size_t nbytes, hipStream_t stream) override {
|
||||
return allocator_->raw_alloc_with_stream(nbytes, stream);
|
||||
}
|
||||
|
||||
void raw_delete(void* ptr) override {
|
||||
allocator_->raw_delete(ptr);
|
||||
}
|
||||
|
||||
void init(int device_count) override {
|
||||
allocator_->init(device_count);
|
||||
}
|
||||
|
||||
double getMemoryFraction(c10::DeviceIndex device) override {
|
||||
return allocator_->getMemoryFraction(device);
|
||||
}
|
||||
|
||||
void setMemoryFraction(double fraction, c10::DeviceIndex device) override {
|
||||
allocator_->setMemoryFraction(fraction, device);
|
||||
}
|
||||
|
||||
std::vector<HIPCachingAllocator::StreamSegmentSize> getExpandableSegmentSizes(c10::DeviceIndex device) override {
|
||||
return allocator_->getExpandableSegmentSizes(device);
|
||||
}
|
||||
|
||||
void enable(bool value) override {
|
||||
allocator_->enable(value);
|
||||
}
|
||||
|
||||
bool isEnabled() const override {
|
||||
return allocator_->isEnabled();
|
||||
}
|
||||
|
||||
void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override {
|
||||
allocator_->cacheInfo(device, largestBlock);
|
||||
}
|
||||
|
||||
void* getBaseAllocation(void* ptr, size_t* size) override {
|
||||
return allocator_->getBaseAllocation(ptr, size);
|
||||
}
|
||||
|
||||
void recordStream(const DataPtr& ptr, HIPStream stream) override {
|
||||
allocator_->recordStream(ptr, stream);
|
||||
}
|
||||
|
||||
HIPCachingAllocator::SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) override {
|
||||
return allocator_->snapshot(mempool_id);
|
||||
}
|
||||
|
||||
void beginAllocateToPool(
|
||||
c10::DeviceIndex device,
|
||||
MempoolId_t mempool_id,
|
||||
std::function<bool(hipStream_t)> filter) override {
|
||||
allocator_->beginAllocateToPool(device, mempool_id, filter);
|
||||
}
|
||||
|
||||
void endAllocateToPool(
|
||||
c10::DeviceIndex device,
|
||||
MempoolId_t mempool_id) override {
|
||||
allocator_->endAllocateToPool(device, mempool_id);
|
||||
}
|
||||
|
||||
void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) override {
|
||||
allocator_->releasePool(device, mempool_id);
|
||||
}
|
||||
|
||||
int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) override {
|
||||
return allocator_->getPoolUseCount(device, mempool_id);
|
||||
}
|
||||
|
||||
void createOrIncrefPool(
|
||||
c10::DeviceIndex device,
|
||||
MempoolId_t mempool_id,
|
||||
HIPAllocator* allocator = nullptr) override {
|
||||
allocator_->createOrIncrefPool(device, mempool_id, allocator);
|
||||
}
|
||||
|
||||
void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) override {
|
||||
allocator_->setUseOnOOM(device, mempool_id);
|
||||
}
|
||||
|
||||
bool checkPoolLiveAllocations(
|
||||
c10::DeviceIndex device,
|
||||
MempoolId_t mempool_id,
|
||||
const std::unordered_set<void*>& expected_live_allocations) override {
|
||||
return allocator_->checkPoolLiveAllocations(device, mempool_id, expected_live_allocations);
|
||||
}
|
||||
|
||||
HIPCachingAllocator::ShareableHandle shareIpcHandle(void* ptr) override {
|
||||
return allocator_->shareIpcHandle(ptr);
|
||||
}
|
||||
|
||||
std::shared_ptr<void> getIpcDevPtr(std::string handle) override {
|
||||
return allocator_->getIpcDevPtr(handle);
|
||||
}
|
||||
|
||||
bool isHistoryEnabled() override {
|
||||
return allocator_->isHistoryEnabled();
|
||||
}
|
||||
|
||||
void recordHistory(
|
||||
bool enabled,
|
||||
HIPCachingAllocator::CreateContextFn context_recorder,
|
||||
size_t alloc_trace_max_entries,
|
||||
HIPCachingAllocator::RecordContext when,
|
||||
bool clearHistory) override {
|
||||
allocator_->recordHistory(enabled, context_recorder, alloc_trace_max_entries, when, clearHistory);
|
||||
}
|
||||
|
||||
void recordAnnotation(
|
||||
const std::vector<std::pair<std::string, std::string>>& md) override {
|
||||
allocator_->recordAnnotation(md);
|
||||
}
|
||||
|
||||
void pushCompileContext(std::string& md) override {
|
||||
allocator_->pushCompileContext(md);
|
||||
}
|
||||
|
||||
void popCompileContext() override {
|
||||
allocator_->popCompileContext();
|
||||
}
|
||||
|
||||
void attachOutOfMemoryObserver(HIPCachingAllocator::OutOfMemoryObserver observer) override {
|
||||
allocator_->attachOutOfMemoryObserver(observer);
|
||||
}
|
||||
|
||||
void attachAllocatorTraceTracker(HIPCachingAllocator::AllocatorTraceTracker tracker) override {
|
||||
allocator_->attachAllocatorTraceTracker(tracker);
|
||||
}
|
||||
|
||||
void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) override {
|
||||
allocator_->enablePeerAccess(dev, dev_to_access);
|
||||
}
|
||||
|
||||
hipError_t memcpyAsync(
|
||||
void* dst,
|
||||
int dstDevice,
|
||||
const void* src,
|
||||
int srcDevice,
|
||||
size_t count,
|
||||
hipStream_t stream,
|
||||
bool p2p_enabled) override {
|
||||
return allocator_->memcpyAsync(dst, dstDevice, src, srcDevice, count, stream, p2p_enabled);
|
||||
}
|
||||
|
||||
std::shared_ptr<HIPCachingAllocator::AllocatorState> getCheckpointState(
|
||||
c10::DeviceIndex device,
|
||||
MempoolId_t id) override {
|
||||
return allocator_->getCheckpointState(device, id);
|
||||
}
|
||||
|
||||
HIPCachingAllocator::CheckpointDelta setCheckpointPoolState(
|
||||
c10::DeviceIndex device,
|
||||
std::shared_ptr<HIPCachingAllocator::AllocatorState> pps) override {
|
||||
auto cpd = allocator_->setCheckpointPoolState(device, pps);
|
||||
for (auto& ptr : cpd.dataptrs_allocd) {
|
||||
ptr.unsafe_set_device(Device(c10::DeviceType::CUDA, ptr.device().index()));
|
||||
}
|
||||
return cpd;
|
||||
}
|
||||
|
||||
std::string name() override {
|
||||
return allocator_->name();
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace c10::hip
|
||||
@ -0,0 +1,18 @@
|
||||
#include <c10/hip/HIPCachingAllocator.h>
|
||||
#include <ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h>
|
||||
#include <ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h>
|
||||
|
||||
namespace c10 { namespace hip {
|
||||
namespace HIPCachingAllocatorMasqueradingAsCUDA {
|
||||
|
||||
HIPCachingAllocator::HIPAllocator* get() {
|
||||
static HIPAllocatorMasqueradingAsCUDA allocator(HIPCachingAllocator::get());
|
||||
return &allocator;
|
||||
}
|
||||
|
||||
void recordStreamMasqueradingAsCUDA(const DataPtr& ptr, HIPStreamMasqueradingAsCUDA stream) {
|
||||
HIPCachingAllocator::recordStream(ptr, stream.hip_stream());
|
||||
}
|
||||
|
||||
} // namespace HIPCachingAllocatorMasqueradingAsCUDA
|
||||
}} // namespace c10::hip
|
||||
194
aten/src/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h
Normal file
@ -0,0 +1,194 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/hip/HIPCachingAllocator.h>
|
||||
#include <ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h>
|
||||
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
|
||||
|
||||
namespace c10 {
|
||||
// forward declaration
|
||||
class DataPtr;
|
||||
namespace hip {
|
||||
namespace HIPCachingAllocatorMasqueradingAsCUDA {
|
||||
|
||||
C10_HIP_API HIPCachingAllocator::HIPAllocator* get();
|
||||
C10_HIP_API void recordStreamMasqueradingAsCUDA(const DataPtr& ptr, HIPStreamMasqueradingAsCUDA stream);
|
||||
|
||||
inline void* raw_alloc(size_t nbytes) {
|
||||
return get()->raw_alloc(nbytes);
|
||||
}
|
||||
|
||||
inline void* raw_alloc_with_stream(size_t nbytes, hipStream_t stream) {
|
||||
return get()->raw_alloc_with_stream(nbytes, stream);
|
||||
}
|
||||
|
||||
inline void raw_delete(void* ptr) {
|
||||
return get()->raw_delete(ptr);
|
||||
}
|
||||
|
||||
inline void init(int device_count) {
|
||||
return get()->init(device_count);
|
||||
}
|
||||
|
||||
inline double getMemoryFraction(c10::DeviceIndex device) {
|
||||
return get()->getMemoryFraction(device);
|
||||
}
|
||||
|
||||
inline void setMemoryFraction(double fraction, c10::DeviceIndex device) {
|
||||
return get()->setMemoryFraction(fraction, device);
|
||||
}
|
||||
|
||||
inline void emptyCache(MempoolId_t mempool_id = {0, 0}) {
|
||||
return get()->emptyCache(mempool_id);
|
||||
}
|
||||
|
||||
inline void enable(bool value) {
|
||||
return get()->enable(value);
|
||||
}
|
||||
|
||||
inline bool isEnabled() {
|
||||
return get()->isEnabled();
|
||||
}
|
||||
|
||||
inline void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) {
|
||||
return get()->cacheInfo(device, largestBlock);
|
||||
}
|
||||
|
||||
inline void* getBaseAllocation(void* ptr, size_t* size) {
|
||||
return get()->getBaseAllocation(ptr, size);
|
||||
}
|
||||
|
||||
inline c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
|
||||
c10::DeviceIndex device) {
|
||||
return get()->getDeviceStats(device);
|
||||
}
|
||||
|
||||
inline void resetAccumulatedStats(c10::DeviceIndex device) {
|
||||
return get()->resetAccumulatedStats(device);
|
||||
}
|
||||
|
||||
inline void resetPeakStats(c10::DeviceIndex device) {
|
||||
return get()->resetPeakStats(device);
|
||||
}
|
||||
|
||||
inline HIPCachingAllocator::SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) {
|
||||
return get()->snapshot(mempool_id);
|
||||
}
|
||||
|
||||
inline std::shared_ptr<HIPCachingAllocator::AllocatorState> getCheckpointState(
    c10::DeviceIndex device,
    MempoolId_t id) {
  return get()->getCheckpointState(device, id);
}

inline HIPCachingAllocator::CheckpointDelta setCheckpointPoolState(
    c10::DeviceIndex device,
    std::shared_ptr<HIPCachingAllocator::AllocatorState> pps) {
  return get()->setCheckpointPoolState(device, std::move(pps));
}

inline void beginAllocateToPool(
    c10::DeviceIndex device,
    MempoolId_t mempool_id,
    std::function<bool(hipStream_t)> filter) {
  get()->beginAllocateToPool(device, mempool_id, std::move(filter));
}

inline void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) {
  get()->endAllocateToPool(device, mempool_id);
}

inline void recordHistory(
    bool enabled,
    HIPCachingAllocator::CreateContextFn context_recorder,
    size_t alloc_trace_max_entries,
    HIPCachingAllocator::RecordContext when,
    bool clearHistory) {
  return get()->recordHistory(
      enabled, context_recorder, alloc_trace_max_entries, when, clearHistory);
}

inline void recordAnnotation(
    const std::vector<std::pair<std::string, std::string>>& md) {
  return get()->recordAnnotation(md);
}

inline void pushCompileContext(std::string& md) {
  return get()->pushCompileContext(md);
}

inline void popCompileContext() {
  return get()->popCompileContext();
}

inline bool isHistoryEnabled() {
  return get()->isHistoryEnabled();
}

inline bool checkPoolLiveAllocations(
    c10::DeviceIndex device,
    MempoolId_t mempool_id,
    const std::unordered_set<void*>& expected_live_allocations) {
  return get()->checkPoolLiveAllocations(
      device, mempool_id, expected_live_allocations);
}

inline void attachOutOfMemoryObserver(HIPCachingAllocator::OutOfMemoryObserver observer) {
  return get()->attachOutOfMemoryObserver(std::move(observer));
}

inline void attachAllocatorTraceTracker(HIPCachingAllocator::AllocatorTraceTracker tracker) {
  return get()->attachAllocatorTraceTracker(std::move(tracker));
}

inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) {
  return get()->releasePool(device, mempool_id);
}

inline void createOrIncrefPool(
    c10::DeviceIndex device,
    MempoolId_t mempool_id,
    HIPCachingAllocator::HIPAllocator* allocator_ptr = nullptr) {
  get()->createOrIncrefPool(device, mempool_id, allocator_ptr);
}

inline void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
  get()->setUseOnOOM(device, mempool_id);
}

inline int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) {
  return get()->getPoolUseCount(device, mempool_id);
}

inline std::shared_ptr<void> getIpcDevPtr(std::string handle) {
  return get()->getIpcDevPtr(std::move(handle));
}

inline HIPCachingAllocator::ShareableHandle shareIpcHandle(void* ptr) {
  return get()->shareIpcHandle(ptr);
}

inline std::string name() {
  return get()->name();
}

inline hipError_t memcpyAsync(
    void* dst,
    int dstDevice,
    const void* src,
    int srcDevice,
    size_t count,
    hipStream_t stream,
    bool p2p_enabled) {
  return get()->memcpyAsync(
      dst, dstDevice, src, srcDevice, count, stream, p2p_enabled);
}

inline void enablePeerAccess(
    c10::DeviceIndex dev,
    c10::DeviceIndex dev_to_access) {
  return get()->enablePeerAccess(dev, dev_to_access);
}

} // namespace HIPCachingAllocatorMasqueradingAsCUDA
} // namespace hip
} // namespace c10
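The wrappers above do nothing but forward to the active HIP allocator under CUDA-flavoured names. A minimal sketch of how calling code might drive the private-pool entry points they expose is shown below; the `capture_into_private_pool` helper, the literal pool id `{1, 0}`, and the `c10::hip::MempoolId_t` spelling are illustrative assumptions, not part of this diff.

```cpp
// Illustrative sketch only -- not part of the file above.
#include <ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h>

namespace alloc = c10::hip::HIPCachingAllocatorMasqueradingAsCUDA;

void capture_into_private_pool(c10::DeviceIndex device, hipStream_t capture_stream) {
  c10::hip::MempoolId_t pool_id{1, 0};  // hypothetical id; real ids come from graph capture
  alloc::createOrIncrefPool(device, pool_id);
  // Only allocations made on capture_stream are routed into the pool.
  alloc::beginAllocateToPool(device, pool_id,
      [capture_stream](hipStream_t s) { return s == capture_stream; });
  // ... enqueue the work whose allocations should live in the pool ...
  alloc::endAllocateToPool(device, pool_id);
  alloc::releasePool(device, pool_id);
}
```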
aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.cpp (new file, 14 lines)
@@ -0,0 +1,14 @@
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>

// THIS IS A MASSIVE HACK. This will BREAK your Caffe2 CUDA code if you
// load ATen_hip, even if you don't ever actually use ATen_hip at runtime.
//
// If you ever link ATen_hip statically into the full library along
// with ATen_cuda (libomnibus), the loading order of this versus the regular
// ATen_cuda will be nondeterministic, and you'll nondeterministically get
// one or the other. (This will be obvious because all of your code
// will fail.)
//
// This hack can be removed once PyTorch is out-of-place HIPified, and
// doesn't pretend CUDA is HIP.
C10_REGISTER_GUARD_IMPL(CUDA, at::cuda::HIPGuardImplMasqueradingAsCUDA)
aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h (new file, 383 lines)
@@ -0,0 +1,383 @@
#pragma once

#include <ATen/hip/HIPConfig.h>

// The includes of HIPGuard.h
#include <c10/hip/impl/HIPGuardImpl.h>
#include <c10/hip/HIPMacros.h>
#include <c10/core/DeviceType.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/InlineStreamGuard.h>
#include <c10/util/Exception.h>

#include <c10/hip/impl/HIPGuardImpl.h>

#include <ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>

// Use of c10::hip namespace here makes hipification easier, because
// I don't have to also fix namespaces. Sorry!
namespace c10 { namespace hip {

// Note [Masquerading as CUDA]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~
// c10_hip is very easy to understand: it is HIPified from c10_cuda,
// and anywhere you said CUDA, the source code now says HIP. HIPified
// PyTorch is much harder to understand: it is HIPified from regular
// PyTorch, yes, but NO source-to-source translation from CUDA to
// HIP occurs; instead, anywhere we see "CUDA", it actually means "HIP".
// For example, when you use HIPified PyTorch, you say x.cuda() to
// move a tensor onto ROCm device.  We call this situation "HIP
// masquerading as CUDA".
//
// This leads to a very awkward situation when we want to call c10_hip
// code from PyTorch, since c10_hip is expecting things to be called
// HIP, but PyTorch is calling them CUDA (masquerading as HIP).  To
// fix this impedance mismatch, we have MasqueradingAsCUDA variants
// for all c10_hip classes.  These translate between the "HIP" and "CUDA
// masquerading as HIP" worlds.  For example,
// HIPGuardImplMasqueradingAsCUDA (this file) provides something like a
// HIPGuardImpl, but it reports its DeviceType as CUDA (e.g., type()
// returns CUDA, getDevice() reports the current HIP device as a CUDA
// device.)
//
// We should be able to delete all of these classes entirely once
// we switch PyTorch to calling a HIP a HIP.
//
// When you add a new MasqueradingAsCUDA class/function, you need to
// also update the rewrite rules in torch/utils/hipify/cuda_to_hip_mappings.py
//
// By the way, note that the cpp file associated with this also
// *overwrites* the entry in the DeviceGuardImpl registry for CUDA with
// this HIP implementation.

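The note above is the key to reading everything that follows: on a ROCm build the public API keeps saying "CUDA" while HIP does the work. A hedged sketch of the user-visible effect, using only standard ATen calls (the function name is illustrative and not part of this header):

```cpp
// Illustrative sketch only -- not part of this header.
#include <ATen/ATen.h>

void masquerading_example() {
  // On a ROCm build this allocates on the current HIP device through
  // HIPCachingAllocatorMasqueradingAsCUDA, yet the tensor reports a CUDA
  // device, because the guard registered for CUDA is
  // HIPGuardImplMasqueradingAsCUDA.
  at::Tensor t = at::ones({2, 2}, at::device(at::kCUDA));
  TORCH_INTERNAL_ASSERT(t.device().is_cuda());
}
```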
struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplInterface {
|
||||
static constexpr c10::DeviceType static_type = c10::DeviceType::CUDA;
|
||||
HIPGuardImplMasqueradingAsCUDA() {}
|
||||
HIPGuardImplMasqueradingAsCUDA(c10::DeviceType t) {
|
||||
TORCH_INTERNAL_ASSERT(t == c10::DeviceType::CUDA);
|
||||
}
|
||||
c10::DeviceType type() const override {
|
||||
return c10::DeviceType::CUDA;
|
||||
}
|
||||
Device exchangeDevice(Device d) const override {
|
||||
TORCH_INTERNAL_ASSERT(d.is_cuda());
|
||||
Device old_device = getDevice();
|
||||
if (old_device.index() != d.index()) {
|
||||
C10_HIP_CHECK(hipSetDevice(d.index()));
|
||||
}
|
||||
return old_device;
|
||||
}
|
||||
Device getDevice() const override {
|
||||
int device;
|
||||
C10_HIP_CHECK(hipGetDevice(&device));
|
||||
return Device(c10::DeviceType::CUDA, device);
|
||||
}
|
||||
void setDevice(Device d) const override {
|
||||
TORCH_INTERNAL_ASSERT(d.is_cuda());
|
||||
C10_HIP_CHECK(hipSetDevice(d.index()));
|
||||
}
|
||||
void uncheckedSetDevice(Device d) const noexcept override {
|
||||
C10_HIP_CHECK_WARN(hipSetDevice(d.index()));
|
||||
}
|
||||
Stream getStream(Device d) const override {
|
||||
return getCurrentHIPStreamMasqueradingAsCUDA(d.index()).unwrap();
|
||||
}
|
||||
Stream getDefaultStream(Device d) const override {
|
||||
return getDefaultHIPStreamMasqueradingAsCUDA(d.index());
|
||||
}
|
||||
Stream getNewStream(Device d, int priority = 0) const override {
|
||||
return getStreamFromPoolMasqueradingAsCUDA(priority, d.index());
|
||||
}
|
||||
Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) const override {
|
||||
return getStreamFromPoolMasqueradingAsCUDA(isHighPriority, d.index());
|
||||
}
|
||||
Stream exchangeStream(Stream s) const override {
|
||||
HIPStreamMasqueradingAsCUDA cs(s);
|
||||
auto old_stream = getCurrentHIPStreamMasqueradingAsCUDA(s.device().index());
|
||||
setCurrentHIPStreamMasqueradingAsCUDA(cs);
|
||||
return old_stream.unwrap();
|
||||
}
|
||||
DeviceIndex deviceCount() const noexcept override {
|
||||
int deviceCnt;
|
||||
hipError_t _err;
|
||||
_err = hipGetDeviceCount(&deviceCnt);
|
||||
if(_err != hipErrorNoDevice && _err != hipSuccess)
|
||||
C10_HIP_CHECK(_err);
|
||||
return deviceCnt;
|
||||
}
|
||||
|
||||
// Event-related functions
|
||||
// Note: hipEventCreateWithFlags should be called on the same device as
|
||||
// the recording stream's device.
|
||||
void createEvent(
|
||||
hipEvent_t* hip_event,
|
||||
const EventFlag flag) const {
|
||||
// Maps PyTorch's Event::Flag to HIP flag
|
||||
auto hip_flag = hipEventDefault;
|
||||
switch (flag) {
|
||||
case EventFlag::PYTORCH_DEFAULT:
|
||||
hip_flag = hipEventDisableTiming;
|
||||
break;
|
||||
case EventFlag::BACKEND_DEFAULT:
|
||||
hip_flag = hipEventDefault;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "HIP event received unknown flag");
|
||||
}
|
||||
|
||||
C10_HIP_CHECK(hipEventCreateWithFlags(hip_event, hip_flag));
|
||||
}
|
||||
|
||||
void destroyEvent(
|
||||
void* event,
|
||||
const DeviceIndex device_index) const noexcept override {
|
||||
if (!event) return;
|
||||
auto hip_event = static_cast<hipEvent_t>(event);
|
||||
int orig_device;
|
||||
C10_HIP_CHECK_WARN(hipGetDevice(&orig_device));
|
||||
C10_HIP_CHECK_WARN(hipSetDevice(device_index));
|
||||
C10_HIP_CHECK_WARN(hipEventDestroy(hip_event));
|
||||
C10_HIP_CHECK_WARN(hipSetDevice(orig_device));
|
||||
}
|
||||
|
||||
void record(void** event,
|
||||
const Stream& stream,
|
||||
const DeviceIndex device_index,
|
||||
const EventFlag flag) const override {
|
||||
TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
|
||||
"Event device index ",
|
||||
device_index,
|
||||
" does not match recording stream's device index ",
|
||||
stream.device_index(),
|
||||
".");
|
||||
|
||||
hipEvent_t hip_event = static_cast<hipEvent_t>(*event);
|
||||
HIPStreamMasqueradingAsCUDA hip_stream{stream};
|
||||
|
||||
// Moves to stream's device to record
|
||||
const auto orig_device = getDevice();
|
||||
setDevice(stream.device());
|
||||
|
||||
// Creates the event (lazily)
|
||||
if (!hip_event) createEvent(&hip_event, flag);
|
||||
C10_HIP_CHECK(hipEventRecord(hip_event, hip_stream));
|
||||
// Makes the void* point to the (possibly just allocated) HIP event
|
||||
*event = hip_event;
|
||||
|
||||
// Resets device
|
||||
setDevice(orig_device);
|
||||
}
|
||||
|
||||
void block(
|
||||
void* event,
|
||||
const Stream& stream) const override {
|
||||
if (!event) return;
|
||||
hipEvent_t hip_event = static_cast<hipEvent_t>(event);
|
||||
HIPStreamMasqueradingAsCUDA hip_stream{stream};
|
||||
const auto orig_device = getDevice();
|
||||
setDevice(stream.device());
|
||||
C10_HIP_CHECK(hipStreamWaitEvent(
|
||||
hip_stream,
|
||||
hip_event,
|
||||
/*flags (must be zero)=*/ 0));
|
||||
setDevice(orig_device);
|
||||
}
|
||||
|
||||
bool queryEvent(void* event) const override {
|
||||
if (!event) return true;
|
||||
hipEvent_t hip_event = static_cast<hipEvent_t>(event);
|
||||
const hipError_t err = hipEventQuery(hip_event);
|
||||
if (err != hipErrorNotReady) C10_HIP_CHECK(err);
|
||||
else {
|
||||
// ignore and clear the error if not ready
|
||||
(void)hipGetLastError();
|
||||
}
|
||||
return (err == hipSuccess);
|
||||
}
|
||||
|
||||
// Stream-related functions
|
||||
bool queryStream(const Stream& stream) const override {
|
||||
HIPStreamMasqueradingAsCUDA hip_stream{stream};
|
||||
return hip_stream.query();
|
||||
}
|
||||
|
||||
void synchronizeStream(const Stream& stream) const override {
|
||||
HIPStreamMasqueradingAsCUDA hip_stream{stream};
|
||||
hip_stream.synchronize();
|
||||
}
|
||||
|
||||
void synchronizeEvent(void* event) const override {
|
||||
if (!event)
|
||||
return;
|
||||
hipEvent_t hip_event = static_cast<hipEvent_t>(event);
|
||||
C10_HIP_CHECK(hipEventSynchronize(hip_event));
|
||||
}
|
||||
|
||||
// Note: synchronizeDevice can be safely called from any device
|
||||
void synchronizeDevice(const c10::DeviceIndex device_index) const override {
|
||||
int orig_device{-1};
|
||||
C10_HIP_CHECK(hipGetDevice(&orig_device));
|
||||
C10_HIP_CHECK(hipSetDevice(device_index));
|
||||
C10_HIP_CHECK(hipDeviceSynchronize());
|
||||
C10_HIP_CHECK(hipSetDevice(orig_device));
|
||||
}
|
||||
|
||||
void recordDataPtrOnStream(
|
||||
const c10::DataPtr& data_ptr,
|
||||
const Stream& stream) const override {
|
||||
HIPStreamMasqueradingAsCUDA hip_stream{stream};
|
||||
HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA(data_ptr, hip_stream);
|
||||
}
|
||||
|
||||
double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
|
||||
const override {
|
||||
TORCH_CHECK(
|
||||
event1 && event2,
|
||||
"Both events must be recorded before calculating elapsed time.");
|
||||
int orig_device;
|
||||
C10_HIP_CHECK(hipGetDevice(&orig_device));
|
||||
C10_HIP_CHECK(hipSetDevice(device_index));
|
||||
hipEvent_t hip_event1 = static_cast<hipEvent_t>(event1);
|
||||
hipEvent_t hip_event2 = static_cast<hipEvent_t>(event2);
|
||||
float time_ms = 0;
|
||||
// raise hipErrorNotReady if either event is recorded but not yet completed
|
||||
C10_HIP_CHECK(hipEventElapsedTime(&time_ms, hip_event1, hip_event2));
|
||||
C10_HIP_CHECK(hipSetDevice(orig_device));
|
||||
return static_cast<double>(time_ms);
|
||||
}
|
||||
};
|
||||
|
||||
// All of the guards which have HIPGuardImpl burned in need to also have
|
||||
// variants using HIPGuardImplMasqueradingAsCUDA.
|
||||
|
||||
/// This code is all a direct copy from c10/cuda/HIPGuardMasqueradingAsCUDA.h, but with
|
||||
/// the correct InlineDeviceGuard burned in. Sorry about the
|
||||
/// copy-pasting.
|
||||
|
||||
struct HIPGuardMasqueradingAsCUDA {
|
||||
explicit HIPGuardMasqueradingAsCUDA() = delete;
|
||||
explicit HIPGuardMasqueradingAsCUDA(DeviceIndex device_index) : guard_(device_index) {}
|
||||
explicit HIPGuardMasqueradingAsCUDA(Device device) : guard_(device) {}
|
||||
|
||||
HIPGuardMasqueradingAsCUDA(const HIPGuardMasqueradingAsCUDA&) = delete;
|
||||
HIPGuardMasqueradingAsCUDA& operator=(const HIPGuardMasqueradingAsCUDA&) = delete;
|
||||
HIPGuardMasqueradingAsCUDA(HIPGuardMasqueradingAsCUDA&& other) = delete;
|
||||
HIPGuardMasqueradingAsCUDA& operator=(HIPGuardMasqueradingAsCUDA&& other) = delete;
|
||||
|
||||
void set_device(Device device) { guard_.set_device(device); }
|
||||
void reset_device(Device device) { guard_.reset_device(device); }
|
||||
void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
|
||||
Device original_device() const { return guard_.original_device(); }
|
||||
Device current_device() const { return guard_.current_device(); }
|
||||
|
||||
private:
|
||||
c10::impl::InlineDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
|
||||
};
|
||||
|
||||
struct OptionalHIPGuardMasqueradingAsCUDA {
|
||||
explicit OptionalHIPGuardMasqueradingAsCUDA() : guard_() {}
|
||||
explicit OptionalHIPGuardMasqueradingAsCUDA(std::optional<Device> device_opt) : guard_(device_opt) {}
|
||||
explicit OptionalHIPGuardMasqueradingAsCUDA(std::optional<DeviceIndex> device_index_opt) : guard_(device_index_opt) {}
|
||||
|
||||
OptionalHIPGuardMasqueradingAsCUDA(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
|
||||
OptionalHIPGuardMasqueradingAsCUDA& operator=(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
|
||||
OptionalHIPGuardMasqueradingAsCUDA(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;
|
||||
OptionalHIPGuardMasqueradingAsCUDA& operator=(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;
|
||||
|
||||
void set_device(Device device) { guard_.set_device(device); }
|
||||
void reset_device(Device device) { guard_.reset_device(device); }
|
||||
void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
|
||||
std::optional<Device> original_device() const { return guard_.original_device(); }
|
||||
std::optional<Device> current_device() const { return guard_.current_device(); }
|
||||
void reset() { guard_.reset(); }
|
||||
|
||||
private:
|
||||
c10::impl::InlineOptionalDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
|
||||
};
|
||||
|
||||
struct HIPStreamGuardMasqueradingAsCUDA {
|
||||
explicit HIPStreamGuardMasqueradingAsCUDA() = delete;
|
||||
explicit HIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
|
||||
HIPStreamGuardMasqueradingAsCUDA(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
|
||||
HIPStreamGuardMasqueradingAsCUDA& operator=(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
|
||||
HIPStreamGuardMasqueradingAsCUDA(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;
|
||||
HIPStreamGuardMasqueradingAsCUDA& operator=(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;
|
||||
|
||||
void reset_stream(Stream stream) { guard_.reset_stream(stream); }
|
||||
|
||||
HIPStreamMasqueradingAsCUDA original_stream() const {
|
||||
return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.original_stream());
|
||||
}
|
||||
HIPStreamMasqueradingAsCUDA current_stream() const {
|
||||
return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.current_stream());
|
||||
}
|
||||
|
||||
Device current_device() const { return guard_.current_device(); }
|
||||
Device original_device() const { return guard_.original_device(); }
|
||||
|
||||
private:
|
||||
c10::impl::InlineStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
|
||||
};
|
||||
|
||||
struct OptionalHIPStreamGuardMasqueradingAsCUDA {
|
||||
explicit OptionalHIPStreamGuardMasqueradingAsCUDA() : guard_() {}
|
||||
explicit OptionalHIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
|
||||
explicit OptionalHIPStreamGuardMasqueradingAsCUDA(std::optional<Stream> stream_opt) : guard_(stream_opt) {}
|
||||
|
||||
OptionalHIPStreamGuardMasqueradingAsCUDA(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
|
||||
OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
|
||||
OptionalHIPStreamGuardMasqueradingAsCUDA(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;
|
||||
OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;
|
||||
|
||||
void reset_stream(Stream stream) { guard_.reset_stream(stream); }
|
||||
|
||||
std::optional<HIPStreamMasqueradingAsCUDA> original_stream() const {
|
||||
auto r = guard_.original_stream();
|
||||
if (r.has_value()) {
|
||||
return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value());
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<HIPStreamMasqueradingAsCUDA> current_stream() const {
|
||||
auto r = guard_.current_stream();
|
||||
if (r.has_value()) {
|
||||
return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value());
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
void reset() { guard_.reset(); }
|
||||
|
||||
private:
|
||||
c10::impl::InlineOptionalStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
|
||||
};
|
||||
|
||||
struct HIPMultiStreamGuardMasqueradingAsCUDA {
|
||||
explicit HIPMultiStreamGuardMasqueradingAsCUDA(ArrayRef<HIPStreamMasqueradingAsCUDA> streams)
|
||||
: guard_(unwrapStreams(streams)) {}
|
||||
|
||||
HIPMultiStreamGuardMasqueradingAsCUDA(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete;
|
||||
HIPMultiStreamGuardMasqueradingAsCUDA& operator=(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete;
|
||||
HIPMultiStreamGuardMasqueradingAsCUDA(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete;
|
||||
HIPMultiStreamGuardMasqueradingAsCUDA& operator=(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete;
|
||||
|
||||
private:
|
||||
c10::impl::InlineMultiStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
|
||||
|
||||
static std::vector<Stream> unwrapStreams(ArrayRef<HIPStreamMasqueradingAsCUDA> hipStreams) {
|
||||
std::vector<Stream> streams;
|
||||
streams.reserve(hipStreams.size());
|
||||
for (const HIPStreamMasqueradingAsCUDA& hipStream : hipStreams) {
|
||||
streams.push_back(hipStream);
|
||||
}
|
||||
return streams;
|
||||
}
|
||||
};
|
||||
|
||||
}} // namespace c10::hip
|
||||
aten/src/ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h (new file, 135 lines)
@@ -0,0 +1,135 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/hip/HIPStream.h>
|
||||
|
||||
// Use of c10::hip namespace here makes hipification easier, because
|
||||
// I don't have to also fix namespaces. Sorry!
|
||||
namespace c10 { namespace hip {
|
||||
|
||||
// See Note [Masquerading as CUDA] for motivation
|
||||
|
||||
class HIPStreamMasqueradingAsCUDA {
|
||||
public:
|
||||
|
||||
enum Unchecked { UNCHECKED };
|
||||
|
||||
explicit HIPStreamMasqueradingAsCUDA(Stream stream)
|
||||
: HIPStreamMasqueradingAsCUDA(UNCHECKED, stream) {
|
||||
// We did the coercion unchecked; check that it was right.
|
||||
TORCH_CHECK(stream.device().is_cuda() /* !!! */);
|
||||
}
|
||||
|
||||
explicit HIPStreamMasqueradingAsCUDA(Unchecked, Stream stream)
|
||||
// Unsafely coerce the "CUDA" stream into a HIP stream
|
||||
: stream_(
|
||||
HIPStream(
|
||||
Stream(
|
||||
Stream::UNSAFE,
|
||||
Device(c10::DeviceType::HIP, stream.device_index()),
|
||||
stream.id())
|
||||
)
|
||||
) {}
|
||||
|
||||
// New constructor, just for this. Does NOT coerce.
|
||||
explicit HIPStreamMasqueradingAsCUDA(HIPStream stream) : stream_(stream) {}
|
||||
|
||||
bool operator==(const HIPStreamMasqueradingAsCUDA& other) const noexcept {
|
||||
return stream_ == other.stream_;
|
||||
}
|
||||
|
||||
bool operator!=(const HIPStreamMasqueradingAsCUDA& other) const noexcept {
|
||||
return stream_ != other.stream_;
|
||||
}
|
||||
|
||||
operator hipStream_t() const { return stream_.stream(); }
|
||||
|
||||
operator Stream() const {
|
||||
// Unsafely coerce HIP stream into a "CUDA" stream
|
||||
return Stream(Stream::UNSAFE, device(), id());
|
||||
}
|
||||
|
||||
DeviceIndex device_index() const { return stream_.device_index(); }
|
||||
|
||||
// Unsafely coerce HIP device into CUDA device
|
||||
c10::DeviceType device_type() const { return c10::DeviceType::CUDA; }
|
||||
|
||||
Device device() const {
|
||||
// Unsafely coerce HIP device into CUDA device
|
||||
return Device(c10::DeviceType::CUDA, stream_.device_index());
|
||||
}
|
||||
|
||||
StreamId id() const { return stream_.id(); }
|
||||
bool query() const { return stream_.query(); }
|
||||
void synchronize() const { stream_.synchronize(); }
|
||||
int priority() const { return stream_.priority(); }
|
||||
hipStream_t stream() const { return stream_.stream(); }
|
||||
|
||||
Stream unwrap() const {
|
||||
// Unsafely coerce HIP stream into "CUDA" stream
|
||||
return Stream(Stream::UNSAFE, device(), id());
|
||||
}
|
||||
|
||||
c10::StreamData3 pack3() const noexcept {
|
||||
// Unsafely coerce HIP stream into "CUDA" stream before packing
|
||||
return unwrap().pack3();
|
||||
}
|
||||
|
||||
static HIPStreamMasqueradingAsCUDA unpack3(StreamId stream_id,
|
||||
DeviceIndex device_index,
|
||||
c10::DeviceType device_type) {
|
||||
// NB: constructor manages CUDA->HIP translation for us
|
||||
return HIPStreamMasqueradingAsCUDA(Stream::unpack3(
|
||||
stream_id, device_index, device_type));
|
||||
}
|
||||
|
||||
static std::tuple<int, int> priority_range() { return HIPStream::priority_range(); }
|
||||
|
||||
// New method, gets the underlying HIPStream
|
||||
HIPStream hip_stream() const { return stream_; }
|
||||
|
||||
private:
|
||||
HIPStream stream_;
|
||||
};
|
||||
|
||||
HIPStreamMasqueradingAsCUDA
|
||||
inline getStreamFromPoolMasqueradingAsCUDA(const bool isHighPriority = false, DeviceIndex device = -1) {
|
||||
return HIPStreamMasqueradingAsCUDA(getStreamFromPool(isHighPriority, device));
|
||||
}
|
||||
|
||||
HIPStreamMasqueradingAsCUDA
|
||||
inline getStreamFromPoolMasqueradingAsCUDA(const int priority, DeviceIndex device = -1) {
|
||||
return HIPStreamMasqueradingAsCUDA(getStreamFromPool(priority, device));
|
||||
}
|
||||
|
||||
HIPStreamMasqueradingAsCUDA
|
||||
inline getStreamFromExternalMasqueradingAsCUDA(hipStream_t ext_stream, DeviceIndex device) {
|
||||
return HIPStreamMasqueradingAsCUDA(getStreamFromExternal(ext_stream, device));
|
||||
}
|
||||
|
||||
inline HIPStreamMasqueradingAsCUDA getDefaultHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
|
||||
return HIPStreamMasqueradingAsCUDA(getDefaultHIPStream(device_index));
|
||||
}
|
||||
|
||||
inline HIPStreamMasqueradingAsCUDA getCurrentHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
|
||||
return HIPStreamMasqueradingAsCUDA(getCurrentHIPStream(device_index));
|
||||
}
|
||||
|
||||
inline void setCurrentHIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA stream) {
|
||||
setCurrentHIPStream(stream.hip_stream());
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& stream, const HIPStreamMasqueradingAsCUDA& s) {
|
||||
stream << s.hip_stream() << " (masquerading as CUDA)";
|
||||
return stream;
|
||||
}
|
||||
|
||||
}} // namespace c10::hip
|
||||
|
||||
namespace std {
|
||||
template <>
|
||||
struct hash<c10::hip::HIPStreamMasqueradingAsCUDA> {
|
||||
size_t operator()(c10::hip::HIPStreamMasqueradingAsCUDA s) const noexcept {
|
||||
return std::hash<c10::Stream>{}(s.unwrap());
|
||||
}
|
||||
};
|
||||
} // namespace std
|
||||
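The stream wrapper defined above exists to shuttle between a generic c10::Stream tagged as CUDA and the underlying c10::hip::HIPStream. A hedged sketch of that round trip follows; the `work_on_stream` helper is an assumption made for the example, not part of the diff:

```cpp
// Illustrative sketch only -- not part of this header.
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>

void work_on_stream(c10::Stream generic) {
  // Checked coercion: TORCH_CHECKs that the generic stream is tagged as CUDA.
  c10::hip::HIPStreamMasqueradingAsCUDA s{generic};
  hipStream_t raw = s.stream();   // raw HIP handle, usable for kernel launches
  (void)raw;
  c10::Stream back = s.unwrap();  // re-tagged as a "CUDA" stream for generic code
  TORCH_INTERNAL_ASSERT(back == generic);
}
```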
@ -39,7 +39,7 @@ using MIOpenPoolType = at::cuda::DeviceThreadHandlePool<
|
||||
|
||||
miopenHandle_t getMiopenHandle() {
|
||||
c10::DeviceIndex device = 0;
|
||||
AT_CUDA_CHECK(at::cuda::GetDevice(&device));
|
||||
AT_CUDA_CHECK(c10::hip::GetDevice(&device));
|
||||
|
||||
// Thread local PoolWindows are lazily-initialized
|
||||
// to avoid initialization issues that caused hangs on Windows.
|
||||
@ -51,7 +51,7 @@ miopenHandle_t getMiopenHandle() {
|
||||
pool->newPoolWindow());
|
||||
|
||||
auto handle = myPoolWindow->reserve(device);
|
||||
MIOPEN_CHECK(miopenSetStream(handle, at::cuda::getCurrentCUDAStream()));
|
||||
MIOPEN_CHECK(miopenSetStream(handle, c10::hip::getCurrentHIPStream()));
|
||||
return handle;
|
||||
}
|
||||
|
||||
|
||||
@ -5,13 +5,9 @@
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/MathConstants.h>
|
||||
|
||||
// ROCm hip compiler doesn't work well with using std:: in kernel functions
|
||||
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
|
||||
// ROCM hcc doesn't work well with using std:: in kernel functions
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#include <c10/cuda/CUDAMathCompat.h>
|
||||
#elif defined(__HIPCC__)
|
||||
#include <c10/hip/HIPMathCompat.h>
|
||||
#endif
|
||||
#define compat_exp c10::cuda::compat::exp
|
||||
#define compat_ceil c10::cuda::compat::ceil
|
||||
#define compat_floor c10::cuda::compat::floor
|
||||
@ -21,6 +17,17 @@
|
||||
#define compat_tan c10::cuda::compat::tan
|
||||
#define compat_abs c10::cuda::compat::abs
|
||||
#define compat_log1p c10::cuda::compat::log1p
|
||||
#elif defined(__HIPCC__)
|
||||
#include <c10/hip/HIPMathCompat.h>
|
||||
#define compat_exp c10::hip::compat::exp
|
||||
#define compat_ceil c10::hip::compat::ceil
|
||||
#define compat_floor c10::hip::compat::floor
|
||||
#define compat_log c10::hip::compat::log
|
||||
#define compat_pow c10::hip::compat::pow
|
||||
#define compat_sqrt c10::hip::compat::sqrt
|
||||
#define compat_tan c10::hip::compat::tan
|
||||
#define compat_abs c10::hip::compat::abs
|
||||
#define compat_log1p c10::hip::compat::log1p
|
||||
#else
|
||||
#define compat_exp std::exp
|
||||
#define compat_ceil std::ceil
|
||||
|
||||
@ -1936,7 +1936,7 @@ static bool should_fold(const Tensor& tensor1, const Tensor& tensor2, bool has_o
|
||||
|
||||
// We order the tensors. t1 will be the larger tensor
|
||||
// We can always transpose tensor2 as the dimensions are always >= 1 (precondition from matmul)
|
||||
// and tensor1_larger iff tensor2.dim() > tensor1.dim(9
|
||||
// and tensor1_larger iff tensor2.dim() > tensor1.dim()
|
||||
const auto t1 = tensor1_larger ? MaybeOwned<Tensor>::borrowed(tensor1)
|
||||
: MaybeOwned<Tensor>::owned(tensor2.mT());
|
||||
const int64_t dim_t1 = t1->dim();
|
||||
@ -1948,20 +1948,11 @@ static bool should_fold(const Tensor& tensor1, const Tensor& tensor2, bool has_o
|
||||
return false;
|
||||
}
|
||||
|
||||
// In this case we *do* incur in an extra copy to avoid creating an unnecessary large tensor in the backward
|
||||
// Suppose we don't fold here. Let t1.shape = [b, m, n] t2.shape = [n, k] like in a transformer
|
||||
// t2 will be expanded to a tensor of shape [b, n, k] and then we do t1.bmm(t2_expanded)
|
||||
// The issue appears in the backward.
|
||||
// The output gradient g of this operation would have shape [b, m, k]
|
||||
// The backward wrt. t2 of bmm would be given by t1.mH @ g, which has shape [b, n, k]
|
||||
// Then, the backward of expand is simply `sum(0)`. As such, we are instantiating a tensor
|
||||
// of shape [b, n, k] unnecessarily, which may cause a large memory footprint, and in the
|
||||
// worst case, an OOM
|
||||
bool t2_requires_grad = tensor1_larger ? tensor2.requires_grad() : tensor1.requires_grad();
|
||||
if (t2_requires_grad && !has_out) {
|
||||
// We should be checking !at::GradMode::is_enabled(), but apparently
|
||||
// this regresses performance in some cases:
|
||||
// https://github.com/pytorch/pytorch/issues/118548#issuecomment-1916022394
|
||||
// If we require a gradient, we should fold to minimize backward memory usage - even if this
|
||||
// leads to a copy in forward because is needed in backward,
|
||||
// only time we avoid this strict pre-allocated memory usage (has_out = True)
|
||||
bool requires_grad = tensor1.requires_grad() || tensor2.requires_grad();
|
||||
if (requires_grad && !has_out) {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
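The comment in the hunk above argues that matmul should fold the batched operand whenever a gradient is required, because expanding the smaller operand forces the backward pass to materialize an unnecessarily large tensor. A worked example with assumed shapes makes the cost concrete:

```cpp
// Worked example for the folding comment above; the shapes are illustrative assumptions.
// t1: [b, m, n] = [32, 1024, 4096], t2: [n, k] = [4096, 4096], fp16 elements.
// Without folding, t2 is expanded to [b, n, k] and backward materializes a gradient
// of that shape before reducing it with sum(0):
constexpr long long b = 32, n = 4096, k = 4096;
constexpr long long expanded_grad_bytes = b * n * k * 2;      // fp16 = 2 bytes
static_assert(expanded_grad_bytes == 1073741824LL, "~1 GiB"); // vs. 32 MiB for a plain [n, k] gradient
```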
@ -52,14 +52,13 @@ inline C10_DEVICE scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
|
||||
#define MIN(X, Y) min_impl(X,Y)
|
||||
#endif
|
||||
|
||||
// ROCm hip compiler doesn't work well with using std:: in kernel functions
|
||||
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
|
||||
// ROCM hcc doesn't work well with using std:: in kernel functions
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#include <c10/cuda/CUDAMathCompat.h>
|
||||
#define compat_pow c10::cuda::compat::pow
|
||||
#elif defined(__HIPCC__)
|
||||
#include <c10/hip/HIPMathCompat.h>
|
||||
#endif
|
||||
#define compat_pow c10::cuda::compat::pow
|
||||
#define compat_pow c10::hip::compat::pow
|
||||
#else
|
||||
#define compat_pow std::pow
|
||||
#endif
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
#pragma once
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
|
||||
|
||||
namespace at::native {
|
||||
|
||||
// Used as an interface between the different BLAS-like libraries
|
||||
@ -21,3 +23,5 @@ static inline char to_blas(TransposeType trans) {
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/native/CompositeRandomAccessorCommon.h>
|
||||
#include <thrust/swap.h>
|
||||
#include <thrust/tuple.h>
|
||||
|
||||
namespace at { namespace native {
|
||||
|
||||
@ -75,30 +75,52 @@ static inline bool can_use_int32_nhwc(
|
||||
return true;
|
||||
}
|
||||
|
||||
static inline bool can_use_int32_nchw(
|
||||
int64_t nbatch, int64_t channels,
|
||||
int64_t height, int64_t width,
|
||||
int64_t pooled_height, int64_t pooled_width) {
|
||||
int64_t hw = height * width;
|
||||
return can_use_int32_nhwc(
|
||||
nbatch, channels, height, width,
|
||||
pooled_height, pooled_width,
|
||||
channels * hw, // in_stride_n
|
||||
hw, // in_stride_c
|
||||
width, // in_stride_h
|
||||
1 // in_stride_w
|
||||
);
|
||||
}
|
||||
|
||||
// kernels borrowed from Caffe
|
||||
template <typename scalar_t>
|
||||
__global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom_data,
|
||||
const int64_t channels, const int64_t height,
|
||||
const int64_t width, const int pooled_height, const int pooled_width,
|
||||
const int kernel_h, const int kernel_w, const int stride_h,
|
||||
const int stride_w, const int pad_h, const int pad_w,
|
||||
const int dilation_h, const int dilation_w, scalar_t* top_data,
|
||||
template <typename scalar_t, typename index_t>
|
||||
__global__ void max_pool_forward_nchw(
|
||||
const index_t nthreads,
|
||||
const scalar_t* bottom_data,
|
||||
const int64_t channels,
|
||||
const int64_t height,
|
||||
const int64_t width,
|
||||
const int pooled_height,
|
||||
const int pooled_width,
|
||||
const int kernel_h, const int kernel_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int pad_h, const int pad_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
scalar_t* top_data,
|
||||
int64_t* top_mask) {
|
||||
CUDA_KERNEL_LOOP(index, nthreads) {
|
||||
int pw = index % pooled_width;
|
||||
int ph = (index / pooled_width) % pooled_height;
|
||||
int c = (index / pooled_width / pooled_height) % channels;
|
||||
int n = index / pooled_width / pooled_height / channels;
|
||||
int hstart = ph * stride_h - pad_h;
|
||||
int wstart = pw * stride_w - pad_w;
|
||||
int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
|
||||
int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
|
||||
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
|
||||
index_t pw = index % pooled_width;
|
||||
index_t ph = (index / pooled_width) % pooled_height;
|
||||
index_t c = (index / pooled_width / pooled_height) % channels;
|
||||
index_t n = index / pooled_width / pooled_height / channels;
|
||||
index_t hstart = ph * stride_h - pad_h;
|
||||
index_t wstart = pw * stride_w - pad_w;
|
||||
index_t hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
|
||||
index_t wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
|
||||
while(hstart < 0)
|
||||
hstart += dilation_h;
|
||||
while(wstart < 0)
|
||||
wstart += dilation_w;
|
||||
scalar_t maxval = at::numeric_limits<scalar_t>::lower_bound(); // -Infinity
|
||||
int maxidx = hstart * width + wstart;
|
||||
index_t maxidx = hstart * width + wstart;
|
||||
const scalar_t* btm_data = bottom_data + (n * channels + c) * height * width;
|
||||
for (int h = hstart; h < hend; h += dilation_h) {
|
||||
for (int w = wstart; w < wend; w += dilation_w) {
|
||||
@ -251,32 +273,39 @@ __global__ void max_pool_forward_nhwc(
|
||||
|
||||
static constexpr int BLOCK_THREADS = 256;
|
||||
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
template <typename scalar_t, typename accscalar_t, typename index_t>
|
||||
#if defined (USE_ROCM)
|
||||
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 4)
|
||||
#else
|
||||
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 8)
|
||||
#endif
|
||||
__global__ void max_pool_backward_nchw(const scalar_t* top_diff,
|
||||
const int64_t* top_mask, const int num, const int64_t channels,
|
||||
const int64_t height, const int64_t width, const int pooled_height,
|
||||
const int pooled_width, const int kernel_h, const int kernel_w,
|
||||
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
|
||||
__global__ void max_pool_backward_nchw(
|
||||
const scalar_t* top_diff,
|
||||
const int64_t* top_mask,
|
||||
const index_t num,
|
||||
const index_t channels,
|
||||
const index_t height,
|
||||
const index_t width,
|
||||
const index_t pooled_height,
|
||||
const index_t pooled_width,
|
||||
const int kernel_h, const int kernel_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int pad_h, const int pad_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
scalar_t* bottom_diff) {
|
||||
CUDA_KERNEL_LOOP(index, height*width) {
|
||||
int h = index / width;
|
||||
int w = index - h * width;
|
||||
int phstart = p_start(h, pad_h, kernel_h, dilation_h, stride_h);
|
||||
int phend = p_end(h, pad_h, pooled_height, stride_h);
|
||||
int pwstart = p_start(w, pad_w, kernel_w, dilation_w, stride_w);
|
||||
int pwend = p_end(w, pad_w, pooled_width, stride_w);
|
||||
for (int n = blockIdx.y; n < num; n += gridDim.y) {
|
||||
for (int c = blockIdx.z; c < channels; c+= gridDim.z) {
|
||||
CUDA_KERNEL_LOOP_TYPE(index, height*width, index_t) {
|
||||
index_t h = index / width;
|
||||
index_t w = index - h * width;
|
||||
index_t phstart = p_start(h, pad_h, kernel_h, dilation_h, stride_h);
|
||||
index_t phend = p_end(h, pad_h, pooled_height, stride_h);
|
||||
index_t pwstart = p_start(w, pad_w, kernel_w, dilation_w, stride_w);
|
||||
index_t pwend = p_end(w, pad_w, pooled_width, stride_w);
|
||||
for (index_t n = blockIdx.y; n < num; n += gridDim.y) {
|
||||
for (index_t c = blockIdx.z; c < channels; c += gridDim.z) {
|
||||
accscalar_t gradient = accscalar_t(0);
|
||||
int offset = (n * channels + c) * pooled_height * pooled_width;
|
||||
for (int ph = phstart; ph < phend; ++ph) {
|
||||
for (int pw = pwstart; pw < pwend; ++pw) {
|
||||
index_t offset = (n * channels + c) * pooled_height * pooled_width;
|
||||
for (index_t ph = phstart; ph < phend; ++ph) {
|
||||
for (index_t pw = pwstart; pw < pwend; ++pw) {
|
||||
if (top_mask[ph * pooled_width + pw + offset] == h * width + w) {
|
||||
gradient += static_cast<accscalar_t>(top_diff[ph * pooled_width + pw + offset]);
|
||||
}
|
||||
@ -469,8 +498,6 @@ const Tensor& indices) {
|
||||
const int64_t in_stride_h = input.stride(-2);
|
||||
const int64_t in_stride_w = input.stride(-1);
|
||||
|
||||
const int count = safe_downcast<int, int64_t>(output.numel());
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
|
||||
"max_pool2d_with_indices_out_cuda_frame",
|
||||
[&] {
|
||||
@ -553,14 +580,42 @@ const Tensor& indices) {
|
||||
break;
|
||||
}
|
||||
case MemoryFormat::Contiguous: {
|
||||
const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
|
||||
BLOCK_THREADS);
|
||||
max_pool_forward_nchw<scalar_t>
|
||||
<<<ceil_div(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
count, input_data,
|
||||
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
output_data, indices_data);
|
||||
const int threads = std::min(
|
||||
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
|
||||
BLOCK_THREADS);
|
||||
const int64_t nthreads = output.numel();
|
||||
bool use_int32 = can_use_int32_nchw(
|
||||
nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
|
||||
const int maxGridX = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
|
||||
const int blocks = static_cast<int>(std::min<int64_t>(
|
||||
ceil_div(nthreads, static_cast<int64_t>(threads)),
|
||||
static_cast<int64_t>(maxGridX)));
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
if (use_int32) {
|
||||
max_pool_forward_nchw<scalar_t, int32_t>
|
||||
<<<blocks, threads, 0, stream>>>(
|
||||
static_cast<int32_t>(nthreads),
|
||||
input_data,
|
||||
static_cast<int32_t>(nInputPlane),
|
||||
static_cast<int32_t>(inputHeight),
|
||||
static_cast<int32_t>(inputWidth),
|
||||
static_cast<int32_t>(outputHeight),
|
||||
static_cast<int32_t>(outputWidth),
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
output_data, indices_data);
|
||||
} else {
|
||||
max_pool_forward_nchw<scalar_t, int64_t>
|
||||
<<<blocks, threads, 0, stream>>>(
|
||||
nthreads,
|
||||
input_data,
|
||||
nInputPlane,
|
||||
inputHeight,
|
||||
inputWidth,
|
||||
outputHeight,
|
||||
outputWidth,
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
output_data, indices_data);
|
||||
}
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
break;
|
||||
}
|
||||
@ -633,8 +688,6 @@ const Tensor& gradInput) {
|
||||
|
||||
gradInput.zero_();
|
||||
|
||||
int64_t count = input.numel();
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
|
||||
"max_pool2d_with_indices_out_cuda_frame",
|
||||
[&] {
|
||||
@ -692,25 +745,45 @@ const Tensor& gradInput) {
|
||||
break;
|
||||
}
|
||||
case MemoryFormat::Contiguous: {
|
||||
int imgcount = inputWidth * inputHeight;
|
||||
dim3 grid;
|
||||
const int blocks = (imgcount + BLOCK_THREADS - 1) / BLOCK_THREADS;
|
||||
grid.x = blocks;
|
||||
grid.y = nbatch;
|
||||
uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
|
||||
if (maxGridY < grid.y) grid.y = maxGridY;
|
||||
grid.z = nInputPlane;
|
||||
uint64_t maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
|
||||
if (maxGridZ < grid.z) grid.z = maxGridZ;
|
||||
|
||||
max_pool_backward_nchw<scalar_t, accscalar_t>
|
||||
<<<grid, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
gradOutput_data,
|
||||
indices_data,
|
||||
nbatch,
|
||||
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
gradInput_data);
|
||||
const int threads = std::min(
|
||||
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
|
||||
BLOCK_THREADS);
|
||||
const int imgcount = inputWidth * inputHeight;
|
||||
const int maxGridX = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
|
||||
const int maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
|
||||
const int maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
|
||||
const int blocks_x = std::min(ceil_div(imgcount, threads), maxGridX);
|
||||
dim3 grid(blocks_x, static_cast<unsigned>(std::min<int64_t>(nbatch, maxGridY)), static_cast<unsigned>(std::min<int64_t>(nInputPlane, maxGridZ)));
|
||||
bool use_int32 = can_use_int32_nchw(
|
||||
nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
if (use_int32) {
|
||||
max_pool_backward_nchw<scalar_t, accscalar_t, int32_t>
|
||||
<<<grid, threads, 0, stream>>>(
|
||||
gradOutput_data,
|
||||
indices_data,
|
||||
static_cast<int32_t>(nbatch),
|
||||
static_cast<int32_t>(nInputPlane),
|
||||
static_cast<int32_t>(inputHeight),
|
||||
static_cast<int32_t>(inputWidth),
|
||||
static_cast<int32_t>(outputHeight),
|
||||
static_cast<int32_t>(outputWidth),
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
gradInput_data);
|
||||
} else {
|
||||
max_pool_backward_nchw<scalar_t, accscalar_t, int64_t>
|
||||
<<<grid, threads, 0, stream>>>(
|
||||
gradOutput_data,
|
||||
indices_data,
|
||||
nbatch,
|
||||
nInputPlane,
|
||||
inputHeight,
|
||||
inputWidth,
|
||||
outputHeight,
|
||||
outputWidth,
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
gradInput_data);
|
||||
}
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
break;
|
||||
}
|
||||
|
||||
@ -267,15 +267,15 @@ void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, con
|
||||
* outer dimensions, which contains several "inner rows").
|
||||
* Each thread processes a single inner row at a time.
|
||||
*/
|
||||
template<typename scalar_t, class BinaryOp>
|
||||
template<typename scalar_t, typename index_t, class BinaryOp>
|
||||
__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, const scalar_t *src_,
|
||||
const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size,
|
||||
const scalar_t init, BinaryOp binary_op)
|
||||
{
|
||||
for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
|
||||
for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
|
||||
const scalar_t *src = src_ + orow * row_size * num_irows + irow;
|
||||
scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow;
|
||||
const scalar_t *src = src_ + static_cast<index_t>(orow) * row_size * num_irows + irow;
|
||||
scalar_t *tgt = tgt_ + (index_t) orow * row_size * num_irows + irow;
|
||||
scalar_t acc = init;
|
||||
|
||||
for (uint32_t col = 0; col < row_size; ++col) {
|
||||
@ -409,10 +409,15 @@ __host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,
|
||||
check_fits_in_unsigned(num_irows, "num_irows");
|
||||
check_fits_in_unsigned(num_orows, "num_orows");
|
||||
check_fits_in_unsigned(row_size, "row_size");
|
||||
|
||||
tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
if (static_cast<size_t>(num_irows) * num_orows * row_size <= UINT_MAX) {
|
||||
tensor_kernel_scan_outer_dim<scalar_t, uint32_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
|
||||
num_orows, num_irows, row_size, init, binary_op);
|
||||
} else {
|
||||
tensor_kernel_scan_outer_dim<scalar_t, size_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
|
||||
num_orows, num_irows, row_size, init, binary_op);
|
||||
}
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
|
||||
@ -157,7 +157,7 @@ void bgemm_kernel_impl(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
auto stream = at::cuda::getCurrentHIPStream().stream();
|
||||
invoker.Run(argument, StreamConfig{stream, false});
|
||||
}
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
#include <numeric>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
|
||||
#include <ATen/native/hip/ck_gemm.h>
|
||||
#include <ATen/native/hip/ck_types.h>
|
||||
|
||||
@ -232,7 +233,7 @@ void gemm_impl(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
||||
}
|
||||
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
auto stream = at::cuda::getCurrentHIPStream().stream();
|
||||
invoker.Run(argument, StreamConfig{stream, false});
|
||||
}
|
||||
|
||||
@ -390,7 +391,7 @@ void gemm_impl_wmma(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
||||
}
|
||||
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
auto stream = at::cuda::getCurrentHIPStream().stream();
|
||||
#if 1
|
||||
invoker.Run(argument, StreamConfig{stream, false});
|
||||
#else
|
||||
|
||||
@ -385,7 +385,7 @@ void launch_grouped_bgemm_ck_impl_dispatch(
|
||||
gemm_instance.SetWorkSpacePointer(&argument, ws_buf);
|
||||
|
||||
auto invoker = gemm_instance.MakeInvoker();
|
||||
hipStream_t stream = c10::cuda::getCurrentCUDAStream();
|
||||
hipStream_t stream = c10::hip::getCurrentHIPStream();
|
||||
invoker.Run(argument, {stream});
|
||||
hipFree(gemm_arg_buf);
|
||||
hipFree(ws_buf);
|
||||
|
||||
@ -278,14 +278,14 @@ BenchmarkCache<size_t> bwd_filter_wssizes;
|
||||
|
||||
struct Workspace {
|
||||
Workspace(size_t size) : size(size), data(NULL) {
|
||||
data = c10::cuda::CUDACachingAllocator::raw_alloc(size);
|
||||
data = c10::hip::HIPCachingAllocator::raw_alloc(size);
|
||||
}
|
||||
Workspace(const Workspace&) = delete;
|
||||
Workspace(Workspace&&) = default;
|
||||
Workspace& operator=(Workspace&&) = default;
|
||||
~Workspace() {
|
||||
if (data) {
|
||||
c10::cuda::CUDACachingAllocator::raw_delete(data);
|
||||
c10::hip::HIPCachingAllocator::raw_delete(data);
|
||||
}
|
||||
}
|
||||
|
||||
@ -587,7 +587,7 @@ void findAlgorithm(const ConvolutionArgs& args, bool benchmark, algo_t* algo) {
|
||||
wsscache.insert(args.params, perfResults.memory);
|
||||
|
||||
if (at::native::_cudnn_get_conv_benchmark_empty_cache()) {
|
||||
c10::cuda::CUDACachingAllocator::emptyCache();
|
||||
c10::hip::HIPCachingAllocator::emptyCache();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -76,14 +76,14 @@ namespace {
|
||||
|
||||
struct DropoutState {
|
||||
DropoutState(size_t size) : size(size), data(NULL) {
|
||||
data = c10::cuda::CUDACachingAllocator::raw_alloc(size);
|
||||
data = c10::hip::HIPCachingAllocator::raw_alloc(size);
|
||||
}
|
||||
DropoutState(const DropoutState&) = delete;
|
||||
DropoutState(DropoutState&&) = default;
|
||||
DropoutState& operator=(DropoutState&&) = default;
|
||||
~DropoutState() {
|
||||
if (data) {
|
||||
c10::cuda::CUDACachingAllocator::raw_delete(data);
|
||||
c10::hip::HIPCachingAllocator::raw_delete(data);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -82,6 +82,7 @@ NSArray<NSNumber*>* getTensorAxes(const TensorBase& t);
|
||||
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
|
||||
std::string getMPSShapeString(MPSShape* shape);
|
||||
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
|
||||
std::string to_hex_key(float);
|
||||
std::string getArrayRefString(const IntArrayRef s);
|
||||
// use has_storage() on the returned tensor to determine if src actually is a view
|
||||
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);
|
||||
|
||||
@ -301,6 +301,10 @@ std::string getArrayRefString(const IntArrayRef s) {
|
||||
return fmt::to_string(fmt::join(s, ","));
|
||||
}
|
||||
|
||||
std::string to_hex_key(float f) {
|
||||
return fmt::format("{:a}", f);
|
||||
}
|
||||
|
||||
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, bool exclude_shape) {
|
||||
fmt::basic_memory_buffer<char, 100> buffer;
|
||||
auto buf_iterator = std::back_inserter(buffer);
|
||||
|
||||
@ -40,7 +40,7 @@ inline c10::metal::opmath_t<T> matmul_inner(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (uint k = 0; k < TILE_DIM; k++) {
|
||||
sum += A_tile[tid.y][k] * B_tile[k][tid.x];
|
||||
sum += c10::metal::mul(A_tile[tid.y][k], B_tile[k][tid.x]);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
@ -832,6 +832,10 @@ INSTANTIATE_MM_OPS(float);
|
||||
INSTANTIATE_MM_OPS(half);
|
||||
INSTANTIATE_MM_OPS(bfloat);
|
||||
|
||||
// Complex MM
|
||||
INSTANTIATE_MM_OPS(float2);
|
||||
INSTANTIATE_MM_OPS(half2);
|
||||
|
||||
// Integral MM
|
||||
INSTANTIATE_MM_OPS(long);
|
||||
INSTANTIATE_MM_OPS(int);
|
||||
|
||||
@ -190,10 +190,16 @@ std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* gr
|
||||
bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) {
|
||||
static bool always_use_metal = c10::utils::has_env("PYTORCH_MPS_PREFER_METAL");
|
||||
constexpr auto max_stride_size = 32768;
|
||||
constexpr auto max_complex_inner_size = 2048;
|
||||
static bool is_macos_14_4_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS);
|
||||
if (always_use_metal || c10::isIntegralType(self.scalar_type(), true)) {
|
||||
return true;
|
||||
}
|
||||
// multiplicationWithPrimaryTensor: returns incorrect results if inner size exceeds 2048
|
||||
// See https://github.com/pytorch/pytorch/issues/167727#issuecomment-3529308548
|
||||
if (c10::isComplexType(self.scalar_type()) && self.size(1) > max_complex_inner_size) {
|
||||
return true;
|
||||
}
|
||||
return !is_macos_14_4_or_newer &&
|
||||
(self.stride(0) > max_stride_size || self.stride(1) > max_stride_size || self.size(0) > max_stride_size ||
|
||||
self.size(1) > max_stride_size || other.stride(0) > max_stride_size || other.stride(1) > max_stride_size ||
|
||||
|
||||
@ -244,8 +244,8 @@ static void clamp_scalar_out_mps(const Tensor& input_t,
|
||||
|
||||
@autoreleasepool {
|
||||
// the optional min/max refs could affect how we build the cached graph
|
||||
std::string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") +
|
||||
(has_max ? ("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
|
||||
std::string key = op_name + (has_min ? ("_min:" + to_hex_key(min_scalar)) : "") +
|
||||
(has_max ? ("_max:" + to_hex_key(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
if (has_min)
|
||||
newCachedGraph->minTensor = [mpsGraph constantWithScalar:min_scalar
|
||||
|
||||
@ -4389,7 +4389,7 @@
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: mv
|
||||
SparseCPU, SparseCUDA: mv_sparse
|
||||
SparseCPU, SparseCUDA, SparseMPS: mv_sparse
|
||||
|
||||
- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
@ -7518,7 +7518,7 @@
|
||||
- func: _sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor
|
||||
variants: method
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA: sparse_mask_projection
|
||||
SparseCPU, SparseCUDA, SparseMPS: sparse_mask_projection
|
||||
autogen: _sparse_mask_projection.out
|
||||
|
||||
- func: _to_cpu(Tensor[] tensors) -> Tensor[]
|
||||
|
||||
@ -30,10 +30,12 @@
|
||||
|
||||
#include <thrust/binary_search.h>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/distance.h>
|
||||
#include <thrust/iterator/constant_iterator.h>
|
||||
#include <thrust/scan.h>
|
||||
#include <thrust/sequence.h>
|
||||
#include <thrust/sort.h>
|
||||
#include <thrust/system/cuda/execution_policy.h>
|
||||
#include <thrust/iterator/constant_iterator.h>
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cusparse.h>
|
||||
@ -60,6 +62,8 @@
|
||||
#include <thrust/transform.h>
|
||||
#include <thrust/unique.h>
|
||||
|
||||
#include <c10/cuda/CUDAMathCompat.h>
|
||||
|
||||
namespace at::native {
|
||||
namespace {
|
||||
|
||||
|
||||
@ -445,6 +445,33 @@ static SparseTensor& mul_out_dense_sparse_mps(
|
||||
return out;
|
||||
}
|
||||
|
||||
static std::tuple<Tensor, Tensor, int64_t> mps_intersect_binary_search(
|
||||
const Tensor& A_keys,
|
||||
const Tensor& B_keys,
|
||||
int64_t lenA,
|
||||
int64_t lenB,
|
||||
bool boolean_flag) {
|
||||
|
||||
auto stream = getCurrentMPSStream();
|
||||
auto outA_idx = at::empty({lenA}, A_keys.options().dtype(at::kLong));
|
||||
auto outB_idx = at::empty({lenA}, A_keys.options().dtype(at::kLong));
|
||||
auto counter = at::zeros({1}, A_keys.options().dtype(at::kInt));
|
||||
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
|
||||
auto enc = stream->commandEncoder();
|
||||
[enc setComputePipelineState:pso];
|
||||
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
|
||||
static_cast<uint32_t>(lenB), boolean_flag);
|
||||
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
|
||||
}
|
||||
});
|
||||
|
||||
const auto match_count = static_cast<int64_t>(counter.item<int32_t>());
|
||||
return std::make_tuple(std::move(outA_idx), std::move(outB_idx), match_count);
|
||||
}
|
||||
|
||||
|
||||
SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTensor& r_) {
|
||||
TORCH_CHECK(r_.is_mps(), "mul: expected 'out' to be MPS, but got ", r_.device());
|
||||
@ -523,22 +550,10 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
|
||||
auto A_keys = A_is_lhs ? lhs_keys : rhs_keys;
|
||||
auto B_keys = A_is_lhs ? rhs_keys : lhs_keys;
|
||||
|
||||
auto outA_idx = at::empty({lenA}, at::device(device).dtype(kLong));
|
||||
auto outB_idx = at::empty({lenA}, at::device(device).dtype(kLong));
|
||||
auto counter = at::zeros({1}, at::device(device).dtype(kInt));
|
||||
auto [outA_idx, outB_idx, M_int64] = mps_intersect_binary_search(
|
||||
A_keys, B_keys, lenA, lenB, A_is_lhs);
|
||||
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
|
||||
auto enc = stream->commandEncoder();
|
||||
[enc setComputePipelineState:pso];
|
||||
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
|
||||
static_cast<uint32_t>(lenB), A_is_lhs);
|
||||
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
|
||||
}
|
||||
});
|
||||
|
||||
const uint32_t M = counter.item<int32_t>(); // number of structural matches
|
||||
const auto M = static_cast<uint32_t>(M_int64); // number of structural matches
|
||||
|
||||
r_.resize_as_(lhs);
|
||||
|
||||
@ -762,6 +777,14 @@ SparseTensor& add_out_sparse_mps(const SparseTensor& self,
|
||||
|
||||
using OptTensor = std::optional<Tensor>;
|
||||
|
||||
static Tensor create_sparse_output_values(
|
||||
const Tensor& template_values,
|
||||
int64_t output_nnz,
|
||||
ScalarType dtype) {
|
||||
auto out_val_sizes = template_values.sizes().vec();
|
||||
out_val_sizes[0] = output_nnz;
|
||||
return at::zeros(out_val_sizes, template_values.options().dtype(dtype));
|
||||
}
|
||||
|
||||
static void sparse_mask_apply_out_mps_kernel(
|
||||
Tensor& result,
|
||||
@ -783,9 +806,9 @@ static void sparse_mask_apply_out_mps_kernel(
|
||||
auto src = src_in.coalesce();
|
||||
auto mask = coalesce_mask ? mask_in.coalesce() : mask_in;
|
||||
|
||||
const int64_t src_nnz = src._nnz();
|
||||
const int64_t mask_nnz = mask._nnz();
|
||||
const int64_t sd = src.sparse_dim();
|
||||
const auto src_nnz = src._nnz();
|
||||
const auto mask_nnz = mask._nnz();
|
||||
const auto sd = src.sparse_dim();
|
||||
result.sparse_resize_(mask.sizes(), mask.sparse_dim(), mask.dense_dim());
|
||||
|
||||
auto commonDtype = at::result_type(src, mask);
|
||||
@ -814,53 +837,27 @@ static void sparse_mask_apply_out_mps_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
auto mask_indices = mask._indices().contiguous();
|
||||
auto src_values = src._values().to(commonDtype).contiguous();
|
||||
auto out_values = create_sparse_output_values(src_values, mask_nnz, commonDtype);
|
||||
|
||||
if (src_nnz == 0) {
|
||||
auto out_indices = mask._indices().contiguous();
|
||||
auto src_values = src._values().to(commonDtype);
|
||||
auto out_val_sizes = src_values.sizes().vec();
|
||||
out_val_sizes[0] = mask_nnz;
|
||||
auto out_values = at::zeros(out_val_sizes, src_values.options());
|
||||
alias_into_sparse(result, out_indices, out_values);
|
||||
alias_into_sparse(result, mask_indices, out_values);
|
||||
result._coalesced_(mask.is_coalesced());
|
||||
return;
|
||||
}
|
||||
|
||||
auto mask_indices = mask._indices().contiguous();
|
||||
auto src_indices = src._indices().contiguous();
|
||||
auto src_values = src._values().to(commonDtype).contiguous();
|
||||
auto mask_keys = flatten_indices(mask._indices().contiguous(), mask.sizes().slice(0, sd)).contiguous();
|
||||
auto src_keys = flatten_indices(src._indices().contiguous(), src.sizes().slice(0, sd)).contiguous();
|
||||
|
||||
auto mask_keys = flatten_indices(mask_indices, mask.sizes().slice(0, sd)).contiguous();
|
||||
auto src_keys = flatten_indices(src_indices, src.sizes().slice(0, sd)).contiguous();
|
||||
|
||||
const bool A_is_src = (src_nnz <= mask_nnz);
|
||||
const int64_t lenA = A_is_src ? src_nnz : mask_nnz;
|
||||
const int64_t lenB = A_is_src ? mask_nnz : src_nnz;
|
||||
const auto A_is_src = (src_nnz <= mask_nnz);
|
||||
const auto lenA = A_is_src ? src_nnz : mask_nnz;
|
||||
const auto lenB = A_is_src ? mask_nnz : src_nnz;
|
||||
auto A_keys = A_is_src ? src_keys : mask_keys;
|
||||
auto B_keys = A_is_src ? mask_keys : src_keys;
|
||||
|
||||
const auto device = result.device();
|
||||
auto stream = getCurrentMPSStream();
|
||||
|
||||
auto outA_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
|
||||
auto outB_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
|
||||
auto counter = at::zeros({1}, at::device(device).dtype(at::kInt));
|
||||
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
|
||||
auto enc = stream->commandEncoder();
|
||||
[enc setComputePipelineState:pso];
|
||||
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
|
||||
static_cast<uint32_t>(lenB), A_is_src);
|
||||
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
|
||||
}
|
||||
});
|
||||
|
||||
const int64_t M = static_cast<int64_t>(counter.item<int32_t>());
|
||||
|
||||
auto out_val_sizes = src_values.sizes().vec();
|
||||
out_val_sizes[0] = mask_nnz;
|
||||
auto out_values = at::zeros(out_val_sizes, src_values.options());
|
||||
auto [outA_idx, outB_idx, M] = mps_intersect_binary_search(
|
||||
A_keys, B_keys, lenA, lenB, A_is_src);
|
||||
|
||||
if (M > 0) {
|
||||
auto src_match = outA_idx.narrow(0, 0, M);
|
||||
@ -878,6 +875,70 @@ static void sparse_mask_apply_out_mps_kernel(
|
||||
result._coalesced_(mask.is_coalesced());
|
||||
}
|
||||
|
||||
static void sparse_mask_projection_out_mps_kernel(
|
||||
Tensor& result,
|
||||
const Tensor& lhs,
|
||||
const Tensor& rhs,
|
||||
const OptTensor& /*x_hash_opt*/,
|
||||
bool accumulate_matches) {
|
||||
|
||||
TORCH_CHECK(lhs.is_sparse() && rhs.is_sparse(), "sparse_mask_projection: expected sparse COO");
|
||||
TORCH_CHECK(lhs.is_mps() && rhs.is_mps(), "sparse_mask_projection: expected MPS tensors");
|
||||
TORCH_CHECK(lhs.sparse_dim() == rhs.sparse_dim(), "sparse_dim mismatch");
|
||||
|
||||
auto lhs_c = lhs.coalesce();
|
||||
auto rhs_c = rhs.coalesce();
|
||||
|
||||
const auto sd = lhs_c.sparse_dim();
|
||||
const auto lhs_nnz = lhs_c._nnz();
|
||||
const auto rhs_nnz = rhs_c._nnz();
|
||||
|
||||
auto commonDtype = at::result_type(lhs_c, rhs_c);
|
||||
TORCH_CHECK(canCast(commonDtype, result.scalar_type()),
|
||||
"Can't convert ", commonDtype, " to output ", result.scalar_type());
|
||||
|
||||
result.sparse_resize_(lhs.sizes(), lhs.sparse_dim(), lhs.dense_dim());
|
||||
|
||||
auto lhs_indices = lhs_c._indices().contiguous();
|
||||
auto rhs_values = rhs_c._values().to(commonDtype).contiguous();
|
||||
auto out_values = create_sparse_output_values(rhs_values, lhs_nnz, commonDtype);
|
||||
|
||||
if (lhs_nnz > 0 && rhs_nnz > 0) {
|
||||
auto lhs_keys = flatten_indices(lhs_indices, lhs_c.sizes().slice(0, sd)).contiguous();
|
||||
auto rhs_keys = flatten_indices(rhs_c._indices().contiguous(), rhs_c.sizes().slice(0, sd)).contiguous();
|
||||
|
||||
const auto A_is_lhs = (lhs_nnz <= rhs_nnz);
|
||||
const auto lenA = A_is_lhs ? lhs_nnz : rhs_nnz;
|
||||
const auto lenB = A_is_lhs ? rhs_nnz : lhs_nnz;
|
||||
auto A_keys = A_is_lhs ? lhs_keys : rhs_keys;
|
||||
auto B_keys = A_is_lhs ? rhs_keys : lhs_keys;
|
||||
|
||||
auto [outA_idx, outB_idx, M] = mps_intersect_binary_search(
|
||||
A_keys, B_keys, lenA, lenB, A_is_lhs);
|
||||
|
||||
if (M > 0) {
|
||||
auto idx_in_A = outA_idx.narrow(0, 0, M);
|
||||
auto idx_in_B = outB_idx.narrow(0, 0, M);
|
||||
auto idx_in_lhs = A_is_lhs ? idx_in_A : idx_in_B;
|
||||
auto idx_in_rhs = A_is_lhs ? idx_in_B : idx_in_A;
|
||||
|
||||
const auto view_cols = rhs_values.numel() / std::max<int64_t>(rhs_nnz, 1);
|
||||
auto rhs_rows = rhs_values.index_select(0, idx_in_rhs).contiguous();
|
||||
auto rhs_rows_2d = rhs_rows.view({M, view_cols});
|
||||
auto out_2d = out_values.view({lhs_nnz, view_cols});
|
||||
|
||||
if (accumulate_matches) {
|
||||
out_2d.index_add_(0, idx_in_lhs, rhs_rows_2d);
|
||||
} else {
|
||||
out_2d.index_copy_(0, idx_in_lhs, rhs_rows_2d);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
alias_into_sparse(result, lhs._indices(), out_values);
|
||||
result._coalesced_(lhs.is_coalesced());
|
||||
}
|
||||
|
||||
static void sparse_mask_intersection_out_mps_kernel(
|
||||
Tensor& result,
|
||||
const Tensor& lhs,
|
||||
@ -1002,4 +1063,5 @@ Tensor sparse_sparse_matmul_mps(const Tensor& mat1_, const Tensor& mat2_) {
|
||||
}
|
||||
|
||||
REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
|
||||
REGISTER_MPS_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_mps_kernel);
|
||||
} // namespace at::native
|
||||
@ -37,6 +37,7 @@
|
||||
#ifdef USE_FLASH_ATTENTION
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/hip/HIPContext.h>
|
||||
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
|
||||
#include <ATen/hip/HIPGraphsUtils.cuh>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
@ -161,7 +162,7 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
|
||||
std::optional<int64_t> window_size_right,
|
||||
const bool return_softmax,
|
||||
const std::optional<at::Generator>& gen_) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
|
||||
check_gpu_arch(stream);
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
@ -347,8 +348,8 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
|
||||
TORCH_CHECK(!paged_KV, "[ROCm] mha_varlen_fwd: block_table_ must be nullopt");
|
||||
TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt");
|
||||
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
|
||||
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
|
||||
check_gpu_arch(stream);
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
@ -559,8 +560,8 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
|
||||
const at::Tensor& philox_offset) {
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
|
||||
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
|
||||
check_gpu_arch(stream);
|
||||
|
||||
bool is_dropout = p_dropout > 0.0;
|
||||
@ -792,8 +793,8 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
|
||||
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
|
||||
check_gpu_arch(stream);
|
||||
|
||||
bool is_dropout = p_dropout > 0.0;
|
||||
|
||||
@ -261,7 +261,7 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
|
||||
if (is_causal) { window_size_right = 0; }
|
||||
|
||||
bool is_dropout = p_dropout > 0.0;
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
auto stream = at::cuda::getCurrentHIPStream().stream();
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
|
||||
@ -365,7 +365,7 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
|
||||
}
|
||||
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
||||
|
||||
@ -261,7 +261,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
bool has_lse = true;
|
||||
@ -299,7 +299,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
|
||||
|
||||
|
||||
hipLaunchKernelGGL(
|
||||
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::cuda::getCurrentCUDAStream(), philox_args, rng_state_ptr);
|
||||
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), philox_args, rng_state_ptr);
|
||||
seed_t = at::scalar_tensor(at::Scalar(static_cast<uint64_t>(rng_state_ptr[0])), at::dtype(at::kLong));
|
||||
offset_t = at::scalar_tensor(at::Scalar(static_cast<uint64_t>(rng_state_ptr[1])), at::dtype(at::kLong));
|
||||
}
|
||||
@ -317,7 +317,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
|
||||
|
||||
if (seqlen_k > 0) {
|
||||
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
auto stream = at::cuda::getCurrentHIPStream().stream();
|
||||
ck_tile::stream_config stream_config{stream};
|
||||
|
||||
auto traits =
|
||||
|
||||
@ -255,7 +255,7 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea
|
||||
if (is_causal) { window_size_right = 0; }
|
||||
|
||||
bool is_dropout = p_dropout > 0.0;
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
|
||||
@ -366,7 +366,7 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea
|
||||
}
|
||||
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
auto softmax_d = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
|
||||
|
||||
@ -273,7 +273,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
bool has_lse = true;
|
||||
@ -307,7 +307,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
auto philox_args = gen->philox_cuda_state(counter_offset);
|
||||
hipLaunchKernelGGL(
|
||||
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::cuda::getCurrentCUDAStream(), philox_args, rng_state_ptr);
|
||||
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), philox_args, rng_state_ptr);
|
||||
}
|
||||
|
||||
// remove const from attn_bias_
|
||||
@ -320,7 +320,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
|
||||
|
||||
if (max_seqlen_k > 0) {
|
||||
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
auto stream = at::cuda::getCurrentHIPStream().stream();
|
||||
ck_tile::stream_config stream_config{stream};
|
||||
|
||||
auto traits =
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include <ATen/TensorIndexing.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/hip/HIPContext.h>
|
||||
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
|
||||
#include <ATen/hip/HIPGraphsUtils.cuh>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
|
||||
benchmarks/dynamo/pr_time_benchmarks/benchmarks/dtensor.py (new file, 62 lines)
@ -0,0 +1,62 @@
import sys

from benchmark_base import BenchmarkBase

import torch
from torch.distributed._tensor import DTensor, Replicate
from torch.testing._internal.distributed.fake_pg import FakeStore


class BenchmarkDTensorDispatch(BenchmarkBase):
    def __init__(self, operator, world_size) -> None:
        super().__init__(
            category=f"dtensor_dispatch_{operator}",
            device="cuda",
        )
        self.world_size = world_size

    def name(self) -> str:
        prefix = f"{self.category()}"
        return prefix

    def description(self) -> str:
        return f"DTensor dispatch time for {self.category()}"

    def _prepare_once(self) -> None:
        self.mesh = torch.distributed.device_mesh.init_device_mesh(
            "cuda", (self.world_size,), mesh_dim_names=("dp",)
        )
        self.a = DTensor.from_local(
            torch.ones(10, 10, device=self.device()), self.mesh, [Replicate()]
        )
        self.b = DTensor.from_local(
            torch.ones(10, 10, device=self.device()), self.mesh, [Replicate()]
        )

    def _prepare(self) -> None:
        pass


class BenchmarkDetach(BenchmarkDTensorDispatch):
    def __init__(self, world_size) -> None:
        super().__init__(operator="detach", world_size=world_size)

    def _work(self) -> None:
        self.a.detach()


def main():
    world_size = 256
    fake_store = FakeStore()
    torch.distributed.init_process_group(
        "fake", store=fake_store, rank=0, world_size=world_size
    )
    result_path = sys.argv[1]
    BenchmarkDetach(world_size).enable_instruction_count().collect_all().append_results(
        result_path
    )
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
@ -125,6 +125,17 @@ AttentionType = Literal[
]
DtypeString = Literal["bfloat16", "float16", "float32"]
SpeedupType = Literal["fwd", "bwd"]
# Operator Name mapping
backend_to_operator_name = {
    "math": "math attention kernel",
    "efficient": "efficient attention kernel",
    "cudnn": "cudnn attention kernel",
    "fav2": "flash attention 2 kernel",
    "fav3": "flash attention 3 kernel",
    "fakv": "flash attention kv cache kernel",
    "og-eager": "eager attention kernel",
    "flex": "flex attention kernel",
}


def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
@ -1265,12 +1276,14 @@ def _output_json_for_dashboard(
|
||||
model: ModelInfo
|
||||
metric: MetricInfo
|
||||
|
||||
operator_name = backend_to_operator_name.get(backend, backend)
|
||||
|
||||
# Benchmark extra info
|
||||
benchmark_extra_info = {
|
||||
"input_config": input_config,
|
||||
"device": device,
|
||||
"arch": device_arch,
|
||||
"operator_name": backend,
|
||||
"operator_name": operator_name,
|
||||
"attn_type": config.attn_type,
|
||||
"shape": str(config.shape),
|
||||
"max_autotune": config.max_autotune,
|
||||
@ -1288,7 +1301,7 @@ def _output_json_for_dashboard(
|
||||
type="attention-benchmark",
|
||||
origins=["pytorch"],
|
||||
extra_info={
|
||||
"operator_name": backend,
|
||||
"operator_name": operator_name,
|
||||
"attn_type": config.attn_type,
|
||||
},
|
||||
),
|
||||
@ -1315,7 +1328,7 @@ def _output_json_for_dashboard(
|
||||
type="attention-benchmark",
|
||||
origins=["pytorch"],
|
||||
extra_info={
|
||||
"operator_name": backend,
|
||||
"operator_name": operator_name,
|
||||
},
|
||||
),
|
||||
metric=MetricInfo(
|
||||
@ -1341,7 +1354,7 @@ def _output_json_for_dashboard(
|
||||
type="attention-benchmark",
|
||||
origins=["pytorch"],
|
||||
extra_info={
|
||||
"operator_name": backend,
|
||||
"operator_name": operator_name,
|
||||
},
|
||||
),
|
||||
metric=MetricInfo(
|
||||
@ -1371,7 +1384,7 @@ def _output_json_for_dashboard(
|
||||
type="attention-benchmark",
|
||||
origins=["pytorch"],
|
||||
extra_info={
|
||||
"operator_name": backend,
|
||||
"operator_name": operator_name,
|
||||
},
|
||||
),
|
||||
metric=MetricInfo(
|
||||
|
||||
@ -19,6 +19,17 @@
|
||||
|
||||
namespace c10 {
|
||||
|
||||
using CaptureId_t = unsigned long long;
|
||||
// first is set if the instance is created by CUDAGraph::capture_begin.
|
||||
// second is set if the instance is created by at::cuda::graph_pool_handle.
|
||||
using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>;
|
||||
|
||||
struct MempoolIdHash {
|
||||
std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
|
||||
return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
|
||||
}
|
||||
};
|
||||
|
||||
// A DataPtr is a unique pointer (with an attached deleter and some
|
||||
// context for the deleter) to some memory, which also records what
|
||||
// device is for its data.
|
||||
@ -120,9 +131,8 @@ class C10_API DataPtr {
|
||||
}
|
||||
// Unsafely mutates the device on a DataPtr. Under normal use,
|
||||
// you should never actually need to call this function.
|
||||
// We used to need this for the implementation of the hack detailed
|
||||
// in Note [Masquerading as CUDA], but that hack has been removed.
|
||||
// Other uses of this function now exist so it cannot be deprecated.
|
||||
// We need this for the implementation of the hack detailed
|
||||
// in Note [Masquerading as CUDA]
|
||||
void unsafe_set_device(Device device) {
|
||||
device_ = device;
|
||||
}
|
||||
|
||||
@ -96,6 +96,13 @@ struct C10_API DeviceAllocator : public c10::Allocator {
|
||||
|
||||
// Resets peak memory usage statistics for the specified device
|
||||
virtual void resetPeakStats(c10::DeviceIndex device) = 0;
|
||||
|
||||
// Return the free memory size and total memory size in bytes for the
|
||||
// specified device.
|
||||
virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "getMemoryInfo is not implemented for this allocator yet.");
|
||||
}
|
||||
};
|
||||
|
||||
// This function is used to get the DeviceAllocator for a specific device type
|
||||
|
||||
@ -27,6 +27,7 @@
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
|
||||
|
||||
namespace c10 {
|
||||
|
||||
@ -205,6 +206,12 @@ inline bool isSignedType(ScalarType t) {
|
||||
break;
|
||||
// Do not add default here, but rather define behavior of every new entry
|
||||
// here. `-Wswitch-enum` would raise a warning in those cases.
|
||||
// TODO: get PyTorch to adopt exhaustive switches by default with a way to
|
||||
// opt specific switches to being non-exhaustive.
|
||||
// Exhaustive:
|
||||
// `-Wswitch-enum`, `-Wswitch-default`, `-Wno-covered-switch-default`
|
||||
// Non-Exhaustive:
|
||||
// `-Wno-switch-enum`, `-Wswitch-default`, `-Wcovered-switch-default`
|
||||
}
|
||||
TORCH_CHECK(false, "Unknown ScalarType ", t);
|
||||
#undef CASE_ISSIGNED
|
||||
|
||||
@ -57,6 +57,8 @@ C10_DECLARE_bool(caffe2_keep_on_shrink);
|
||||
// respect caffe2_keep_on_shrink.
|
||||
C10_DECLARE_int64(caffe2_max_keep_on_shrink_memory);
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
|
||||
|
||||
namespace at {
|
||||
class Tensor;
|
||||
class TensorBase;
|
||||
@ -3303,3 +3305,5 @@ static_assert(
|
||||
#undef C10_GCC_VERSION_MINOR
|
||||
|
||||
} // namespace c10
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
@ -1012,12 +1012,6 @@ PrivatePoolState::PrivatePoolState(
|
||||
}
|
||||
}
|
||||
|
||||
struct MempoolIdHash {
|
||||
std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
|
||||
return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
|
||||
}
|
||||
};
|
||||
|
||||
cudaError_t allocPrimitive(void** ptr, size_t size, AllocParams& p) {
|
||||
if (p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator()) {
|
||||
*ptr = p.pool->owner_PrivatePool->allocator()->raw_alloc(size);
|
||||
@ -4510,66 +4504,3 @@ std::atomic<CUDAAllocator*> allocator;
|
||||
static BackendStaticInitializer backend_static_initializer;
|
||||
} // namespace cuda::CUDACachingAllocator
|
||||
} // namespace c10
|
||||
|
||||
namespace c10::cuda {
|
||||
|
||||
// uid_ is incremented when a user creates a MemPool,
|
||||
// for example: using graph_pool_handle() or c10::cuda::MemPool().
|
||||
//
|
||||
// uuid_ is incremented when CUDAGraph creates a MemPool
|
||||
// as a result of a user not providing a pool.
|
||||
//
|
||||
// MempoolId_t of {0, 0} is used to denote when no MemPool has been
|
||||
// passed to a function, either by user or CUDAGraphs. For example,
|
||||
// default value of MempoolId_t for capture_begin function is {0, 0}.
|
||||
// That's why uid_ and uuid_ start at 1.
|
||||
std::atomic<CaptureId_t> MemPool::uid_{1};
|
||||
std::atomic<CaptureId_t> MemPool::uuid_{1};
|
||||
|
||||
MemPool::MemPool(
|
||||
CUDACachingAllocator::CUDAAllocator* allocator,
|
||||
bool is_user_created,
|
||||
bool use_on_oom)
|
||||
: allocator_(allocator), is_user_created_(is_user_created) {
|
||||
if (is_user_created_) {
|
||||
id_ = {0, uid_++};
|
||||
} else {
|
||||
id_ = {uuid_++, 0};
|
||||
}
|
||||
device_ = c10::cuda::current_device();
|
||||
CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
|
||||
if (use_on_oom) {
|
||||
CUDACachingAllocator::setUseOnOOM(device_, id_);
|
||||
}
|
||||
}
|
||||
|
||||
MemPool::~MemPool() {
|
||||
TORCH_INTERNAL_ASSERT(use_count() == 1);
|
||||
CUDACachingAllocator::releasePool(device_, id_);
|
||||
c10::cuda::CUDACachingAllocator::emptyCache(id_);
|
||||
}
|
||||
|
||||
MempoolId_t MemPool::id() {
|
||||
return id_;
|
||||
}
|
||||
|
||||
CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
|
||||
return allocator_;
|
||||
}
|
||||
|
||||
int MemPool::use_count() {
|
||||
return CUDACachingAllocator::getPoolUseCount(device_, id_);
|
||||
}
|
||||
|
||||
c10::DeviceIndex MemPool::device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
|
||||
if (is_user_created) {
|
||||
return {0, uid_++};
|
||||
}
|
||||
return {uuid_++, 0};
|
||||
}
|
||||
|
||||
} // namespace c10::cuda
|
||||
|
||||
@ -345,6 +345,13 @@ class CUDAAllocator : public DeviceAllocator {
|
||||
c10::DeviceIndex device,
|
||||
std::shared_ptr<AllocatorState> pps) = 0;
|
||||
virtual std::string name() = 0;
|
||||
std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) override {
|
||||
c10::DeviceGuard device_guard({at::kCUDA, device});
|
||||
size_t free = 0;
|
||||
size_t total = 0;
|
||||
C10_CUDA_CHECK(cudaMemGetInfo(&free, &total));
|
||||
return {free, total};
|
||||
}
|
||||
};
|
||||
|
||||
// Allocator object, statically initialized
|
||||
@ -555,41 +562,7 @@ inline std::string getUserMetadata() {
|
||||
} // namespace c10::cuda::CUDACachingAllocator
|
||||
|
||||
namespace c10::cuda {
|
||||
|
||||
// Keep BC only
|
||||
using c10::CaptureId_t;
|
||||
using c10::MempoolId_t;
|
||||
|
||||
// MemPool represents a pool of memory in a caching allocator. Currently,
|
||||
// it's just the ID of the pool object maintained in the CUDACachingAllocator.
|
||||
//
|
||||
// An allocator pointer can be passed to the MemPool to define how the
|
||||
// allocations should be done in the pool. For example: using a different
|
||||
// system allocator such as ncclMemAlloc.
|
||||
struct C10_CUDA_API MemPool {
|
||||
MemPool(
|
||||
CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
|
||||
bool is_user_created = true,
|
||||
bool use_on_oom = false);
|
||||
MemPool(const MemPool&) = delete;
|
||||
MemPool(MemPool&&) = default;
|
||||
MemPool& operator=(const MemPool&) = delete;
|
||||
MemPool& operator=(MemPool&&) = default;
|
||||
~MemPool();
|
||||
|
||||
MempoolId_t id();
|
||||
CUDACachingAllocator::CUDAAllocator* allocator();
|
||||
int use_count();
|
||||
c10::DeviceIndex device();
|
||||
static MempoolId_t graph_pool_handle(bool is_user_created = true);
|
||||
|
||||
private:
|
||||
static std::atomic<CaptureId_t> uid_;
|
||||
static std::atomic<CaptureId_t> uuid_;
|
||||
CUDACachingAllocator::CUDAAllocator* allocator_;
|
||||
bool is_user_created_;
|
||||
MempoolId_t id_;
|
||||
c10::DeviceIndex device_;
|
||||
};
|
||||
|
||||
} // namespace c10::cuda
|
||||
|
||||
@ -295,11 +295,19 @@ DeviceAssertionsData* CUDAKernelLaunchRegistry::
|
||||
C10_CUDA_CHECK_WO_DSA(
|
||||
cudaMallocManaged(&uvm_assertions_ptr, sizeof(DeviceAssertionsData)));
|
||||
|
||||
#if CUDART_VERSION >= 13000
|
||||
cudaMemLocation cpuDevice;
|
||||
cpuDevice.type = cudaMemLocationTypeDevice;
|
||||
cpuDevice.id = cudaCpuDeviceId;
|
||||
#else
|
||||
const auto cpuDevice = cudaCpuDeviceId;
|
||||
#endif
|
||||
|
||||
C10_CUDA_CHECK_WO_DSA(cudaMemAdvise(
|
||||
uvm_assertions_ptr,
|
||||
sizeof(DeviceAssertionsData),
|
||||
cudaMemAdviseSetPreferredLocation,
|
||||
cudaCpuDeviceId));
|
||||
cpuDevice));
|
||||
|
||||
// GPU will establish direct mapping of data in CPU memory, no page faults
|
||||
// will be generated
|
||||
@ -307,7 +315,7 @@ DeviceAssertionsData* CUDAKernelLaunchRegistry::
|
||||
uvm_assertions_ptr,
|
||||
sizeof(DeviceAssertionsData),
|
||||
cudaMemAdviseSetAccessedBy,
|
||||
cudaCpuDeviceId));
|
||||
cpuDevice));
|
||||
|
||||
// Initialize the memory from the CPU; otherwise, pages may have to be created
|
||||
// on demand. We think that UVM documentation indicates that first access may
|
||||
|
||||
@ -4,13 +4,12 @@
|
||||
#include <c10/util/TypeSafeSignMath.h>
|
||||
#include <cmath>
|
||||
|
||||
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#include <c10/cuda/CUDAMathCompat.h>
|
||||
#define C10_COMPAT_COPYSIGN c10::cuda::compat::copysign
|
||||
#elif defined(__HIPCC__)
|
||||
#include <c10/hip/HIPMathCompat.h>
|
||||
#endif
|
||||
#define C10_COMPAT_COPYSIGN c10::cuda::compat::copysign
|
||||
#define C10_COMPAT_COPYSIGN c10::hip::compat::copysign
|
||||
#else
|
||||
#include <c10/util/copysign.h>
|
||||
#define C10_COMPAT_COPYSIGN c10::copysign
|
||||
|
||||
@ -926,15 +926,14 @@ class DeviceCachingAllocator {
|
||||
(release_cached_blocks() && alloc_block(params, true));
|
||||
}
|
||||
if (!block_found) {
|
||||
c10::xpu::DeviceProp device_prop;
|
||||
c10::xpu::get_device_properties(&device_prop, device);
|
||||
auto device_total = device_prop.global_mem_size;
|
||||
const auto& raw_device = c10::xpu::get_raw_device(device);
|
||||
const auto device_total =
|
||||
raw_device.get_info<sycl::info::device::global_mem_size>();
|
||||
// Estimate the available device memory when the SYCL runtime does not
|
||||
// support the corresponding aspect (ext_intel_free_memory).
|
||||
size_t device_free = device_prop.global_mem_size -
|
||||
size_t device_free = device_total -
|
||||
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
|
||||
.current;
|
||||
auto& raw_device = c10::xpu::get_raw_device(device);
|
||||
// TODO: Remove the aspect check once the SYCL runtime bug is fixed on
|
||||
// affected devices.
|
||||
if (raw_device.has(sycl::aspect::ext_intel_free_memory)) {
|
||||
@ -1052,21 +1051,37 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<size_t, size_t> getMemoryInfo() {
|
||||
const auto& device = c10::xpu::get_raw_device(device_index);
|
||||
const size_t total = device.get_info<sycl::info::device::global_mem_size>();
|
||||
TORCH_CHECK(
|
||||
device.has(sycl::aspect::ext_intel_free_memory),
|
||||
"The device (",
|
||||
device.get_info<sycl::info::device::name>(),
|
||||
") doesn't support querying the available free memory. ",
|
||||
"You can file an issue at https://github.com/pytorch/pytorch/issues ",
|
||||
"to help us prioritize its implementation.");
|
||||
const size_t free =
|
||||
device.get_info<sycl::ext::intel::info::device::free_memory>();
|
||||
return {free, total};
|
||||
}
|
||||
|
||||
double getMemoryFraction() {
|
||||
if (!set_fraction) {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
c10::xpu::DeviceProp device_prop;
|
||||
c10::xpu::get_device_properties(&device_prop, device_index);
|
||||
const auto device_total =
|
||||
xpu::get_raw_device(device_index)
|
||||
.get_info<sycl::info::device::global_mem_size>();
|
||||
return static_cast<double>(allowed_memory_maximum) /
|
||||
static_cast<double>(device_prop.global_mem_size);
|
||||
static_cast<double>(device_total);
|
||||
}
|
||||
|
||||
void setMemoryFraction(double fraction) {
|
||||
c10::xpu::DeviceProp device_prop;
|
||||
c10::xpu::get_device_properties(&device_prop, device_index);
|
||||
auto device_total = device_prop.global_mem_size;
|
||||
const auto device_total =
|
||||
xpu::get_raw_device(device_index)
|
||||
.get_info<sycl::info::device::global_mem_size>();
|
||||
allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
|
||||
set_fraction = true;
|
||||
}
|
||||
@ -1240,6 +1255,11 @@ class XPUAllocator : public DeviceAllocator {
|
||||
c10::xpu::get_raw_device(dev_to_access));
|
||||
}
|
||||
|
||||
std::pair<size_t, size_t> getMemoryInfo(DeviceIndex device) override {
|
||||
assertValidDevice(device);
|
||||
return device_allocators[device]->getMemoryInfo();
|
||||
}
|
||||
|
||||
double getMemoryFraction(DeviceIndex device) {
|
||||
assertValidDevice(device);
|
||||
return device_allocators[device]->getMemoryFraction();
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# This will define the following variables:
|
||||
# SYCL_FOUND : True if the system has the SYCL library.
|
||||
# SYCL_INCLUDE_DIR : Include directories needed to use SYCL.
|
||||
# SYCL_LIBRARY_DIR :The path to the SYCL library.
|
||||
# SYCL_LIBRARY_DIR : The path to the SYCL library.
|
||||
# SYCL_LIBRARY : SYCL library fullname.
|
||||
# SYCL_COMPILER_VERSION : SYCL compiler version.
|
||||
|
||||
|
||||
@ -40,6 +40,7 @@
|
||||
:nosignatures:
|
||||
|
||||
empty_cache
|
||||
get_memory_info
|
||||
max_memory_allocated
|
||||
max_memory_reserved
|
||||
memory_allocated
|
||||
|
||||
docs/source/accelerator/hooks.md (new file, 164 lines)
@ -0,0 +1,164 @@
# Accelerator Hooks

## Background

OpenReg hooks provide a mechanism for integrating custom accelerator devices into PyTorch's runtime system. OpenReg (Open Registration) is PyTorch's extensibility framework that allows accelerator vendors to register custom device backends without modifying PyTorch core code.

## Design

The following tables list all hooks that accelerator vendors need to implement when integrating a new device backend. These hooks are categorized into two priority levels:

- **High Priority Hooks**: Core APIs that the PyTorch runtime directly depends on. Accelerator vendors are recommended to implement all high priority hooks to ensure full PyTorch compatibility and enable basic device functionality.

- **Low Priority Hooks**: Device management and utility APIs that PyTorch does not directly depend on. These hooks enhance user experience and multi-device support but are *optional*. Accelerator vendors can choose to implement them based on their specific requirements and use cases.

### High Priority Hooks

| Hook Method | Description | Application Scenario |
| --- | --- | --- |
| `init()` | Initializes the accelerator runtime and device contexts | Set up necessary state when PyTorch first accesses the device |
| `hasPrimaryContext(DeviceIndex)` | Checks if a primary context exists for the device | Determine whether device initialization has occurred |
| `getDefaultGenerator(DeviceIndex)` | Returns the default random number generator for a device | Access the device's primary RNG for reproducible random operations |
| `getNewGenerator(DeviceIndex)` | Creates a new independent random number generator | Create isolated RNG instances for parallel operations |
| `getDeviceFromPtr(void*)` | Determines which device a memory pointer belongs to | Identify the accelerator device associated with a memory allocation |
| `getPinnedMemoryAllocator()` | Returns an allocator for pinned (page-locked) host memory | Allocate host memory that can be efficiently transferred to/from the accelerator |
| `isPinnedPtr(void*)` | Checks if a pointer points to pinned memory | Validate memory types before performing operations |

### Low Priority Hooks

| Hook Method | Description | Application Scenario |
| --- | --- | --- |
| `isBuilt()` | Returns whether the accelerator backend is built/compiled into the extension | Check whether the accelerator library is available at compile time |
| `isAvailable()` | Returns whether the accelerator hardware is available at runtime | Verify whether accelerator devices can be detected and initialized |
| `deviceCount()` | Returns the number of available accelerator devices | Enumerate all available accelerator devices for device selection |
| `setCurrentDevice(DeviceIndex)` | Sets the active device for the current thread | Switch the current thread's context to a specific accelerator device |
| `getCurrentDevice()` | Returns the currently active device index | Query which accelerator device is active in the current thread |
| `exchangeDevice(DeviceIndex)` | Atomically exchanges the current device and returns the previous one | Temporarily switch devices and restore the previous device afterward |
| `maybeExchangeDevice(DeviceIndex)` | Conditionally exchanges device only if the index is valid | Safely attempt device switching with validation |

## Implementation

We can just take `getDefaultGenerator` as an implementation example:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG HOOK EXAMPLES
   :end-before: LITERALINCLUDE END: OPENREG HOOK EXAMPLES
   :linenos:
```

In this implementation:

1. **Override the base interface**: The `getDefaultGenerator` method overrides the virtual method from `at::PrivateUse1HooksInterface`.

2. **Delegate to device-specific implementation**: It calls `getDefaultOpenRegGenerator(device_index)`, which manages a per-device generator instance.

3. **Return device-specific generator**: The returned `at::Generator` wraps an `OpenRegGeneratorImpl` that implements device-specific random number generation.

This pattern applies to all hooks: override the interface method, validate inputs, delegate to your device-specific API, and return results in PyTorch's expected format.
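To make that pattern concrete without the rendered literalinclude, a minimal sketch of such an override is shown below. The class and helper names (`OpenRegHooksInterface`, `getDefaultOpenRegGenerator`) follow the OpenReg example referenced above, but the body and exact signature are approximations rather than the actual implementation.

```cpp
// Minimal sketch of a PrivateUse1 hooks override (illustrative only).
#include <ATen/detail/PrivateUse1HooksInterface.h>

struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
  // Called when at::globalContext().defaultGenerator(device) is asked for a
  // PrivateUse1 device; the signature shown here is approximate.
  const at::Generator& getDefaultGenerator(
      c10::DeviceIndex device_index) const override {
    // Delegate to the device-specific manager that owns one generator per
    // device and initializes it lazily.
    return getDefaultOpenRegGenerator(device_index);
  }
};
```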
## Integration Example

The following sections demonstrate how PyTorch integrates with accelerator hooks when accessing the default random number generator. The example traces the complete flow from user-facing Python code down to the device-specific implementation.

### Layer 1: User Code

User code initiates the operation by calling `manual_seed` to set the random seed for reproducible results:

```python
import torch
torch.openreg.manual_seed(42)
```

### Layer 2: Extension Python API

The Python API layer handles device management and calls into the C++ extension (defined in [`torch_openreg/openreg/random.py`][random.py]):

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py
   :language: python
   :start-after: LITERALINCLUDE START: OPENREG MANUAL SEED
   :end-before: LITERALINCLUDE END: OPENREG MANUAL SEED
   :linenos:
```

The `manual_seed` function gets the current device index and calls `torch_openreg._C._get_default_generator(idx)` to obtain the device-specific generator, then sets the seed on it.
### Layer 3: Python/C++ Bridge

The C++ extension exposes `_getDefaultGenerator` to Python, which bridges to PyTorch's core runtime:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR
   :end-before: LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR
   :linenos:
   :emphasize-lines: 10-11
```

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
   :end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
   :linenos:
   :emphasize-lines: 3
```

This function unpacks the device index from Python, creates a `PrivateUse1` device object, and calls `at::globalContext().defaultGenerator()`. PyTorch's context then dispatches to the registered hooks.
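A condensed sketch of what that bridge does is shown below. It is illustrative only: the real `Module.cpp` differs in argument parsing and error handling, and the helper calls (`THPUtils_unpackLong`, `THPGenerator_initDefaultGenerator`) are assumed from PyTorch's usual Python-binding utilities rather than copied from the extension.

```cpp
// Illustrative sketch of the Python/C++ bridge described above
// (assumes the usual torch/csrc Python-binding headers).
static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
  // Unpack the device index passed in from Python.
  const auto device_index = THPUtils_unpackLong(arg);
  // Build a PrivateUse1 device and ask the global context for its default
  // generator; this call dispatches into the registered accelerator hooks.
  c10::Device device(c10::DeviceType::PrivateUse1,
                     static_cast<c10::DeviceIndex>(device_index));
  auto gen = at::globalContext().defaultGenerator(device);
  // Wrap the generator as a Python object for the caller.
  return THPGenerator_initDefaultGenerator(gen);
}
```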
### Layer 4: PyTorch Core Context

PyTorch's Context class dispatches to the appropriate accelerator hooks ([`aten/src/ATen/Context.h`][Context.h]):

```{eval-rst}
.. literalinclude:: ../../../aten/src/ATen/Context.h
   :language: c++
   :lines: 60-103
   :linenos:
   :emphasize-lines: 8-9, 24-25
```

This layered architecture enables PyTorch to remain device-agnostic while delegating hardware-specific operations to accelerator implementations. The hooks are registered once at module load time:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG HOOK REGISTER
   :end-before: LITERALINCLUDE END: OPENREG HOOK REGISTER
   :linenos:
   :emphasize-lines: 4
```
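Registration of this kind usually boils down to handing a hooks object to ATen once at load time. The sketch below assumes the `at::RegisterPrivateUse1HooksInterface` entry point and the `OpenRegHooksInterface` class sketched earlier; the real `OpenRegHooks.cpp` may use a helper macro instead.

```cpp
// Illustrative sketch of one-time hook registration at module load.
#include <ATen/detail/PrivateUse1HooksInterface.h>

static OpenRegHooksInterface openreg_hooks; // hooks object sketched above

// Runs once when the shared library is loaded, before any PrivateUse1 call.
static bool openreg_hooks_registered = []() {
  at::RegisterPrivateUse1HooksInterface(&openreg_hooks);
  return true;
}();
```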
### Layer 5: Accelerator Hooks

The hooks interface provides the abstraction that PyTorch uses to delegate to device-specific implementations:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG HOOK EXAMPLES
   :end-before: LITERALINCLUDE END: OPENREG HOOK EXAMPLES
   :linenos:
```

The `getDefaultGenerator` hook method overrides the base interface and delegates to `getDefaultOpenRegGenerator`, which manages the actual generator instances.

### Layer 6: Device-Specific Implementation

The device-specific implementation manages per-device generator instances:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR IMPL
   :end-before: LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR IMPL
   :linenos:
```

This function maintains a static vector of generators (one per device), initializes them on first access, validates the device index, and returns the appropriate generator instance.
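As a rough illustration of that description (not the actual `OpenRegGenerator.cpp`), such a manager might look like the sketch below; `OpenRegGeneratorImpl` and the backend device count are assumptions taken from the surrounding text.

```cpp
// Illustrative per-device default-generator manager (approximate, not the
// real implementation).
#include <mutex>
#include <vector>
#include <ATen/core/Generator.h>
#include <c10/util/Exception.h>

const at::Generator& getDefaultOpenRegGenerator(c10::DeviceIndex device_index) {
  // One default generator per device, created lazily on first access.
  static std::vector<at::Generator> default_generators;
  static std::once_flag init_flag;
  std::call_once(init_flag, []() {
    const c10::DeviceIndex num_devices = 2; // assumed backend device count
    for (c10::DeviceIndex i = 0; i < num_devices; ++i) {
      default_generators.push_back(
          at::make_generator<OpenRegGeneratorImpl>(i)); // assumed impl class
    }
  });
  TORCH_CHECK(
      device_index >= 0 &&
          device_index < static_cast<c10::DeviceIndex>(default_generators.size()),
      "Invalid OpenReg device index: ", device_index);
  return default_generators[device_index];
}
```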
[random.py]: https://github.com/pytorch/pytorch/tree/main/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py#L48-L53 "random.py"
[Context.h]: https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/Context.h#L61-L102 "Context.h"
@ -42,6 +42,7 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
|
||||
:glob:
|
||||
:maxdepth: 1
|
||||
|
||||
hooks
|
||||
autoload
|
||||
operators
|
||||
amp
|
||||
|
||||
@ -24,15 +24,11 @@ def gen_data(special_op_lists, analysis_name):
|
||||
all_ops = get_ops_for_key(None)
|
||||
composite_ops = get_ops_for_key("CompositeImplicitAutograd")
|
||||
noncomposite_ops = all_ops - composite_ops
|
||||
with open("../../aten/src/ATen/native/native_functions.yaml") as f:
|
||||
ops = yaml.load(f.read(), Loader=yaml.CLoader)
|
||||
|
||||
ops = yaml.load(
|
||||
open("../../aten/src/ATen/native/native_functions.yaml").read(),
|
||||
Loader=yaml.CLoader,
|
||||
)
|
||||
|
||||
annotated_ops = {
|
||||
a.strip(): b.strip() for a, b in list(csv.reader(open("annotated_ops")))
|
||||
}
|
||||
with open("annotated_ops") as f:
|
||||
annotated_ops = {a.strip(): b.strip() for a, b in csv.reader(f)}
|
||||
|
||||
uniq_ops = []
|
||||
uniq_names = set()
|
||||
|
||||
setup.py
@ -1358,6 +1358,45 @@ class concat_license_files:
|
||||
|
||||
# Need to create the proper LICENSE.txt for the wheel
|
||||
class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
|
||||
def _wrap_headers_with_macro(self, bdist_dir: Path) -> None:
|
||||
"""Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION).
|
||||
|
||||
Excludes:
|
||||
- torch/include/torch/headeronly/*
|
||||
- torch/include/torch/csrc/stable/*
|
||||
- torch/include/torch/csrc/inductor/aoti_torch/c/ (only shim headers)
|
||||
- torch/include/torch/csrc/inductor/aoti_torch/generated/
|
||||
"""
|
||||
header_extensions = (".h", ".hpp", ".cuh")
|
||||
header_files = [
|
||||
f for ext in header_extensions for f in bdist_dir.rglob(f"*{ext}")
|
||||
]
|
||||
|
||||
# Paths to exclude from wrapping
|
||||
exclude_dir_patterns = [
|
||||
"torch/include/torch/headeronly/",
|
||||
"torch/include/torch/csrc/stable/",
|
||||
"torch/include/torch/csrc/inductor/aoti_torch/c/",
|
||||
"torch/include/torch/csrc/inductor/aoti_torch/generated/",
|
||||
]
|
||||
|
||||
for header_file in header_files:
|
||||
rel_path = header_file.relative_to(bdist_dir).as_posix()
|
||||
|
||||
if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns):
|
||||
report(f"Skipping header: {rel_path}")
|
||||
continue
|
||||
|
||||
original_content = header_file.read_text(encoding="utf-8")
|
||||
wrapped_content = (
|
||||
"#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
|
||||
f"{original_content}"
|
||||
"\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
|
||||
)
|
||||
|
||||
header_file.write_text(wrapped_content, encoding="utf-8")
|
||||
report(f"Wrapped header: {rel_path}")
|
||||
|
||||
def run(self) -> None:
|
||||
with concat_license_files(include_files=True):
|
||||
super().run()
|
||||
@ -1380,6 +1419,14 @@ class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
|
||||
# need an __init__.py file otherwise we wouldn't have a package
|
||||
(bdist_dir / "torch" / "__init__.py").touch()
|
||||
|
||||
# Wrap all header files with TORCH_STABLE_ONLY macro
|
||||
assert self.bdist_dir is not None, "bdist_dir should be set during wheel build"
|
||||
bdist_dir = Path(self.bdist_dir)
|
||||
report(
|
||||
"-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)"
|
||||
)
|
||||
self._wrap_headers_with_macro(bdist_dir)
|
||||
|
||||
|
||||
class clean(Command):
|
||||
user_options: ClassVar[list[tuple[str, str | None, str]]] = []
|
||||
|
||||
@ -308,12 +308,16 @@ class StepcurrentPlugin:
|
||||
self.report_status = ""
|
||||
assert config.cache is not None
|
||||
self.cache: pytest.Cache = config.cache
|
||||
self.directory = f"{STEPCURRENT_CACHE_DIR}/{config.getoption('stepcurrent')}"
|
||||
self.lastrun: Optional[str] = self.cache.get(self.directory, None)
|
||||
directory = f"{STEPCURRENT_CACHE_DIR}/{config.getoption('stepcurrent')}"
|
||||
self.lastrun_location = f"{directory}/lastrun"
|
||||
self.lastrun: Optional[str] = self.cache.get(self.lastrun_location, None)
|
||||
self.initial_val = self.lastrun
|
||||
self.skip: bool = config.getoption("stepcurrent_skip")
|
||||
self.run_single: bool = config.getoption("run_single")
|
||||
|
||||
self.made_failing_xml_location = f"{directory}/made_failing_xml"
|
||||
self.cache.set(self.made_failing_xml_location, False)
|
||||
|
||||
def pytest_collection_modifyitems(self, config: Config, items: list[Any]) -> None:
|
||||
if not self.lastrun:
|
||||
self.report_status = "Cannot find last run test, not skipping"
|
||||
@ -349,8 +353,10 @@ class StepcurrentPlugin:
|
||||
|
||||
def pytest_runtest_protocol(self, item, nextitem) -> None:
|
||||
self.lastrun = item.nodeid
|
||||
self.cache.set(self.directory, self.lastrun)
|
||||
self.cache.set(self.lastrun_location, self.lastrun)
|
||||
|
||||
def pytest_sessionfinish(self, session, exitstatus):
|
||||
if exitstatus == 0:
|
||||
self.cache.set(self.directory, self.initial_val)
|
||||
self.cache.set(self.lastrun_location, self.initial_val)
|
||||
if exitstatus != 0:
|
||||
self.cache.set(self.made_failing_xml_location, True)
|
||||
|
||||
@ -38,7 +38,7 @@ using torch::stable::Tensor;
|
||||
Tensor sgd_out_of_place(
|
||||
const Tensor param,
|
||||
const Tensor grad,
|
||||
const float weight_decay,
|
||||
const double weight_decay,
|
||||
const double lr,
|
||||
const bool maximize) {
|
||||
STD_TORCH_CHECK(param.dim() == 1, "param must be 1D");
|
||||
@ -57,7 +57,7 @@ Tensor sgd_out_of_place(
|
||||
reinterpret_cast<float*>(param.data_ptr()),
|
||||
reinterpret_cast<float*>(grad.data_ptr()),
|
||||
reinterpret_cast<float*>(out.data_ptr()),
|
||||
weight_decay,
|
||||
float(weight_decay),
|
||||
lr,
|
||||
maximize,
|
||||
param.numel()
|
||||
@ -66,44 +66,29 @@ Tensor sgd_out_of_place(
|
||||
return out;
|
||||
}
|
||||
|
||||
void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = sgd_out_of_place(
|
||||
torch::stable::detail::to<Tensor>(stack[0]),
|
||||
torch::stable::detail::to<Tensor>(stack[1]),
|
||||
float(torch::stable::detail::to<double>(stack[2])),
|
||||
torch::stable::detail::to<double>(stack[3]),
|
||||
torch::stable::detail::to<bool>(stack[4]));
|
||||
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
|
||||
m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
|
||||
m.impl("sgd_out_of_place", &boxed_sgd_out_of_place);
|
||||
m.impl("sgd_out_of_place", TORCH_BOX(&sgd_out_of_place));
|
||||
}
|
||||
|
||||
Tensor identity(Tensor t) {
|
||||
return t;
|
||||
}
|
||||
|
||||
void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = identity(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("identity(Tensor t) -> Tensor");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CUDA, m) {
|
||||
m.impl("identity", &boxed_identity);
|
||||
m.impl("identity", TORCH_BOX(&identity));
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
|
||||
m.impl("identity", &boxed_identity);
|
||||
m.impl("identity", TORCH_BOX(&identity));
|
||||
}
|
||||
|
||||
Tensor my_abs(Tensor t) {
|
||||
@ -114,17 +99,12 @@ Tensor my_abs(Tensor t) {
|
||||
return torch::stable::detail::to<Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor tensor_res = my_abs(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(tensor_res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("my_abs(Tensor t) -> Tensor");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("my_abs", &boxed_my_abs);
|
||||
m.impl("my_abs", TORCH_BOX(&my_abs));
|
||||
}
|
||||
|
||||
Tensor my_ones_like(Tensor t, StableIValue device) {
|
||||
@ -145,17 +125,12 @@ Tensor my_ones_like(Tensor t, StableIValue device) {
|
||||
return torch::stable::detail::to<Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = my_ones_like(torch::stable::detail::to<Tensor>(stack[0]), stack[1]);
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("my_ones_like(Tensor t, Device d) -> Tensor");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("my_ones_like", &boxed_my_ones_like);
|
||||
m.impl("my_ones_like", TORCH_BOX(&my_ones_like));
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) {
|
||||
@ -177,19 +152,12 @@ std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3
|
||||
torch::stable::detail::to<bool>(stack_is_leaf[0]));
|
||||
}
|
||||
|
||||
void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto tuple = exp_neg_is_leaf(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<Tensor>(stack[2]));
|
||||
stack[0] = torch::stable::detail::from(std::get<0>(tuple));
|
||||
stack[1] = torch::stable::detail::from(std::get<1>(tuple));
|
||||
stack[2] = torch::stable::detail::from(std::get<2>(tuple));
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("exp_neg_is_leaf", &boxed_exp_neg_is_leaf);
|
||||
m.impl("exp_neg_is_leaf", TORCH_BOX(&exp_neg_is_leaf));
|
||||
}
|
||||
|
||||
Tensor neg_exp(Tensor t) {
|
||||
@ -200,17 +168,12 @@ Tensor neg_exp(Tensor t) {
|
||||
return torch::stable::detail::to<Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("neg_exp(Tensor t) -> Tensor");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("neg_exp", &boxed_neg_exp);
|
||||
m.impl("neg_exp", TORCH_BOX(&neg_exp));
|
||||
}
|
||||
|
||||
Tensor divide_neg_exp(Tensor t) {
|
||||
@ -229,108 +192,53 @@ Tensor divide_neg_exp(Tensor t) {
|
||||
return torch::stable::detail::to<Tensor>(stack_div[0]);
|
||||
}
|
||||
|
||||
void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = divide_neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("divide_neg_exp(Tensor t) -> Tensor");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("divide_neg_exp", &boxed_divide_neg_exp);
|
||||
m.impl("divide_neg_exp", TORCH_BOX(÷_neg_exp));
|
||||
}
|
||||
|
||||
bool is_contiguous(Tensor t) {
|
||||
return t.is_contiguous();
|
||||
}
|
||||
|
||||
void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
bool res = is_contiguous(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("is_contiguous(Tensor t) -> bool");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("is_contiguous", &boxed_is_contiguous);
|
||||
m.impl("is_contiguous", TORCH_BOX(&is_contiguous));
|
||||
}
|
||||
|
||||
Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) {
|
||||
return transpose(t, dim0, dim1);
|
||||
}
|
||||
|
||||
void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_transpose(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<int64_t>(stack[1]), torch::stable::detail::to<int64_t>(stack[2]));
|
||||
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_empty_like(Tensor t) {
|
||||
return empty_like(t);
|
||||
}
|
||||
|
||||
void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_empty_like(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
bool my_is_cpu(Tensor t) {
|
||||
return t.is_cpu();
|
||||
}
|
||||
|
||||
|
||||
void boxed_my_is_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_is_cpu(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor fill_infinity(Tensor t) {
|
||||
auto value = std::numeric_limits<float>::infinity();
|
||||
return fill_(t, value);
|
||||
}
|
||||
|
||||
void boxed_fill_infinity(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
auto res = fill_infinity(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_pad(Tensor t) {
|
||||
std::string mode = "constant";
|
||||
double value = 0.0;
|
||||
return pad(t, {1, 2, 2, 1}, mode, value);
|
||||
}
|
||||
|
||||
void boxed_my_pad(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
auto res = my_pad(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_narrow(Tensor t, int64_t dim, int64_t start, int64_t length) {
|
||||
return narrow(t, dim, start, length);
|
||||
}
|
||||
|
||||
void boxed_my_narrow(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
auto res = my_narrow(
|
||||
torch::stable::detail::to<Tensor>(stack[0]),
|
||||
torch::stable::detail::to<int64_t>(stack[1]),
|
||||
torch::stable::detail::to<int64_t>(stack[2]),
|
||||
torch::stable::detail::to<int64_t>(stack[3]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_new_empty_dtype_variant(Tensor t) {
|
||||
// Still using a std::vector below even though people can just pass in an
|
||||
// initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
|
||||
@ -342,40 +250,19 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
|
||||
return new_empty(t, sizes, dtype);
|
||||
}
|
||||
|
||||
void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_new_empty_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_new_zeros_dtype_variant(Tensor t) {
|
||||
auto dtype = std::make_optional(at::ScalarType::Float);
|
||||
return new_zeros(t, {2, 5}, dtype);
|
||||
}
|
||||
|
||||
void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_new_zeros_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) {
|
||||
return copy_(dst, src, non_blocking);
|
||||
}
|
||||
|
||||
void boxed_my_copy_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor tensor_res = my_copy_(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<bool>(stack[2]));
|
||||
stack[0] = torch::stable::detail::from(tensor_res);
|
||||
}
|
||||
|
||||
Tensor my_clone(Tensor t) {
|
||||
return clone(t);
|
||||
}
|
||||
|
||||
void boxed_my_clone(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor tensor_res = my_clone(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(tensor_res);
|
||||
}
|
||||
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
|
||||
m.def("my_empty_like(Tensor t) -> Tensor");
|
||||
@ -389,57 +276,39 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("my_transpose", &boxed_my_transpose);
|
||||
m.impl("my_empty_like", &boxed_empty_like);
|
||||
m.impl("fill_infinity", &boxed_fill_infinity);
|
||||
m.impl("my_is_cpu", &boxed_my_is_cpu);
|
||||
m.impl("my_new_empty_dtype_variant", &boxed_my_new_empty_dtype_variant);
|
||||
m.impl("my_new_zeros_dtype_variant", &boxed_my_new_zeros_dtype_variant);
|
||||
m.impl("my_copy_", &boxed_my_copy_);
|
||||
m.impl("my_clone", &boxed_my_clone);
|
||||
m.impl("my_transpose", TORCH_BOX(&my_transpose));
|
||||
m.impl("my_empty_like", TORCH_BOX(&my_empty_like));
|
||||
m.impl("fill_infinity", TORCH_BOX(&fill_infinity));
|
||||
m.impl("my_is_cpu", TORCH_BOX(&my_is_cpu));
|
||||
m.impl("my_new_empty_dtype_variant", TORCH_BOX(&my_new_empty_dtype_variant));
|
||||
m.impl("my_new_zeros_dtype_variant", TORCH_BOX(&my_new_zeros_dtype_variant));
|
||||
m.impl("my_copy_", TORCH_BOX(&my_copy_));
|
||||
m.impl("my_clone", TORCH_BOX(&my_clone));
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) {
|
||||
m.impl("my_pad", &boxed_my_pad);
|
||||
m.impl("my_narrow", &boxed_my_narrow);
|
||||
m.impl("my_pad", TORCH_BOX(&my_pad));
|
||||
m.impl("my_narrow", TORCH_BOX(&my_narrow));
|
||||
}
|
||||
|
||||
Tensor my_zero_(Tensor t) {
|
||||
return zero_(t);
|
||||
}
|
||||
|
||||
void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_zero_(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_amax(Tensor t) {
|
||||
return amax(t, 0, false);
|
||||
}
|
||||
|
||||
void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_amax(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_amax_vec(Tensor t) {
|
||||
return amax(t, {0,1}, false);
|
||||
}
|
||||
|
||||
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_amax_vec(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("my_zero_(Tensor(a!) t) -> Tensor(a!)");
|
||||
m.def("my_amax(Tensor a) -> Tensor");
|
||||
m.def("my_amax_vec(Tensor a) -> Tensor");
|
||||
m.def("my_is_cpu(Tensor t) -> bool");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
|
||||
m.impl("my_zero_", &boxed_my_zero_);
|
||||
m.def("test_default_constructor(bool undefined) -> bool");
|
||||
}
|
||||
|
||||
bool test_default_constructor(bool defined) {
|
||||
@ -461,22 +330,12 @@ bool test_default_constructor(bool defined) {
|
||||
return out.defined();
|
||||
}
|
||||
|
||||
void boxed_test_default_constructor(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
bool res = test_default_constructor(torch::stable::detail::to<bool>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("test_default_constructor(bool undefined) -> bool");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("test_default_constructor", &boxed_test_default_constructor);
|
||||
m.impl("my_amax", &boxed_my_amax);
|
||||
m.impl("my_amax_vec", &boxed_my_amax_vec);
|
||||
m.impl("my_zero_", TORCH_BOX(&my_zero_));
|
||||
m.impl("my_amax", TORCH_BOX(&my_amax));
|
||||
m.impl("my_amax_vec", TORCH_BOX(&my_amax_vec));
|
||||
m.impl("test_default_constructor", TORCH_BOX(&test_default_constructor));
|
||||
}
|
||||
|
||||
std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
|
||||
@ -485,23 +344,11 @@ std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor
|
||||
return torch::stable::detail::to<std::vector<Tensor>>(stack[0]);
|
||||
}
|
||||
|
||||
void boxed_my__foreach_mul(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
// Why is the following NOT torch::stable::detail::to<HeaderOnlyArrayRef<Tensor>>(stack[0])? Because calling `to`
|
||||
// on a StableIValue means that the result is owning its underlying data now! HeaderOnlyArrayRef
|
||||
// is not owning, so it cannot safely steward the result of the torch::stable::detail::to<>.
|
||||
auto res = my__foreach_mul(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
void my__foreach_mul_(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
|
||||
std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
|
||||
aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data());
|
||||
}
|
||||
|
||||
void boxed_my__foreach_mul_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
my__foreach_mul_(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
|
||||
}
|
||||
|
||||
std::vector<Tensor> make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) {
|
||||
// This function tests that my__foreach_mul can take in std::initializer_lists
|
||||
// in addition to std::vectors.
|
||||
@ -512,11 +359,6 @@ std::vector<Tensor> make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) {
|
||||
return my__foreach_mul({t1_1, t2_1}, {t1_2, t2_2});
|
||||
}
|
||||
|
||||
void boxed_make_tensor_clones_and_call_foreach(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = make_tensor_clones_and_call_foreach(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]");
|
||||
m.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()");
|
||||
@ -524,9 +366,9 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("my__foreach_mul", &boxed_my__foreach_mul);
|
||||
m.impl("my__foreach_mul_", &boxed_my__foreach_mul_);
|
||||
m.impl("make_tensor_clones_and_call_foreach", &boxed_make_tensor_clones_and_call_foreach);
|
||||
m.impl("my__foreach_mul", TORCH_BOX(&my__foreach_mul));
|
||||
m.impl("my__foreach_mul_", TORCH_BOX(&my__foreach_mul_));
|
||||
m.impl("make_tensor_clones_and_call_foreach", TORCH_BOX(&make_tensor_clones_and_call_foreach));
|
||||
}
|
||||
|
||||
// Test functions for torch::stable::Tensor device method
|
||||
@ -690,14 +532,6 @@ int64_t test_device_guard(int64_t device_index) {
|
||||
return currentDevice;
|
||||
}
|
||||
|
||||
void boxed_test_device_guard(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
int res = test_device_guard(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
int64_t test_device_guard_set_index() {
|
||||
using torch::stable::accelerator::DeviceGuard;
|
||||
|
||||
@ -709,14 +543,6 @@ int64_t test_device_guard_set_index() {
|
||||
return currentDevice;
|
||||
}
|
||||
|
||||
void boxed_test_device_guard_set_index(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
int64_t res = test_device_guard_set_index();
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
int64_t test_stream(int32_t device_index) {
|
||||
STD_TORCH_CHECK(
|
||||
device_index >= std::numeric_limits<int32_t>::min() &&
|
||||
@ -726,26 +552,10 @@ int64_t test_stream(int32_t device_index) {
|
||||
return torch::stable::accelerator::getCurrentStream(device_index).id();
|
||||
}
|
||||
|
||||
void boxed_test_stream(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
int64_t res = test_stream(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
int64_t test_get_current_device_index() {
|
||||
return torch::stable::accelerator::getCurrentDeviceIndex();
|
||||
}
|
||||
|
||||
void boxed_test_get_current_device_index(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
int64_t res = test_get_current_device_index();
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("test_device_guard(int device_index) -> int");
|
||||
m.def("test_device_guard_set_index() -> int");
|
||||
@ -754,10 +564,10 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("test_device_guard", &boxed_test_device_guard);
|
||||
m.impl("test_device_guard_set_index", &boxed_test_device_guard_set_index);
|
||||
m.impl("test_stream", &boxed_test_stream);
|
||||
m.impl("test_get_current_device_index", &boxed_test_get_current_device_index);
|
||||
m.impl("test_device_guard", TORCH_BOX(&test_device_guard));
|
||||
m.impl("test_device_guard_set_index", TORCH_BOX(&test_device_guard_set_index));
|
||||
m.impl("test_stream", TORCH_BOX(&test_stream));
|
||||
m.impl("test_get_current_device_index", TORCH_BOX(&test_get_current_device_index));
|
||||
}
|
||||
|
||||
#endif // LAE_USE_CUDA
|
||||
|
||||
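For context, a minimal sketch of how the registrations above might be exercised from Python once the extension is compiled and loaded. The namespace `libtorch_agnostic` and the op schemas come from the `m.def` lines in the diff; the import name and everything else is an assumption, not part of this change:

```python
# Hedged sketch: assumes the C++ extension above has been built and that
# importing it loads the shared library, which runs the STABLE_TORCH_LIBRARY
# registrations for the "libtorch_agnostic" namespace.
import torch
import libtorch_agnostic  # hypothetical import name for the built extension

t = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# Ops registered via m.def/m.impl become callable under torch.ops.<namespace>.
transposed = torch.ops.libtorch_agnostic.my_transpose(t, 0, 1)
assert transposed.shape == (3, 2)

# my_is_cpu follows the "-> bool" schema defined above.
assert torch.ops.libtorch_agnostic.my_is_cpu(t)
```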
@@ -33,7 +33,7 @@ class clean(distutils.command.clean.clean):

def get_extension():
    extra_compile_args = {
        "cxx": ["-fdiagnostics-color=always"],
        "cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
    }

    extension = CppExtension
@@ -5,6 +5,7 @@ static std::vector<at::Generator> default_generators;

namespace c10::openreg {

// LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR IMPL
const at::Generator& getDefaultOpenRegGenerator(c10::DeviceIndex device_index) {
  static bool flag [[maybe_unused]] = []() {
    auto deivce_nums = device_count();
@@ -24,5 +25,6 @@ const at::Generator& getDefaultOpenRegGenerator(c10::DeviceIndex device_index) {
  }
  return default_generators[idx];
}
// LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR IMPL

} // namespace c10::openreg

@@ -1,5 +1,6 @@
#include "OpenRegHooks.h"

// LITERALINCLUDE START: OPENREG HOOK REGISTER
namespace c10::openreg {

static bool register_hook_flag [[maybe_unused]] = []() {
@@ -9,3 +10,4 @@ static bool register_hook_flag [[maybe_unused]] = []() {
}();

} // namespace c10::openreg
// LITERALINCLUDE END: OPENREG HOOK REGISTER
@@ -8,17 +8,58 @@

#include <include/openreg.h>

#include "OpenRegFunctions.h"
#include "OpenRegGenerator.h"

namespace c10::openreg {
struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
struct OPENREG_EXPORT OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
  OpenRegHooksInterface() {};
  ~OpenRegHooksInterface() override = default;

  bool hasPrimaryContext(c10::DeviceIndex device_index) const override {
  void init() const override {
    // Initialize OpenReg runtime if needed
    // This is called when PyTorch first accesses the device
  }

  bool hasPrimaryContext(DeviceIndex device_index) const override {
    return true;
  }

  bool isBuilt() const override {
    // This extension is compiled as part of the OpenReg test extension.
    return true;
  }

  bool isAvailable() const override {
    // Consider OpenReg available if there's at least one device reported.
    return device_count() > 0;
  }

  DeviceIndex deviceCount() const override {
    return device_count();
  }

  void setCurrentDevice(DeviceIndex device) const override {
    set_device(device);
  }

  DeviceIndex getCurrentDevice() const override {
    return current_device();
  }

  DeviceIndex exchangeDevice(DeviceIndex device) const override {
    return ExchangeDevice(device);
  }

  DeviceIndex maybeExchangeDevice(DeviceIndex device) const override {
    // Only exchange if the requested device is valid; otherwise, no-op and return current
    auto count = device_count();
    if (device < 0 || device >= count) {
      return getCurrentDevice();
    }
    return exchangeDevice(device);
  }

  at::Allocator* getPinnedMemoryAllocator() const override {
    return at::getHostAllocator(at::kPrivateUse1);
  }
@@ -30,12 +71,23 @@ struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
    return attr.type == orMemoryTypeHost;
  }

  const at::Generator& getDefaultGenerator(
      c10::DeviceIndex device_index) const override {
  at::Device getDeviceFromPtr(void* data) const override {
    orPointerAttributes attr{};
    auto err = orPointerGetAttributes(&attr, data);
    if (err == orSuccess && attr.type == orMemoryTypeDevice) {
      return at::Device(at::DeviceType::PrivateUse1, static_cast<int>(attr.device));
    } else {
      TORCH_CHECK(false, "failed to get device from pointer");
    }
    return at::Device(at::DeviceType::PrivateUse1, current_device());
  }
  // LITERALINCLUDE START: OPENREG HOOK EXAMPLES
  const at::Generator& getDefaultGenerator(DeviceIndex device_index) const override {
    return getDefaultOpenRegGenerator(device_index);
  }
  // LITERALINCLUDE END: OPENREG HOOK EXAMPLES

  at::Generator getNewGenerator(c10::DeviceIndex device_index) const override {
  at::Generator getNewGenerator(DeviceIndex device_index) const override {
    return at::make_generator<OpenRegGeneratorImpl>(device_index);
  }
};
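These hooks are what route generic accelerator queries to the OpenReg backend. A rough sketch of the Python-visible effect, assuming the OpenReg test extension is installed and registered as the PrivateUse1 backend (the import name and the routing claims below are assumptions, not part of this diff):

```python
# Hedged sketch: assumes the OpenReg extension is importable and registers
# itself as the active PrivateUse1/accelerator backend, so the hooks above
# back the generic torch.accelerator queries.
import torch
import torch_openreg  # hypothetical import name for the test extension

# isAvailable()/deviceCount() in OpenRegHooksInterface would back queries like:
if torch.accelerator.is_available():
    print("visible OpenReg devices:", torch.accelerator.device_count())
```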
@@ -17,6 +17,7 @@ static PyObject* _initExtension(PyObject* self, PyObject* noargs) {
  END_HANDLE_TH_ERRORS
}

// LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR
static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
@@ -31,6 +32,7 @@ static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {

  END_HANDLE_TH_ERRORS
}
// LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR

PyObject* _setDevice(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
@@ -73,6 +75,7 @@ PyObject* _getDeviceCount(PyObject* self, PyObject* noargs) {
  END_HANDLE_TH_ERRORS
}

// LITERALINCLUDE START: OPENREG MODULE METHODS
static PyMethodDef methods[] = {
    {"_init", _initExtension, METH_NOARGS, nullptr},
    {"_get_default_generator", _getDefaultGenerator, METH_O, nullptr},
@@ -81,7 +84,7 @@ static PyMethodDef methods[] = {
    {"_exchangeDevice", _exchangeDevice, METH_O, nullptr},
    {"_get_device_count", _getDeviceCount, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};

// LITERALINCLUDE END: OPENREG MODULE METHODS
/*
 * When ASAN is enabled, PyTorch modifies the dlopen flag during import,
 * causing all global and weak symbols in _C.so and its dependent libraries
@@ -45,6 +45,7 @@ def initial_seed() -> int:
    return default_generator.initial_seed()


# LITERALINCLUDE START: OPENREG MANUAL SEED
def manual_seed(seed: int) -> None:
    seed = int(seed)

@@ -53,6 +54,9 @@ def manual_seed(seed: int) -> None:
    default_generator.manual_seed(seed)


# LITERALINCLUDE END: OPENREG MANUAL SEED


def manual_seed_all(seed: int) -> None:
    seed = int(seed)
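A small usage sketch for the seeding helpers shown above. `manual_seed` and `initial_seed` come from the diff itself; the module path used to reach them is an assumption:

```python
# Hedged sketch: the module path is assumed; the functions mirror the diff above.
import torch_openreg.openreg as openreg  # hypothetical module path

openreg.manual_seed(1234)       # seeds the current device's default generator
print(openreg.initial_seed())   # reads the seed back from the default generator
```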
@@ -1,67 +0,0 @@
import distutils.command.clean
import shutil
from pathlib import Path

from setuptools import find_packages, setup

from torch.utils.cpp_extension import BuildExtension, CppExtension


ROOT_DIR = Path(__file__).parent
CSRC_DIR = ROOT_DIR / "torch_stable_test" / "csrc"


class clean(distutils.command.clean.clean):
    def run(self):
        # Run default behavior first
        distutils.command.clean.clean.run(self)

        # Remove extension
        for path in (ROOT_DIR / "torch_stable_test").glob("**/*.so"):
            path.unlink()
        # Remove build and dist and egg-info directories
        dirs = [
            ROOT_DIR / "build",
            ROOT_DIR / "dist",
            ROOT_DIR / "torch_stable_test.egg-info",
        ]
        for path in dirs:
            if path.exists():
                shutil.rmtree(str(path), ignore_errors=True)


def get_extension():
    extra_compile_args = {
        "cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
    }

    sources = list(CSRC_DIR.glob("**/*.cpp"))

    return [
        CppExtension(
            "torch_stable_test._C",
            sources=sorted(str(s) for s in sources),
            py_limited_api=True,
            extra_compile_args=extra_compile_args,
            extra_link_args=[],
        )
    ]


setup(
    name="torch_stable_test",
    version="0.0",
    author="PyTorch Core Team",
    description="Test extension to verify TORCH_STABLE_ONLY flag",
    packages=find_packages(exclude=("test",)),
    package_data={"torch_stable_test": ["*.dll", "*.dylib", "*.so"]},
    install_requires=[
        "torch",
    ],
    ext_modules=get_extension(),
    cmdclass={
        "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
        "clean": clean,
    },
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
@@ -1 +0,0 @@
#include <ATen/core/TensorBase.h> // This should trigger the TORCH_STABLE_ONLY error
@@ -1,22 +0,0 @@
# Owner(s): ["module: cpp"]

from pathlib import Path

from torch.testing._internal.common_utils import (
    install_cpp_extension,
    IS_WINDOWS,
    run_tests,
    TestCase,
)


if not IS_WINDOWS:

    class TestTorchStable(TestCase):
        def test_setup_fails(self):
            with self.assertRaisesRegex(RuntimeError, "build failed for cpp extension"):
                install_cpp_extension(extension_root=Path(__file__).parent.parent)


if __name__ == "__main__":
    run_tests()
@@ -225,9 +225,11 @@ class ApiTest(unittest.TestCase):
            raise_child_failure_error_fn("trainer", trainer_error_file)
        pf = cm.exception.get_first_failure()[1]
        # compare worker error file with reply file and overridden error code
        expect = json.load(open(pf.error_file))
        with open(pf.error_file) as f:
            expect = json.load(f)
        expect["message"]["errorCode"] = pf.exitcode
        actual = json.load(open(self.test_error_file))
        with open(self.test_error_file) as f:
            actual = json.load(f)
        self.assertTrue(
            json.dumps(expect, sort_keys=True),
            json.dumps(actual, sort_keys=True),
@ -5,6 +5,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._dynamo.testing import CompileCounterWithBackend
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.distributed.tensor import (
|
||||
DeviceMesh,
|
||||
@ -30,6 +31,8 @@ from torch.utils._debug_mode import (
|
||||
_RedistributeCall,
|
||||
_TritonKernelCall,
|
||||
DebugMode,
|
||||
hash_tensor_fn,
|
||||
norm_hash_fn,
|
||||
)
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from torch.utils._triton import has_triton_package
|
||||
@ -114,6 +117,28 @@ class TestDTensorDebugMode(TestCase):
|
||||
"aten::sum(t: f32[1, 32]) # {'hash': " in debug_mode.debug_string()
|
||||
)
|
||||
|
||||
# check tuple hash functions
|
||||
with (
|
||||
DebugMode() as debug_mode,
|
||||
DebugMode.log_tensor_hashes(hash_fn=["norm", "hash_tensor"]),
|
||||
):
|
||||
mm(x_dtensor, y_dtensor)
|
||||
|
||||
output_hash = debug_mode.operators[-1].log["hash"]
|
||||
norm_ = lambda x: norm_hash_fn(x, use_scalar=True) # noqa: E731
|
||||
hash_ = lambda x: hash_tensor_fn(x, use_scalar=True) # noqa: E731
|
||||
|
||||
self.assertEqual(output_hash[0], norm_(eager_out))
|
||||
self.assertEqual(output_hash[1], hash_(eager_out))
|
||||
|
||||
# some edge cases
|
||||
self.assertEqual(norm_(torch.tensor(torch.nan)), torch.nan)
|
||||
self.assertEqual(norm_(torch.tensor(torch.inf)), torch.inf)
|
||||
self.assertEqual(norm_(torch.complex(torch.ones(4), torch.zeros(4))), 4)
|
||||
self.assertEqual(hash_(torch.ones(4, dtype=torch.float8_e5m2)), 0)
|
||||
self.assertEqual(hash_(torch.ones(4, dtype=torch.int8)), 0)
|
||||
self.assertEqual(hash_(torch.ones(5, dtype=torch.int8)), 1)
|
||||
|
||||
def test_debug_string_inside_context(self):
|
||||
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
||||
|
||||
@ -384,14 +409,22 @@ class TestDTensorDebugMode(TestCase):
|
||||
self.assertIn("torch.ops.higher_order.cond", debug_mode.debug_string())
|
||||
|
||||
def test_compile(self):
|
||||
@torch.compile
|
||||
cnt = CompileCounterWithBackend("inductor")
|
||||
|
||||
@torch.compile(backend=cnt)
|
||||
def f(x):
|
||||
return x.sin().cos()
|
||||
|
||||
x = torch.randn(8)
|
||||
f(x)
|
||||
with DebugMode() as debug_mode:
|
||||
f(x)
|
||||
self.assertEqual(len(debug_mode.debug_string()), 0)
|
||||
self.assertEqual(len(debug_mode.debug_string()), 0)
|
||||
f(x)
|
||||
f(x)
|
||||
self.assertEqual(
|
||||
cnt.frame_count, 1
|
||||
) # check DebugMode doesn't trigger additional recompilations
|
||||
|
||||
def test_nn_module(self):
|
||||
class Foo(torch.nn.Module):
|
||||
@ -441,6 +474,9 @@ class TestDTensorDebugMode(TestCase):
|
||||
op for op in debug_mode.operators if str(op.op) == "aten.sum.dim_IntList"
|
||||
][-1]
|
||||
self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)
|
||||
self.assertTrue(
|
||||
"self.l2(self.l1(x))" in debug_mode.debug_string(show_stack_trace=True)
|
||||
)
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "requires GPU")
|
||||
@unittest.skipIf(not has_triton_package(), "requires triton")
|
||||
|
||||
@ -6,10 +6,7 @@ import unittest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.fx.traceback as fx_traceback
|
||||
from torch._dynamo.functional_export import (
|
||||
_dynamo_graph_capture_for_export,
|
||||
dynamo_graph_capture_for_export,
|
||||
)
|
||||
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
|
||||
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
|
||||
from torch._functorch.partitioners import min_cut_rematerialization_partition
|
||||
from torch._guards import tracing, TracingContext
|
||||
@ -153,17 +150,6 @@ def graph_capture_and_aot_export_joint_with_descriptors_v2(model, args, kwargs=N
|
||||
return aot_export_joint_with_descriptors_alone(gm, args, kwargs)
|
||||
|
||||
|
||||
def graph_capture_and_aot_export_joint_with_descriptors(model, args, kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
with torch._dynamo.config.patch(install_free_tensors=True):
|
||||
# TODO: switch to use the official graph_capture API once it is ready
|
||||
gm = _dynamo_graph_capture_for_export(model)(*args, **kwargs)
|
||||
fake_mode = gm.meta.get("fake_mode", None)
|
||||
with tracing(TracingContext(fake_mode)):
|
||||
return aot_export_joint_with_descriptors_alone(gm, args, kwargs)
|
||||
|
||||
|
||||
def aot_export_joint_with_descriptors_alone(model, args, kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
@ -360,7 +346,6 @@ class DTensorExportTest(TestCase):
|
||||
"export_fn",
|
||||
[
|
||||
graph_capture_and_aot_export_joint_with_descriptors_v2,
|
||||
graph_capture_and_aot_export_joint_with_descriptors,
|
||||
aot_export_joint_with_descriptors_alone,
|
||||
],
|
||||
)
|
||||
@ -386,10 +371,6 @@ class DTensorExportTest(TestCase):
|
||||
graph_capture_and_aot_export_joint_with_descriptors_v2,
|
||||
"[[4, 10], [4], [10, 4], [10], [4, 10], [4], [10, 4], [10], [s64, 10], [s64, 10]]",
|
||||
),
|
||||
(
|
||||
graph_capture_and_aot_export_joint_with_descriptors,
|
||||
"[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_dynamic_shapes(self, export_fn_with_answer):
|
||||
@ -434,7 +415,6 @@ class DTensorExportTest(TestCase):
|
||||
"export_fn",
|
||||
[
|
||||
dynamo_graph_capture_for_export,
|
||||
_dynamo_graph_capture_for_export,
|
||||
],
|
||||
)
|
||||
def test_einsum_dtensor_export(self, export_fn):
|
||||
@ -456,11 +436,7 @@ class DTensorExportTest(TestCase):
|
||||
|
||||
# Run model to verify it works
|
||||
output = model(*inputs)
|
||||
with torch._dynamo.config.patch(
|
||||
install_free_tensors=(export_fn is _dynamo_graph_capture_for_export)
|
||||
):
|
||||
# TODO: switch to use the official graph_capture API once it is ready
|
||||
gm = export_fn(model)(*inputs)
|
||||
gm = export_fn(model)(*inputs)
|
||||
output_gm = gm(*inputs)
|
||||
self.assertEqual(output, output_gm)
|
||||
|
||||
@ -468,7 +444,6 @@ class DTensorExportTest(TestCase):
|
||||
"export_fn",
|
||||
[
|
||||
graph_capture_and_aot_export_joint_with_descriptors_v2,
|
||||
graph_capture_and_aot_export_joint_with_descriptors,
|
||||
],
|
||||
)
|
||||
def test_flex_attention_dtensor_export(self, export_fn):
|
||||
@ -531,7 +506,7 @@ class DTensorExportTest(TestCase):
|
||||
return nest_fn(leaf) + 1
|
||||
|
||||
z = torch.randn(16, 16)
|
||||
gm = graph_capture_and_aot_export_joint_with_descriptors(fn, (z,))
|
||||
gm = graph_capture_and_aot_export_joint_with_descriptors_v2(fn, (z,))
|
||||
|
||||
self.assertEqual(fn(z), gm(z)[0])
|
||||
|
||||
@ -546,7 +521,7 @@ class DTensorExportTest(TestCase):
|
||||
y = torch.randint(1, (10,)).bool()
|
||||
x_dt = distribute_tensor(x, device_mesh, placements=[Replicate()])
|
||||
y_dt = distribute_tensor(y, device_mesh, placements=[Replicate()])
|
||||
_dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)
|
||||
dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)
|
||||
|
||||
class Bar(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
@ -556,25 +531,25 @@ class DTensorExportTest(TestCase):
|
||||
|
||||
x = torch.randint(1000, (4, 64, 16))
|
||||
x_dt = distribute_tensor(x, device_mesh, placements=[Replicate()])
|
||||
gm = _dynamo_graph_capture_for_export(Bar())(x_dt)
|
||||
gm = dynamo_graph_capture_for_export(Bar())(x_dt)
|
||||
self.assertExpectedInline(
|
||||
str(gm.graph).strip(),
|
||||
"""\
|
||||
graph():
|
||||
%l_flat_args_0_ : [num_users=2] = placeholder[target=arg_0]
|
||||
%max_1 : [num_users=1] = call_method[target=max](args = (%l_flat_args_0_,), kwargs = {})
|
||||
%l_x_ : torch.distributed.tensor.DTensor [num_users=2] = placeholder[target=L_x_]
|
||||
%max_1 : [num_users=1] = call_method[target=max](args = (%l_x_,), kwargs = {})
|
||||
%clamp : [num_users=1] = call_function[target=torch.clamp](args = (%max_1,), kwargs = {min: 1})
|
||||
%item : [num_users=2] = call_method[target=item](args = (%clamp,), kwargs = {})
|
||||
%ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%item, 1), kwargs = {})
|
||||
%_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 1 on node 'ge_1'), kwargs = {})
|
||||
%res : [num_users=2] = call_function[target=operator.getitem](args = (%l_flat_args_0_, slice(None, item, None)), kwargs = {})
|
||||
%getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%res, _local_tensor), kwargs = {})
|
||||
%getitem : [num_users=2] = call_function[target=operator.getitem](args = (%l_x_, slice(None, item, None)), kwargs = {})
|
||||
%getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%getitem, _local_tensor), kwargs = {})
|
||||
%sym_size_int : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%getattr_1, 0), kwargs = {})
|
||||
%ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
|
||||
%_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u2 >= 0 on node 'ge_2'), kwargs = {})
|
||||
%le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 4), kwargs = {})
|
||||
%_assert_scalar_default_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u2 <= 4 on node 'le'), kwargs = {})
|
||||
return (res,)""", # noqa: B950
|
||||
str(gm.graph).strip(),
|
||||
return (getitem,)""", # noqa: B950
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -706,11 +706,11 @@ class DistTensorOpsTest(DTensorTestBase):
|
||||
@with_comms
|
||||
def test_dtensor_dtype_conversion(self):
|
||||
from torch.distributed.tensor.debug import (
|
||||
_clear_sharding_prop_cache,
|
||||
_get_sharding_prop_cache_info,
|
||||
_clear_fast_path_sharding_prop_cache,
|
||||
_get_fast_path_sharding_prop_cache_stats,
|
||||
)
|
||||
|
||||
_clear_sharding_prop_cache()
|
||||
_clear_fast_path_sharding_prop_cache()
|
||||
device_mesh = self.build_device_mesh()
|
||||
shard_spec = [Shard(0)]
|
||||
# by default we start from bf16 dtype
|
||||
@ -730,13 +730,13 @@ class DistTensorOpsTest(DTensorTestBase):
|
||||
self.assertEqual(bf16_sharded_dtensor1.to_local().dtype, torch.bfloat16)
|
||||
|
||||
# by this point we only have cache misses
|
||||
hits, misses, _, _ = _get_sharding_prop_cache_info()
|
||||
hits, misses = _get_fast_path_sharding_prop_cache_stats()
|
||||
self.assertEqual(hits, 0)
|
||||
self.assertEqual(misses, 2)
|
||||
|
||||
# convert to fp32 again and see if there's cache hit
|
||||
bf16_sharded_dtensor1.float()
|
||||
hits, misses, _, _ = _get_sharding_prop_cache_info()
|
||||
hits, misses = _get_fast_path_sharding_prop_cache_stats()
|
||||
# by now we should have cache hit
|
||||
self.assertEqual(hits, 1)
|
||||
self.assertEqual(misses, 2)
|
||||
|
||||
@ -664,6 +664,101 @@ class TestViewOps(DTensorTestBase):
|
||||
)
|
||||
self.assertEqual(dist_x.placements, [Partial(), Shard(0)])
|
||||
|
||||
@with_comms
|
||||
def test_storage_offset_slice(self):
|
||||
"""
|
||||
Test that storage_offset is properly tracked on DTensor when slicing
|
||||
a replicated tensor.
|
||||
"""
|
||||
mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
|
||||
# Create a replicated DTensor
|
||||
tensor = torch.randn(10, device=self.device_type)
|
||||
dtensor = distribute_tensor(tensor, mesh, [Replicate()])
|
||||
|
||||
# Perform a slice operation [1:]
|
||||
with CommDebugMode() as comm_mode:
|
||||
sliced_dtensor = dtensor[1:]
|
||||
# Slicing should not trigger any communication
|
||||
self.assertEqual(comm_mode.get_total_counts(), 0)
|
||||
|
||||
# Verify that the DTensor's storage_offset matches the expected value
|
||||
self.assertEqual(sliced_dtensor.storage_offset(), 1)
|
||||
|
||||
# Verify that the local tensor also has the correct storage_offset
|
||||
self.assertEqual(sliced_dtensor.to_local().storage_offset(), 1)
|
||||
|
||||
# Verify the shape is correct
|
||||
self.assertEqual(sliced_dtensor.shape, torch.Size([9]))
|
||||
|
||||
# Verify the values are correct
|
||||
expected = tensor[1:]
|
||||
self.assertEqual(sliced_dtensor.full_tensor(), expected)
|
||||
|
||||
@with_comms
|
||||
def test_storage_offset_shard_dim0_slice_dim1(self):
|
||||
"""
|
||||
Test that storage_offset is properly tracked when tensor is sharded on dim 0
|
||||
and sliced on dim 1.
|
||||
"""
|
||||
mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
|
||||
# Create a 2D tensor and shard on dim 0
|
||||
tensor = torch.randn(12, 8, device=self.device_type)
|
||||
dtensor = distribute_tensor(tensor, mesh, [Shard(0)])
|
||||
|
||||
# Perform a slice operation [:, 2:]
|
||||
with CommDebugMode() as comm_mode:
|
||||
sliced_dtensor = dtensor[:, 2:]
|
||||
# Slicing should not trigger any communication
|
||||
self.assertEqual(comm_mode.get_total_counts(), 0)
|
||||
|
||||
# The storage_offset should be 2 (skipping 2 elements in each row)
|
||||
self.assertEqual(sliced_dtensor.storage_offset(), 2)
|
||||
|
||||
# Verify that the local tensor also has the correct storage_offset
|
||||
self.assertEqual(sliced_dtensor.to_local().storage_offset(), 2)
|
||||
|
||||
# Verify the shape is correct
|
||||
expected_shape = torch.Size([12, 6])
|
||||
self.assertEqual(sliced_dtensor.shape, expected_shape)
|
||||
|
||||
# Verify the values are correct
|
||||
expected = tensor[:, 2:]
|
||||
self.assertEqual(sliced_dtensor.full_tensor(), expected)
|
||||
|
||||
@with_comms
|
||||
def test_storage_offset_shard_dim1_slice_dim0(self):
|
||||
"""
|
||||
Test that storage_offset is properly tracked when tensor is sharded on dim 1
|
||||
and sliced on dim 0.
|
||||
"""
|
||||
mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
|
||||
# Create a 2D tensor and shard on dim 1
|
||||
tensor = torch.randn(10, 12, device=self.device_type)
|
||||
dtensor = distribute_tensor(tensor, mesh, [Shard(1)])
|
||||
|
||||
# Perform a slice operation [2:, :]
|
||||
with CommDebugMode() as comm_mode:
|
||||
sliced_dtensor = dtensor[2:, :]
|
||||
# Slicing should not trigger any communication
|
||||
self.assertEqual(comm_mode.get_total_counts(), 0)
|
||||
|
||||
local_dim1_size = 12 // self.world_size
|
||||
expected_offset = 2 * local_dim1_size
|
||||
self.assertEqual(sliced_dtensor.storage_offset(), expected_offset)
|
||||
|
||||
self.assertEqual(sliced_dtensor.to_local().storage_offset(), expected_offset)
|
||||
|
||||
# Verify the shape is correct
|
||||
expected_shape = torch.Size([8, 12])
|
||||
self.assertEqual(sliced_dtensor.shape, expected_shape)
|
||||
|
||||
# Verify the values are correct
|
||||
expected = tensor[2:, :]
|
||||
self.assertEqual(sliced_dtensor.full_tensor(), expected)
|
||||
|
||||
|
||||
TestViewOpsWithLocalTensor = create_local_tensor_test_class(
|
||||
TestViewOps,
|
||||
|
||||
@ -54,12 +54,10 @@ from torch.testing._internal.common_distributed import (
|
||||
verify_ddp_error_logged,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
MI300_ARCH,
|
||||
retry_on_connect_failures,
|
||||
run_tests,
|
||||
skip_but_pass_in_sandcastle,
|
||||
skipIfRocm,
|
||||
skipIfRocmArch,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
@ -1233,7 +1231,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
|
||||
self._test_gather_stress(inputs, lambda t: t.clone())
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skipIfRocmArch(MI300_ARCH)
|
||||
@skipIfRocm
|
||||
@requires_gloo()
|
||||
def test_gather_stress_cuda(self):
|
||||
inputs = [torch.tensor([i + self.rank]).cuda() for i in range(1000)]
|
||||
|
||||
@ -1681,14 +1681,13 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = True); wrap_body_0 = l_x_ = None
|
||||
getitem: "f32[4, 4]" = tag_activation_checkpoint[0]
|
||||
getitem_1: "f32[4, 4]" = tag_activation_checkpoint[1]; tag_activation_checkpoint = None
|
||||
return (getitem, getitem_1)
|
||||
getitem: "f32[4, 4]" = tag_activation_checkpoint[0]; tag_activation_checkpoint = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[4, 4]"):
|
||||
y: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
|
||||
return (y, y)
|
||||
return (y,)
|
||||
""",
|
||||
)
|
||||
|
||||
@ -1798,9 +1797,9 @@ class GraphModule(torch.nn.Module):
|
||||
out: "f32[4, 4]" = l_x_.sin()
|
||||
|
||||
sin_1: "f32[4, 4]" = torch.sin(o)
|
||||
child: "f32[4, 4]" = torch.cos(sin_1)
|
||||
child_1: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
|
||||
return (child, child_1, matmul, o, out, sin_1)
|
||||
cos: "f32[4, 4]" = torch.cos(sin_1)
|
||||
sin_2: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
|
||||
return (cos, sin_2, matmul, o, out, sin_1)
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
@ -222,13 +222,13 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
matmul: "f32[3, 3]" = l_x_ @ l_y_
|
||||
sin: "f32[3, 3]" = matmul.sin(); matmul = None
|
||||
child: "f32[3, 3]" = sin.cos(); sin = None
|
||||
cos: "f32[3, 3]" = sin.cos(); sin = None
|
||||
|
||||
child_1: "f32[3, 3]" = l_x_ + l_y_
|
||||
child_2: "f32[3, 3]" = l_x_ - l_y_
|
||||
add: "f32[3, 3]" = l_x_ + l_y_
|
||||
sub: "f32[3, 3]" = l_x_ - l_y_
|
||||
|
||||
child_3: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
|
||||
return (child, child_1, child_2, child_3)
|
||||
matmul_1: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
|
||||
return (cos, add, sub, matmul_1)
|
||||
""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
|
||||
@ -962,7 +962,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
||||
x = (torch.randn(4, 16, requires_grad=True),)
|
||||
|
||||
with self.assertRaisesRegex(Exception, "weight = self.linear.w"):
|
||||
torch._dynamo.functional_export._dynamo_graph_capture_for_export(Model())(x)
|
||||
torch._dynamo.functional_export.dynamo_graph_capture_for_export(Model())(x)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(ExceptionTests)
|
||||
|
||||
@ -131,7 +131,7 @@ def default_args_generator(seed_value):
|
||||
yield new_args
|
||||
|
||||
|
||||
class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
||||
class HigherOrderOpTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
|
||||
def _assert_wrap_fallback(self, func, args, setup=lambda: None):
|
||||
counters.clear()
|
||||
backend = EagerAndRecordGraphs()
|
||||
@ -249,7 +249,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
# when testing with dynamic shape, symbols are lifted as input
|
||||
arg_count = ifdynstaticdefault(2, 3)
|
||||
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count)
|
||||
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 1)
|
||||
|
||||
def test_return_captured_vars(self):
|
||||
freevar1 = torch.randn(3)
|
||||
@ -267,7 +267,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
||||
# be the input.
|
||||
# when testing with dynamic shape, a symbol is lifted as input
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 4)
|
||||
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 1)
|
||||
|
||||
def test_return_captured_var_used_multiple_times(self):
|
||||
freevar = torch.randn(3)
|
||||
@ -282,7 +282,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
||||
x = torch.randn(3)
|
||||
# when testing with dynamic shape, a symbol is lifted as input
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 3)
|
||||
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 2)
|
||||
|
||||
def test_capture_untracked_global(self):
|
||||
def f(x):
|
||||
@ -762,15 +762,15 @@ class GraphModule(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, u0, c); wrap_body_0 = s77 = l_x_ = u0 = c = None
|
||||
child: "f32[s77]" = wrap[0]
|
||||
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
||||
return (child, child_1)
|
||||
getitem: "f32[s77]" = wrap[0]
|
||||
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
||||
return (getitem, getitem_1)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
||||
child: "f32[s77]" = l_x_.sin(); l_x_ = None
|
||||
child_1: "f32[u0, 1]" = c.sin(); c = None
|
||||
return (child, child_1)
|
||||
sin: "f32[s77]" = l_x_.sin(); l_x_ = None
|
||||
sin_1: "f32[u0, 1]" = c.sin(); c = None
|
||||
return (sin, sin_1)
|
||||
""",
|
||||
)
|
||||
else:
|
||||
@ -801,15 +801,15 @@ class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, u0, c); wrap_body_0 = l_x_ = u0 = c = None
|
||||
child: "f32[3]" = wrap[0]
|
||||
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
||||
return (child, child_1)
|
||||
getitem: "f32[3]" = wrap[0]
|
||||
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
||||
return (getitem, getitem_1)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
||||
child: "f32[3]" = l_x_.sin(); l_x_ = None
|
||||
child_1: "f32[u0, 1]" = c.sin(); c = None
|
||||
return (child, child_1)
|
||||
sin: "f32[3]" = l_x_.sin(); l_x_ = None
|
||||
sin_1: "f32[u0, 1]" = c.sin(); c = None
|
||||
return (sin, sin_1)
|
||||
""",
|
||||
)
|
||||
|
||||
@ -922,16 +922,16 @@ class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, size, c); wrap_body_0 = l_x_ = size = c = None
|
||||
child: "f32[3]" = wrap[0]
|
||||
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
||||
return (child, child_1)
|
||||
getitem: "f32[3]" = wrap[0]
|
||||
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
||||
return (getitem, getitem_1)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
|
||||
sin: "f32[3]" = l_x_.sin(); l_x_ = None
|
||||
child: "f32[3]" = sin + size; sin = size = None
|
||||
child_1: "f32[u0, 1]" = c.sin(); c = None
|
||||
return (child, child_1)
|
||||
add: "f32[3]" = sin + size; sin = size = None
|
||||
sin_1: "f32[u0, 1]" = c.sin(); c = None
|
||||
return (add, sin_1)
|
||||
""",
|
||||
)
|
||||
|
||||
@ -2458,10 +2458,10 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, l_arg1_0_: "f32[3]", l_arg2_0_: "f32[3]"):
|
||||
child: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None
|
||||
add: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None
|
||||
|
||||
child_1: "f32[3]" = l_arg2_0_ + 1; l_arg2_0_ = None
|
||||
return (child, child_1)
|
||||
add_1: "f32[3]" = l_arg2_0_ + 1; l_arg2_0_ = None
|
||||
return (add, add_1)
|
||||
""",
|
||||
)
|
||||
|
||||
@ -2655,9 +2655,9 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[2, 3]"):
|
||||
child: "f32[2, 3]" = l_x_.sin()
|
||||
child_1: "f32[2, 3]" = l_x_.cos(); l_x_ = None
|
||||
return (child, child_1)
|
||||
sin: "f32[2, 3]" = l_x_.sin()
|
||||
cos: "f32[2, 3]" = l_x_.cos(); l_x_ = None
|
||||
return (sin, cos)
|
||||
""",
|
||||
)
|
||||
|
||||
@ -2687,13 +2687,13 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
|
||||
value: "f32[3]" = wrap[0]; wrap = None
|
||||
return (value,)
|
||||
getitem: "f32[3]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[3]"):
|
||||
child: "f32[3]" = -l_x_; l_x_ = None
|
||||
return (child,)
|
||||
neg: "f32[3]" = -l_x_; l_x_ = None
|
||||
return (neg,)
|
||||
""",
|
||||
)
|
||||
|
||||
@ -3318,17 +3318,17 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
hints_wrapper_body_1 = self.hints_wrapper_body_1
|
||||
hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_1, (x, l_y_), {}, hints = {'outer_body': True}); hints_wrapper_body_1 = x = l_y_ = None
|
||||
res: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
|
||||
return (res,)
|
||||
getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
|
||||
return (getitem,)
|
||||
|
||||
class hints_wrapper_body_1(torch.nn.Module):
|
||||
def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
|
||||
hints_wrapper_body_0 = self.hints_wrapper_body_0
|
||||
hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_0, (x, l_y_), {}, hints = {'inner_body': True}); hints_wrapper_body_0 = x = l_y_ = None
|
||||
x_1: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
|
||||
getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
|
||||
|
||||
x_2: "f32[2, 4]" = torch.abs(x_1); x_1 = None
|
||||
return (x_2,)
|
||||
x_1: "f32[2, 4]" = torch.abs(getitem); getitem = None
|
||||
return (x_1,)
|
||||
|
||||
class hints_wrapper_body_0(torch.nn.Module):
|
||||
def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
|
||||
@ -3396,7 +3396,9 @@ class GraphModule(torch.nn.Module):
|
||||
fn_with_hints(x, y)
|
||||
|
||||
|
||||
class HigherOrderOpVmapGuardTests(LoggingTestCase):
|
||||
class HigherOrderOpVmapGuardTests(
|
||||
torch._dynamo.test_case.TestCaseWithNestedGraphBreaks, LoggingTestCase
|
||||
):
|
||||
@make_logging_test(recompiles=True)
|
||||
def test_vmap_grad_guard_ok(self, records):
|
||||
vmap = torch.vmap
|
||||
@ -3665,7 +3667,9 @@ class HigherOrderOpVmapGuardTests(LoggingTestCase):
|
||||
self.assertGreater(len(records), 0)
|
||||
|
||||
|
||||
class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
||||
class FuncTorchHigherOrderOpTests(
|
||||
torch._dynamo.test_case.TestCaseWithNestedGraphBreaks
|
||||
):
|
||||
def tearDown(self):
|
||||
# Ensure that in the case of a test failure, the next test won't fail
|
||||
# because of a previous call to _vmap_increment_nesting that wasn't undone
|
||||
@ -6782,7 +6786,9 @@ class GraphModule(torch.nn.Module):
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
|
||||
class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
|
||||
class ActivationCheckpointingTests(
|
||||
torch._dynamo.test_case.TestCaseWithNestedGraphBreaks
|
||||
):
|
||||
def _validate(self, fn, backend, *args, skip_check=False, fullgraph=True):
|
||||
cloned_args = []
|
||||
for arg in args:
|
||||
@ -7173,7 +7179,7 @@ xfail_hops_compile = {
|
||||
}
|
||||
|
||||
|
||||
class TestHigherOrderOpsOpInfo(torch._dynamo.test_case.TestCase):
|
||||
class TestHigherOrderOpsOpInfo(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
|
||||
@requires_cuda_and_triton
|
||||
@parametrize("backend", ("aot_eager", "inductor"))
|
||||
@ops(
|
||||
|
||||
@ -874,6 +874,32 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreak
|
||||
self.assertEqual(cnts.frame_count, 8)
|
||||
self.assertEqual(cnts.op_count, 10)
|
||||
|
||||
def test_functorch_with_nested_graph_break(self):
|
||||
def f1(x):
|
||||
x = x * 2
|
||||
torch._dynamo.graph_break()
|
||||
return x * 4
|
||||
|
||||
def f2(x):
|
||||
return (f1(x * 8) * 16).sum()
|
||||
|
||||
def f3(x):
|
||||
return torch.func.grad(f2)(x * 32) * 64
|
||||
|
||||
def f4(x):
|
||||
return f3(x * 128) * 256
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
||||
x = torch.randn(3)
|
||||
actual = f4(x)
|
||||
expected = torch.compile(f4, backend=cnts, fullgraph=False)(x)
|
||||
self.assertEqual(actual, expected)
|
||||
self.assertEqual(len(torch._dynamo.utils.counters["graph_break"]), 1)
|
||||
# f4 + f3, f3 end + f4 end
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
# multiplication by 32, 64, 128, 256
|
||||
self.assertEqual(cnts.op_count, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
@ -8146,7 +8146,6 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
|
||||
unsafe_grad(y) # should not warn
|
||||
self.assertEqual(len(w), 1)
|
||||
|
||||
@torch._dynamo.config.patch(install_free_tensors=True)
|
||||
def test_partial_export(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -8166,14 +8165,14 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
|
||||
def forward(self, a, b):
|
||||
return a + b
|
||||
|
||||
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
|
||||
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
|
||||
|
||||
foo = Foo()
|
||||
foo.parallelize()
|
||||
x = torch.randn(4, 4, dtype=torch.float32)
|
||||
y = torch.randn(4, 4, dtype=torch.float32)
|
||||
ref = foo(x, y)
|
||||
gm = _dynamo_graph_capture_for_export(foo)(x, y)
|
||||
gm = dynamo_graph_capture_for_export(foo)(x, y)
|
||||
res = gm(x, y)
|
||||
self.assertEqual(res, ref)
|
||||
|
||||
|
||||
@ -388,6 +388,37 @@ class <lambda>(torch.nn.Module):
|
||||
|
||||
fn(torch.ones(2, 2, device="cuda:0"))
|
||||
|
||||
@requires_cuda
|
||||
def test_current_stream_api(self) -> None:
|
||||
from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
|
||||
from torch._dynamo.variables.streams import get_current_stream
|
||||
|
||||
cur_stream = torch.accelerator.current_stream()
|
||||
s0 = None
|
||||
|
||||
def stream_generation_backend(gm, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
nonlocal s0
|
||||
s0_ind = get_current_stream(torch.device("cuda:0"))
|
||||
self.assertEqual(get_external_object_by_index(s0_ind), cur_stream)
|
||||
with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
|
||||
gm.graph.call_function(
|
||||
get_external_object_by_index, args=(s0_ind,), kwargs={}
|
||||
)
|
||||
gm.graph.call_function(
|
||||
lambda x: self.assertEqual(
|
||||
cur_stream, get_external_object_by_index(x)
|
||||
),
|
||||
args=(s0_ind,),
|
||||
kwargs={},
|
||||
)
|
||||
return gm
|
||||
|
||||
@torch.compile(backend=stream_generation_backend)
|
||||
def fn(x):
|
||||
return x + 1
|
||||
|
||||
fn(torch.ones(2, 2, device="cuda:0"))
|
||||
|
||||
@requires_cuda
|
||||
def test_stream_with_mutation(self):
|
||||
def fn(x, y):
|
||||
|
||||