mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-15 14:54:56 +08:00
Compare commits
94 Commits
trunk/374e
...
update_sub
| Author | SHA1 | Date | |
|---|---|---|---|
| e6046904ed | |||
| deabb3e36d | |||
| 79d2397b6b | |||
| 6ef3a62c36 | |||
| 530e782239 | |||
| c66a6c432e | |||
| 3d7a8b7e61 | |||
| de0d69b2c4 | |||
| bc60b86066 | |||
| d7782ddde7 | |||
| fb04e9ad03 | |||
| cfe799b4aa | |||
| b7f52773e6 | |||
| f6b54d8899 | |||
| da91bf5262 | |||
| 1c1638297e | |||
| ee0b5b4b1c | |||
| fcfb213c5a | |||
| 08042bbb9c | |||
| e20ca3bc2e | |||
| 4ed26f7382 | |||
| 4c79305b87 | |||
| f4b8c4f907 | |||
| d629b7a459 | |||
| 0922ba5f42 | |||
| c87295c044 | |||
| 7aa210d215 | |||
| 5a368b8010 | |||
| 602102be50 | |||
| 200156e385 | |||
| a2daf3fc86 | |||
| 52b45c16de | |||
| 2ef85bed5a | |||
| d99c6bcf69 | |||
| 8378abda84 | |||
| 5b42a5d9a6 | |||
| caca3f2eec | |||
| 9e2bf129e1 | |||
| c429b1fc5c | |||
| 1176b2b0b7 | |||
| dd37a1a434 | |||
| a74adcf80e | |||
| 5eac46a011 | |||
| e0fff31ae3 | |||
| 7ede33b8e3 | |||
| 065176cd97 | |||
| 02ee7dd7d3 | |||
| 99fdca8f4d | |||
| 9d1a74cb0c | |||
| 40e6f090d9 | |||
| bfddfde50c | |||
| b6570615f8 | |||
| 226850cc66 | |||
| f8a2ce3b9a | |||
| e2c6834584 | |||
| 0e7235ed73 | |||
| 3522e0ce74 | |||
| 50bf1f0b81 | |||
| c78e64622e | |||
| 5623628894 | |||
| 2aba180114 | |||
| 45b2c3d312 | |||
| 5b1e112cf9 | |||
| 5e6ac5c6e1 | |||
| 79317dc7a7 | |||
| 96a4c4b3d1 | |||
| 05bcfcc5d1 | |||
| 8cf0bdde45 | |||
| 813e5eae9b | |||
| 2ef236e3e3 | |||
| 532389fe9e | |||
| 08de54f1ea | |||
| 0cd0bd7217 | |||
| fe33d7cadf | |||
| a9542426d0 | |||
| f79cdc89db | |||
| 3d063519bf | |||
| 0b3bdb0d89 | |||
| 8f00ec31ca | |||
| 21f32e4af3 | |||
| 940979a229 | |||
| 4fc688625a | |||
| 23f4f323ea | |||
| 9ac3fc0d0a | |||
| 38806f381a | |||
| cfb3a6b3da | |||
| d8384e296e | |||
| d273422582 | |||
| fadb62f592 | |||
| e5eb89e111 | |||
| b5e0e6932a | |||
| 6ea779188c | |||
| 460c7e196c | |||
| 7aac506cdc |
@ -1,19 +0,0 @@
|
||||
# Aarch64 (ARM/Graviton) Support Scripts
|
||||
Scripts for building aarch64 PyTorch PIP Wheels. These scripts build the following wheels:
|
||||
* torch
|
||||
* torchvision
|
||||
* torchaudio
|
||||
* torchtext
|
||||
* torchdata
|
||||
## Aarch64_ci_build.sh
|
||||
This script is designed to support CD operations within the PyPA manylinux aarch64 container, and is meant to be executed inside that container. It prepares the container and then executes __aarch64_wheel_ci_build.py__ to build the wheels. The script assumes the PyTorch repo is located at ```/pytorch``` and puts the wheels into ```/artifacts```.
|
||||
### Usage
|
||||
```DESIRED_PYTHON=<PythonVersion> aarch64_ci_build.sh```
|
||||
|
||||
__NOTE:__ CI build is currently __EXPERIMENTAL__
|
||||
|
||||
## Build_aarch64_wheel.py
|
||||
This app allows a person to build using AWS EC2 resources. It requires the AWS CLI and Boto3, with AWS credentials, to create the EC2 instances used for the wheel builds. It can be used in a CodeBuild CD pipeline or from a local system.
|
||||
|
||||
### Usage
|
||||
```build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.8 --branch <RCtag>```
|
||||
@ -1,53 +0,0 @@
|
||||
#!/bin/bash
# Builds the aarch64 PyTorch wheel: selects CUDA build settings from the
# environment, sources aarch64_ci_setup.sh, then runs aarch64_wheel_ci_build.py.
#
# Environment:
#   GPU_ARCH_VERSION                    optional; selects TORCH_CUDA_ARCH_LIST
#   DESIRED_CUDA                        required (the script aborts under `set -u`
#                                       if it is unset); "cpu" builds a CPU wheel
#   PYTORCH_EXTRA_INSTALL_REQUIREMENTS  optional; when non-empty the NVIDIA
#                                       libraries are taken from PyPI
set -eux -o pipefail

GPU_ARCH_VERSION=${GPU_ARCH_VERSION:-}

# Set CUDA architecture lists to match x86 build_cuda.sh
if [[ "$GPU_ARCH_VERSION" == *"12.6"* ]]; then
    export TORCH_CUDA_ARCH_LIST="8.0;9.0"
elif [[ "$GPU_ARCH_VERSION" == *"12.8"* ]]; then
    export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
elif [[ "$GPU_ARCH_VERSION" == *"12.9"* ]]; then
    export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
elif [[ "$GPU_ARCH_VERSION" == *"13.0"* ]]; then
    export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;11.0;12.0+PTX"
fi

# Compress the fatbin with -compress-mode=size for CUDA 13
if [[ "$DESIRED_CUDA" == *"13"* ]]; then
    export TORCH_NVCC_FLAGS="-compress-mode=size"
    # Bundle ptxas into the cu13 wheel, see https://github.com/pytorch/pytorch/issues/163801
    export BUILD_BUNDLE_PTXAS=1
fi

SCRIPTPATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
# Quoted so a checkout path containing whitespace is not word-split.
source "$SCRIPTPATH/aarch64_ci_setup.sh"

###############################################################################
# Run aarch64 builder python
###############################################################################
cd /
# adding safe directory for git as the permissions will be
# on the mounted pytorch repo
git config --global --add safe.directory /pytorch
pip install -r /pytorch/requirements.txt
pip install auditwheel==6.2.0 wheel
if [ "$DESIRED_CUDA" = "cpu" ]; then
    echo "BASE_CUDA_VERSION is not set. Building cpu wheel."
    python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn
else
    echo "BASE_CUDA_VERSION is set to: $DESIRED_CUDA"
    export USE_SYSTEM_NCCL=1

    # Check if we should use NVIDIA libs from PyPI (similar to x86 build_cuda.sh logic)
    # The ":-" default keeps `set -u` from aborting the script when the variable
    # is not defined at all; an unset variable is treated like an empty one.
    if [[ -z "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" ]]; then
        echo "Bundling CUDA libraries with wheel for aarch64."
    else
        echo "Using nvidia libs from pypi for aarch64."
        echo "Updated PYTORCH_EXTRA_INSTALL_REQUIREMENTS for aarch64: $PYTORCH_EXTRA_INSTALL_REQUIREMENTS"
        export USE_NVIDIA_PYPI_LIBS=1
    fi

    python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn --enable-cuda
fi
|
||||
@ -1,21 +0,0 @@
|
||||
#!/bin/bash
set -eux -o pipefail

# This script is used to prepare the Docker container for the
# aarch64_wheel_ci_build.py python script by creating symlinks from the
# desired /opt/python to /usr/local/bin/
#
# Requires DESIRED_PYTHON to be set (the script aborts under `set -u` otherwise).
# DESIRED_PYTHON_BIN_DIR is presumably exported by set_desired_python.sh, which
# is sourced below - confirm against that script.

NUMPY_VERSION=2.0.2
if [[ "$DESIRED_PYTHON" == "3.13" || "$DESIRED_PYTHON" == "3.13t" ]]; then
    NUMPY_VERSION=2.1.2
fi

SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"
# Expansions are quoted so paths containing whitespace are not word-split.
source "$SCRIPTPATH/../manywheel/set_desired_python.sh"

pip install -q "numpy==${NUMPY_VERSION}" pyyaml==6.0.2 scons==4.7.0 ninja==1.11.1 patchelf==0.17.2

for tool in python python3 pip pip3 ninja scons patchelf; do
    ln -sf "${DESIRED_PYTHON_BIN_DIR}/${tool}" /usr/local/bin;
done

python --version
|
||||
@ -1,333 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# encoding: UTF-8
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from subprocess import check_call, check_output
|
||||
|
||||
|
||||
def list_dir(path: str) -> list[str]:
    """
    List the entries of ``path`` by running ``ls -1``.

    The raw output is split on newlines without any filtering, so the
    returned list always ends with an empty string.
    """
    listing = check_output(["ls", "-1", path])
    return listing.decode().split("\n")
|
||||
|
||||
|
||||
def replace_tag(filename) -> None:
    """
    Rewrite the first ``Tag:`` line of ``filename`` in place.

    Every ``-linux_`` occurrence on that line becomes ``-manylinux_2_28_``.
    Later ``Tag:`` lines are left alone. The file is written back even when
    no ``Tag:`` line is found.
    """
    with open(filename) as src:
        content = src.readlines()

    # Index of the first line that starts with "Tag:", or None.
    tag_index = next(
        (idx for idx, text in enumerate(content) if text.startswith("Tag:")),
        None,
    )
    if tag_index is not None:
        previous = content[tag_index]
        content[tag_index] = previous.replace("-linux_", "-manylinux_2_28_")
        print(f"Updated tag from {previous} to {content[tag_index]}")

    with open(filename, "w") as dst:
        dst.writelines(content)
|
||||
|
||||
|
||||
def patch_library_rpath(
    folder: str,
    lib_name: str,
    use_nvidia_pypi_libs: bool = False,
    desired_cuda: str = "",
) -> None:
    """Apply patchelf to set RPATH for a library in torch/lib

    Args:
        folder: directory that contains the unpacked wheel under ``tmp/``.
        lib_name: file name of the library inside ``{folder}/tmp/torch/lib``.
        use_nvidia_pypi_libs: when true, the RPATH also lists the ``nvidia/*/lib``
            directories (relative to ``$ORIGIN``) in addition to ``$ORIGIN``.
        desired_cuda: CUDA tag; when it contains ``"130"`` the single ``cu13``
            directory is used instead of the per-component directories.

    Raises:
        subprocess.CalledProcessError: if patchelf exits with a non-zero status.

    Nothing happens when the library file does not exist.
    """
    lib_dir = f"{folder}/tmp/torch/lib"
    lib_path = f"{lib_dir}/{lib_name}"

    if use_nvidia_pypi_libs:
        # For PyPI NVIDIA libraries, construct CUDA RPATH
        cuda_rpaths = [
            "$ORIGIN/../../nvidia/cudnn/lib",
            "$ORIGIN/../../nvidia/nvshmem/lib",
            "$ORIGIN/../../nvidia/nccl/lib",
            "$ORIGIN/../../nvidia/cusparselt/lib",
        ]

        if "130" in desired_cuda:
            cuda_rpaths.append("$ORIGIN/../../nvidia/cu13/lib")
        else:
            cuda_rpaths.extend(
                [
                    "$ORIGIN/../../nvidia/cublas/lib",
                    "$ORIGIN/../../nvidia/cuda_cupti/lib",
                    "$ORIGIN/../../nvidia/cuda_nvrtc/lib",
                    "$ORIGIN/../../nvidia/cuda_runtime/lib",
                    "$ORIGIN/../../nvidia/cufft/lib",
                    "$ORIGIN/../../nvidia/curand/lib",
                    "$ORIGIN/../../nvidia/cusolver/lib",
                    "$ORIGIN/../../nvidia/cusparse/lib",
                    "$ORIGIN/../../nvidia/nvtx/lib",
                    "$ORIGIN/../../nvidia/cufile/lib",
                ]
            )

        # Add $ORIGIN for local torch libs
        rpath = ":".join(cuda_rpaths) + ":$ORIGIN"
    else:
        # For bundled libraries, just use $ORIGIN
        rpath = "$ORIGIN"

    if os.path.exists(lib_path):
        # Run patchelf without a shell: the literal "$ORIGIN" needs no quoting,
        # paths containing spaces work, and a patchelf failure raises instead
        # of being silently ignored.
        check_call(
            ["patchelf", "--set-rpath", rpath, "--force-rpath", lib_name],
            cwd=lib_dir,
        )
|
||||
|
||||
|
||||
def copy_and_patch_library(
    src_path: str,
    folder: str,
    use_nvidia_pypi_libs: bool = False,
    desired_cuda: str = "",
) -> None:
    """Copy a library to torch/lib and patch its RPATH

    A ``src_path`` that does not exist is skipped without error. Otherwise the
    file is copied into ``{folder}/tmp/torch/lib`` under its base name and
    handed to ``patch_library_rpath`` with the same options.
    """
    if not os.path.exists(src_path):
        return

    lib_name = os.path.basename(src_path)
    destination = f"{folder}/tmp/torch/lib/{lib_name}"
    shutil.copy2(src_path, destination)
    patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda)
|
||||
|
||||
|
||||
def package_cuda_wheel(wheel_path, desired_cuda) -> None:
    """
    Package the cuda wheel libraries

    Unzips the wheel at ``wheel_path`` into a ``tmp`` directory next to it,
    deletes the original wheel, copies extra shared libraries into
    ``tmp/torch/lib`` (patching their RPATH), rewrites the platform tag in the
    ``.dist-info/WHEEL`` file, repacks the wheel into the same directory with
    ``wheel pack`` and finally removes ``tmp``.

    The ``USE_NVIDIA_PYPI_LIBS`` environment variable selects which set of
    libraries is copied: "1" copies only the non-CUDA libraries and re-patches
    the torch libraries; any other value bundles the CUDA libraries as well.

    Raises ``ValueError`` in the bundling mode when ``desired_cuda`` contains
    neither "13" nor "12".

    NOTE(review): the shell commands are run through ``os.system`` and their
    exit codes are not checked.
    """
    folder = os.path.dirname(wheel_path)
    # os.mkdir raises if {folder}/tmp already exists.
    os.mkdir(f"{folder}/tmp")
    os.system(f"unzip {wheel_path} -d {folder}/tmp")
    # Delete original wheel since it will be repackaged
    os.system(f"rm {wheel_path}")

    # Check if we should use PyPI NVIDIA libraries or bundle system libraries
    use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"

    if use_nvidia_pypi_libs:
        print("Using nvidia libs from pypi - skipping CUDA library bundling")
        # For PyPI approach, we don't bundle CUDA libraries - they come from PyPI packages
        # We only need to bundle non-NVIDIA libraries
        minimal_libs_to_copy = [
            "/lib64/libgomp.so.1",
            "/usr/lib64/libgfortran.so.5",
            "/acl/build/libarm_compute.so",
            "/acl/build/libarm_compute_graph.so",
            "/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
            "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
            "/usr/local/lib/libnvpl_lapack_core.so.0",
            "/usr/local/lib/libnvpl_blas_core.so.0",
        ]

        # Copy minimal libraries to unzipped_folder/torch/lib
        for lib_path in minimal_libs_to_copy:
            copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda)

        # Patch torch libraries used for searching libraries
        torch_libs_to_patch = [
            "libtorch.so",
            "libtorch_cpu.so",
            "libtorch_cuda.so",
            "libtorch_cuda_linalg.so",
            "libtorch_global_deps.so",
            "libtorch_python.so",
            "libtorch_nvshmem.so",
            "libc10.so",
            "libc10_cuda.so",
            "libcaffe2_nvrtc.so",
            "libshm.so",
        ]
        for lib_name in torch_libs_to_patch:
            patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda)
    else:
        print("Bundling CUDA libraries with wheel")
        # Original logic for bundling system CUDA libraries
        # Common libraries for all CUDA versions
        common_libs = [
            # Non-NVIDIA system libraries
            "/lib64/libgomp.so.1",
            "/usr/lib64/libgfortran.so.5",
            "/acl/build/libarm_compute.so",
            "/acl/build/libarm_compute_graph.so",
            # Common CUDA libraries (same for all versions)
            "/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
            "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
            "/usr/local/lib/libnvpl_lapack_core.so.0",
            "/usr/local/lib/libnvpl_blas_core.so.0",
            "/usr/local/cuda/extras/CUPTI/lib64/libnvperf_host.so",
            "/usr/local/cuda/lib64/libcudnn.so.9",
            "/usr/local/cuda/lib64/libcusparseLt.so.0",
            "/usr/local/cuda/lib64/libcurand.so.10",
            "/usr/local/cuda/lib64/libnccl.so.2",
            "/usr/local/cuda/lib64/libnvshmem_host.so.3",
            "/usr/local/cuda/lib64/libcudnn_adv.so.9",
            "/usr/local/cuda/lib64/libcudnn_cnn.so.9",
            "/usr/local/cuda/lib64/libcudnn_graph.so.9",
            "/usr/local/cuda/lib64/libcudnn_ops.so.9",
            "/usr/local/cuda/lib64/libcudnn_engines_runtime_compiled.so.9",
            "/usr/local/cuda/lib64/libcudnn_engines_precompiled.so.9",
            "/usr/local/cuda/lib64/libcudnn_heuristic.so.9",
            "/usr/local/cuda/lib64/libcufile.so.0",
            "/usr/local/cuda/lib64/libcufile_rdma.so.1",
            "/usr/local/cuda/lib64/libcusparse.so.12",
        ]

        # CUDA version-specific libraries
        # NOTE(review): these are substring checks on desired_cuda, and "13" is
        # tested before "12".
        if "13" in desired_cuda:
            # Last character of desired_cuda is used as the minor version.
            minor_version = desired_cuda[-1]
            version_specific_libs = [
                "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.13",
                "/usr/local/cuda/lib64/libcublas.so.13",
                "/usr/local/cuda/lib64/libcublasLt.so.13",
                "/usr/local/cuda/lib64/libcudart.so.13",
                "/usr/local/cuda/lib64/libcufft.so.12",
                "/usr/local/cuda/lib64/libcusolver.so.12",
                "/usr/local/cuda/lib64/libnvJitLink.so.13",
                "/usr/local/cuda/lib64/libnvrtc.so.13",
                f"/usr/local/cuda/lib64/libnvrtc-builtins.so.13.{minor_version}",
            ]
        elif "12" in desired_cuda:
            # Get the last character for libnvrtc-builtins version (e.g., "129" -> "9")
            minor_version = desired_cuda[-1]
            version_specific_libs = [
                "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12",
                "/usr/local/cuda/lib64/libcublas.so.12",
                "/usr/local/cuda/lib64/libcublasLt.so.12",
                "/usr/local/cuda/lib64/libcudart.so.12",
                "/usr/local/cuda/lib64/libcufft.so.11",
                "/usr/local/cuda/lib64/libcusolver.so.11",
                "/usr/local/cuda/lib64/libnvJitLink.so.12",
                "/usr/local/cuda/lib64/libnvrtc.so.12",
                f"/usr/local/cuda/lib64/libnvrtc-builtins.so.12.{minor_version}",
            ]
        else:
            raise ValueError(f"Unsupported CUDA version: {desired_cuda}.")

        # Combine all libraries
        libs_to_copy = common_libs + version_specific_libs

        # Copy libraries to unzipped_folder/torch/lib
        # (copy_and_patch_library skips source paths that do not exist)
        for lib_path in libs_to_copy:
            copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda)

    # Make sure the wheel is tagged with manylinux_2_28
    # Only the first *.dist-info directory found is processed.
    for f in os.scandir(f"{folder}/tmp/"):
        if f.is_dir() and f.name.endswith(".dist-info"):
            replace_tag(f"{f.path}/WHEEL")
            break

    os.system(f"wheel pack {folder}/tmp/ -d {folder}")
    os.system(f"rm -rf {folder}/tmp/")
|
||||
|
||||
|
||||
def complete_wheel(folder: str) -> str:
    """
    Complete wheel build and put in artifact location

    Takes the first entry listed in ``/{folder}/dist``. For a folder whose
    name contains "pytorch", and when the module-level ``enable_cuda`` flag is
    false, the wheel is repaired with ``auditwheel repair`` and the repaired
    wheel is moved from ``wheelhouse`` back into ``dist``. The resulting wheel
    is then copied to ``/artifacts`` and its file name is returned.
    """
    dist_dir = f"/{folder}/dist"
    wheelhouse_dir = f"/{folder}/wheelhouse"
    wheel_name = list_dir(dist_dir)[0]

    # For cuda we don't run auditwheel: the cuda dependencies are packaged into
    # the wheel by the custom package_cuda_wheel() step instead, which also
    # takes care of the Manylinux platform tag.
    if "pytorch" not in folder or enable_cuda:
        repaired_wheel_name = list_dir(dist_dir)[0]
    else:
        print("Repairing Wheel with AuditWheel")
        check_call(["auditwheel", "repair", f"dist/{wheel_name}"], cwd=folder)
        repaired_wheel_name = list_dir(wheelhouse_dir)[0]

        print(f"Moving {repaired_wheel_name} wheel to /{folder}/dist")
        os.rename(
            f"{wheelhouse_dir}/{repaired_wheel_name}",
            f"{dist_dir}/{repaired_wheel_name}",
        )

    print(f"Copying {repaired_wheel_name} to artifacts")
    shutil.copy2(
        f"{dist_dir}/{repaired_wheel_name}", f"/artifacts/{repaired_wheel_name}"
    )

    return repaired_wheel_name
|
||||
|
||||
|
||||
def parse_arguments():
    """
    Parse inline arguments

    Returns the ``argparse`` namespace with the boolean flags ``debug``,
    ``build_only``, ``enable_mkldnn`` and ``enable_cuda`` plus the optional
    string ``test_only``.
    """
    from argparse import ArgumentParser

    parser = ArgumentParser("AARCH64 wheels python CD")
    # Plain on/off switches, declared in the same order as before.
    for switch in ("--debug", "--build-only"):
        parser.add_argument(switch, action="store_true")
    parser.add_argument("--test-only", type=str)
    for switch in ("--enable-mkldnn", "--enable-cuda"):
        parser.add_argument(switch, action="store_true")

    parsed = parser.parse_args()
    return parsed
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Entry point: assemble the build environment variables, build the wheel
    # in /pytorch, repackage it for CUDA builds and copy it to the artifacts
    # location.
    args = parse_arguments()
    enable_mkldnn = args.enable_mkldnn
    enable_cuda = args.enable_cuda
    # git prints the branch name followed by a newline; strip it, otherwise
    # the comparison with "nightly"/"main" below can never match.
    branch = (
        check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd="/pytorch")
        .decode()
        .strip()
    )

    print("Building PyTorch wheel")
    build_vars = ""
    # MAX_JOB=5 is not required for CPU backend (see commit 465d98b)
    if enable_cuda:
        build_vars += "MAX_JOBS=5 "

        # Handle PyPI NVIDIA libraries vs bundled libraries
        use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"
        if use_nvidia_pypi_libs:
            print("Configuring build for PyPI NVIDIA libraries")
            # Configure for dynamic linking (matching x86 logic)
            build_vars += "ATEN_STATIC_CUDA=0 USE_CUDA_STATIC_LINK=0 USE_CUPTI_SO=1 "
        else:
            print("Configuring build for bundled NVIDIA libraries")
            # Keep existing static linking approach - already configured above

    override_package_version = os.getenv("OVERRIDE_PACKAGE_VERSION")
    desired_cuda = os.getenv("DESIRED_CUDA")
    if override_package_version is not None:
        version = override_package_version
        build_vars += (
            f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version} PYTORCH_BUILD_NUMBER=1 "
        )
    elif branch in ["nightly", "main"]:
        # Committer date of HEAD as YYYYMMDD.
        build_date = (
            check_output(["git", "log", "--pretty=format:%cs", "-1"], cwd="/pytorch")
            .decode()
            .replace("-", "")
        )
        # Drop the last two characters of the version string in version.txt.
        version = (
            check_output(["cat", "version.txt"], cwd="/pytorch").decode().strip()[:-2]
        )
        if enable_cuda:
            build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date}+{desired_cuda} PYTORCH_BUILD_NUMBER=1 "
        else:
            build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1 "
    elif branch.startswith(("v1.", "v2.")):
        # "v2.1.0-rc3" -> "2.1.0"; a tag without a "-" suffix is used whole
        # (minus the leading "v").
        release_version = branch[1:].partition("-")[0]
        build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={release_version} PYTORCH_BUILD_NUMBER=1 "

    if enable_mkldnn:
        print("build pytorch with mkldnn+acl backend")
        build_vars += "USE_MKLDNN=ON USE_MKLDNN_ACL=ON "
        build_vars += "ACL_ROOT_DIR=/acl "
        if enable_cuda:
            build_vars += "BLAS=NVPL "
        else:
            build_vars += "BLAS=OpenBLAS OpenBLAS_HOME=/opt/OpenBLAS "
    else:
        print("build pytorch without mkldnn backend")

    # build_vars is a space-separated list of VAR=value assignments that
    # prefixes the build command in the shell.
    os.system(f"cd /pytorch; {build_vars} python3 -m build --wheel --no-isolation")
    if enable_cuda:
        print("Updating Cuda Dependency")
        filename = os.listdir("/pytorch/dist/")
        wheel_path = f"/pytorch/dist/{filename[0]}"
        package_cuda_wheel(wheel_path, desired_cuda)
    pytorch_wheel_name = complete_wheel("/pytorch/")
    print(f"Build Complete. Created {pytorch_wheel_name}..")
|
||||
@ -1,999 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# This script is for building AARCH64 wheels using AWS EC2 instances.
|
||||
# To generate binaries for the release follow these steps:
|
||||
# 1. Update mappings for each of the Domain Libraries by adding new row to a table like this:
|
||||
# "v1.11.0": ("0.11.0", "rc1"),
|
||||
# 2. Run script with following arguments for each of the supported python versions and required tag, for example:
|
||||
# build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.8 --branch v1.11.0-rc3
|
||||
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
|
||||
import boto3
|
||||
|
||||
|
||||
# AMI images for us-east-1, change the following based on your ~/.aws/config
# Maps an OS label to its AMI id; the trailing comment on each entry gives the
# login name to use on instances started from that image.
os_amis = {
    "ubuntu20_04": "ami-052eac90edaa9d08f",  # login_name: ubuntu
    "ubuntu22_04": "ami-0c6c29c5125214c77",  # login_name: ubuntu
    "redhat8": "ami-0698b90665a2ddcf1",  # login_name: ec2-user
}

# Used as the default `ami` argument of start_instance().
ubuntu20_04_ami = os_amis["ubuntu20_04"]
|
||||
|
||||
|
||||
def compute_keyfile_path(key_name: Optional[str] = None) -> tuple[str, str]:
    """Resolve the SSH key file path and key name.

    When ``key_name`` is not given it is read from the ``AWS_KEY_NAME``
    environment variable. If no name is available, the result is
    ``(SSH_KEY_PATH or "", "")``. Otherwise the path is ``SSH_KEY_PATH`` when
    set, falling back to ``~/.ssh/<key_name>.pem``.

    Returns:
        A ``(keyfile_path, key_name)`` tuple.
    """
    resolved_name = os.getenv("AWS_KEY_NAME") if key_name is None else key_name
    if resolved_name is None:
        return os.getenv("SSH_KEY_PATH", ""), ""

    fallback_path = os.path.join(
        os.path.expanduser("~"), ".ssh", f"{resolved_name}.pem"
    )
    return os.getenv("SSH_KEY_PATH", fallback_path), resolved_name
|
||||
|
||||
|
||||
ec2 = boto3.resource("ec2")
|
||||
|
||||
|
||||
def ec2_get_instances(filter_name, filter_value):
    """Return the EC2 instances matching a single ``filter_name == filter_value`` filter.

    Queries the module-level ``ec2`` resource.
    """
    single_filter = {"Name": filter_name, "Values": [filter_value]}
    return ec2.instances.filter(Filters=[single_filter])
|
||||
|
||||
|
||||
def ec2_instances_of_type(instance_type="t4g.2xlarge"):
    """Return the EC2 instances whose ``instance-type`` equals ``instance_type``."""
    instances = ec2_get_instances("instance-type", instance_type)
    return instances
|
||||
|
||||
|
||||
def ec2_instances_by_id(instance_id):
    """Return the instance with the given id, or ``None`` when none matches."""
    matches = list(ec2_get_instances("instance-id", instance_id))
    if matches:
        return matches[0]
    return None
|
||||
|
||||
|
||||
def start_instance(
    key_name, ami=ubuntu20_04_ami, instance_type="t4g.2xlarge", ebs_size: int = 50
):
    """Create one EC2 instance, wait until it is running and return it.

    Args:
        key_name: name of the EC2 key pair to attach.
        ami: image id to boot from.
        instance_type: EC2 instance type.
        ebs_size: value passed as ``VolumeSize`` for the ``/dev/sda1`` volume.

    Returns:
        The instance object looked up again by id after it reached the
        running state.
    """
    # Volume attached as /dev/sda1; deleted together with the instance.
    root_volume = {
        "DeviceName": "/dev/sda1",
        "Ebs": {
            "DeleteOnTermination": True,
            "VolumeSize": ebs_size,
            "VolumeType": "standard",
        },
    }
    created = ec2.create_instances(
        ImageId=ami,
        InstanceType=instance_type,
        SecurityGroups=["ssh-allworld"],
        KeyName=key_name,
        MinCount=1,
        MaxCount=1,
        BlockDeviceMappings=[root_volume],
    )
    inst = created[0]
    print(f"Create instance {inst.id}")
    inst.wait_until_running()
    # Look the instance up again after it is running; the public DNS name is
    # read from this fresh object.
    running_inst = ec2_instances_by_id(inst.id)
    print(f"Instance started at {running_inst.public_dns_name}")
    return running_inst
|
||||
|
||||
|
||||
class RemoteHost:
    """A remote machine reached over ssh/scp, optionally via a docker container on it.

    While ``container_id`` is ``None`` the generic helpers (``run_cmd``,
    ``check_output``, ``upload_file``, ``download_file``) talk to the host
    directly. Once ``start_docker`` has set ``container_id`` they go through
    ``docker exec`` / ``docker cp`` on the host instead.
    """

    addr: str
    keyfile_path: str
    login_name: str
    container_id: Optional[str] = None
    ami: Optional[str] = None

    def __init__(self, addr: str, keyfile_path: str, login_name: str = "ubuntu"):
        self.addr = addr
        self.keyfile_path = keyfile_path
        self.login_name = login_name

    def _gen_ssh_prefix(self) -> list[str]:
        """Return the ssh argv prefix, ending with ``--``, that targets this host."""
        target = f"{self.login_name}@{self.addr}"
        return [
            "ssh",
            "-o",
            "StrictHostKeyChecking=no",
            "-i",
            self.keyfile_path,
            target,
            "--",
        ]

    @staticmethod
    def _split_cmd(args: Union[str, list[str]]) -> list[str]:
        """Split a command string on whitespace; anything else is returned as-is."""
        if isinstance(args, str):
            return args.split()
        return args

    def _docker_exec_cmd(self) -> list[str]:
        """Return the argv that starts ``bash`` inside the container via ssh."""
        return self._gen_ssh_prefix() + [
            "docker",
            "exec",
            "-i",
            self.container_id,
            "bash",
        ]

    def _bash_stdin(self, args: Union[str, list[str]]) -> bytes:
        """Return the bytes fed to the container's bash: ``.bashrc`` is sourced first."""
        script = " ".join(["source .bashrc && "] + self._split_cmd(args))
        return script.encode("utf-8")

    def run_ssh_cmd(self, args: Union[str, list[str]]) -> None:
        """Run a command on the host over ssh; raises on a non-zero exit."""
        full_cmd = self._gen_ssh_prefix() + self._split_cmd(args)
        subprocess.check_call(full_cmd)

    def check_ssh_output(self, args: Union[str, list[str]]) -> str:
        """Run a command on the host over ssh and return its decoded stdout."""
        full_cmd = self._gen_ssh_prefix() + self._split_cmd(args)
        raw = subprocess.check_output(full_cmd)
        return raw.decode("utf-8")

    def scp_upload_file(self, local_file: str, remote_file: str) -> None:
        """Copy ``local_file`` to ``remote_file`` on the host with scp."""
        destination = f"{self.login_name}@{self.addr}:{remote_file}"
        subprocess.check_call(
            ["scp", "-i", self.keyfile_path, local_file, destination]
        )

    def scp_download_file(
        self, remote_file: str, local_file: Optional[str] = None
    ) -> None:
        """Copy ``remote_file`` from the host with scp (into ``.`` by default)."""
        if local_file is None:
            local_file = "."
        source = f"{self.login_name}@{self.addr}:{remote_file}"
        subprocess.check_call(["scp", "-i", self.keyfile_path, source, local_file])

    def start_docker(self, image="quay.io/pypa/manylinux2014_aarch64:latest") -> None:
        """Install and start docker on the host, then start a detached container.

        The id printed by ``docker run`` is stored in ``container_id``.
        """
        setup_cmds = (
            "sudo apt-get install -y docker.io",
            f"sudo usermod -a -G docker {self.login_name}",
            "sudo service docker start",
            f"docker pull {image}",
        )
        for cmd in setup_cmds:
            self.run_ssh_cmd(cmd)
        launched = self.check_ssh_output(f"docker run -t -d -w /root {image}")
        self.container_id = launched.strip()

    def using_docker(self) -> bool:
        """Return whether a container id has been set."""
        return self.container_id is not None

    def run_cmd(self, args: Union[str, list[str]]) -> None:
        """Run a command on the host, or inside the container when one is set.

        Raises ``subprocess.CalledProcessError`` on a non-zero exit status.
        """
        if not self.using_docker():
            return self.run_ssh_cmd(args)
        assert self.container_id is not None
        exec_cmd = self._docker_exec_cmd()
        proc = subprocess.Popen(exec_cmd, stdin=subprocess.PIPE)
        proc.communicate(input=self._bash_stdin(args))
        exit_code = proc.wait()
        if exit_code != 0:
            raise subprocess.CalledProcessError(exit_code, exec_cmd)

    def check_output(self, args: Union[str, list[str]]) -> str:
        """Like ``run_cmd`` but capture stdout and return it decoded as UTF-8."""
        if not self.using_docker():
            return self.check_ssh_output(args)
        assert self.container_id is not None
        exec_cmd = self._docker_exec_cmd()
        proc = subprocess.Popen(
            exec_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE
        )
        (out, err) = proc.communicate(input=self._bash_stdin(args))
        exit_code = proc.wait()
        if exit_code != 0:
            raise subprocess.CalledProcessError(
                exit_code, exec_cmd, output=out, stderr=err
            )
        return out.decode("utf-8")

    def upload_file(self, local_file: str, remote_file: str) -> None:
        """Upload a file to the host, or into the container's ``/root`` via the host's ``/tmp``."""
        if not self.using_docker():
            return self.scp_upload_file(local_file, remote_file)
        staging = os.path.join("/tmp", os.path.basename(local_file))
        self.scp_upload_file(local_file, staging)
        self.run_ssh_cmd(
            ["docker", "cp", staging, f"{self.container_id}:/root/{remote_file}"]
        )
        self.run_ssh_cmd(["rm", staging])

    def download_file(self, remote_file: str, local_file: Optional[str] = None) -> None:
        """Download a file from the host, or from the container's ``/root`` via the host's ``/tmp``."""
        if not self.using_docker():
            return self.scp_download_file(remote_file, local_file)
        staging = os.path.join("/tmp", os.path.basename(remote_file))
        self.run_ssh_cmd(
            ["docker", "cp", f"{self.container_id}:/root/{remote_file}", staging]
        )
        self.scp_download_file(staging, local_file)
        self.run_ssh_cmd(["rm", staging])

    def download_wheel(
        self, remote_file: str, local_file: Optional[str] = None
    ) -> None:
        """Download a wheel.

        When a container is in use and no local name is given, the local file
        name is the remote base name with ``-linux_aarch64.whl`` replaced by
        ``-manylinux2014_aarch64.whl``.
        """
        if local_file is None and self.using_docker():
            wheel_basename = os.path.basename(remote_file)
            local_file = wheel_basename.replace(
                "-linux_aarch64.whl", "-manylinux2014_aarch64.whl"
            )
        self.download_file(remote_file, local_file)

    def list_dir(self, path: str) -> list[str]:
        """Return the output of ``ls -1 path`` split on newlines."""
        listing = self.check_output(["ls", "-1", path])
        return listing.split("\n")
|
||||
|
||||
|
||||
def wait_for_connection(addr, port, timeout=15, attempt_cnt=5):
    """Block until a TCP connection to ``(addr, port)`` succeeds.

    Up to ``attempt_cnt`` attempts are made, each with a connect timeout of
    ``timeout`` seconds, sleeping ``timeout`` seconds between failed attempts.
    A refused or timed-out connection on the final attempt is re-raised.
    """
    import socket

    attempts_left = attempt_cnt
    while attempts_left > 0:
        attempts_left -= 1
        try:
            with socket.create_connection((addr, port), timeout=timeout):
                return
        except (ConnectionRefusedError, TimeoutError):  # noqa: PERF203
            if attempts_left == 0:
                raise
            time.sleep(timeout)
|
||||
|
||||
|
||||
def update_apt_repo(host: RemoteHost) -> None:
    """Stop the background apt services on ``host`` and refresh the package index.

    Both services are asked to stop first; only afterwards does the function
    wait for each of them to become inactive. ``apt-get update`` is then run
    twice with a short pause in between.
    """
    background_services = ("apt-daily.service", "unattended-upgrades.service")

    time.sleep(5)
    for service in background_services:
        host.run_cmd(f"sudo systemctl stop {service} || true")
    for service in background_services:
        host.run_cmd(
            f"while systemctl is-active --quiet {service}; do sleep 1; done"
        )
    host.run_cmd("sudo apt-get update")
    time.sleep(3)
    host.run_cmd("sudo apt-get update")
|
||||
|
||||
|
||||
def install_condaforge(
    host: RemoteHost, suffix: str = "latest/download/Miniforge3-Linux-aarch64.sh"
) -> None:
    """Download and run a Miniforge installer on ``host`` and put it on ``PATH``.

    ``suffix`` is the path below the miniforge GitHub releases URL. In a docker
    container the ``PATH`` line is appended to ``.bashrc``; otherwise it is
    inserted with ``sed`` before the "If not running interactively" line.
    """
    print("Install conda-forge")
    installer = os.path.basename(suffix)
    host.run_cmd(f"curl -OL https://github.com/conda-forge/miniforge/releases/{suffix}")
    host.run_cmd(f"sh -f {installer} -b")
    host.run_cmd(f"rm -f {installer}")

    if not host.using_docker():
        host.run_cmd(
            [
                "sed",
                "-i",
                "'/^# If not running interactively.*/i PATH=$HOME/miniforge3/bin:$PATH'",
                ".bashrc",
            ]
        )
        return

    host.run_cmd("echo 'PATH=$HOME/miniforge3/bin:$PATH'>>.bashrc")
|
||||
|
||||
|
||||
def install_condaforge_python(host: RemoteHost, python_version="3.8") -> None:
    """Install Miniforge on ``host`` and then the requested Python with build deps.

    Python 3.6 uses the Miniforge 4.10.3-10 installer; every other version
    uses 4.11.0-4 and additionally installs setuptools with a version
    constraint.
    """
    if python_version == "3.6":
        # Python-3.6 EOLed and not compatible with conda-4.11
        install_condaforge(
            host, suffix="download/4.10.3-10/Miniforge3-4.10.3-10-Linux-aarch64.sh"
        )
        host.run_cmd(f"conda install -y python={python_version} numpy pyyaml")
    else:
        install_condaforge(
            host, suffix="download/4.11.0-4/Miniforge3-4.11.0-4-Linux-aarch64.sh"
        )
        # Pytorch-1.10 or older are not compatible with setuptools=59.6 or newer
        # NOTE(review): the constraint below is ">=59.5.0", which does not
        # exclude 59.6+ - confirm whether "<=59.5.0" was intended.
        # The spec is single-quoted because the command is interpreted by a
        # shell on the remote side; unquoted, ">=59.5.0" is parsed as an
        # output redirection to a file named "=59.5.0" and the version
        # constraint is lost.
        host.run_cmd(
            f"conda install -y python={python_version} numpy pyyaml 'setuptools>=59.5.0'"
        )
|
||||
|
||||
|
||||
def embed_libgomp(host: RemoteHost, use_conda, wheel_name) -> None:
    """Upload the embed-library helper script to ``host`` and run it on ``wheel_name``.

    auditwheel and patchelf are installed first (patchelf via conda or apt,
    depending on ``use_conda``). The script text comes from the module-level
    ``embed_library_script``. Inside a docker container the script is run with
    ``--update-tag``.
    """
    from tempfile import NamedTemporaryFile

    host.run_cmd("pip3 install auditwheel")
    if use_conda:
        patchelf_install = "conda install -y patchelf"
    else:
        patchelf_install = "sudo apt-get install -y patchelf"
    host.run_cmd(patchelf_install)

    with NamedTemporaryFile() as script_file:
        script_file.write(embed_library_script.encode("utf-8"))
        # Flush so the whole script is on disk before it is uploaded.
        script_file.flush()
        host.upload_file(script_file.name, "embed_library.py")

    print("Embedding libgomp into wheel")
    embed_cmd = f"python3 embed_library.py {wheel_name}"
    if host.using_docker():
        embed_cmd += " --update-tag"
    host.run_cmd(embed_cmd)
|
||||
|
||||
|
||||
def checkout_repo(
    host: RemoteHost,
    *,
    branch: str = "main",
    url: str,
    git_clone_flags: str,
    mapping: dict[str, tuple[str, str]],
) -> Optional[str]:
    """Clone ``url`` on ``host`` at the ref that corresponds to ``branch``.

    ``mapping`` maps a branch prefix to a ``(version, rc)`` pair. For the first
    prefix that ``branch`` starts with, the tag ``v<version>-<rc>`` is cloned
    and ``<version>`` is returned. When no prefix matches, ``branch`` itself is
    cloned and ``None`` is returned.
    """
    matched_prefix = next(
        (candidate for candidate in mapping if branch.startswith(candidate)),
        None,
    )
    if matched_prefix is None:
        host.run_cmd(f"git clone {url} -b {branch} {git_clone_flags}")
        return None

    release = mapping[matched_prefix]
    tag = f"v{release[0]}-{release[1]}"
    host.run_cmd(f"git clone {url} -b {tag} {git_clone_flags}")
    return release[0]
|
||||
|
||||
|
||||
def build_torchvision(
    host: RemoteHost,
    *,
    branch: str = "main",
    use_conda: bool = True,
    git_clone_flags: str,
    run_smoke_tests: bool = True,
) -> str:
    """Build the TorchVision wheel on ``host``, download it and return its file name.

    The checkout is removed from the host afterwards. When ``run_smoke_tests`` is
    true the freshly built wheel is pip-installed and ``vision/test/smoke_test.py``
    is executed before the cleanup.
    """
    # PyTorch release prefix -> (TorchVision version, release-candidate suffix).
    release_map = {
        "v1.7.1": ("0.8.2", "rc2"),
        "v1.8.0": ("0.9.0", "rc3"),
        "v1.8.1": ("0.9.1", "rc1"),
        "v1.9.0": ("0.10.0", "rc1"),
        "v1.10.0": ("0.11.1", "rc1"),
        "v1.10.1": ("0.11.2", "rc1"),
        "v1.10.2": ("0.11.3", "rc1"),
        "v1.11.0": ("0.12.0", "rc1"),
        "v1.12.0": ("0.13.0", "rc4"),
        "v1.12.1": ("0.13.1", "rc6"),
        "v1.13.0": ("0.14.0", "rc4"),
        "v1.13.1": ("0.14.1", "rc2"),
        "v2.0.0": ("0.15.1", "rc2"),
        "v2.0.1": ("0.15.2", "rc2"),
    }
    print("Checking out TorchVision repo")
    build_version = checkout_repo(
        host,
        branch=branch,
        url="https://github.com/pytorch/vision",
        git_clone_flags=git_clone_flags,
        mapping=release_map,
    )
    print("Building TorchVision wheel")

    # libpng and jpeg are required to build the image.so extension.
    if use_conda:
        host.run_cmd("conda install -y libpng jpeg")
        # Dropping the shared objects forces static linking.
        host.run_cmd(
            "rm miniforge3/lib/libpng.so miniforge3/lib/libpng16.so miniforge3/lib/libjpeg.so"
        )
        # libpng needs libz, so setup.py is patched to link against it as well.
        host.run_cmd(
            [
                'sed -i -e \'s/image_link_flags\\.append("png")/image_link_flags += ["png", "z"]/\' vision/setup.py'
            ]
        )

    if branch == "nightly":
        version = host.check_output(
            ["if [ -f vision/version.txt ]; then cat vision/version.txt; fi"]
        ).strip()
        if len(version) == 0:
            # Older revisions keep the version inside setup.py instead.
            setup_line = host.check_output(
                ["grep", '"version = \'"', "vision/setup.py"]
            ).strip()
            version = setup_line.split("'")[1][:-2]
        # First word of the latest commit subject, with dashes removed.
        last_subject = host.check_output(
            "cd vision && git log --pretty=format:%s -1"
        ).strip()
        build_date = last_subject.split()[0].replace("-", "")
        build_vars = f"BUILD_VERSION={version}.dev{build_date}"
    elif build_version is not None:
        pytorch_version = branch[1:].split("-", maxsplit=1)[0]
        build_vars = f"BUILD_VERSION={build_version} PYTORCH_VERSION={pytorch_version}"
    else:
        build_vars = ""
    if host.using_docker():
        build_vars = f"{build_vars} CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

    host.run_cmd(f"cd vision && {build_vars} python3 -m build --wheel --no-isolation")
    vision_wheel_name = host.list_dir("vision/dist")[0]
    wheel_path = os.path.join("vision", "dist", vision_wheel_name)
    embed_libgomp(host, use_conda, wheel_path)

    print("Copying TorchVision wheel")
    host.download_wheel(wheel_path)
    if run_smoke_tests:
        host.run_cmd(f"pip3 install {wheel_path}")
        host.run_cmd("python3 vision/test/smoke_test.py")
    print("Delete vision checkout")
    host.run_cmd("rm -rf vision")

    return vision_wheel_name
|
||||
|
||||
|
||||
def build_torchdata(
    host: RemoteHost,
    *,
    branch: str = "main",
    use_conda: bool = True,
    git_clone_flags: str = "",
) -> str:
    """Build the TorchData wheel on ``host``, download it and return its file name.

    The repository is always cloned with ``--recurse-submodules`` appended to
    ``git_clone_flags``.
    """
    print("Checking out TorchData repo")
    clone_flags = git_clone_flags + " --recurse-submodules"
    # PyTorch release prefix -> (TorchData version, release-candidate suffix).
    release_map = {
        "v1.13.1": ("0.5.1", ""),
        "v2.0.0": ("0.6.0", "rc5"),
        "v2.0.1": ("0.6.1", "rc1"),
    }
    build_version = checkout_repo(
        host,
        branch=branch,
        url="https://github.com/pytorch/data",
        git_clone_flags=clone_flags,
        mapping=release_map,
    )
    print("Building TorchData wheel")

    if branch == "nightly":
        version = host.check_output(
            ["if [ -f data/version.txt ]; then cat data/version.txt; fi"]
        ).strip()
        # First word of the latest commit subject, with dashes removed.
        last_subject = host.check_output(
            "cd data && git log --pretty=format:%s -1"
        ).strip()
        build_date = last_subject.split()[0].replace("-", "")
        build_vars = f"BUILD_VERSION={version}.dev{build_date}"
    elif build_version is not None:
        pytorch_version = branch[1:].split("-", maxsplit=1)[0]
        build_vars = f"BUILD_VERSION={build_version} PYTORCH_VERSION={pytorch_version}"
    else:
        build_vars = ""
    if host.using_docker():
        build_vars = f"{build_vars} CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

    host.run_cmd(f"cd data && {build_vars} python3 -m build --wheel --no-isolation")
    wheel_name = host.list_dir("data/dist")[0]
    wheel_path = os.path.join("data", "dist", wheel_name)
    embed_libgomp(host, use_conda, wheel_path)

    print("Copying TorchData wheel")
    host.download_wheel(wheel_path)

    return wheel_name
|
||||
|
||||
|
||||
def build_torchtext(
    host: RemoteHost,
    *,
    branch: str = "main",
    use_conda: bool = True,
    git_clone_flags: str = "",
) -> str:
    """Build the TorchText wheel on ``host``, download it and return its file name.

    The repository is always cloned with ``--recurse-submodules`` appended to
    ``git_clone_flags``.
    """
    print("Checking out TorchText repo")
    clone_flags = git_clone_flags + " --recurse-submodules"
    # PyTorch release prefix -> (TorchText version, release-candidate suffix).
    release_map = {
        "v1.9.0": ("0.10.0", "rc1"),
        "v1.10.0": ("0.11.0", "rc2"),
        "v1.10.1": ("0.11.1", "rc1"),
        "v1.10.2": ("0.11.2", "rc1"),
        "v1.11.0": ("0.12.0", "rc1"),
        "v1.12.0": ("0.13.0", "rc2"),
        "v1.12.1": ("0.13.1", "rc5"),
        "v1.13.0": ("0.14.0", "rc3"),
        "v1.13.1": ("0.14.1", "rc1"),
        "v2.0.0": ("0.15.1", "rc2"),
        "v2.0.1": ("0.15.2", "rc2"),
    }
    build_version = checkout_repo(
        host,
        branch=branch,
        url="https://github.com/pytorch/text",
        git_clone_flags=clone_flags,
        mapping=release_map,
    )
    print("Building TorchText wheel")

    if branch == "nightly":
        version = host.check_output(
            ["if [ -f text/version.txt ]; then cat text/version.txt; fi"]
        ).strip()
        # First word of the latest commit subject, with dashes removed.
        last_subject = host.check_output(
            "cd text && git log --pretty=format:%s -1"
        ).strip()
        build_date = last_subject.split()[0].replace("-", "")
        build_vars = f"BUILD_VERSION={version}.dev{build_date}"
    elif build_version is not None:
        pytorch_version = branch[1:].split("-", maxsplit=1)[0]
        build_vars = f"BUILD_VERSION={build_version} PYTORCH_VERSION={pytorch_version}"
    else:
        build_vars = ""
    if host.using_docker():
        build_vars = f"{build_vars} CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

    host.run_cmd(f"cd text && {build_vars} python3 -m build --wheel --no-isolation")
    wheel_name = host.list_dir("text/dist")[0]
    wheel_path = os.path.join("text", "dist", wheel_name)
    embed_libgomp(host, use_conda, wheel_path)

    print("Copying TorchText wheel")
    host.download_wheel(wheel_path)

    return wheel_name
|
||||
|
||||
|
||||
def build_torchaudio(
    host: RemoteHost,
    *,
    branch: str = "main",
    use_conda: bool = True,
    git_clone_flags: str = "",
) -> str:
    """Build the TorchAudio wheel on ``host``, download it and return its file name.

    The repository is always cloned with ``--recurse-submodules`` appended to
    ``git_clone_flags``; ``packaging/ffmpeg/build.sh`` is run before the wheel
    build with ``USE_FFMPEG=1`` exported.
    """
    print("Checking out TorchAudio repo")
    clone_flags = git_clone_flags + " --recurse-submodules"
    # PyTorch release prefix -> (TorchAudio version, release-candidate suffix).
    release_map = {
        "v1.9.0": ("0.9.0", "rc2"),
        "v1.10.0": ("0.10.0", "rc5"),
        "v1.10.1": ("0.10.1", "rc1"),
        "v1.10.2": ("0.10.2", "rc1"),
        "v1.11.0": ("0.11.0", "rc1"),
        "v1.12.0": ("0.12.0", "rc3"),
        "v1.12.1": ("0.12.1", "rc5"),
        "v1.13.0": ("0.13.0", "rc4"),
        "v1.13.1": ("0.13.1", "rc2"),
        "v2.0.0": ("2.0.1", "rc3"),
        "v2.0.1": ("2.0.2", "rc2"),
    }
    build_version = checkout_repo(
        host,
        branch=branch,
        url="https://github.com/pytorch/audio",
        git_clone_flags=clone_flags,
        mapping=release_map,
    )
    print("Building TorchAudio wheel")

    if branch == "nightly":
        # The version is read from the quoted "version = '...'" line in setup.py.
        setup_line = host.check_output(
            ["grep", '"version = \'"', "audio/setup.py"]
        ).strip()
        version = setup_line.split("'")[1][:-2]
        # First word of the latest commit subject, with dashes removed.
        last_subject = host.check_output(
            "cd audio && git log --pretty=format:%s -1"
        ).strip()
        build_date = last_subject.split()[0].replace("-", "")
        build_vars = f"BUILD_VERSION={version}.dev{build_date}"
    elif build_version is not None:
        pytorch_version = branch[1:].split("-", maxsplit=1)[0]
        build_vars = f"BUILD_VERSION={build_version} PYTORCH_VERSION={pytorch_version}"
    else:
        build_vars = ""
    if host.using_docker():
        build_vars = f"{build_vars} CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

    host.run_cmd(
        f"cd audio && export FFMPEG_ROOT=$(pwd)/third_party/ffmpeg && export USE_FFMPEG=1 \
        && ./packaging/ffmpeg/build.sh \
        && {build_vars} python3 -m build --wheel --no-isolation"
    )

    wheel_name = host.list_dir("audio/dist")[0]
    wheel_path = os.path.join("audio", "dist", wheel_name)
    embed_libgomp(host, use_conda, wheel_path)

    print("Copying TorchAudio wheel")
    host.download_wheel(wheel_path)

    return wheel_name
|
||||
|
||||
|
||||
def configure_system(
    host: RemoteHost,
    *,
    compiler: str = "gcc-8",
    use_conda: bool = True,
    python_version: str = "3.8",
) -> None:
    """Install the build prerequisites on ``host``.

    With ``use_conda`` the conda-forge Python is installed first. Non-docker
    hosts get their toolchain from apt; docker hosts install ``sudo`` via yum
    and ``ninja``/``scons`` via conda. Without conda, the Python development
    packages come from apt and Cython/numpy from pip.

    ``compiler`` is accepted but not used by this function.
    """
    if use_conda:
        install_condaforge_python(host, python_version)

    print("Configuring the system")
    if host.using_docker():
        host.run_cmd("yum install -y sudo")
        host.run_cmd("conda install -y ninja scons")
    else:
        update_apt_repo(host)
        host.run_cmd("sudo apt-get install -y ninja-build g++ git cmake gfortran unzip")

    system_python = not use_conda
    if system_python:
        host.run_cmd(
            "sudo apt-get install -y python3-dev python3-yaml python3-setuptools python3-wheel python3-pip"
        )
    # Installed in both modes, after pip itself is available.
    host.run_cmd("pip3 install dataclasses typing-extensions")
    if system_python:
        print("Installing Cython + numpy from PyPy")
        for package in ("Cython", "numpy"):
            host.run_cmd(f"sudo pip3 install {package}")
|
||||
|
||||
|
||||
def build_domains(
    host: RemoteHost,
    *,
    branch: str = "main",
    use_conda: bool = True,
    git_clone_flags: str = "",
) -> tuple[str, str, str, str]:
    """Build the domain-library wheels on ``host``.

    The libraries are built in the order vision, audio, data, text, all with
    the same options.

    Returns:
        ``(vision, audio, data, text)`` wheel file names.
    """
    shared_opts = {
        "branch": branch,
        "use_conda": use_conda,
        "git_clone_flags": git_clone_flags,
    }
    vision_wheel_name = build_torchvision(host, **shared_opts)
    audio_wheel_name = build_torchaudio(host, **shared_opts)
    data_wheel_name = build_torchdata(host, **shared_opts)
    text_wheel_name = build_torchtext(host, **shared_opts)
    return (vision_wheel_name, audio_wheel_name, data_wheel_name, text_wheel_name)
|
||||
|
||||
|
||||
def start_build(
    host: RemoteHost,
    *,
    branch: str = "main",
    compiler: str = "gcc-8",
    use_conda: bool = True,
    python_version: str = "3.8",
    pytorch_only: bool = False,
    pytorch_build_number: Optional[str] = None,
    shallow_clone: bool = True,
    enable_mkldnn: bool = False,
) -> tuple[str, str, str, str, str]:
    """Configure ``host``, build the PyTorch wheel and optionally the domain wheels.

    Args:
        host: Remote machine (optionally with a docker container) to build on.
        branch: PyTorch branch or tag to clone. ``"nightly"`` and ``v1.*`` /
            ``v2.*`` names additionally set ``PYTORCH_BUILD_VERSION``.
        compiler: Forwarded to ``configure_system``.
        use_conda: Use conda-provided Python; forced on for docker hosts.
        python_version: Python version forwarded to ``configure_system``.
        pytorch_only: Skip the domain libraries when true.
        pytorch_build_number: Optional ``--build-number`` passed to the build.
        shallow_clone: Clone with ``--depth 1 --shallow-submodules``.
        enable_mkldnn: Build with MKLDNN + ACL; forced off for non-docker hosts.

    Returns:
        ``(pytorch, vision, audio, data, text)`` wheel file names. When
        ``pytorch_only`` is true the four domain entries are ``None``.
    """
    git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
    if host.using_docker() and not use_conda:
        print("Auto-selecting conda option for docker images")
        use_conda = True
    if not host.using_docker():
        print("Disable mkldnn for host builds")
        enable_mkldnn = False

    configure_system(
        host, compiler=compiler, use_conda=use_conda, python_version=python_version
    )

    if host.using_docker():
        print("Move libgfortant.a into a standard location")
        # HACK: pypa libgfortran.a is compiled without PIC, which leads to the following error
        # libgfortran.a(error.o)(.text._gfortrani_st_printf+0x34): unresolvable R_AARCH64_ADR_PREL_PG_HI21 relocation against symbol `__stack_chk_guard@@GLIBC_2.17'  # noqa: E501, B950
        # Workaround by copying gfortran library from the host
        host.run_ssh_cmd("sudo apt-get install -y gfortran-8")
        host.run_cmd("mkdir -p /usr/lib/gcc/aarch64-linux-gnu/8")
        host.run_ssh_cmd(
            [
                "docker",
                "cp",
                "/usr/lib/gcc/aarch64-linux-gnu/8/libgfortran.a",
                f"{host.container_id}:/opt/rh/devtoolset-10/root/usr/lib/gcc/aarch64-redhat-linux/10/",
            ]
        )

    print("Checking out PyTorch repo")
    host.run_cmd(
        f"git clone --recurse-submodules -b {branch} https://github.com/pytorch/pytorch {git_clone_flags}"
    )

    host.run_cmd("pytorch/.ci/docker/common/install_openblas.sh")

    print("Building PyTorch wheel")
    build_opts = ""
    if pytorch_build_number is not None:
        build_opts += f" -C--build-option=--build-number={pytorch_build_number}"
    # Breakpad build fails on aarch64
    build_vars = "USE_BREAKPAD=0 "
    if branch == "nightly":
        # First word of the latest commit subject, with dashes removed.
        build_date = (
            host.check_output("cd pytorch && git log --pretty=format:%s -1")
            .strip()
            .split()[0]
            .replace("-", "")
        )
        # The last two characters of version.txt are dropped.
        version = host.check_output("cat pytorch/version.txt").strip()[:-2]
        build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1"
    if branch.startswith(("v1.", "v2.")):
        # Strip the leading "v" and anything from the first "-" on. Splitting
        # (rather than slicing up to find("-")) keeps the full version when the
        # branch has no "-" at all, e.g. "v2.0.1" -> "2.0.1".
        release_version = branch[1:].split("-", maxsplit=1)[0]
        build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={release_version} PYTORCH_BUILD_NUMBER=1"
    if host.using_docker():
        build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
    if enable_mkldnn:
        host.run_cmd("pytorch/.ci/docker/common/install_acl.sh")
        print("build pytorch with mkldnn+acl backend")
        build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
        build_vars += " BLAS=OpenBLAS"
        build_vars += " OpenBLAS_HOME=/opt/OpenBLAS"
        build_vars += " ACL_ROOT_DIR=/acl"
        host.run_cmd(
            f"cd $HOME/pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
        )
        print("Repair the wheel")
        pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
        ld_library_path = "/acl/build:$HOME/pytorch/build/lib"
        host.run_cmd(
            f"export LD_LIBRARY_PATH={ld_library_path} && auditwheel repair $HOME/pytorch/dist/{pytorch_wheel_name}"
        )
        print("replace the original wheel with the repaired one")
        pytorch_repaired_wheel_name = host.list_dir("wheelhouse")[0]
        host.run_cmd(
            f"cp $HOME/wheelhouse/{pytorch_repaired_wheel_name} $HOME/pytorch/dist/{pytorch_wheel_name}"
        )
    else:
        print("build pytorch without mkldnn backend")
        host.run_cmd(
            f"cd pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
        )

    print("Deleting build folder")
    host.run_cmd("cd pytorch && rm -rf build")
    pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
    embed_libgomp(host, use_conda, os.path.join("pytorch", "dist", pytorch_wheel_name))
    print("Copying the wheel")
    host.download_wheel(os.path.join("pytorch", "dist", pytorch_wheel_name))

    print("Installing PyTorch wheel")
    host.run_cmd(f"pip3 install pytorch/dist/{pytorch_wheel_name}")

    if pytorch_only:
        return (pytorch_wheel_name, None, None, None, None)
    domain_wheels = build_domains(
        host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
    )

    return (pytorch_wheel_name, *domain_wheels)
|
||||
|
||||
|
||||
embed_library_script = """
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from auditwheel.patcher import Patchelf
|
||||
from auditwheel.wheeltools import InWheelCtx
|
||||
from auditwheel.elfutils import elf_file_filter
|
||||
from auditwheel.repair import copylib
|
||||
from auditwheel.lddtree import lddtree
|
||||
from subprocess import check_call
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
|
||||
def replace_tag(filename):
|
||||
with open(filename, 'r') as f:
|
||||
lines = f.read().split("\\n")
|
||||
for i,line in enumerate(lines):
|
||||
if not line.startswith("Tag: "):
|
||||
continue
|
||||
lines[i] = line.replace("-linux_", "-manylinux2014_")
|
||||
print(f'Updated tag from {line} to {lines[i]}')
|
||||
|
||||
with open(filename, 'w') as f:
|
||||
f.write("\\n".join(lines))
|
||||
|
||||
|
||||
class AlignedPatchelf(Patchelf):
|
||||
def set_soname(self, file_name: str, new_soname: str) -> None:
|
||||
check_call(['patchelf', '--page-size', '65536', '--set-soname', new_soname, file_name])
|
||||
|
||||
def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
|
||||
check_call(['patchelf', '--page-size', '65536', '--replace-needed', soname, new_soname, file_name])
|
||||
|
||||
|
||||
def embed_library(whl_path, lib_soname, update_tag=False):
|
||||
patcher = AlignedPatchelf()
|
||||
out_dir = TemporaryDirectory()
|
||||
whl_name = os.path.basename(whl_path)
|
||||
tmp_whl_name = os.path.join(out_dir.name, whl_name)
|
||||
with InWheelCtx(whl_path) as ctx:
|
||||
torchlib_path = os.path.join(ctx._tmpdir.name, 'torch', 'lib')
|
||||
ctx.out_wheel=tmp_whl_name
|
||||
new_lib_path, new_lib_soname = None, None
|
||||
for filename, elf in elf_file_filter(ctx.iter_files()):
|
||||
if not filename.startswith('torch/lib'):
|
||||
continue
|
||||
libtree = lddtree(filename)
|
||||
if lib_soname not in libtree['needed']:
|
||||
continue
|
||||
lib_path = libtree['libs'][lib_soname]['path']
|
||||
if lib_path is None:
|
||||
print(f"Can't embed {lib_soname} as it could not be found")
|
||||
break
|
||||
if lib_path.startswith(torchlib_path):
|
||||
continue
|
||||
|
||||
if new_lib_path is None:
|
||||
new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
|
||||
patcher.replace_needed(filename, lib_soname, new_lib_soname)
|
||||
print(f'Replacing {lib_soname} with {new_lib_soname} for {filename}')
|
||||
if update_tag:
|
||||
# Add manylinux2014 tag
|
||||
for filename in ctx.iter_files():
|
||||
if os.path.basename(filename) != 'WHEEL':
|
||||
continue
|
||||
replace_tag(filename)
|
||||
shutil.move(tmp_whl_name, whl_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
embed_library(sys.argv[1], 'libgomp.so.1', len(sys.argv) > 2 and sys.argv[2] == '--update-tag')
|
||||
"""
|
||||
|
||||
|
||||
def run_tests(host: RemoteHost, whl: str, branch="main") -> None:
    """Install the wheel ``whl`` on ``host`` and run ``test_torch.py`` against it.

    Args:
        host: Remote machine to test on.
        whl: Local path of the wheel; it is uploaded and installed with pip.
        branch: PyTorch branch whose ``test`` directory is cloned and run.
    """
    print("Configuring the system")
    update_apt_repo(host)
    host.run_cmd("sudo apt-get install -y python3-pip git")
    host.run_cmd("sudo pip3 install Cython")
    host.run_cmd("sudo pip3 install numpy")
    host.upload_file(whl, ".")
    host.run_cmd(f"sudo pip3 install {whl}")
    # Import smoke test; the parentheses must balance or python3 -c fails with
    # a SyntaxError before torch is even imported.
    host.run_cmd("python3 -c 'import torch;print(torch.rand((3,3)))'")
    host.run_cmd(f"git clone -b {branch} https://github.com/pytorch/pytorch")
    host.run_cmd("cd pytorch/test; python3 test_torch.py -v")
|
||||
|
||||
|
||||
def get_instance_name(instance) -> Optional[str]:
    """Return the value of the instance's ``Name`` tag, or ``None`` if absent.

    ``instance.tags`` may be ``None``; otherwise it is iterated as a sequence of
    ``{"Key": ..., "Value": ...}`` mappings and the first ``Name`` entry wins.
    """
    tags = instance.tags
    if tags is None:
        return None
    return next((tag["Value"] for tag in tags if tag["Key"] == "Name"), None)
|
||||
|
||||
|
||||
def list_instances(instance_type: str) -> None:
    """Print one summary line per EC2 instance of ``instance_type``.

    Each line holds the id, Name tag, public DNS name, state name and the
    availability zone of the first network interface (``None`` when the
    instance has no interfaces).
    """
    print(f"All instances of type {instance_type}")
    for instance in ec2_instances_of_type(instance_type):
        ifaces = instance.network_interfaces
        if len(ifaces) > 0:
            az = ifaces[0].subnet.availability_zone
        else:
            az = None
        name = get_instance_name(instance)
        state = instance.state["Name"]
        print(f"{instance.id} {name} {instance.public_dns_name} {state} {az}")
|
||||
|
||||
|
||||
def terminate_instances(instance_type: str) -> None:
    """Terminate every EC2 instance of ``instance_type`` and wait for completion.

    All terminations are requested first; the waits happen in a second pass.
    """
    print(f"Terminating all instances of type {instance_type}")
    # Materialized once so the same instances are terminated and waited on.
    doomed = list(ec2_instances_of_type(instance_type))
    for inst in doomed:
        print(f"Terminating {inst.id}")
        inst.terminate()
    print("Waiting for termination to complete")
    for inst in doomed:
        inst.wait_until_terminated()
|
||||
|
||||
|
||||
def parse_arguments():
    """Parse the command line of the EC2 build/test driver.

    ``--os`` and ``--ami`` are mutually exclusive. Returns the namespace
    produced by ``ArgumentParser.parse_args()``.
    """
    from argparse import ArgumentParser

    ap = ArgumentParser("Build and test AARCH64 wheels using EC2")
    ap.add_argument("--key-name", type=str)
    ap.add_argument("--debug", action="store_true")
    ap.add_argument("--build-only", action="store_true")
    ap.add_argument("--test-only", type=str)

    # The image is chosen either by OS name or by explicit AMI id, never both.
    image_group = ap.add_mutually_exclusive_group()
    image_group.add_argument("--os", type=str, choices=list(os_amis))
    image_group.add_argument("--ami", type=str)

    ap.add_argument(
        "--python-version",
        type=str,
        choices=["3.6", "3.7", "3.8", "3.9", "3.10", "3.11"],
        default=None,
    )
    for flag in (
        "--alloc-instance",
        "--list-instances",
        "--pytorch-only",
        "--keep-running",
        "--terminate-instances",
    ):
        ap.add_argument(flag, action="store_true")
    ap.add_argument("--instance-type", type=str, default="t4g.2xlarge")
    ap.add_argument("--ebs-size", type=int, default=50)
    ap.add_argument("--branch", type=str, default="main")
    ap.add_argument("--use-docker", action="store_true")
    ap.add_argument(
        "--compiler",
        type=str,
        choices=["gcc-7", "gcc-8", "gcc-9", "clang"],
        default="gcc-8",
    )
    ap.add_argument("--use-torch-from-pypi", action="store_true")
    ap.add_argument("--pytorch-build-number", type=str, default=None)
    ap.add_argument("--disable-mkldnn", action="store_true")
    return ap.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script driver: pick an AMI, start an EC2 instance, then either run tests,
    # only allocate the instance, or build the wheels, and finally terminate the
    # instance unless --keep-running was given.
    args = parse_arguments()
    # AMI precedence: explicit --ami, then the --os lookup, then the default image.
    if args.ami is not None:
        ami = args.ami
    elif args.os is not None:
        ami = os_amis[args.os]
    else:
        ami = ubuntu20_04_ami
    keyfile_path, key_name = compute_keyfile_path(args.key_name)

    if args.list_instances:
        list_instances(args.instance_type)
        sys.exit(0)

    if args.terminate_instances:
        terminate_instances(args.instance_type)
        sys.exit(0)

    if len(key_name) == 0:
        raise RuntimeError("""
            Cannot start build without key_name, please specify
            --key-name argument or AWS_KEY_NAME environment variable.""")
    if len(keyfile_path) == 0 or not os.path.exists(keyfile_path):
        raise RuntimeError(f"""
            Cannot find keyfile with name: [{key_name}] in path: [{keyfile_path}], please
            check `~/.ssh/` folder or manually set SSH_KEY_PATH environment variable.""")

    # Starting the instance
    inst = start_instance(
        key_name, ami=ami, instance_type=args.instance_type, ebs_size=args.ebs_size
    )
    # Use the resolved key name: args.key_name is None when the key was not
    # given on the command line (presumably taken from AWS_KEY_NAME instead),
    # which would otherwise tag the instance "None-...".
    instance_name = f"{key_name}-{args.os}"
    if args.python_version is not None:
        instance_name += f"-py{args.python_version}"
    inst.create_tags(
        DryRun=False,
        Tags=[
            {
                "Key": "Name",
                "Value": instance_name,
            }
        ],
    )
    addr = inst.public_dns_name
    wait_for_connection(addr, 22)
    host = RemoteHost(addr, keyfile_path)
    host.ami = ami
    if args.use_docker:
        update_apt_repo(host)
        host.start_docker()

    if args.test_only:
        run_tests(host, args.test_only)
        sys.exit(0)

    if args.alloc_instance:
        # Allocation-only mode: optionally install Python, then leave the
        # instance running.
        if args.python_version is None:
            sys.exit(0)
        install_condaforge_python(host, args.python_version)
        sys.exit(0)

    python_version = args.python_version if args.python_version is not None else "3.10"

    if args.use_torch_from_pypi:
        configure_system(host, compiler=args.compiler, python_version=python_version)
        print("Installing PyTorch wheel")
        host.run_cmd("pip3 install torch")
        build_domains(
            host, branch=args.branch, git_clone_flags=" --depth 1 --shallow-submodules"
        )
    else:
        start_build(
            host,
            branch=args.branch,
            compiler=args.compiler,
            python_version=python_version,
            pytorch_only=args.pytorch_only,
            pytorch_build_number=args.pytorch_build_number,
            enable_mkldnn=not args.disable_mkldnn,
        )
    if not args.keep_running:
        print(f"Waiting for instance {inst.id} to terminate")
        inst.terminate()
        inst.wait_until_terminated()
|
||||
@ -1,87 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from subprocess import check_call
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
from auditwheel.elfutils import elf_file_filter
|
||||
from auditwheel.lddtree import lddtree
|
||||
from auditwheel.patcher import Patchelf
|
||||
from auditwheel.repair import copylib
|
||||
from auditwheel.wheeltools import InWheelCtx
|
||||
|
||||
|
||||
def replace_tag(filename):
    """Rewrite the ``Tag:`` lines of the file in place.

    Every line starting with ``"Tag: "`` has ``-linux_`` replaced by
    ``-manylinux2014_``; all other lines are written back unchanged.

    Args:
        filename: Path of the text file to update (a wheel's ``WHEEL`` file).
    """
    with open(filename) as f:
        # Split on a real newline: the escaped "\\n" form only belongs inside
        # the copy of this script that is embedded in another string literal.
        lines = f.read().split("\n")
    for i, line in enumerate(lines):
        if not line.startswith("Tag: "):
            continue
        lines[i] = line.replace("-linux_", "-manylinux2014_")
        print(f"Updated tag from {line} to {lines[i]}")

    with open(filename, "w") as f:
        f.write("\n".join(lines))
|
||||
|
||||
|
||||
class AlignedPatchelf(Patchelf):
    """``Patchelf`` variant that always invokes patchelf with ``--page-size 65536``."""

    def set_soname(self, file_name: str, new_soname: str) -> None:
        """Set the SONAME of ``file_name`` to ``new_soname``."""
        cmd = ["patchelf", "--page-size", "65536"]
        cmd += ["--set-soname", new_soname, file_name]
        check_call(cmd)

    def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
        """Replace the needed entry ``soname`` of ``file_name`` with ``new_soname``."""
        cmd = ["patchelf", "--page-size", "65536"]
        cmd += ["--replace-needed", soname, new_soname, file_name]
        check_call(cmd)
|
||||
|
||||
|
||||
def embed_library(whl_path, lib_soname, update_tag=False):
    """Copy ``lib_soname`` into the wheel's ``torch/lib`` and repoint its users.

    Every ELF file under ``torch/lib`` that needs ``lib_soname`` from outside
    the wheel is patched to use a copy placed in ``torch/lib``. The rewritten
    wheel replaces ``whl_path``.

    Args:
        whl_path: Path of the wheel to modify in place.
        lib_soname: SONAME of the library to embed, e.g. ``"libgomp.so.1"``.
        update_tag: When true, also rewrite the tags in the wheel's ``WHEEL``
            file via ``replace_tag``.
    """
    patcher = AlignedPatchelf()
    whl_name = os.path.basename(whl_path)
    # The context manager removes the scratch directory deterministically,
    # including when patching fails, instead of relying on the finalizer.
    with TemporaryDirectory() as out_dir:
        tmp_whl_name = os.path.join(out_dir, whl_name)
        with InWheelCtx(whl_path) as ctx:
            torchlib_path = os.path.join(ctx._tmpdir.name, "torch", "lib")
            ctx.out_wheel = tmp_whl_name
            new_lib_path, new_lib_soname = None, None
            for filename, _ in elf_file_filter(ctx.iter_files()):
                if not filename.startswith("torch/lib"):
                    continue
                libtree = lddtree(filename)
                if lib_soname not in libtree["needed"]:
                    continue
                lib_path = libtree["libs"][lib_soname]["path"]
                if lib_path is None:
                    print(f"Can't embed {lib_soname} as it could not be found")
                    break
                # Already resolved to a copy inside the wheel; nothing to do.
                if lib_path.startswith(torchlib_path):
                    continue

                # Copy the library only once; later files reuse the same copy.
                if new_lib_path is None:
                    new_lib_soname, new_lib_path = copylib(
                        lib_path, torchlib_path, patcher
                    )
                patcher.replace_needed(filename, lib_soname, new_lib_soname)
                print(f"Replacing {lib_soname} with {new_lib_soname} for {filename}")
            if update_tag:
                # Add manylinux2014 tag
                for filename in ctx.iter_files():
                    if os.path.basename(filename) != "WHEEL":
                        continue
                    replace_tag(filename)
        # Moved only after the wheel context has closed, and before the
        # scratch directory is removed.
        shutil.move(tmp_whl_name, whl_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Usage: embed_library.py WHEEL [--update-tag]
    wheel_path = sys.argv[1]
    want_tag_update = len(sys.argv) > 2 and sys.argv[2] == "--update-tag"
    embed_library(wheel_path, "libgomp.so.1", want_tag_update)
|
||||
@ -4,14 +4,17 @@ set -ex
|
||||
|
||||
SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||
|
||||
# Source the common build script for architecture-specific configurations (MKLDNN, ACL, etc.)
|
||||
source "${SCRIPTPATH}/../pytorch/build.sh" || true
|
||||
|
||||
case "${GPU_ARCH_TYPE:-BLANK}" in
|
||||
cuda)
|
||||
cuda | cuda-aarch64)
|
||||
bash "${SCRIPTPATH}/build_cuda.sh"
|
||||
;;
|
||||
rocm)
|
||||
bash "${SCRIPTPATH}/build_rocm.sh"
|
||||
;;
|
||||
cpu | cpu-cxx11-abi | cpu-s390x)
|
||||
cpu | cpu-cxx11-abi | cpu-aarch64 | cpu-s390x)
|
||||
bash "${SCRIPTPATH}/build_cpu.sh"
|
||||
;;
|
||||
xpu)
|
||||
|
||||
@ -18,12 +18,31 @@ retry () {
|
||||
$* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*)
|
||||
}
|
||||
|
||||
# Detect architecture first
|
||||
ARCH=$(uname -m)
|
||||
echo "Detected architecture: $ARCH"
|
||||
|
||||
PLATFORM=""
|
||||
# TODO move this into the Docker images
|
||||
OS_NAME=$(awk -F= '/^NAME/{print $2}' /etc/os-release)
|
||||
if [[ "$OS_NAME" == *"AlmaLinux"* ]]; then
|
||||
retry yum install -q -y zip openssl
|
||||
PLATFORM="manylinux_2_28_x86_64"
|
||||
# Set platform based on architecture
|
||||
case $ARCH in
|
||||
x86_64)
|
||||
PLATFORM="manylinux_2_28_x86_64"
|
||||
;;
|
||||
aarch64)
|
||||
PLATFORM="manylinux_2_28_aarch64"
|
||||
;;
|
||||
s390x)
|
||||
PLATFORM="manylinux_2_28_s390x"
|
||||
;;
|
||||
*)
|
||||
echo "Unsupported architecture: $ARCH"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then
|
||||
retry dnf install -q -y zip openssl
|
||||
elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
|
||||
@ -38,6 +57,8 @@ else
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Platform set to: $PLATFORM"
|
||||
|
||||
# We use the package name to test the package by passing this to 'pip install'
|
||||
# This is the env variable that setup.py uses to name the package. Note that
|
||||
# pip 'normalizes' the name first by changing all - to _
|
||||
@ -299,8 +320,8 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
|
||||
# ROCm workaround for roctracer dlopens
|
||||
if [[ "$DESIRED_CUDA" == *"rocm"* ]]; then
|
||||
patchedpath=$(fname_without_so_number $destpath)
|
||||
# Keep the so number for XPU dependencies and libgomp.so.1 to avoid twice load
|
||||
elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" ]]; then
|
||||
# Keep the so number for XPU dependencies, libgomp.so.1, libgfortran.so.5, ACL libraries, and NVPL libraries to avoid loading them twice
|
||||
elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" || "$filename" == libarm_compute* || "$filename" == libnvpl* || "$filename" == "libgfortran.so.5" ]]; then
|
||||
patchedpath=$destpath
|
||||
else
|
||||
patchedpath=$(fname_with_sha256 $destpath)
|
||||
@ -346,9 +367,22 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
|
||||
done
|
||||
|
||||
# Create the manylinux_2_28 tag; this needs to happen before the RECORD file is regenerated
|
||||
if [[ $PLATFORM == "manylinux_2_28_x86_64" && $GPU_ARCH_TYPE != "cpu-s390x" && $GPU_ARCH_TYPE != "xpu" ]]; then
|
||||
# Support all architectures (x86_64, aarch64, s390x)
|
||||
if [[ "$IS_MANYLINUX2_28" == "1" && $GPU_ARCH_TYPE != "xpu" ]]; then
|
||||
wheel_file=$(echo $(basename $pkg) | sed -e 's/-cp.*$/.dist-info\/WHEEL/g')
|
||||
sed -i -e s#linux_x86_64#"${PLATFORM}"# $wheel_file;
|
||||
echo "Updating wheel tag for $ARCH architecture"
|
||||
# Replace linux_* with manylinux_2_28_* based on architecture
|
||||
case $ARCH in
|
||||
x86_64)
|
||||
sed -i -e 's#linux_x86_64#manylinux_2_28_x86_64#g' $wheel_file
|
||||
;;
|
||||
aarch64)
|
||||
sed -i -e 's#linux_aarch64#manylinux_2_28_aarch64#g' $wheel_file
|
||||
;;
|
||||
s390x)
|
||||
sed -i -e 's#linux_s390x#manylinux_2_28_s390x#g' $wheel_file
|
||||
;;
|
||||
esac
|
||||
fi
|
||||
|
||||
# regenerate the RECORD file with new hashes
|
||||
|
||||
@ -15,6 +15,10 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
|
||||
EXTRA_CAFFE2_CMAKE_FLAGS=()
|
||||
fi
|
||||
|
||||
# Detect architecture
|
||||
ARCH=$(uname -m)
|
||||
echo "Building CPU wheel for architecture: $ARCH"
|
||||
|
||||
WHEELHOUSE_DIR="wheelhousecpu"
|
||||
LIBTORCH_HOUSE_DIR="libtorch_housecpu"
|
||||
if [[ -z "$PYTORCH_FINAL_PACKAGE_DIR" ]]; then
|
||||
@ -34,8 +38,10 @@ elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then
|
||||
elif [[ "$OS_NAME" == *"AlmaLinux"* ]]; then
|
||||
LIBGOMP_PATH="/usr/lib64/libgomp.so.1"
|
||||
elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
|
||||
if [[ "$(uname -m)" == "s390x" ]]; then
|
||||
if [[ "$ARCH" == "s390x" ]]; then
|
||||
LIBGOMP_PATH="/usr/lib/s390x-linux-gnu/libgomp.so.1"
|
||||
elif [[ "$ARCH" == "aarch64" ]]; then
|
||||
LIBGOMP_PATH="/usr/lib/aarch64-linux-gnu/libgomp.so.1"
|
||||
else
|
||||
LIBGOMP_PATH="/usr/lib/x86_64-linux-gnu/libgomp.so.1"
|
||||
fi
|
||||
@ -49,6 +55,32 @@ DEPS_SONAME=(
|
||||
"libgomp.so.1"
|
||||
)
|
||||
|
||||
# Add ARM-specific library dependencies for CPU builds
|
||||
if [[ "$ARCH" == "aarch64" ]]; then
|
||||
echo "Adding ARM-specific CPU library dependencies"
|
||||
|
||||
# ARM Compute Library (if available)
|
||||
if [[ -d "/acl/build" ]]; then
|
||||
echo "Adding ARM Compute Library for CPU"
|
||||
DEPS_LIST+=(
|
||||
"/acl/build/libarm_compute.so"
|
||||
"/acl/build/libarm_compute_graph.so"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libarm_compute.so"
|
||||
"libarm_compute_graph.so"
|
||||
)
|
||||
fi
|
||||
|
||||
# ARM system libraries
|
||||
DEPS_LIST+=(
|
||||
"/usr/lib64/libgfortran.so.5"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libgfortran.so.5"
|
||||
)
|
||||
fi
|
||||
|
||||
rm -rf /usr/local/cuda*
|
||||
|
||||
SOURCE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )"
|
||||
|
||||
@ -29,6 +29,10 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
|
||||
EXTRA_CAFFE2_CMAKE_FLAGS=()
|
||||
fi
|
||||
|
||||
# Detect architecture
|
||||
ARCH=$(uname -m)
|
||||
echo "Building for architecture: $ARCH"
|
||||
|
||||
# Determine CUDA version and architectures to build for
|
||||
#
|
||||
# NOTE: We should first check `DESIRED_CUDA` when determining `CUDA_VERSION`,
|
||||
@ -53,34 +57,60 @@ fi
|
||||
cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.')
|
||||
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
|
||||
|
||||
# Function to remove architectures from a list
|
||||
# Remove the given entries from a ';'-separated architecture list.
#   $1    - the list, e.g. "7.0;7.5;8.0;8.6;9.0"
#   $2... - entries to drop, e.g. "7.0" "8.6"
# Prints the filtered list on stdout.
# Entries are matched as whole ';'-delimited items, so the final entry (which
# has no trailing ';') is removed as well and "5.0" never matches inside
# "15.0". Suffixed entries such as "12.0+PTX" are removed only when named
# exactly.
remove_archs() {
    # Sentinel delimiters let every entry be matched as ";entry;".
    local result=";$1;"
    shift
    local arch
    for arch in "$@"; do
        # Repeat until stable so adjacent duplicates (";8.6;8.6;") all go.
        while [[ "$result" == *";${arch};"* ]]; do
            result="${result//;"${arch}";/;}"
        done
    done
    # Drop the sentinel delimiters again.
    result="${result#;}"
    result="${result%;}"
    echo "$result"
}
|
||||
|
||||
# Function to filter CUDA architectures for aarch64
|
||||
# aarch64 ARM GPUs only support certain compute capabilities
|
||||
# Keep: 8.0 (A100), 9.0+ (Hopper, Grace Hopper, newer)
|
||||
# Remove: < 8.0 (no ARM GPUs), 8.6 (x86_64 RTX 3090/A6000 only)
|
||||
filter_aarch64_archs() {
    # $1 is the ';'-separated arch list; the filtered list is printed on stdout.
    # These entries are the ones dropped for aarch64 builds.
    local -a dropped_archs=("5.0" "6.0" "7.0" "7.5" "8.6")
    local filtered
    filtered=$(remove_archs "$1" "${dropped_archs[@]}")
    echo "$filtered"
}
|
||||
|
||||
# Base: Common architectures across all modern CUDA versions
|
||||
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0"
|
||||
|
||||
case ${CUDA_VERSION} in
|
||||
#removing sm_50-sm_60 as these architectures are deprecated in CUDA 12.8/9 and will be removed in future releases
|
||||
#however we would like to keep sm_70 architecture see: https://github.com/pytorch/pytorch/issues/157517
|
||||
12.8)
|
||||
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0"
|
||||
;;
|
||||
12.9)
|
||||
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0+PTX"
|
||||
# WAR to resolve the ld error in libtorch build with CUDA 12.9
|
||||
12.6) TORCH_CUDA_ARCH_LIST="5.0;6.0;${TORCH_CUDA_ARCH_LIST}" ;; # Only 12.6 includes Legacy Maxwell/Pascal that will be removed in future releases
|
||||
12.8) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0" ;; # +Hopper/Blackwell support
|
||||
12.9) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0+PTX" # +Hopper/Blackwell support + PTX for forward compatibility
|
||||
if [[ "$PACKAGE_TYPE" == "libtorch" ]]; then
|
||||
TORCH_CUDA_ARCH_LIST="7.5;8.0;9.0;10.0;12.0+PTX"
|
||||
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//7.0;/}" # Remove 7.0 to resolve the ld error
|
||||
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//8.6;/}" # Remove 8.6 for libtorch
|
||||
fi
|
||||
;;
|
||||
13.0)
|
||||
TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;12.0+PTX"
|
||||
;;
|
||||
12.6)
|
||||
TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6;9.0"
|
||||
;;
|
||||
*)
|
||||
echo "unknown cuda version $CUDA_VERSION"
|
||||
exit 1
|
||||
TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;$([[ "$ARCH" == "aarch64" ]] && echo "11.0;" || echo "")12.0+PTX"
|
||||
export TORCH_NVCC_FLAGS="-compress-mode=size"
|
||||
export BUILD_BUNDLE_PTXAS=1
|
||||
;;
|
||||
*) echo "unknown cuda version $CUDA_VERSION"; exit 1 ;;
|
||||
esac
|
||||
|
||||
# Filter for aarch64: Remove < 8.0 and 8.6
|
||||
[[ "$ARCH" == "aarch64" ]] && TORCH_CUDA_ARCH_LIST=$(filter_aarch64_archs "$TORCH_CUDA_ARCH_LIST")
|
||||
|
||||
echo "TORCH_CUDA_ARCH_LIST set to: $TORCH_CUDA_ARCH_LIST"
|
||||
export TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST}
|
||||
echo "${TORCH_CUDA_ARCH_LIST}"
|
||||
|
||||
# Disable MAGMA for aarch64 as pre-built libraries are x86-64 only
|
||||
if [[ "$ARCH" == "aarch64" ]]; then
|
||||
echo "Disabling MAGMA for aarch64 architecture"
|
||||
export USE_MAGMA=0
|
||||
fi
|
||||
|
||||
# Package directories
|
||||
WHEELHOUSE_DIR="wheelhouse$cuda_version_nodot"
|
||||
LIBTORCH_HOUSE_DIR="libtorch_house$cuda_version_nodot"
|
||||
@ -244,6 +274,51 @@ else
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Add ARM-specific library dependencies
|
||||
if [[ "$ARCH" == "aarch64" ]]; then
|
||||
echo "Adding ARM-specific library dependencies"
|
||||
|
||||
# ARM Compute Library (if available)
|
||||
if [[ -d "/acl/build" ]]; then
|
||||
echo "Adding ARM Compute Library"
|
||||
DEPS_LIST+=(
|
||||
"/acl/build/libarm_compute.so"
|
||||
"/acl/build/libarm_compute_graph.so"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libarm_compute.so"
|
||||
"libarm_compute_graph.so"
|
||||
)
|
||||
fi
|
||||
|
||||
# ARM system libraries
|
||||
DEPS_LIST+=(
|
||||
"/lib64/libgomp.so.1"
|
||||
"/usr/lib64/libgfortran.so.5"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libgomp.so.1"
|
||||
"libgfortran.so.5"
|
||||
)
|
||||
|
||||
# NVPL libraries (ARM optimized BLAS/LAPACK)
|
||||
if [[ -d "/usr/local/lib" && -f "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0" ]]; then
|
||||
echo "Adding NVPL libraries for ARM"
|
||||
DEPS_LIST+=(
|
||||
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0"
|
||||
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0"
|
||||
"/usr/local/lib/libnvpl_lapack_core.so.0"
|
||||
"/usr/local/lib/libnvpl_blas_core.so.0"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libnvpl_lapack_lp64_gomp.so.0"
|
||||
"libnvpl_blas_lp64_gomp.so.0"
|
||||
"libnvpl_lapack_core.so.0"
|
||||
"libnvpl_blas_core.so.0"
|
||||
)
|
||||
fi
|
||||
fi
|
||||
|
||||
# run_tests.sh requires DESIRED_CUDA to know what tests to exclude
|
||||
export DESIRED_CUDA="$cuda_version_nodot"
|
||||
|
||||
@ -251,9 +326,11 @@ export DESIRED_CUDA="$cuda_version_nodot"
|
||||
rm -rf /usr/local/cuda || true
|
||||
ln -s "/usr/local/cuda-${CUDA_VERSION}" /usr/local/cuda
|
||||
|
||||
# Switch `/usr/local/magma` to the desired CUDA version
|
||||
rm -rf /usr/local/magma || true
|
||||
ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma
|
||||
# Switch `/usr/local/magma` to the desired CUDA version (skip for aarch64)
|
||||
if [[ "$ARCH" != "aarch64" ]]; then
|
||||
rm -rf /usr/local/magma || true
|
||||
ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma
|
||||
fi
|
||||
|
||||
export CUDA_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev) # 10.0.130
|
||||
export CUDA_VERSION_SHORT=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev | cut -f1,2 -d".") # 10.0
|
||||
|
||||
@ -86,10 +86,20 @@ else
|
||||
fi
|
||||
fi
|
||||
|
||||
# Enable MKLDNN with ARM Compute Library for ARM builds
|
||||
if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then
|
||||
export USE_MKLDNN=1
|
||||
|
||||
# ACL is required for aarch64 builds
|
||||
if [[ ! -d "/acl" ]]; then
|
||||
echo "ERROR: ARM Compute Library not found at /acl"
|
||||
echo "ACL is required for aarch64 builds. Check Docker image setup."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export USE_MKLDNN_ACL=1
|
||||
export ACL_ROOT_DIR=/acl
|
||||
echo "ARM Compute Library enabled for MKLDNN: ACL_ROOT_DIR=/acl"
|
||||
fi
|
||||
|
||||
if [[ "$BUILD_ENVIRONMENT" == *riscv64* ]]; then
|
||||
|
||||
@ -100,337 +100,6 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _compile_and_extract_symbols(
    cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None
) -> list[str]:
    """
    Compile a C++ snippet into an object file and list the symbols found in it.

    Args:
        cpp_content: Text of the C++ translation unit to compile.
        compile_flags: Compiler command line (executable first); the source
            path and ``-o <object>`` are appended to it.
        exclude_list: Symbol names to leave out of the result. ``["main"]`` is
            used when this is None.

    Returns:
        Names of the symbols reported by ``get_symbols`` for the object file,
        without the excluded names.

    Raises:
        RuntimeError: If the compiler exits with a non-zero status.
    """
    import subprocess
    import tempfile

    excluded = ["main"] if exclude_list is None else exclude_list

    # Source and object live in a scratch directory that is removed on exit.
    with tempfile.TemporaryDirectory() as workdir:
        source_path = Path(workdir) / "test.cpp"
        object_path = Path(workdir) / "test.o"
        source_path.write_text(cpp_content)

        command = compile_flags + [str(source_path), "-o", str(object_path)]
        proc = subprocess.run(command, capture_output=True, text=True, timeout=60)
        if proc.returncode != 0:
            raise RuntimeError(f"Compilation failed: {proc.stderr}")

        # get_symbols yields (address, type, name) triples; only names are kept.
        kept = []
        for _addr, _stype, name in get_symbols(str(object_path)):
            if name not in excluded:
                kept.append(name)
        return kept
|
||||
|
||||
|
||||
def check_stable_only_symbols(install_root: Path) -> None:
    """
    Check how TORCH_STABLE_ONLY and TORCH_TARGET_VERSION affect the main torch headers.

    One translation unit that includes the main headers is compiled four times
    and the symbols of each resulting object file are counted:
    1. no macros defined         -> a non-zero number of symbols is expected
    2. -DTORCH_STABLE_ONLY       -> zero symbols expected
    3. -DTORCH_TARGET_VERSION=1  -> zero symbols expected
    4. both macros               -> zero symbols expected
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    test_cpp_content = """
// Main torch C++ API headers
#include <torch/torch.h>
#include <torch/all.h>

// ATen tensor library
#include <ATen/ATen.h>

// Core c10 headers (commonly used)
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Optional.h>

int main() { return 0; }
"""

    base_compile_flags = [
        "g++",
        "-std=c++17",
        f"-I{include_dir}",
        f"-I{include_dir}/torch/csrc/api/include",
        "-c",  # Compile only, don't link
    ]

    # Baseline: constexpr symbols, inline functions used by other headers etc.
    # are expected to show up when no macro is defined.
    baseline_count = len(
        _compile_and_extract_symbols(
            cpp_content=test_cpp_content,
            compile_flags=base_compile_flags,
        )
    )
    print(f"Found {baseline_count} symbols without any macros defined")
    assert baseline_count != 0, (
        "Expected a non-zero number of symbols without any macros"
    )

    # Each macro combination below must leave no symbols at all.
    hidden_cases = [
        (["-DTORCH_STABLE_ONLY"], "TORCH_STABLE_ONLY macro"),
        (["-DTORCH_TARGET_VERSION=1"], "TORCH_TARGET_VERSION macro"),
        (["-DTORCH_STABLE_ONLY", "-DTORCH_TARGET_VERSION=1"], "both macros"),
    ]
    for extra_flags, label in hidden_cases:
        found = len(
            _compile_and_extract_symbols(
                cpp_content=test_cpp_content,
                compile_flags=base_compile_flags + extra_flags,
            )
        )
        assert found == 0, f"Expected no symbols with {label}, but found {found}"
|
||||
|
||||
|
||||
def check_stable_api_symbols(install_root: Path) -> None:
    """
    Check that the headers under torch/csrc/stable still yield symbols when
    compiled with TORCH_STABLE_ONLY defined.

    Every ``*.h`` file found below ``include/torch/csrc/stable`` is included in
    one translation unit. The torch/csrc/stable/c/shim.h header is additionally
    tested in check_stable_c_shim_symbols.

    Raises:
        RuntimeError: If no header is found below the stable directory.
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    stable_dir = include_dir / "torch" / "csrc" / "stable"
    assert stable_dir.exists(), f"Expected {stable_dir} to be present"

    stable_headers = list(stable_dir.rglob("*.h"))
    if not stable_headers:
        raise RuntimeError("Could not find any stable headers")

    # One #include per header, written relative to the include directory.
    includes_str = "\n".join(
        f"#include <{header.relative_to(include_dir).as_posix()}>"
        for header in stable_headers
    )
    test_stable_content = f"""
{includes_str}
int main() {{ return 0; }}
"""

    include_flags = [f"-I{include_dir}", f"-I{include_dir}/torch/csrc/api/include"]
    compile_flags = ["g++", "-std=c++17", *include_flags, "-c", "-DTORCH_STABLE_ONLY"]

    found = len(
        _compile_and_extract_symbols(
            cpp_content=test_stable_content,
            compile_flags=compile_flags,
        )
    )
    print(f"Found {found} symbols in torch/csrc/stable")
    assert found > 0, (
        f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
        f"but found {found} symbols"
    )
|
||||
|
||||
|
||||
def check_headeronly_symbols(install_root: Path) -> None:
    """
    Check that the headers under torch/headeronly still yield symbols when
    compiled with TORCH_STABLE_ONLY defined.

    Headers whose include-relative path contains a platform-specific keyword
    are left out of the generated translation unit.

    Raises:
        RuntimeError: If no header is found below the headeronly directory.
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    headeronly_dir = include_dir / "torch" / "headeronly"
    assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present"
    headeronly_headers = list(headeronly_dir.rglob("*.h"))
    if not headeronly_headers:
        raise RuntimeError("Could not find any headeronly headers")

    # Platform-specific headers may not compile everywhere, so skip them.
    platform_specific_keywords = [
        "cpu/vec",
    ]

    def _is_portable(header):
        lowered = header.relative_to(include_dir).as_posix().lower()
        return not any(keyword in lowered for keyword in platform_specific_keywords)

    portable_headers = [header for header in headeronly_headers if _is_portable(header)]

    includes_str = "\n".join(
        f"#include <{header.relative_to(include_dir).as_posix()}>"
        for header in portable_headers
    )
    test_headeronly_content = f"""
{includes_str}
int main() {{ return 0; }}
"""

    include_flags = [f"-I{include_dir}", f"-I{include_dir}/torch/csrc/api/include"]
    compile_flags = ["g++", "-std=c++17", *include_flags, "-c", "-DTORCH_STABLE_ONLY"]

    found = len(
        _compile_and_extract_symbols(
            cpp_content=test_headeronly_content,
            compile_flags=compile_flags,
        )
    )
    print(f"Found {found} symbols in torch/headeronly")
    assert found > 0, (
        f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
        f"but found {found} symbols"
    )
|
||||
|
||||
|
||||
def check_aoti_shim_symbols(install_root: Path) -> None:
    """
    Check that the AOTI shim header still yields symbols when compiled with
    TORCH_STABLE_ONLY defined.
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    # The header has no constexpr symbols etc., so the snippet takes the
    # address of two shim functions to make some symbols appear.
    test_shim_content = """
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
int main() {
    int32_t (*fp1)() = &aoti_torch_device_type_cpu;
    int32_t (*fp2)() = &aoti_torch_dtype_float32;
    (void)fp1; (void)fp2;
    return 0;
}
"""

    include_flags = [f"-I{include_dir}", f"-I{include_dir}/torch/csrc/api/include"]
    compile_flags = ["g++", "-std=c++17", *include_flags, "-c", "-DTORCH_STABLE_ONLY"]

    found = len(
        _compile_and_extract_symbols(
            cpp_content=test_shim_content,
            compile_flags=compile_flags,
        )
    )
    assert found > 0, (
        f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
        f"but found {found} symbols"
    )
|
||||
|
||||
|
||||
def check_stable_c_shim_symbols(install_root: Path) -> None:
    """
    Check that torch/csrc/stable/c/shim.h still yields symbols when compiled
    with TORCH_STABLE_ONLY defined.

    Raises:
        RuntimeError: If the stable C shim header is not present.
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    # The header must be installed before it can be compiled against.
    shim_header = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
    if not shim_header.exists():
        raise RuntimeError("Could not find stable c shim")

    # The header has no constexpr symbols etc., so the snippet takes the
    # address of two shim functions to make some symbols appear.
    test_stable_shim_content = """
#include <torch/csrc/stable/c/shim.h>
int main() {
    // Reference stable C API functions to create undefined symbols
    AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
    AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
    (void)fp1; (void)fp2;
    return 0;
}
"""

    include_flags = [f"-I{include_dir}", f"-I{include_dir}/torch/csrc/api/include"]
    compile_flags = ["g++", "-std=c++17", *include_flags, "-c", "-DTORCH_STABLE_ONLY"]

    found = len(
        _compile_and_extract_symbols(
            cpp_content=test_stable_shim_content,
            compile_flags=compile_flags,
        )
    )
    assert found > 0, (
        f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
        f"but found {found} symbols"
    )
|
||||
|
||||
|
||||
def check_lib_symbols_for_abi_correctness(lib: str) -> None:
|
||||
print(f"lib: {lib}")
|
||||
cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)
|
||||
@ -460,13 +129,6 @@ def main() -> None:
|
||||
check_lib_symbols_for_abi_correctness(libtorch_cpu_path)
|
||||
check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path)
|
||||
|
||||
# Check symbols when TORCH_STABLE_ONLY is defined
|
||||
check_stable_only_symbols(install_root)
|
||||
check_stable_api_symbols(install_root)
|
||||
check_headeronly_symbols(install_root)
|
||||
check_aoti_shim_symbols(install_root)
|
||||
check_stable_c_shim_symbols(install_root)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -389,6 +389,13 @@ test_lazy_tensor_meta_reference_disabled() {
|
||||
export -n TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE
|
||||
}
|
||||
|
||||
test_dynamo_core() {
  # Run run_test.py with the dynamo core test selection (timed), then call
  # assert_git_not_dirty.
  local -a run_test_args=(
    --include-dynamo-core-tests
    --verbose
    --upload-artifacts-while-running
  )
  time python test/run_test.py "${run_test_args[@]}"
  assert_git_not_dirty
}
|
||||
|
||||
test_dynamo_wrapped_shard() {
|
||||
if [[ -z "$NUM_TEST_SHARDS" ]]; then
|
||||
@ -1814,6 +1821,8 @@ elif [[ "${TEST_CONFIG}" == *inductor* ]]; then
|
||||
test_inductor_shard "${SHARD_NUMBER}"
|
||||
elif [[ "${TEST_CONFIG}" == *einops* ]]; then
|
||||
test_einops
|
||||
elif [[ "${TEST_CONFIG}" == *dynamo_core* ]]; then
|
||||
test_dynamo_core
|
||||
elif [[ "${TEST_CONFIG}" == *dynamo_wrapped* ]]; then
|
||||
install_torchvision
|
||||
test_dynamo_wrapped_shard "${SHARD_NUMBER}"
|
||||
|
||||
1
.github/pytorch-probot.yml
vendored
1
.github/pytorch-probot.yml
vendored
@ -7,6 +7,7 @@ ciflow_push_tags:
|
||||
- ciflow/binaries
|
||||
- ciflow/binaries_libtorch
|
||||
- ciflow/binaries_wheel
|
||||
- ciflow/dynamo
|
||||
- ciflow/h100
|
||||
- ciflow/h100-cutlass-backend
|
||||
- ciflow/h100-distributed
|
||||
|
||||
2
.github/scripts/generate_pytorch_version.py
vendored
2
.github/scripts/generate_pytorch_version.py
vendored
@ -50,7 +50,7 @@ def get_tag() -> str:
|
||||
|
||||
def get_base_version() -> str:
    """Return the contents of ``version.txt`` with the legacy suffix removed.

    The file is read from the directory returned by ``get_pytorch_root()`` and
    stripped of surrounding whitespace before
    ``LEGACY_BASE_VERSION_SUFFIX_PATTERN`` is removed from it.
    """
    root = get_pytorch_root()
    # read_text() opens and closes the file itself, so no handle is left open
    # the way a bare open(...).read() would leave one.
    dirty_version = Path(root / "version.txt").read_text().strip()
    # Strips trailing a0 from version.txt, not too sure why it's there in the
    # first place
    return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)
|
||||
|
||||
7
.github/workflows/_binary-build-linux.yml
vendored
7
.github/workflows/_binary-build-linux.yml
vendored
@ -260,11 +260,8 @@ jobs:
|
||||
"${DOCKER_IMAGE}"
|
||||
)
|
||||
docker exec -t -w "${PYTORCH_ROOT}" "${container_name}" bash -c "bash .circleci/scripts/binary_populate_env.sh"
|
||||
if [[ ${BUILD_ENVIRONMENT} == *"aarch64"* ]]; then
|
||||
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/aarch64_linux/aarch64_ci_build.sh"
|
||||
else
|
||||
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh"
|
||||
fi
|
||||
# Unified build script for all architectures (x86_64, aarch64, s390x)
|
||||
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh"
|
||||
|
||||
- name: Chown artifacts
|
||||
if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' && inputs.build_environment != 'linux-s390x-binary-manywheel' }}
|
||||
|
||||
2
.github/workflows/_linux-test.yml
vendored
2
.github/workflows/_linux-test.yml
vendored
@ -326,7 +326,7 @@ jobs:
|
||||
SCCACHE_BUCKET: ${{ !contains(matrix.runner, 'b200') && 'ossci-compiler-cache-circleci-v2' || '' }}
|
||||
SCCACHE_REGION: ${{ !contains(matrix.runner, 'b200') && 'us-east-1' || '' }}
|
||||
SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }}
|
||||
DOCKER_IMAGE: ${{ inputs.docker-image }}
|
||||
DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }}
|
||||
XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla
|
||||
PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }}
|
||||
|
||||
70
.github/workflows/dynamo-unittest.yml
vendored
Normal file
70
.github/workflows/dynamo-unittest.yml
vendored
Normal file
@ -0,0 +1,70 @@
|
||||
# Workflow: Dynamo Unit Test
|
||||
# runs unit tests for dynamo.
|
||||
name: dynamo-unittest
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- ciflow/dynamo/*
|
||||
workflow_call:
|
||||
schedule:
|
||||
- cron: 29 8 * * * # about 1:29am PDT
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
get-label-type:
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
opt_out_experiments: lf
|
||||
|
||||
dynamo-build:
|
||||
name: dynamo-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.11', '3.12']
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-py${{ matrix.python-version }}-clang12
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang12
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "dynamo_core", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
dynamo-test:
|
||||
name: dynamo-test
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: [get-label-type, dynamo-build]
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.11', '3.12']
|
||||
with:
|
||||
build-environment: linux-jammy-py${{ matrix.python-version }}-clang12
|
||||
docker-image: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang12
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "dynamo_core", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
]}
|
||||
secrets: inherit
|
||||
330
.spin/cmds.py
Normal file
330
.spin/cmds.py
Normal file
@ -0,0 +1,330 @@
|
||||
import hashlib
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
import spin
|
||||
|
||||
|
||||
def file_digest(file, algorithm: str):
    """Return a hashlib hash object updated with everything read from *file*.

    ``hashlib.file_digest`` is tried first. If that raises ``AttributeError``
    (for example because the function is unavailable), the file is consumed
    manually in 8192-byte reads instead.

    Args:
        file: Open file object to read from.
        algorithm: Algorithm name accepted by ``hashlib``.
    """
    try:
        return hashlib.file_digest(file, algorithm)
    except AttributeError:
        pass  # Fallback to manual implementation below
    digest = hashlib.new(algorithm)
    while True:
        block = file.read(8192)
        if not block:
            break
        digest.update(block)
    return digest
|
||||
|
||||
|
||||
def _hash_file(file):
    """Return the hexadecimal SHA-256 digest of the file at path *file*."""
    # Binary mode: the digest is computed over the raw bytes.
    with open(file, "rb") as stream:
        digest = file_digest(stream, "sha256")
    return digest.hexdigest()
|
||||
|
||||
|
||||
def _hash_files(files):
    """Return a dict mapping each path in *files* to its ``_hash_file`` digest."""
    digests = {}
    for path in files:
        digests[path] = _hash_file(path)
    return digests
|
||||
|
||||
|
||||
def _read_hashes(hash_file: Path):
    """Load a ``<digest> <path>`` listing into a ``{path: digest}`` dict.

    Args:
        hash_file: Text file with one record per line: a digest, whitespace,
            then the path the digest belongs to.

    Returns:
        Mapping from path to digest; empty when *hash_file* does not exist.
    """
    if not hash_file.exists():
        return {}
    hashes = {}
    with hash_file.open("r") as f:
        for line in f:
            # Split on the first run of whitespace instead of slicing at fixed
            # columns, so the separator width does not matter.
            fields = line.split(maxsplit=1)
            if len(fields) != 2:
                continue  # blank or malformed line: nothing to record
            digest, path = fields
            hashes[path.strip()] = digest
    return hashes
|
||||
|
||||
|
||||
def _updated_hashes(hash_file, files_to_hash):
    """Return fresh hashes of *files_to_hash* when they differ from *hash_file*.

    Returns:
        The newly computed ``{file: hash}`` dict if it is not equal to the one
        read from *hash_file*, otherwise None.
    """
    recorded = _read_hashes(hash_file)
    current = _hash_files(files_to_hash)
    return None if current == recorded else current
|
||||
|
||||
|
||||
@click.command()
def regenerate_version():
    """Regenerate version.py."""
    # Run the generator module with the interpreter that is running this command.
    spin.util.run(
        [sys.executable, "-m", "tools.generate_torch_version", "--is-debug=false"]
    )
|
||||
|
||||
|
||||
# Type stub generators run by regenerate_type_stubs. Each entry is a tuple
#   (name, hash_file, files_to_hash, cmd)
# where `name` is used in console messages, `hash_file` is the Path holding the
# recorded hashes (or None when there is none), `files_to_hash` are the inputs
# whose hashes are compared, and `cmd` is the command line to run.
TYPE_STUBS = [
    (
        "Pytorch type stubs",
        Path(".lintbin/.pytorch-type-stubs.sha256"),
        # Inputs that are hashed; they are the same files passed to gen_pyi below.
        [
            "aten/src/ATen/native/native_functions.yaml",
            "aten/src/ATen/native/tags.yaml",
            "tools/autograd/deprecated.yaml",
        ],
        [
            sys.executable,
            "-m",
            "tools.pyi.gen_pyi",
            "--native-functions-path",
            "aten/src/ATen/native/native_functions.yaml",
            "--tags-path",
            "aten/src/ATen/native/tags.yaml",
            "--deprecated-functions-path",
            "tools/autograd/deprecated.yaml",
        ],
    ),
    (
        "Datapipes type stubs",
        # No hash file and no inputs to hash for this entry.
        None,
        [],
        [
            sys.executable,
            "torch/utils/data/datapipes/gen_pyi.py",
        ],
    ),
]
|
||||
|
||||
|
||||
@click.command()
def regenerate_type_stubs():
    """Regenerate type stubs."""
    # For every TYPE_STUBS entry: run its command when it has no hash file, or
    # when the hashes of its input files differ from the recorded ones; in the
    # latter case the new hashes are written back to the hash file.
    for name, hash_file, files_to_hash, cmd in TYPE_STUBS:
        if not hash_file:
            # Nothing recorded to compare against, so always regenerate.
            click.echo(f"No hash file for {name}. Regenerating...")
            spin.util.run(cmd)
            click.echo("Type stubs regenerated.")
            continue

        hashes = _updated_hashes(hash_file, files_to_hash)
        if not hashes:
            click.echo(f"No changes detected in type stub files for {name}.")
            continue

        click.echo(f"Changes detected in type stub files for {name}. Regenerating...")
        spin.util.run(cmd)
        hash_file.parent.mkdir(parents=True, exist_ok=True)
        with hash_file.open("w") as f:
            # One record per line: the hash, two spaces, then the path
            # (the layout sha256sum uses), so a 64-character hash is followed
            # by the path starting at column 66.
            f.writelines(f"{digest}  {path}\n" for path, digest in hashes.items())
        click.echo("Type stubs and hashes updated.")
|
||||
|
||||
|
||||
@click.command()
def regenerate_clangtidy_files():
    """Regenerate clang-tidy files."""
    # Run the generator module with the interpreter that is running this command.
    spin.util.run(
        [sys.executable, "-m", "tools.linter.clang_tidy.generate_build_files"]
    )
|
||||
|
||||
|
||||
#: These linters are expected to need less than 3s cpu time total
|
||||
VERY_FAST_LINTERS = {
|
||||
"ATEN_CPU_GPU_AGNOSTIC",
|
||||
"BAZEL_LINTER",
|
||||
"C10_NODISCARD",
|
||||
"C10_UNUSED",
|
||||
"CALL_ONCE",
|
||||
"CMAKE_MINIMUM_REQUIRED",
|
||||
"CONTEXT_DECORATOR",
|
||||
"COPYRIGHT",
|
||||
"CUBINCLUDE",
|
||||
"DEPLOY_DETECTION",
|
||||
"ERROR_PRONE_ISINSTANCE",
|
||||
"EXEC",
|
||||
"HEADER_ONLY_LINTER",
|
||||
"IMPORT_LINTER",
|
||||
"INCLUDE",
|
||||
"LINTRUNNER_VERSION",
|
||||
"MERGE_CONFLICTLESS_CSV",
|
||||
"META_NO_CREATE_UNBACKED",
|
||||
"NEWLINE",
|
||||
"NOQA",
|
||||
"NO_WORKFLOWS_ON_FORK",
|
||||
"ONCE_FLAG",
|
||||
"PYBIND11_INCLUDE",
|
||||
"PYBIND11_SPECIALIZATION",
|
||||
"PYPIDEP",
|
||||
"PYPROJECT",
|
||||
"RAWCUDA",
|
||||
"RAWCUDADEVICE",
|
||||
"ROOT_LOGGING",
|
||||
"TABS",
|
||||
"TESTOWNERS",
|
||||
"TYPEIGNORE",
|
||||
"TYPENOSKIP",
|
||||
"WORKFLOWSYNC",
|
||||
}
|
||||
|
||||
|
||||
#: These linters are expected to take a few seconds, but less than 10s cpu time total
|
||||
FAST_LINTERS = {
|
||||
"CMAKE",
|
||||
"DOCSTRING_LINTER",
|
||||
"GHA",
|
||||
"NATIVEFUNCTIONS",
|
||||
"RUFF",
|
||||
"SET_LINTER",
|
||||
"SHELLCHECK",
|
||||
"SPACES",
|
||||
}
|
||||
|
||||
|
||||
#: These linters are expected to take more than 10s cpu time total;
|
||||
#: some need more than 1 hour.
|
||||
SLOW_LINTERS = {
|
||||
"ACTIONLINT",
|
||||
"CLANGFORMAT",
|
||||
"CLANGTIDY",
|
||||
"CODESPELL",
|
||||
"FLAKE8",
|
||||
"GB_REGISTRY",
|
||||
"PYFMT",
|
||||
"PYREFLY",
|
||||
"TEST_DEVICE_BIAS",
|
||||
"TEST_HAS_MAIN",
|
||||
}
|
||||
|
||||
|
||||
ALL_LINTERS = VERY_FAST_LINTERS | FAST_LINTERS | SLOW_LINTERS
|
||||
|
||||
|
||||
LINTRUNNER_CACHE_INFO = (
|
||||
Path(".lintbin/.lintrunner.sha256"),
|
||||
[
|
||||
"requirements.txt",
|
||||
"pyproject.toml",
|
||||
".lintrunner.toml",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
LINTRUNNER_BASE_CMD = [
|
||||
"uvx",
|
||||
"--python",
|
||||
"3.10",
|
||||
"lintrunner@0.12.7",
|
||||
]
|
||||
|
||||
|
||||
@click.command()
def setup_lint():
    """Set up lintrunner with current CI version."""
    init_cmd = [*LINTRUNNER_BASE_CMD, "init"]
    # check=True raises CalledProcessError on a non-zero exit status; the
    # child's stdout/stderr are captured as text rather than echoed.
    subprocess.run(init_cmd, check=True, capture_output=True, text=True)
|
||||
|
||||
|
||||
def _check_linters():
    """Compare the linters lintrunner lists with the categories in this file.

    Runs ``lintrunner list`` and checks the reported names against
    ``ALL_LINTERS``. A yellow warning is printed for each non-empty
    mismatch set; names are sorted so the message is stable between runs.

    Returns:
        A ``(unknown_linters, missing_linters)`` tuple of sets: names that
        lintrunner reported but no category contains, and names that a
        category contains but lintrunner did not report.
    """
    cmd = LINTRUNNER_BASE_CMD + ["list"]
    ret = spin.util.run(cmd, output=False, stderr=subprocess.PIPE)
    # The first output line is skipped (presumably a heading -- confirm
    # against lintrunner's output). Blank lines are dropped so they are not
    # reported as a linter with an empty name.
    reported = ret.stdout.decode().strip().split("\n")[1:]
    linters = {name.strip() for name in reported if name.strip()}
    unknown_linters = linters - ALL_LINTERS
    missing_linters = ALL_LINTERS - linters
    if unknown_linters:
        click.secho(
            f"Unknown linters found; please add them to the correct category "
            f"in .spin/cmds.py: {', '.join(sorted(unknown_linters))}",
            fg="yellow",
        )
    if missing_linters:
        click.secho(
            f"Missing linters found; please update the corresponding category "
            f"in .spin/cmds.py: {', '.join(sorted(missing_linters))}",
            fg="yellow",
        )
    return unknown_linters, missing_linters
|
||||
|
||||
|
||||
@spin.util.extend_command(
    setup_lint,
    doc=f"""
        If configuration has changed, update lintrunner.

        Compares the stored old hashes of configuration files with new ones and
        performs setup via setup-lint if the hashes have changed.
        Hashes are stored in {LINTRUNNER_CACHE_INFO[0]}; the following files are
        considered: {", ".join(LINTRUNNER_CACHE_INFO[1])}.
        """,
)
@click.pass_context
def lazy_setup_lint(ctx, parent_callback, **kwargs):
    # The help text is supplied through ``doc=`` above, so no docstring here.
    cache_file = LINTRUNNER_CACHE_INFO[0]
    fresh_hashes = _updated_hashes(*LINTRUNNER_CACHE_INFO)
    if not fresh_hashes:
        click.echo("No changes detected in lint configuration files. Skipping setup.")
    else:
        click.echo(
            "Changes detected in lint configuration files. Setting up linting tools..."
        )
        # parent_callback is the callback of the extended setup_lint command.
        parent_callback(**kwargs)
        # Store the new hashes after setup has been invoked, one
        # "<hash> <file>" line per configuration file.
        cache_file.parent.mkdir(parents=True, exist_ok=True)
        with cache_file.open("w") as f:
            f.writelines(
                f"{digest} {path}\n" for path, digest in fresh_hashes.items()
            )
        click.echo("Linting tools set up and hashes updated.")

    # The generated files are refreshed on every call, in this order.
    for banner, command in (
        ("Regenerating version...", regenerate_version),
        ("Regenerating type stubs...", regenerate_type_stubs),
    ):
        click.echo(banner)
        ctx.invoke(command)
    click.echo("Done.")
    _check_linters()
|
||||
|
||||
|
||||
@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def lint(ctx, apply_patches, **kwargs):
    """Lint all files."""
    # Two lintrunner passes are made after the lazy setup step:
    #   1. the very fast and fast linters, with --all-files;
    #   2. the slow linters, without --all-files.
    ctx.invoke(lazy_setup_lint)

    # Work on a copy: ``cmd = LINTRUNNER_BASE_CMD`` followed by ``cmd += ...``
    # would extend the module-level list in place and leak "--apply-patches"
    # into every command built from it afterwards.
    cmd = list(LINTRUNNER_BASE_CMD)
    if apply_patches:
        cmd.append("--apply-patches")

    # Sets have no defined order; sort so the command line is reproducible.
    all_files_linters = sorted(VERY_FAST_LINTERS | FAST_LINTERS)
    changed_files_linters = sorted(SLOW_LINTERS)

    all_files_cmd = cmd + [
        "--take",
        ",".join(all_files_linters),
        "--all-files",
    ]
    spin.util.run(all_files_cmd)

    changed_files_cmd = cmd + [
        "--take",
        ",".join(changed_files_linters),
    ]
    spin.util.run(changed_files_cmd)
|
||||
|
||||
|
||||
@click.command()
@click.pass_context
def fixlint(ctx, **kwargs):
    """Autofix all files."""
    # Same as the ``lint`` command with its --apply-patches flag set.
    ctx.invoke(lint, apply_patches=True)
|
||||
|
||||
|
||||
@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def quicklint(ctx, apply_patches, **kwargs):
    """Lint changed files."""
    # Runs lintrunner once, with no --take / --all-files arguments, after the
    # lazy setup step.
    ctx.invoke(lazy_setup_lint)

    # Work on a copy: ``cmd = LINTRUNNER_BASE_CMD`` followed by ``cmd += ...``
    # would extend the module-level list in place and leak "--apply-patches"
    # into every command built from it afterwards.
    cmd = list(LINTRUNNER_BASE_CMD)
    if apply_patches:
        cmd.append("--apply-patches")
    spin.util.run(cmd)
|
||||
|
||||
|
||||
@click.command()
@click.pass_context
def quickfix(ctx, **kwargs):
    """Autofix changed files."""
    # Same as the ``quicklint`` command with its --apply-patches flag set.
    ctx.invoke(quicklint, apply_patches=True)
|
||||
@ -223,6 +223,62 @@ CONVERT_FROM_BF16_TEMPLATE(double)
|
||||
CONVERT_FROM_BF16_TEMPLATE(float16_t)
|
||||
#endif
|
||||
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
|
||||
// clang-[17, 20] crashes when autovectorizing static cast to bf16
|
||||
// Below is a workaround to have some vectorization
|
||||
// Works decently well for smaller int types
|
||||
template <typename from_type>
|
||||
inline void convertToBf16Impl(
|
||||
const from_type* __restrict src,
|
||||
c10::BFloat16* __restrict dst,
|
||||
uint64_t n) {
|
||||
bfloat16_t* dstPtr = reinterpret_cast<bfloat16_t*>(dst);
|
||||
uint64_t loopBound = n - (n % 16);
|
||||
uint64_t i = 0;
|
||||
for (; i < loopBound; i += 16) {
|
||||
float32x4_t a, b, c, d;
|
||||
a[0] = static_cast<float>(src[i]);
|
||||
a[1] = static_cast<float>(src[i + 1]);
|
||||
a[2] = static_cast<float>(src[i + 2]);
|
||||
a[3] = static_cast<float>(src[i + 3]);
|
||||
b[0] = static_cast<float>(src[i + 4]);
|
||||
b[1] = static_cast<float>(src[i + 5]);
|
||||
b[2] = static_cast<float>(src[i + 6]);
|
||||
b[3] = static_cast<float>(src[i + 7]);
|
||||
c[0] = static_cast<float>(src[i + 8]);
|
||||
c[1] = static_cast<float>(src[i + 9]);
|
||||
c[2] = static_cast<float>(src[i + 10]);
|
||||
c[3] = static_cast<float>(src[i + 11]);
|
||||
d[0] = static_cast<float>(src[i + 12]);
|
||||
d[1] = static_cast<float>(src[i + 13]);
|
||||
d[2] = static_cast<float>(src[i + 14]);
|
||||
d[3] = static_cast<float>(src[i + 15]);
|
||||
|
||||
vst1q_bf16(dstPtr + i, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(a), b));
|
||||
vst1q_bf16(dstPtr + i + 8, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(c), d));
|
||||
}
|
||||
|
||||
#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
|
||||
for (; i < n; i++) {
|
||||
float a = static_cast<float>(src[i]);
|
||||
dstPtr[i] = vcvth_bf16_f32(a);
|
||||
}
|
||||
}
|
||||
|
||||
#define CONVERT_TO_BF16_TEMPLATE(from_type) \
|
||||
template <> \
|
||||
inline void convert(const from_type* src, c10::BFloat16* dst, int64_t n) { \
|
||||
return convertToBf16Impl<from_type>(src, dst, n); \
|
||||
}
|
||||
|
||||
CONVERT_TO_BF16_TEMPLATE(uint8_t)
|
||||
CONVERT_TO_BF16_TEMPLATE(int8_t)
|
||||
CONVERT_TO_BF16_TEMPLATE(int16_t)
|
||||
CONVERT_TO_BF16_TEMPLATE(int32_t)
|
||||
|
||||
#endif
|
||||
|
||||
inline void convertBoolToBfloat16Impl(
|
||||
const bool* __restrict src,
|
||||
c10::BFloat16* __restrict dst,
|
||||
|
||||
@ -175,17 +175,24 @@ void CUDAGraph::instantiate() {
|
||||
// Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
|
||||
// who prefer not to report error message through these arguments moving forward
|
||||
// (they prefer return value, or errors on api calls internal to the capture)
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
|
||||
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, 0));
|
||||
// ROCM appears to fail with HIP error: invalid argument
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && !defined(USE_ROCM)
|
||||
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, cudaGraphInstantiateFlagUseNodePriority));
|
||||
#else
|
||||
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
|
||||
#endif
|
||||
//Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory.
|
||||
//It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch.
|
||||
} else {
|
||||
#if !defined(USE_ROCM)
|
||||
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
|
||||
graph_,
|
||||
cudaGraphInstantiateFlagAutoFreeOnLaunch | cudaGraphInstantiateFlagUseNodePriority));
|
||||
#else
|
||||
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
|
||||
graph_,
|
||||
cudaGraphInstantiateFlagAutoFreeOnLaunch));
|
||||
#endif
|
||||
}
|
||||
has_graph_exec_ = true;
|
||||
}
|
||||
|
||||
@ -904,19 +904,11 @@ Tensor mvlgamma(const Tensor& self, int64_t p) {
|
||||
return args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER);
|
||||
}
|
||||
|
||||
// since mvlgamma_ has different signature from its
|
||||
// out and functional variant, we explicitly
|
||||
// define it (instead of using structured kernel).
|
||||
Tensor& mvlgamma_(Tensor& self, int64_t p) {
|
||||
mvlgamma_check(self, p);
|
||||
Tensor args = native::arange(
|
||||
-p *HALF + HALF,
|
||||
HALF,
|
||||
HALF,
|
||||
optTypeMetaToScalarType(self.options().dtype_opt()),
|
||||
self.options().layout_opt(),
|
||||
self.options().device_opt(),
|
||||
self.options().pinned_memory_opt());
|
||||
args = args.add(self.unsqueeze(-1));
|
||||
const auto p2_sub_p = static_cast<double>(p * (p - 1));
|
||||
return self.copy_(args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER));
|
||||
return at::mvlgamma_out(self, self, p);
|
||||
}
|
||||
|
||||
Tensor& mvlgamma_out(const Tensor& self, int64_t p, Tensor& result) {
|
||||
|
||||
@ -75,30 +75,52 @@ static inline bool can_use_int32_nhwc(
|
||||
return true;
|
||||
}
|
||||
|
||||
static inline bool can_use_int32_nchw(
|
||||
int64_t nbatch, int64_t channels,
|
||||
int64_t height, int64_t width,
|
||||
int64_t pooled_height, int64_t pooled_width) {
|
||||
int64_t hw = height * width;
|
||||
return can_use_int32_nhwc(
|
||||
nbatch, channels, height, width,
|
||||
pooled_height, pooled_width,
|
||||
channels * hw, // in_stride_n
|
||||
hw, // in_stride_c
|
||||
width, // in_stride_h
|
||||
1 // in_stride_w
|
||||
);
|
||||
}
|
||||
|
||||
// kernels borrowed from Caffe
|
||||
template <typename scalar_t>
|
||||
__global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom_data,
|
||||
const int64_t channels, const int64_t height,
|
||||
const int64_t width, const int pooled_height, const int pooled_width,
|
||||
const int kernel_h, const int kernel_w, const int stride_h,
|
||||
const int stride_w, const int pad_h, const int pad_w,
|
||||
const int dilation_h, const int dilation_w, scalar_t* top_data,
|
||||
template <typename scalar_t, typename index_t>
|
||||
__global__ void max_pool_forward_nchw(
|
||||
const index_t nthreads,
|
||||
const scalar_t* bottom_data,
|
||||
const int64_t channels,
|
||||
const int64_t height,
|
||||
const int64_t width,
|
||||
const int pooled_height,
|
||||
const int pooled_width,
|
||||
const int kernel_h, const int kernel_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int pad_h, const int pad_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
scalar_t* top_data,
|
||||
int64_t* top_mask) {
|
||||
CUDA_KERNEL_LOOP(index, nthreads) {
|
||||
int pw = index % pooled_width;
|
||||
int ph = (index / pooled_width) % pooled_height;
|
||||
int c = (index / pooled_width / pooled_height) % channels;
|
||||
int n = index / pooled_width / pooled_height / channels;
|
||||
int hstart = ph * stride_h - pad_h;
|
||||
int wstart = pw * stride_w - pad_w;
|
||||
int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
|
||||
int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
|
||||
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
|
||||
index_t pw = index % pooled_width;
|
||||
index_t ph = (index / pooled_width) % pooled_height;
|
||||
index_t c = (index / pooled_width / pooled_height) % channels;
|
||||
index_t n = index / pooled_width / pooled_height / channels;
|
||||
index_t hstart = ph * stride_h - pad_h;
|
||||
index_t wstart = pw * stride_w - pad_w;
|
||||
index_t hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
|
||||
index_t wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
|
||||
while(hstart < 0)
|
||||
hstart += dilation_h;
|
||||
while(wstart < 0)
|
||||
wstart += dilation_w;
|
||||
scalar_t maxval = at::numeric_limits<scalar_t>::lower_bound(); // -Infinity
|
||||
int maxidx = hstart * width + wstart;
|
||||
index_t maxidx = hstart * width + wstart;
|
||||
const scalar_t* btm_data = bottom_data + (n * channels + c) * height * width;
|
||||
for (int h = hstart; h < hend; h += dilation_h) {
|
||||
for (int w = wstart; w < wend; w += dilation_w) {
|
||||
@ -251,32 +273,39 @@ __global__ void max_pool_forward_nhwc(
|
||||
|
||||
static constexpr int BLOCK_THREADS = 256;
|
||||
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
template <typename scalar_t, typename accscalar_t, typename index_t>
|
||||
#if defined (USE_ROCM)
|
||||
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 4)
|
||||
#else
|
||||
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 8)
|
||||
#endif
|
||||
__global__ void max_pool_backward_nchw(const scalar_t* top_diff,
|
||||
const int64_t* top_mask, const int num, const int64_t channels,
|
||||
const int64_t height, const int64_t width, const int pooled_height,
|
||||
const int pooled_width, const int kernel_h, const int kernel_w,
|
||||
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
|
||||
__global__ void max_pool_backward_nchw(
|
||||
const scalar_t* top_diff,
|
||||
const int64_t* top_mask,
|
||||
const index_t num,
|
||||
const index_t channels,
|
||||
const index_t height,
|
||||
const index_t width,
|
||||
const index_t pooled_height,
|
||||
const index_t pooled_width,
|
||||
const int kernel_h, const int kernel_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int pad_h, const int pad_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
scalar_t* bottom_diff) {
|
||||
CUDA_KERNEL_LOOP(index, height*width) {
|
||||
int h = index / width;
|
||||
int w = index - h * width;
|
||||
int phstart = p_start(h, pad_h, kernel_h, dilation_h, stride_h);
|
||||
int phend = p_end(h, pad_h, pooled_height, stride_h);
|
||||
int pwstart = p_start(w, pad_w, kernel_w, dilation_w, stride_w);
|
||||
int pwend = p_end(w, pad_w, pooled_width, stride_w);
|
||||
for (int n = blockIdx.y; n < num; n += gridDim.y) {
|
||||
for (int c = blockIdx.z; c < channels; c+= gridDim.z) {
|
||||
CUDA_KERNEL_LOOP_TYPE(index, height*width, index_t) {
|
||||
index_t h = index / width;
|
||||
index_t w = index - h * width;
|
||||
index_t phstart = p_start(h, pad_h, kernel_h, dilation_h, stride_h);
|
||||
index_t phend = p_end(h, pad_h, pooled_height, stride_h);
|
||||
index_t pwstart = p_start(w, pad_w, kernel_w, dilation_w, stride_w);
|
||||
index_t pwend = p_end(w, pad_w, pooled_width, stride_w);
|
||||
for (index_t n = blockIdx.y; n < num; n += gridDim.y) {
|
||||
for (index_t c = blockIdx.z; c < channels; c += gridDim.z) {
|
||||
accscalar_t gradient = accscalar_t(0);
|
||||
int offset = (n * channels + c) * pooled_height * pooled_width;
|
||||
for (int ph = phstart; ph < phend; ++ph) {
|
||||
for (int pw = pwstart; pw < pwend; ++pw) {
|
||||
index_t offset = (n * channels + c) * pooled_height * pooled_width;
|
||||
for (index_t ph = phstart; ph < phend; ++ph) {
|
||||
for (index_t pw = pwstart; pw < pwend; ++pw) {
|
||||
if (top_mask[ph * pooled_width + pw + offset] == h * width + w) {
|
||||
gradient += static_cast<accscalar_t>(top_diff[ph * pooled_width + pw + offset]);
|
||||
}
|
||||
@ -469,8 +498,6 @@ const Tensor& indices) {
|
||||
const int64_t in_stride_h = input.stride(-2);
|
||||
const int64_t in_stride_w = input.stride(-1);
|
||||
|
||||
const int count = safe_downcast<int, int64_t>(output.numel());
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
|
||||
"max_pool2d_with_indices_out_cuda_frame",
|
||||
[&] {
|
||||
@ -553,14 +580,42 @@ const Tensor& indices) {
|
||||
break;
|
||||
}
|
||||
case MemoryFormat::Contiguous: {
|
||||
const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
|
||||
BLOCK_THREADS);
|
||||
max_pool_forward_nchw<scalar_t>
|
||||
<<<ceil_div(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
count, input_data,
|
||||
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
output_data, indices_data);
|
||||
const int threads = std::min(
|
||||
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
|
||||
BLOCK_THREADS);
|
||||
const int64_t nthreads = output.numel();
|
||||
bool use_int32 = can_use_int32_nchw(
|
||||
nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
|
||||
const int maxGridX = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
|
||||
const int blocks = static_cast<int>(std::min<int64_t>(
|
||||
ceil_div(nthreads, static_cast<int64_t>(threads)),
|
||||
static_cast<int64_t>(maxGridX)));
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
if (use_int32) {
|
||||
max_pool_forward_nchw<scalar_t, int32_t>
|
||||
<<<blocks, threads, 0, stream>>>(
|
||||
static_cast<int32_t>(nthreads),
|
||||
input_data,
|
||||
static_cast<int32_t>(nInputPlane),
|
||||
static_cast<int32_t>(inputHeight),
|
||||
static_cast<int32_t>(inputWidth),
|
||||
static_cast<int32_t>(outputHeight),
|
||||
static_cast<int32_t>(outputWidth),
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
output_data, indices_data);
|
||||
} else {
|
||||
max_pool_forward_nchw<scalar_t, int64_t>
|
||||
<<<blocks, threads, 0, stream>>>(
|
||||
nthreads,
|
||||
input_data,
|
||||
nInputPlane,
|
||||
inputHeight,
|
||||
inputWidth,
|
||||
outputHeight,
|
||||
outputWidth,
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
output_data, indices_data);
|
||||
}
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
break;
|
||||
}
|
||||
@ -633,8 +688,6 @@ const Tensor& gradInput) {
|
||||
|
||||
gradInput.zero_();
|
||||
|
||||
int64_t count = input.numel();
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
|
||||
"max_pool2d_with_indices_out_cuda_frame",
|
||||
[&] {
|
||||
@ -692,25 +745,45 @@ const Tensor& gradInput) {
|
||||
break;
|
||||
}
|
||||
case MemoryFormat::Contiguous: {
|
||||
int imgcount = inputWidth * inputHeight;
|
||||
dim3 grid;
|
||||
const int blocks = (imgcount + BLOCK_THREADS - 1) / BLOCK_THREADS;
|
||||
grid.x = blocks;
|
||||
grid.y = nbatch;
|
||||
uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
|
||||
if (maxGridY < grid.y) grid.y = maxGridY;
|
||||
grid.z = nInputPlane;
|
||||
uint64_t maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
|
||||
if (maxGridZ < grid.z) grid.z = maxGridZ;
|
||||
|
||||
max_pool_backward_nchw<scalar_t, accscalar_t>
|
||||
<<<grid, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
gradOutput_data,
|
||||
indices_data,
|
||||
nbatch,
|
||||
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
gradInput_data);
|
||||
const int threads = std::min(
|
||||
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
|
||||
BLOCK_THREADS);
|
||||
const int imgcount = inputWidth * inputHeight;
|
||||
const int maxGridX = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
|
||||
const int maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
|
||||
const int maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
|
||||
const int blocks_x = std::min(ceil_div(imgcount, threads), maxGridX);
|
||||
dim3 grid(blocks_x, static_cast<unsigned>(std::min<int64_t>(nbatch, maxGridY)), static_cast<unsigned>(std::min<int64_t>(nInputPlane, maxGridZ)));
|
||||
bool use_int32 = can_use_int32_nchw(
|
||||
nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
if (use_int32) {
|
||||
max_pool_backward_nchw<scalar_t, accscalar_t, int32_t>
|
||||
<<<grid, threads, 0, stream>>>(
|
||||
gradOutput_data,
|
||||
indices_data,
|
||||
static_cast<int32_t>(nbatch),
|
||||
static_cast<int32_t>(nInputPlane),
|
||||
static_cast<int32_t>(inputHeight),
|
||||
static_cast<int32_t>(inputWidth),
|
||||
static_cast<int32_t>(outputHeight),
|
||||
static_cast<int32_t>(outputWidth),
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
gradInput_data);
|
||||
} else {
|
||||
max_pool_backward_nchw<scalar_t, accscalar_t, int64_t>
|
||||
<<<grid, threads, 0, stream>>>(
|
||||
gradOutput_data,
|
||||
indices_data,
|
||||
nbatch,
|
||||
nInputPlane,
|
||||
inputHeight,
|
||||
inputWidth,
|
||||
outputHeight,
|
||||
outputWidth,
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
gradInput_data);
|
||||
}
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
break;
|
||||
}
|
||||
|
||||
@ -78,9 +78,18 @@ __global__ void EmbeddingBag_updateOutputKernel_max(
|
||||
scalar_t weightFeatMax = 0;
|
||||
int64_t bag_size_ = 0;
|
||||
int64_t maxWord = -1;
|
||||
|
||||
// Separate validation loop reduces register pressure in the main loop below.
|
||||
// No early exit (break) on invalid input as benchmarking shows it degrades performance.
|
||||
bool has_invalid_index = false;
|
||||
for (int64_t emb = begin; emb < end; emb++) {
|
||||
index_t input_idx = input[emb];
|
||||
has_invalid_index = has_invalid_index || (input_idx < 0 || input_idx >= numRows);
|
||||
}
|
||||
CUDA_KERNEL_ASSERT(!has_invalid_index && "Invalid input index in EmbeddingBag: index out of range [0, numRows)");
|
||||
|
||||
for (int64_t emb = begin; emb < end; emb++) {
|
||||
bool pad = (input[emb] == padding_idx);
|
||||
CUDA_KERNEL_ASSERT(input[emb] < numRows);
|
||||
const int64_t weightRow = input[emb] * weight_stride0;
|
||||
scalar_t weightValue = weightFeat[weightRow];
|
||||
if (bag_size_ == 0 || weightValue > weightFeatMax) {
|
||||
@ -129,10 +138,19 @@ __global__ void EmbeddingBag_updateOutputKernel_sum_mean(
|
||||
CUDA_KERNEL_ASSERT(end >= begin);
|
||||
accscalar_t weightFeatSum = 0;
|
||||
int64_t bag_size_ = 0;
|
||||
|
||||
// Separate validation loop reduces register pressure in the main loop below.
|
||||
// No early exit (break) on invalid input as benchmarking shows it degrades performance.
|
||||
bool has_invalid_index = false;
|
||||
for (int64_t emb = begin; emb < end; emb++) {
|
||||
index_t input_idx = input[emb];
|
||||
has_invalid_index = has_invalid_index || (input_idx < 0 || input_idx >= numRows);
|
||||
}
|
||||
CUDA_KERNEL_ASSERT(!has_invalid_index && "Invalid input index in EmbeddingBag: index out of range [0, numRows)");
|
||||
|
||||
for (int64_t emb = begin; emb < end; emb++) {
|
||||
index_t input_idx = input[emb];
|
||||
bool pad = (input_idx == padding_idx);
|
||||
CUDA_KERNEL_ASSERT(0 <= input_idx && input_idx < numRows);
|
||||
const int64_t weightRow = input_idx * weight_stride0;
|
||||
scalar_t weightValue = weightFeat[weightRow];
|
||||
weightValue = pad ? static_cast<scalar_t>(0) : weightValue;
|
||||
|
||||
@ -78,9 +78,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const SwizzleType& swizzle_a,
|
||||
const SwizzleType swizzle_a,
|
||||
const Tensor& scale_b,
|
||||
const SwizzleType& swizzle_b,
|
||||
const SwizzleType swizzle_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
Tensor& out) {
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
|
||||
@ -740,7 +740,12 @@ _scaled_rowwise_rowwise(
|
||||
TORCH_CHECK_VALUE(scale_a.numel() == mat_a.size(0) && scale_a.scalar_type() == kFloat, "scale_a must have ", mat_a.size(0), " Float elements, got ", scale_a.numel())
|
||||
TORCH_CHECK_VALUE(scale_b.numel() == mat_b.size(1) && scale_b.scalar_type() == kFloat, "scale_b must have ", mat_b.size(1), " Float elements, got ", scale_b.numel())
|
||||
|
||||
TORCH_CHECK_VALUE(scale_a.stride(1) == 1, "expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1));
|
||||
// if we have a scale of shape [256, 1] (say), then stride can be [1, 0] - handle this case
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(1) == 1 ||
|
||||
scale_a.size(1) == 1,
|
||||
"expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1)
|
||||
);
|
||||
TORCH_CHECK_VALUE(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
|
||||
|
||||
auto scaling_choice_a = ScalingType::RowWise;
|
||||
|
||||
@ -337,10 +337,6 @@ Tensor _convolution_out(
|
||||
TORCH_CHECK(
|
||||
3 == ndim || 4 == ndim || 5 == ndim,
|
||||
"convolution only supports 3D, 4D, 5D tensor");
|
||||
// get computation format for Conv/TransposedConv
|
||||
bool is_channels_last_suggested =
|
||||
use_channels_last_for_conv(input_r, weight_r);
|
||||
|
||||
Tensor input = input_r, weight = weight_r;
|
||||
// PyTorch does not support ChannelsLast1D case,
|
||||
// thus we need the transformation here
|
||||
@ -348,13 +344,8 @@ Tensor _convolution_out(
|
||||
input = view4d(input_r);
|
||||
weight = view4d(weight_r);
|
||||
}
|
||||
// ensure the input/weight/bias/output are contiguous in desired format
|
||||
at::MemoryFormat mfmt = is_channels_last_suggested
|
||||
? get_cl_tag_by_ndim(input.ndimension())
|
||||
: at::MemoryFormat::Contiguous;
|
||||
auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r;
|
||||
input = input.contiguous(mfmt);
|
||||
weight = weight.contiguous(mfmt);
|
||||
// get computation format for Conv/TransposedConv
|
||||
bool is_channels_last_suggested = use_channels_last_for_conv(input, weight);
|
||||
|
||||
auto k = weight.ndimension();
|
||||
if (k == input.ndimension() + 1) {
|
||||
@ -388,6 +379,14 @@ Tensor _convolution_out(
|
||||
expand_param_if_needed(output_padding_, "output_padding", dim);
|
||||
params.groups = groups_;
|
||||
}
|
||||
|
||||
// ensure the input/weight/bias/output are contiguous in desired format
|
||||
at::MemoryFormat mfmt = is_channels_last_suggested
|
||||
? get_cl_tag_by_ndim(input.ndimension())
|
||||
: at::MemoryFormat::Contiguous;
|
||||
auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r;
|
||||
input = input.contiguous(mfmt);
|
||||
weight = weight.contiguous(mfmt);
|
||||
check_shape_forward(input, weight, bias, params, true);
|
||||
|
||||
Tensor output;
|
||||
@ -514,18 +513,9 @@ Tensor convolution_overrideable(
|
||||
at::borrow_from_optional_tensor(bias_r_opt);
|
||||
const Tensor& bias_r = *bias_r_maybe_owned;
|
||||
|
||||
auto k = weight_r.ndimension();
|
||||
at::MemoryFormat backend_memory_format = at::MemoryFormat::Contiguous;
|
||||
if (xpu_conv_use_channels_last(input_r, weight_r)) {
|
||||
backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d
|
||||
: at::MemoryFormat::ChannelsLast;
|
||||
}
|
||||
Tensor input_c = input_r.contiguous(backend_memory_format);
|
||||
Tensor weight_c = weight_r.contiguous(backend_memory_format);
|
||||
|
||||
return _convolution(
|
||||
input_c,
|
||||
weight_c,
|
||||
input_r,
|
||||
weight_r,
|
||||
bias_r,
|
||||
stride_,
|
||||
padding_,
|
||||
|
||||
342
aten/src/ATen/native/mkldnn/xpu/ScaledBlas.cpp
Normal file
342
aten/src/ATen/native/mkldnn/xpu/ScaledBlas.cpp
Normal file
@ -0,0 +1,342 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/BlasBackend.h>
|
||||
#include <ATen/WrapDimUtilsMulti.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
|
||||
#include <ATen/native/xpu/Blas.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_addmm_activation_native.h>
|
||||
#include <ATen/ops/_efficientzerotensor.h>
|
||||
#include <ATen/ops/_scaled_mm_native.h>
|
||||
#include <ATen/ops/_unsafe_view_native.h>
|
||||
#include <ATen/ops/abs.h>
|
||||
#include <ATen/ops/addmm_native.h>
|
||||
#include <ATen/ops/addmv_native.h>
|
||||
#include <ATen/ops/baddbmm_native.h>
|
||||
#include <ATen/ops/bmm_native.h>
|
||||
#include <ATen/ops/copy_native.h>
|
||||
#include <ATen/ops/dot_native.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/empty_strided.h>
|
||||
#include <ATen/ops/gelu.h>
|
||||
#include <ATen/ops/max.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/mul.h>
|
||||
#include <ATen/ops/ones.h>
|
||||
#include <ATen/ops/relu.h>
|
||||
#include <ATen/ops/scalar_tensor_native.h>
|
||||
#include <ATen/ops/vdot_native.h>
|
||||
#endif
|
||||
|
||||
namespace at::native {
|
||||
|
||||
using at::blas::ScalingType;
|
||||
using at::blas::SwizzleType;
|
||||
|
||||
namespace {
|
||||
/*
|
||||
* Scaling Type Determination:
|
||||
* ---------------------------
|
||||
* Conditions and corresponding Scaling Types:
|
||||
*
|
||||
* - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`:
|
||||
* - Returns BlockWise (with additional size checks).
|
||||
*
|
||||
* - Else if scale.numel() == 1:
|
||||
* - Returns TensorWise.
|
||||
*
|
||||
* - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) ==
|
||||
* 1:
|
||||
* - Returns RowWise.
|
||||
*
|
||||
* - Otherwise:
|
||||
* - Returns Error.
|
||||
*/
|
||||
|
||||
// True when `t` is a float8 tensor and `scale` is a float tensor holding
// exactly one element.
bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
  if (!at::isFloat8Type(t.scalar_type())) {
    return false;
  }
  if (scale.scalar_type() != at::kFloat) {
    return false;
  }
  return scale.numel() == 1;
}
|
||||
|
||||
// True when `t` is a float8 tensor and `scale` is a contiguous 2-D float
// tensor of shape (t.size(0), 1).
bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
  if (!at::isFloat8Type(t.scalar_type()) ||
      scale.scalar_type() != at::kFloat) {
    return false;
  }
  // The rank is checked before any size is queried, as in a short-circuit
  // && chain.
  if (scale.dim() != 2) {
    return false;
  }
  if (scale.size(0) != t.size(0) || scale.size(1) != 1) {
    return false;
  }
  return scale.is_contiguous();
}
|
||||
|
||||
bool is_desired_scaling(
|
||||
const at::Tensor& t,
|
||||
const at::Tensor& scale,
|
||||
ScalingType desired_scaling) {
|
||||
auto result = desired_scaling == ScalingType::TensorWise
|
||||
? is_tensorwise_scaling(t, scale)
|
||||
: is_rowwise_scaling(t, scale);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::pair<ScalingType, ScalingType> get_joint_scaling(
|
||||
std::initializer_list<std::pair<ScalingType, ScalingType>> options,
|
||||
const at::Tensor& a,
|
||||
const at::Tensor& b,
|
||||
const at::Tensor& scale_a,
|
||||
const at::Tensor& scale_b) {
|
||||
for (auto [lhs, rhs] : options) {
|
||||
if (is_desired_scaling(a, scale_a, lhs) &&
|
||||
is_desired_scaling(b.t(), scale_b.t(), rhs)) {
|
||||
return {lhs, rhs};
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Invalid scaling configuration.\n"
|
||||
"- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
|
||||
"- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (",
|
||||
a.size(0),
|
||||
", 1) and scale_b should be (1, ",
|
||||
b.size(1),
|
||||
"), and both should be contiguous.\n"
|
||||
"Got a.dtype()=",
|
||||
a.scalar_type(),
|
||||
", scale_a.dtype()=",
|
||||
scale_a.scalar_type(),
|
||||
", scale_a.size()=",
|
||||
scale_a.sizes(),
|
||||
", scale_a.stride()=",
|
||||
scale_a.strides(),
|
||||
", ",
|
||||
"b.dtype()=",
|
||||
b.scalar_type(),
|
||||
", scale_b.dtype()=",
|
||||
scale_b.scalar_type(),
|
||||
", scale_b.size()=",
|
||||
scale_b.sizes(),
|
||||
" and scale_b.stride()=",
|
||||
scale_b.strides());
|
||||
}
|
||||
|
||||
// Thin wrapper that forwards a scaled matmul to the oneDNN backend and returns
// `out`. An empty result scale is always passed on, and `alpha` is accepted
// but not consumed.
Tensor& _scaled_gemm(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const ScalingType scaling_choice_a,
    const ScalingType scaling_choice_b,
    const std::optional<Tensor>& bias,
    const bool use_fast_accum,
    Tensor& out,
    const std::optional<Tensor>& alpha = std::nullopt) {
  // TODO: scale_result and alpha are neither defined nor used yet.
  const std::optional<Tensor> no_result_scale;
  at::native::onednn::scaled_matmul(
      mat1,
      mat2,
      out,
      scale_a,
      scale_b,
      scaling_choice_a,
      scaling_choice_b,
      bias,
      no_result_scale,
      use_fast_accum);
  return out;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Computes matrix multiply + bias while applying scaling to input and output
|
||||
// matrices. Scales are only applicable when matrices are of Float8 type and
|
||||
// assumed to be equal to 1.0 by default. If output matrix type is 16 or 32-bit
|
||||
// type, scale_result is not applied. Known limitations:
|
||||
// - Only works if mat1 is row-major and mat2 is column-major
|
||||
// - Only works if mat1.size(1) and both dimensions of mat2 are divisible by 16
|
||||
// - If 1-dimensional tensors are used then scale_a should be size =
|
||||
// mat1.size(0)
|
||||
// and scale_b should have size equal to mat2.size(1)
|
||||
// Arguments:
|
||||
// - `mat1`: the first operand of the matrix multiply, can be type
|
||||
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
|
||||
// - `mat2`: the second operand of the matrix multiply, can be type
|
||||
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
|
||||
// - `bias`: the bias, can be type `torch.float32`, `torch.float16` or `torch.bfloat16`
|
||||
// - `out_dtype`: the output dtype, can either be a float8 or a higher
|
||||
// precision floating point type
|
||||
// - `scale_a`: a tensor with the inverse scale of `mat1`, whose
|
||||
// shape/strides/dtype depend on the scaling scheme
|
||||
// - `scale_b`: a tensor with the inverse scale of `mat2`, whose
|
||||
// shape/strides/dtype depend on the scaling scheme
|
||||
// - `scale_result`: a scalar tensor with the scale of the output, only
|
||||
// utilized if the output is a float8 type
|
||||
// - `use_fast_accum`: Not applicable for XPU. For now, it should always be
|
||||
// false.
|
||||
// - `out`: a reference to the output tensor
|
||||
|
||||
// Validates the operands of a scaled float8 matmul, resizes `out` to
// (mat1.size(0), mat2.size(1)) and dispatches to _scaled_gemm().
// Every validation failure is reported through TORCH_CHECK. When any of
// M, K or N is zero the function returns early (zero-filling `out` if K == 0).
// `scale_result` is validated but not forwarded to the kernel.
Tensor& _scaled_mm_out_xpu(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    std::optional<c10::ScalarType> out_dtype,
    bool use_fast_accum,
    Tensor& out) {
  // Note: fast_accum is not supported in XPU for now.
  TORCH_CHECK(!use_fast_accum, "fast_accum is not supported in XPU for now.");

  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");

  // Both operands are known to be 2-D from here on, so the dimensions can be
  // named once: mat1 is (m x k), mat2 is (k2 x n).
  const int64_t m = mat1.size(0);
  const int64_t k = mat1.size(1);
  const int64_t k2 = mat2.size(0);
  const int64_t n = mat2.size(1);

  TORCH_CHECK(
      k == k2,
      "mat1 and mat2 shapes cannot be multiplied (",
      m,
      "x",
      k,
      " and ",
      k2,
      "x",
      n,
      ")");

  // Check what type of scaling we are doing based on inputs. This list is
  // sorted by decreasing priority.

  // List of supported datatypes for XPU with oneDNN:
  // https://uxlfoundation.github.io/oneDNN/dev_guide_matmul.html#data-types
  const auto scaling_choices = get_joint_scaling(
      {
          std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
          std::make_pair(ScalingType::RowWise, ScalingType::RowWise),
      },
      mat1,
      mat2,
      scale_a,
      scale_b);
  const ScalingType scaling_choice_a = scaling_choices.first;
  const ScalingType scaling_choice_b = scaling_choices.second;

  const bool scale_result_ok = !scale_result.has_value() ||
      (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat);
  TORCH_CHECK(scale_result_ok, "scale_result must be a float scalar");

  // The message arguments are only evaluated when the check fails, which can
  // only happen when `bias` holds a value.
  const bool bias_size_ok = !bias.has_value() || bias->numel() == n;
  TORCH_CHECK(
      bias_size_ok, "Bias must be size ", n, " but got ", bias->numel());

  TORCH_CHECK(
      k % 16 == 0,
      "Expected trailing dimension of mat1 to be divisible by 16 ",
      "but got mat1 shape: (",
      m,
      "x",
      k,
      ").");
  TORCH_CHECK(
      k2 % 16 == 0 && n % 16 == 0,
      "mat2 shape (",
      k2,
      "x",
      n,
      ") must be divisible by 16");

  // Check types
  const bool out_dtype_ok =
      !out_dtype.has_value() || *out_dtype == out.scalar_type();
  TORCH_CHECK(out_dtype_ok, "out_dtype must match output matrix type");
  TORCH_CHECK(
      at::isFloat8Type(mat1.scalar_type()),
      "Expected mat1 to be Float8 matrix got ",
      mat1.scalar_type());
  TORCH_CHECK(
      at::isFloat8Type(mat2.scalar_type()),
      "Expected mat2 to be Float8 matrix got ",
      mat2.scalar_type());
  // TODO: oneDNN currently only supports e4m3 with group scales on BMG, and
  // only 1D scales (not 2D). More checks need to be added here.

  if (bias.has_value()) {
    const auto bias_dtype = bias->scalar_type();
    TORCH_CHECK(
        bias_dtype == kFloat || bias_dtype == c10::ScalarType::BFloat16 ||
            bias_dtype == c10::ScalarType::Half,
        "Bias must be Float32 or BFloat16 or Half, but got ",
        bias_dtype);
  }

  {
    // Undefined tensors stand in for the absent optional arguments.
    auto bias_ = bias.value_or(Tensor());
    auto scale_result_ = scale_result.value_or(Tensor());

    // NOLINTNEXTLINE(*c-array*)
    TensorArg targs[]{
        {out, "out", 0},
        {mat1, "mat1", 1},
        {mat2, "mat2", 2},
        {bias_, "bias", 3},
        {scale_a, "scale_a", 4},
        {scale_b, "scale_b", 5},
        {scale_result_, "scale_result", 6}};
    checkAllSameGPU(__func__, targs);
  }

  // Validation checks have passed; resize the output to its actual size.
  at::native::resize_output(out, {m, n});

  // If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm
  // kernels do not support this case).
  const bool has_empty_dim = (m == 0) || (k == 0) || (n == 0);
  if (has_empty_dim) {
    // `out` was created with `at::empty`. In the case where we are multiplying
    // MxK by KxN and K is the zero dim, we need to initialize here to properly
    // return a tensor of zeros.
    if (k == 0) {
      out.zero_();
    }
    return out;
  }

  // TODO: scale_result is not supported yet.
  return _scaled_gemm(
      mat1,
      mat2,
      scale_a,
      scale_b,
      scaling_choice_a,
      scaling_choice_b,
      bias,
      use_fast_accum,
      out);
}
|
||||
|
||||
// Functional variant of _scaled_mm_out_xpu(): creates an empty output tensor
// whose dtype is `out_dtype` (falling back to mat_a's dtype when unset) and
// lets the out-variant validate, resize and fill it.
Tensor _scaled_mm_xpu(
    const Tensor& mat_a,
    const Tensor& mat_b,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    std::optional<c10::ScalarType> out_dtype,
    bool use_fast_accum) {
  const c10::ScalarType result_dtype =
      out_dtype.has_value() ? *out_dtype : mat_a.scalar_type();
  // Start from a zero-element placeholder; the out-variant resizes it.
  Tensor out = at::empty({0}, mat_a.options().dtype(result_dtype));
  return _scaled_mm_out_xpu(
      mat_a,
      mat_b,
      scale_a,
      scale_b,
      bias,
      scale_result,
      out_dtype,
      use_fast_accum,
      out);
}
|
||||
|
||||
} // namespace at::native
|
||||
@ -1,3 +1,4 @@
|
||||
#include <ATen/BlasBackend.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
@ -8,7 +9,6 @@
|
||||
#include <oneapi/dnnl/dnnl.hpp>
|
||||
|
||||
namespace at::native::onednn {
|
||||
|
||||
at::Tensor broadcast_bias2D(
|
||||
at::Tensor& dst,
|
||||
at::Tensor& bias,
|
||||
@ -328,4 +328,236 @@ void quantized_matmul(
|
||||
result.copy_(dst);
|
||||
}
|
||||
|
||||
// Describes how to configure oneDNN scales for a given role/ScalingType
|
||||
struct ScaleSpec {
|
||||
// specifies the way scale values will be applied to an ARG tensor.
|
||||
int mask;
|
||||
// specifies how scales are grouped along dimensions where
|
||||
// multiple scale factors are used.
|
||||
dnnl::memory::dims groups;
|
||||
// specifies data type for scale factors.
|
||||
dnnl::memory::data_type dtype;
|
||||
|
||||
// Helper to compute expected number of elements for scale tensors
|
||||
// arg_type: "src" for SRC (groups pattern {1, X}),
|
||||
// "wei" for WEIGHTS (groups pattern {X, 1})
|
||||
int64_t expected_numel(
|
||||
int64_t outer_dim,
|
||||
int64_t inner_dim,
|
||||
const std::string& arg_type) const {
|
||||
if (groups == dnnl::memory::dims{1, 1})
|
||||
return 1; // tensorwise scaling
|
||||
|
||||
TORCH_CHECK(
|
||||
arg_type == "src" || arg_type == "wei",
|
||||
"Expected arg_type to be 'src' or 'wei', but got '",
|
||||
arg_type,
|
||||
"'");
|
||||
|
||||
// For rowwise: SRC groups={1, K}, WEI groups={K, 1}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
(groups == dnnl::memory::dims{1, inner_dim} ||
|
||||
groups == dnnl::memory::dims{inner_dim, 1}),
|
||||
"The groups must be either {1, inner_dim} or {inner_dim, 1}. But got ",
|
||||
groups,
|
||||
".");
|
||||
return outer_dim;
|
||||
}
|
||||
|
||||
// Normalize an incoming scale tensor to contiguous storage and appropriate
|
||||
// dtype/view
|
||||
at::Tensor normalize(const at::Tensor& scale) const {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
dtype == dnnl::memory::data_type::f32,
|
||||
"tensor scale currently must be f32, but got scale dtype: ",
|
||||
scale.scalar_type());
|
||||
return scale.to(at::kFloat).contiguous();
|
||||
}
|
||||
};
|
||||
|
||||
// This function defines how to set scales mask and groups according to:
|
||||
// https://github.com/uxlfoundation/oneDNN/blob/main/tests/benchdnn/doc/knobs_attr.md#--attr-scales
|
||||
// The returned value will be used in
|
||||
// `set_scales(arg, mask, groups, data_type)`.
|
||||
// Builds the ScaleSpec (mask, groups, dtype) for one matmul argument.
//   scaling_type: TensorWise or RowWise (anything else trips an internal
//                 assert).
//   M, K, N:      problem sizes; only K is currently used, as the group size.
//   arg_type:     "src" or "wei"; any other value fails a TORCH_CHECK.
inline ScaleSpec make_scale_spec(
    at::blas::ScalingType scaling_type,
    int64_t M,
    int64_t K,
    int64_t N,
    const std::string& arg_type) {
  const bool is_src = (arg_type == "src");
  TORCH_CHECK(
      is_src || arg_type == "wei",
      "Expected arg_type to be 'src' or 'wei', but got '",
      arg_type,
      "'");
  // Condition left textually unchanged so the assertion message stays the
  // same.
  TORCH_INTERNAL_ASSERT(
      (scaling_type == at::blas::ScalingType::TensorWise ||
       scaling_type == at::blas::ScalingType::RowWise),
      "Currently only support scaling_type for TensorWise or RowWise");

  if (scaling_type == at::blas::ScalingType::TensorWise) {
    // Scale tensorwise. The same as `--attr-scales=common`.
    // mask=0        : scale whole tensor
    // groups={1, 1} : there is only one group for scaling
    return ScaleSpec{0, {1, 1}, dnnl::memory::data_type::f32};
  }

  // RowWise. The same as `--attr-scales=per_dim_01`.
  // mask={(1 << 0) | (1 << 1)}: scale on both dim0 and dim1
  // SRC: groups={1, K}, WEIGHTS: groups={K, 1}
  const int64_t group_size = K; // Currently only K is used for grouping
  const int rowwise_mask = (1 << 0) | (1 << 1);
  const dnnl::memory::dims rowwise_groups = is_src
      ? dnnl::memory::dims{1, group_size}
      : dnnl::memory::dims{group_size, 1};
  return ScaleSpec{
      rowwise_mask, rowwise_groups, dnnl::memory::data_type::f32};
}
|
||||
|
||||
// Runs a scaled matmul `result = mat1 x mat2 (+ bias)` through oneDNN with
// runtime scales attached to SRC and WEIGHTS (and to DST when `scale_result`
// holds a defined tensor).
//
// Parameters:
//   mat1, mat2        - the two operands; M and K are read from mat1 and N
//                       from mat2.
//   result            - destination tensor; its storage is written in place.
//   scale_a, scale_b  - scale tensors for mat1 / mat2. They are converted to
//                       contiguous float32 and must contain the number of
//                       elements implied by the scaling choice.
//   scaling_choice_a/b- TensorWise or RowWise.
//   bias              - optional bias; a 1-D bias is reshaped to (1, n).
//   scale_result      - optional single-element scale for the destination.
//   use_fast_accum    - accepted for interface parity; not used here.
//
// Returns the SYCL event of the submitted matmul primitive.
// Raises (TORCH_CHECK) when a scale tensor has an unexpected element count.
sycl::event scaled_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    Tensor& result,
    const Tensor& scale_a,
    const Tensor& scale_b,
    at::blas::ScalingType scaling_choice_a,
    at::blas::ScalingType scaling_choice_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    bool use_fast_accum) {
  auto& engine = GpuEngineManager::Instance().get_engine();
  auto& stream = GpuStreamManager::Instance().get_stream();

  // Steps:
  // 1. create memory descriptors and the primitive descriptor
  // 2. wrap the tensor storage in oneDNN memory objects
  // 3. execute

  const int64_t M = mat1.size(0);
  const int64_t K = mat1.size(1);
  const int64_t N = mat2.size(1);

  // 1.1 Create memory descriptors
  dnnl::memory::desc src_md = get_onednn_md(mat1);
  dnnl::memory::desc weights_md = get_onednn_md(mat2);
  dnnl::memory::desc dst_md = get_onednn_md(result);

  dnnl::memory::desc bias_md;
  const bool with_bias = bias.has_value();
  // Held at function scope so the (possibly reshaped) bias storage stays
  // alive until the primitive has been submitted.
  at::Tensor possible_reshaped_bias = bias.value_or(at::Tensor());
  if (with_bias) {
    if (possible_reshaped_bias.dim() == 1) {
      possible_reshaped_bias =
          possible_reshaped_bias.reshape({1, possible_reshaped_bias.size(0)});
    }
    bias_md = get_onednn_md(possible_reshaped_bias);
  }

  // 1.2 Create primitive attributes and set scales mask
  const ScaleSpec src_spec = make_scale_spec(scaling_choice_a, M, K, N, "src");
  const ScaleSpec wei_spec = make_scale_spec(scaling_choice_b, M, K, N, "wei");

  dnnl::primitive_attr op_attr = dnnl::primitive_attr();

#if ONEDNN_SUPPORT_DETERMINISTIC
  if (at::globalContext().deterministicAlgorithms() ||
      at::globalContext().deterministicMkldnn())
    op_attr.set_deterministic(true);
#endif

  op_attr.set_scales(
      DNNL_ARG_SRC, src_spec.mask, src_spec.groups, src_spec.dtype);
  op_attr.set_scales(
      DNNL_ARG_WEIGHTS, wei_spec.mask, wei_spec.groups, wei_spec.dtype);
  // scale_result tensor currently only supports scalar (TensorWise scaling).
  const bool with_dst_scale = scale_result && scale_result->defined();
  if (with_dst_scale) {
    op_attr.set_scales(DNNL_ARG_DST, 0, {1}, dnnl::memory::data_type::f32);
  }

  op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

  // 1.3 Create the matmul primitive descriptor
  dnnl::matmul::primitive_desc matmul_pd = with_bias
      ? dnnl::matmul::primitive_desc(
            engine, src_md, weights_md, bias_md, dst_md, op_attr)
      : dnnl::matmul::primitive_desc(
            engine, src_md, weights_md, dst_md, op_attr);

  // 1.4 (Possible) Additional Checks
  // TODO: If a memory desc does not align with the actual tensor, we might
  // need to reorder weights similar to CPU's reorder_if_differ_in() call,
  // for example when weights differ from matmul_pd.weights_desc().

  // 2. Prepare memory
  auto src_usr_m = make_onednn_memory(src_md, engine, mat1.data_ptr());
  auto weights_usr_m = make_onednn_memory(weights_md, engine, mat2.data_ptr());
  auto dst_usr_m = make_onednn_memory(dst_md, engine, result.data_ptr());
  dnnl::memory b_usr_m;
  if (with_bias) {
    b_usr_m =
        make_onednn_memory(bias_md, engine, possible_reshaped_bias.data_ptr());
  }

  // The oneDNN memories below are built from raw data pointers. normalize()
  // and to()/contiguous() may return a fresh tensor, so the normalized
  // tensors are held at function scope; a lambda- or block-local temporary
  // would be released before the primitive is submitted and leave the pointer
  // dangling.
  const at::Tensor scale_a_prepared = src_spec.normalize(scale_a);
  const at::Tensor scale_b_prepared = wei_spec.normalize(scale_b);
  at::Tensor dst_scale_f32;
  if (with_dst_scale) {
    dst_scale_f32 = scale_result->to(at::kFloat).contiguous();
  }

  // Wraps an already-normalized scale tensor as a flat 1-D oneDNN memory.
  auto make_scale_mem_from_spec = [&](const ScaleSpec& spec,
                                      int64_t expected_numel,
                                      const at::Tensor& prepared) {
    TORCH_CHECK(
        prepared.numel() == expected_numel,
        "Scale buffer length mismatch. Expected ",
        expected_numel,
        ", got ",
        prepared.numel());
    dnnl::memory::desc scale_md(
        {prepared.numel()}, spec.dtype, dnnl::memory::format_tag::x);
    return make_onednn_memory(scale_md, engine, prepared.data_ptr());
  };

  auto scratchpad =
      make_onednn_memory(matmul_pd.scratchpad_desc(), engine, nullptr);

  // 3. Setup args for exec
  std::unordered_map<int, dnnl::memory> args;
  args.insert({DNNL_ARG_SRC, src_usr_m});
  args.insert({DNNL_ARG_WEIGHTS, weights_usr_m});
  args.insert({DNNL_ARG_DST, dst_usr_m});
  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
  if (with_bias) {
    args.insert({DNNL_ARG_BIAS, b_usr_m});
  }

  // Attach runtime scales using specs
  auto src_sc_mem = make_scale_mem_from_spec(
      src_spec, src_spec.expected_numel(M, K, "src"), scale_a_prepared);
  auto wei_sc_mem = make_scale_mem_from_spec(
      wei_spec, wei_spec.expected_numel(N, K, "wei"), scale_b_prepared);
  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_sc_mem});
  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_sc_mem});
  if (with_dst_scale) {
    // Bind single f32 scalar as DST scale
    dnnl::memory::desc dst_sc_md(
        {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
    auto dst_sc_mem =
        make_onednn_memory(dst_sc_md, engine, dst_scale_f32.data_ptr());
    args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_sc_mem});
  }

  dnnl::matmul matmul_p = dnnl::matmul(matmul_pd);
  sycl::event matmul_fwd_event =
      dnnl::sycl_interop::execute(matmul_p, stream, args);
  return matmul_fwd_event;
}
|
||||
|
||||
} // namespace at::native::onednn
|
||||
|
||||
@ -78,6 +78,10 @@ dnnl::memory::data_type get_onednn_dtype(
|
||||
return dnnl::memory::data_type::f32;
|
||||
case at::ScalarType::BFloat16:
|
||||
return dnnl::memory::data_type::bf16;
|
||||
case at::ScalarType::Float8_e4m3fn:
|
||||
return dnnl::memory::data_type::f8_e4m3;
|
||||
case at::ScalarType::Float8_e5m2:
|
||||
return dnnl::memory::data_type::f8_e5m2;
|
||||
default:
|
||||
if (!allow_undef) {
|
||||
TORCH_CHECK(
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/BlasBackend.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/Attr.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
|
||||
@ -202,4 +203,16 @@ void sdpa_backward(
|
||||
Tensor& grad_query,
|
||||
Tensor& grad_key,
|
||||
Tensor& grad_value);
|
||||
|
||||
sycl::event scaled_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
Tensor& result,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
at::blas::ScalingType scaling_choice_a,
|
||||
at::blas::ScalingType scaling_choice_b,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& scale_result,
|
||||
bool use_fast_accum);
|
||||
} // namespace at::native::onednn
|
||||
|
||||
@ -82,6 +82,7 @@ NSArray<NSNumber*>* getTensorAxes(const TensorBase& t);
|
||||
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
|
||||
std::string getMPSShapeString(MPSShape* shape);
|
||||
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
|
||||
std::string to_hex_key(float);
|
||||
std::string getArrayRefString(const IntArrayRef s);
|
||||
// use has_storage() on the returned tensor to determine if src actually is a view
|
||||
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);
|
||||
|
||||
@ -301,6 +301,10 @@ std::string getArrayRefString(const IntArrayRef s) {
|
||||
return fmt::to_string(fmt::join(s, ","));
|
||||
}
|
||||
|
||||
std::string to_hex_key(float f) {
|
||||
return fmt::format("{:a}", f);
|
||||
}
|
||||
|
||||
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, bool exclude_shape) {
|
||||
fmt::basic_memory_buffer<char, 100> buffer;
|
||||
auto buf_iterator = std::back_inserter(buffer);
|
||||
|
||||
@ -40,7 +40,7 @@ inline c10::metal::opmath_t<T> matmul_inner(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (uint k = 0; k < TILE_DIM; k++) {
|
||||
sum += A_tile[tid.y][k] * B_tile[k][tid.x];
|
||||
sum += c10::metal::mul(A_tile[tid.y][k], B_tile[k][tid.x]);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
@ -96,7 +96,9 @@ kernel void addmm(
|
||||
auto bias =
|
||||
biasData[thread_id.y * strides[3].x + thread_id.x * strides[3].y];
|
||||
outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] =
|
||||
static_cast<T>(alpha_beta[0] * sum + alpha_beta[1] * bias);
|
||||
static_cast<T>(
|
||||
c10::metal::mul(alpha_beta[0], sum) +
|
||||
c10::metal::mul(alpha_beta[1], bias));
|
||||
}
|
||||
}
|
||||
|
||||
@ -832,6 +834,10 @@ INSTANTIATE_MM_OPS(float);
|
||||
INSTANTIATE_MM_OPS(half);
|
||||
INSTANTIATE_MM_OPS(bfloat);
|
||||
|
||||
// Complex MM
|
||||
INSTANTIATE_MM_OPS(float2);
|
||||
INSTANTIATE_MM_OPS(half2);
|
||||
|
||||
// Integral MM
|
||||
INSTANTIATE_MM_OPS(long);
|
||||
INSTANTIATE_MM_OPS(int);
|
||||
|
||||
@ -121,7 +121,7 @@ Tensor& do_metal_addmm(const Tensor& self,
|
||||
const Scalar& alpha,
|
||||
const Scalar& beta,
|
||||
const Tensor& bias) {
|
||||
if (beta.toDouble() == 0 && alpha.toDouble() == 1) {
|
||||
if (beta.isFloatingPoint() && alpha.isFloatingPoint() && beta.toDouble() == 0 && alpha.toDouble() == 1) {
|
||||
return do_metal_mm(self, other, output);
|
||||
}
|
||||
auto stream = getCurrentMPSStream();
|
||||
@ -147,13 +147,15 @@ Tensor& do_metal_addmm(const Tensor& self,
|
||||
std::array<int64_t, 2> i64;
|
||||
std::array<int32_t, 2> i32;
|
||||
std::array<float, 2> f32;
|
||||
} alpha_beta;
|
||||
std::array<c10::complex<float>, 2> c64;
|
||||
} alpha_beta{};
|
||||
if (output.scalar_type() == kLong) {
|
||||
alpha_beta.i64 = {alpha.toLong(), beta.toLong()};
|
||||
} else if (c10::isIntegralType(output.scalar_type(), true)) {
|
||||
alpha_beta.i32 = {alpha.toInt(), beta.toInt()};
|
||||
} else if (c10::isComplexType(output.scalar_type())) {
|
||||
alpha_beta.c64 = {alpha.toComplexFloat(), beta.toComplexFloat()};
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(c10::isFloatingType(output.scalar_type()));
|
||||
alpha_beta.f32 = {alpha.toFloat(), beta.toFloat()};
|
||||
}
|
||||
constexpr uint32_t TILE_DIM = 16; // fastest performance from tests on multiple macs
|
||||
@ -190,10 +192,16 @@ std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* gr
|
||||
bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) {
|
||||
static bool always_use_metal = c10::utils::has_env("PYTORCH_MPS_PREFER_METAL");
|
||||
constexpr auto max_stride_size = 32768;
|
||||
constexpr auto max_complex_inner_size = 2048;
|
||||
static bool is_macos_14_4_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS);
|
||||
if (always_use_metal || c10::isIntegralType(self.scalar_type(), true)) {
|
||||
return true;
|
||||
}
|
||||
// multiplicationWithPrimaryTensor: returns incorrect results if inner size exceeds 2048
|
||||
// See https://github.com/pytorch/pytorch/issues/167727#issuecomment-3529308548
|
||||
if (c10::isComplexType(self.scalar_type()) && self.size(1) > max_complex_inner_size) {
|
||||
return true;
|
||||
}
|
||||
return !is_macos_14_4_or_newer &&
|
||||
(self.stride(0) > max_stride_size || self.stride(1) > max_stride_size || self.size(0) > max_stride_size ||
|
||||
self.size(1) > max_stride_size || other.stride(0) > max_stride_size || other.stride(1) > max_stride_size ||
|
||||
|
||||
@ -244,8 +244,8 @@ static void clamp_scalar_out_mps(const Tensor& input_t,
|
||||
|
||||
@autoreleasepool {
|
||||
// the optional min/max refs could affect how we build the cached graph
|
||||
std::string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") +
|
||||
(has_max ? ("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
|
||||
std::string key = op_name + (has_min ? ("_min:" + to_hex_key(min_scalar)) : "") +
|
||||
(has_max ? ("_max:" + to_hex_key(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
if (has_min)
|
||||
newCachedGraph->minTensor = [mpsGraph constantWithScalar:min_scalar
|
||||
|
||||
@ -189,6 +189,10 @@ skip:
|
||||
- hf_Whisper
|
||||
- hf_distil_whisper
|
||||
- timm_vision_transformer_large
|
||||
# https://github.com/pytorch/pytorch/issues/167895
|
||||
- stable_diffusion
|
||||
- stable_diffusion_text_encoder
|
||||
- stable_diffusion_unet
|
||||
|
||||
device:
|
||||
cpu:
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# These load paths point to different files in internal and OSS environment
|
||||
|
||||
load("@bazel_skylib//lib:paths.bzl", "paths")
|
||||
load("//tools/build_defs:cell_defs.bzl", "get_fbsource_cell")
|
||||
load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
|
||||
load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
|
||||
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
|
||||
@ -590,6 +591,9 @@ def pt_operator_query_codegen(
|
||||
pt_allow_forced_schema_registration = True,
|
||||
compatible_with = [],
|
||||
apple_sdks = None):
|
||||
if get_fbsource_cell() == "fbcode":
|
||||
return
|
||||
|
||||
oplist_dir_name = name + "_pt_oplist"
|
||||
|
||||
# @lint-ignore BUCKLINT
|
||||
@ -865,6 +869,9 @@ def define_buck_targets(
|
||||
pt_xplat_cxx_library = fb_xplat_cxx_library,
|
||||
c2_fbandroid_xplat_compiler_flags = [],
|
||||
labels = []):
|
||||
if get_fbsource_cell() == "fbcode":
|
||||
return
|
||||
|
||||
# @lint-ignore BUCKLINT
|
||||
fb_native.filegroup(
|
||||
name = "metal_build_srcs",
|
||||
|
||||
@ -34,20 +34,6 @@ namespace c10 {
|
||||
// See [dtype Macros note] in torch/headeronly/core/ScalarType.h
|
||||
// regarding macros.
|
||||
|
||||
template <typename T>
|
||||
struct CppTypeToScalarType;
|
||||
|
||||
#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \
|
||||
template <> \
|
||||
struct CppTypeToScalarType<cpp_type> \
|
||||
: std:: \
|
||||
integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \
|
||||
};
|
||||
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
|
||||
|
||||
#undef SPECIALIZE_CppTypeToScalarType
|
||||
|
||||
#define DEFINE_CONSTANT(_, name) \
|
||||
constexpr ScalarType k##name = ScalarType::name;
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/FileSystem.h>
|
||||
#include <c10/util/Logging.h>
|
||||
#include <c10/util/Type.h>
|
||||
|
||||
@ -27,7 +28,7 @@ Error::Error(
|
||||
const void* caller)
|
||||
: Error(
|
||||
str("[enforce fail at ",
|
||||
detail::StripBasename(file),
|
||||
c10::filesystem::path(file).filename(),
|
||||
":",
|
||||
line,
|
||||
"] ",
|
||||
|
||||
@ -379,7 +379,11 @@ C10_API std::string GetExceptionString(const std::exception& e);
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
#ifdef STRIP_ERROR_MESSAGES
|
||||
#define TORCH_RETHROW(e, ...) throw
|
||||
#define TORCH_RETHROW(e, ...) \
|
||||
do { \
|
||||
(void)e; /* Suppress unused variable warning */ \
|
||||
throw; \
|
||||
} while (false)
|
||||
#else
|
||||
#define TORCH_RETHROW(e, ...) \
|
||||
do { \
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
#include <c10/util/Backtrace.h>
|
||||
#include <c10/util/FileSystem.h>
|
||||
#include <c10/util/Flags.h>
|
||||
#include <c10/util/Lazy.h>
|
||||
#include <c10/util/Logging.h>
|
||||
@ -478,8 +479,7 @@ MessageLogger::MessageLogger(
|
||||
<< std::setfill('0') << " " << std::setw(2) << timeinfo->tm_hour
|
||||
<< ":" << std::setw(2) << timeinfo->tm_min << ":" << std::setw(2)
|
||||
<< timeinfo->tm_sec << "." << std::setw(9) << ns << " "
|
||||
<< c10::detail::StripBasename(std::string(file)) << ":" << line
|
||||
<< "] ";
|
||||
<< c10::filesystem::path(file).filename() << ":" << line << "] ";
|
||||
}
|
||||
|
||||
// Output the contents of the stream to the proper channel on destruction.
|
||||
|
||||
@ -734,7 +734,7 @@ void PyTorchStreamWriter::setup(const string& file_name) {
|
||||
file_name,
|
||||
std::ofstream::out | std::ofstream::trunc | std::ofstream::binary
|
||||
);
|
||||
} catch (const std::ios_base::failure& e) {
|
||||
} catch (const std::ios_base::failure&) {
|
||||
#ifdef _WIN32
|
||||
// Windows have verbose error code, we prefer to use it than std errno.
|
||||
uint32_t error_code = GetLastError();
|
||||
|
||||
@ -118,6 +118,11 @@ if(INTERN_BUILD_ATEN_OPS)
|
||||
list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
|
||||
endif()
|
||||
endif()
|
||||
if("${_arch}" STREQUAL "121a")
|
||||
if(_existing_arch_flags MATCHES ".*compute_120.*")
|
||||
list(APPEND _file_compile_flags "-gencode;arch=compute_121a,code=sm_121a")
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
list(JOIN _file_compile_flags " " _file_compile_flags)
|
||||
|
||||
@ -126,7 +131,7 @@ if(INTERN_BUILD_ATEN_OPS)
|
||||
|
||||
_BUILD_FOR_ADDITIONAL_ARCHS(
|
||||
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
|
||||
"89;90a;100a;103a;120a")
|
||||
"89;90a;100a;103a;120a;121a")
|
||||
_BUILD_FOR_ADDITIONAL_ARCHS(
|
||||
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
|
||||
"90a")
|
||||
|
||||
@ -10,7 +10,7 @@ API. This API can roughly be divided into five parts:
|
||||
- **TorchScript**: An interface to the TorchScript JIT compiler and interpreter.
|
||||
- **C++ Extensions**: A means of extending the Python API with custom C++ and CUDA routines.
|
||||
|
||||
Combining, these building blocks form a research and
|
||||
Combined, these building blocks form a research and
|
||||
production ready C++ library for tensor computation and dynamic neural
|
||||
networks with strong emphasis on GPU acceleration as well as fast CPU
|
||||
performance. It is currently in use at Facebook in research and
|
||||
@ -76,7 +76,7 @@ C++ Frontend
|
||||
------------
|
||||
|
||||
The PyTorch C++ frontend provides a high level, pure C++ modeling interface for
|
||||
neural network and general ML(Machine Learning) research and production use cases,
|
||||
neural networks and general ML (Machine Learning) research and production use cases,
|
||||
largely following the Python API in design and provided functionality. The C++
|
||||
frontend includes the following:
|
||||
|
||||
|
||||
113
docs/source/accelerator/device.md
Normal file
113
docs/source/accelerator/device.md
Normal file
@ -0,0 +1,113 @@
|
||||
# Device Management
|
||||
|
||||
## Background
|
||||
|
||||
Device management handles basic operations like querying how many devices are available and switching between them. Accelerator backends need to wrap their device runtime's APIs and expose them to PyTorch.
|
||||
|
||||
The OpenReg implementation ([`OpenRegFunctions.h/cpp`][OpenReg Device Management]) shows how to wrap a third-party runtime. These functions are used throughout the backend - by streams, events, generators, and Python bindings.
|
||||
|
||||
## Design
|
||||
|
||||
Accelerator vendors need to implement these core functions:
|
||||
|
||||
| Function Name | Description | Application Scenarios |
|
||||
| ------------------------- | ---------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- |
|
||||
| `device_count()` | Query the total number of available devices in the system | - Application initialization<br>- Multi-device workload distribution<br>- Validating device indices before use |
|
||||
| `current_device()` | Get the currently active device for the calling thread | - Debugging and logging<br>- Determining tensor placement<br>- Guard implementations |
|
||||
| `set_device()` | Change the active device for subsequent operations | - Switching context between devices<br>- Initializing specific device resources<br>- Multi-GPU training loops |
|
||||
| `exchange_device()` | Atomically swap device and return the previous device | - Implementing device guards<br>- Temporarily switching device context<br>- RAII-based device management |
|
||||
| `maybe_exchange_device()` | Conditionally exchange device only if the index is valid (-1 OK) | - Safe device switching with optional indices<br>- Guard implementations with nullable device values |
|
||||
|
||||
These functions are building blocks for more complex features like streams, events, and memory management. Make sure to validate inputs and handle errors properly.
|
||||
|
||||
## Implementation
|
||||
|
||||
This section shows how to implement device management using `set_device` as an example. The implementation requires:
|
||||
1. C++ wrappers around the device runtime
|
||||
2. Python bindings to expose the C++ functions
|
||||
3. User-friendly Python APIs
|
||||
|
||||
### C++ Side
|
||||
|
||||
Wrap the device runtime's API and add error handling. The `SetDevice` function shows this pattern:
|
||||
|
||||
```{eval-rst}
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
|
||||
:language: c++
|
||||
:start-after: LITERALINCLUDE START: OPENREG SetDevice FUNCTION
|
||||
:end-before: LITERALINCLUDE END: OPENREG SetDevice FUNCTION
|
||||
:linenos:
|
||||
```
|
||||
```{eval-rst}
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
|
||||
:language: c++
|
||||
:start-after: LITERALINCLUDE START: OPENREG set_device FUNCTION
|
||||
:end-before: LITERALINCLUDE END: OPENREG set_device FUNCTION
|
||||
:linenos:
|
||||
```
|
||||
|
||||
### Binding
|
||||
|
||||
Expose the C++ functions to Python by wrapping them in binding functions and registering those in the extension module's method table:
|
||||
|
||||
```{eval-rst}
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
|
||||
:language: c++
|
||||
:start-after: LITERALINCLUDE START: MODULE SET DEVICE HELPER
|
||||
:end-before: LITERALINCLUDE END: MODULE SET DEVICE HELPER
|
||||
:linenos:
|
||||
```
|
||||
```{eval-rst}
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
|
||||
:language: c++
|
||||
:start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
|
||||
:end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
|
||||
:linenos:
|
||||
:emphasize-lines: 5
|
||||
```
|
||||
|
||||
### Python Side
|
||||
|
||||
Wrap the C++ bindings with user-friendly Python functions:
|
||||
|
||||
```{eval-rst}
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py
|
||||
:language: python
|
||||
:start-after: LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
|
||||
:end-before: LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
|
||||
:linenos:
|
||||
```
|
||||
|
||||
Here's the complete mapping from C++ to Python:
|
||||
|
||||
| C++ Binding Function | C++ Binding API (pybind11) | Python User API | Description |
|
||||
| -------------------- | ---------------------------------------- | -------------------------------- | -------------------------------------------- |
|
||||
| `_getDeviceCount` | `torch_openreg._C._get_device_count()` | `torch.openreg.device_count()` | Returns the total number of devices |
|
||||
| `_getDevice` | `torch_openreg._C._get_device()` | `torch.openreg.current_device()` | Returns the current active device index |
|
||||
| `_setDevice` | `torch_openreg._C._set_device(idx)` | `torch.openreg.set_device(idx)` | Sets the active device |
|
||||
| `_exchangeDevice` | `torch_openreg._C._exchange_device(idx)` | N/A (internal use only) | Atomically swaps device and returns previous |
|
||||
|
||||
## Guard
|
||||
|
||||
Device guards provide automatic device switching with exception safety. They're similar to lock guards in C++ - they switch device on construction and restore it on destruction.
|
||||
|
||||
Implement `DeviceGuardImplInterface` to integrate with PyTorch's guard system:
|
||||
|
||||
```{eval-rst}
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h
|
||||
:language: c++
|
||||
:start-after: LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
|
||||
:end-before: LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
|
||||
:linenos:
|
||||
```
|
||||
|
||||
**What needs to be implemented:**
|
||||
|
||||
1. **exchangeDevice()**: Switch to a new device and return the old one (used by guard constructors)
|
||||
2. **getDevice()**: Get the current device
|
||||
3. **setDevice()**: Set the active device
|
||||
4. **Type checking**: Validate that device type matches the backend
|
||||
|
||||
This makes the guard available to PyTorch for the `PrivateUse1` device type. Users can then use standard PyTorch device guards with the custom backend.
|
||||
|
||||
[OpenReg Device Management]: https://github.com/pytorch/pytorch/blob/main/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp "OpenReg Device Management"
|
||||
@ -42,6 +42,7 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
|
||||
:glob:
|
||||
:maxdepth: 1
|
||||
|
||||
device
|
||||
hooks
|
||||
autoload
|
||||
operators
|
||||
|
||||
@ -254,7 +254,7 @@ To toggle the reduced precision reduction flags in C++, one can do
|
||||
|
||||
.. _fp16accumulation:
|
||||
|
||||
Full FP16 Accmumulation in FP16 GEMMs
|
||||
Full FP16 Accumulation in FP16 GEMMs
|
||||
-------------------------------------
|
||||
|
||||
Certain GPUs have increased performance when doing _all_ FP16 GEMM accumulation
|
||||
|
||||
@ -30,5 +30,6 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
|
||||
skip_guard_on_all_nn_modules_unsafe
|
||||
keep_tensor_guards_unsafe
|
||||
skip_guard_on_globals_unsafe
|
||||
skip_all_guards_unsafe
|
||||
nested_compile_region
|
||||
```
|
||||
|
||||
@ -24,15 +24,11 @@ def gen_data(special_op_lists, analysis_name):
|
||||
all_ops = get_ops_for_key(None)
|
||||
composite_ops = get_ops_for_key("CompositeImplicitAutograd")
|
||||
noncomposite_ops = all_ops - composite_ops
|
||||
with open("../../aten/src/ATen/native/native_functions.yaml") as f:
|
||||
ops = yaml.load(f.read(), Loader=yaml.CLoader)
|
||||
|
||||
ops = yaml.load(
|
||||
open("../../aten/src/ATen/native/native_functions.yaml").read(),
|
||||
Loader=yaml.CLoader,
|
||||
)
|
||||
|
||||
annotated_ops = {
|
||||
a.strip(): b.strip() for a, b in list(csv.reader(open("annotated_ops")))
|
||||
}
|
||||
with open("annotated_ops") as f:
|
||||
annotated_ops = {a.strip(): b.strip() for a, b in csv.reader(f)}
|
||||
|
||||
uniq_ops = []
|
||||
uniq_names = set()
|
||||
|
||||
@ -376,3 +376,19 @@ keep-runtime-typing = true
|
||||
|
||||
[tool.codespell]
|
||||
ignore-words = "tools/linter/dictionary.txt"
|
||||
|
||||
[tool.spin]
|
||||
package = 'torch'
|
||||
|
||||
[tool.spin.commands]
|
||||
"Build" = [
|
||||
".spin/cmds.py:lint",
|
||||
".spin/cmds.py:fixlint",
|
||||
".spin/cmds.py:quicklint",
|
||||
".spin/cmds.py:quickfix",
|
||||
]
|
||||
"Regenerate" = [
|
||||
".spin/cmds.py:regenerate_version",
|
||||
".spin/cmds.py:regenerate_type_stubs",
|
||||
".spin/cmds.py:regenerate_clangtidy_files",
|
||||
]
|
||||
|
||||
@ -32,7 +32,7 @@ project-excludes = [
|
||||
"torch/utils/tensorboard/summary.py",
|
||||
# formatting issues, will turn on after adjusting where suppressions can be
|
||||
# in import statements
|
||||
"tools/flight_recorder/components/types.py",
|
||||
"torch/distributed/flight_recorder/components/types.py",
|
||||
"torch/linalg/__init__.py",
|
||||
"torch/package/importer.py",
|
||||
"torch/package/_package_pickler.py",
|
||||
|
||||
@ -14,6 +14,7 @@ lintrunner ; platform_machine != "s390x" and platform_machine != "riscv64"
|
||||
networkx>=2.5.1
|
||||
optree>=0.13.0
|
||||
psutil
|
||||
spin
|
||||
sympy>=1.13.3
|
||||
typing-extensions>=4.13.2
|
||||
wheel
|
||||
|
||||
49
setup.py
49
setup.py
@ -1358,45 +1358,6 @@ class concat_license_files:
|
||||
|
||||
# Need to create the proper LICENSE.txt for the wheel
|
||||
class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
|
||||
def _wrap_headers_with_macro(self, bdist_dir: Path) -> None:
|
||||
"""Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION).
|
||||
|
||||
Excludes:
|
||||
- torch/include/torch/headeronly/*
|
||||
- torch/include/torch/csrc/stable/*
|
||||
- torch/include/torch/csrc/inductor/aoti_torch/c/ (only shim headers)
|
||||
- torch/include/torch/csrc/inductor/aoti_torch/generated/
|
||||
"""
|
||||
header_extensions = (".h", ".hpp", ".cuh")
|
||||
header_files = [
|
||||
f for ext in header_extensions for f in bdist_dir.rglob(f"*{ext}")
|
||||
]
|
||||
|
||||
# Paths to exclude from wrapping
|
||||
exclude_dir_patterns = [
|
||||
"torch/include/torch/headeronly/",
|
||||
"torch/include/torch/csrc/stable/",
|
||||
"torch/include/torch/csrc/inductor/aoti_torch/c/",
|
||||
"torch/include/torch/csrc/inductor/aoti_torch/generated/",
|
||||
]
|
||||
|
||||
for header_file in header_files:
|
||||
rel_path = header_file.relative_to(bdist_dir).as_posix()
|
||||
|
||||
if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns):
|
||||
report(f"Skipping header: {rel_path}")
|
||||
continue
|
||||
|
||||
original_content = header_file.read_text(encoding="utf-8")
|
||||
wrapped_content = (
|
||||
"#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
|
||||
f"{original_content}"
|
||||
"\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
|
||||
)
|
||||
|
||||
header_file.write_text(wrapped_content, encoding="utf-8")
|
||||
report(f"Wrapped header: {rel_path}")
|
||||
|
||||
def run(self) -> None:
|
||||
with concat_license_files(include_files=True):
|
||||
super().run()
|
||||
@ -1419,14 +1380,6 @@ class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
|
||||
# need an __init__.py file otherwise we wouldn't have a package
|
||||
(bdist_dir / "torch" / "__init__.py").touch()
|
||||
|
||||
# Wrap all header files with TORCH_STABLE_ONLY macro
|
||||
assert self.bdist_dir is not None, "bdist_dir should be set during wheel build"
|
||||
bdist_dir = Path(self.bdist_dir)
|
||||
report(
|
||||
"-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)"
|
||||
)
|
||||
self._wrap_headers_with_macro(bdist_dir)
|
||||
|
||||
|
||||
class clean(Command):
|
||||
user_options: ClassVar[list[tuple[str, str | None, str]]] = []
|
||||
@ -1632,7 +1585,7 @@ def configure_extension_build() -> tuple[
|
||||
if cmake_cache_vars["USE_DISTRIBUTED"]:
|
||||
# Only enable fr_trace command if distributed is enabled
|
||||
entry_points["console_scripts"].append(
|
||||
"torchfrtrace = tools.flight_recorder.fr_trace:main",
|
||||
"torchfrtrace = torch.distributed.flight_recorder.fr_trace:main",
|
||||
)
|
||||
return ext_modules, cmdclass, packages, entry_points, extra_install_requires
|
||||
|
||||
|
||||
@ -13,6 +13,17 @@ TEST(TestScalarType, ScalarTypeToCPPTypeT) {
|
||||
#undef DEFINE_CHECK
|
||||
}
|
||||
|
||||
TEST(TestScalarType, CppTypeToScalarType) {
|
||||
using torch::headeronly::CppTypeToScalarType;
|
||||
using torch::headeronly::ScalarType;
|
||||
|
||||
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
|
||||
EXPECT_EQ(CppTypeToScalarType<TYPE>::value, ScalarType::SCALARTYPE);
|
||||
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
|
||||
#undef DEFINE_CHECK
|
||||
}
|
||||
|
||||
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
|
||||
{ \
|
||||
EXPECT_EQ( \
|
||||
|
||||
@ -634,3 +634,38 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("test_parallel_for", &boxed_test_parallel_for);
|
||||
m.impl("test_get_num_threads", &boxed_test_get_num_threads);
|
||||
}
|
||||
|
||||
Tensor my_empty(
|
||||
torch::headeronly::HeaderOnlyArrayRef<int64_t> size,
|
||||
std::optional<torch::headeronly::ScalarType> dtype,
|
||||
std::optional<torch::stable::Device> device,
|
||||
std::optional<bool> pin_memory) {
|
||||
return empty(size, dtype, device, pin_memory);
|
||||
}
|
||||
|
||||
Tensor my_flatten(Tensor t, int64_t start_dim, int64_t end_dim) {
|
||||
return flatten(t, start_dim, end_dim);
|
||||
}
|
||||
|
||||
Tensor my_reshape(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> shape) {
|
||||
return reshape(t, shape);
|
||||
}
|
||||
|
||||
Tensor my_view(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> size) {
|
||||
return view(t, size);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def(
|
||||
"my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor");
|
||||
m.def("my_flatten(Tensor t, int start_dim=0, int end_dim=-1) -> Tensor");
|
||||
m.def("my_reshape(Tensor t, int[] shape) -> Tensor");
|
||||
m.def("my_view(Tensor t, int[] size) -> Tensor");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("my_empty", TORCH_BOX(&my_empty));
|
||||
m.impl("my_flatten", TORCH_BOX(&my_flatten));
|
||||
m.impl("my_reshape", TORCH_BOX(&my_reshape));
|
||||
m.impl("my_view", TORCH_BOX(&my_view));
|
||||
}
|
||||
|
||||
@ -487,3 +487,58 @@ def test_get_num_threads() -> int:
|
||||
Returns: int - the number of threads for the parallel backend
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.test_get_num_threads.default()
|
||||
|
||||
|
||||
def my_empty(size, dtype=None, device=None, pin_memory=None) -> Tensor:
    """
    Allocates an uninitialized tensor through the libtorch_agnostic my_empty op.

    Args:
        size: list[int] - dimensions of the tensor to allocate
        dtype: ScalarType or None - element type, forwarded to the op as given
        device: Device or None - target device, forwarded to the op as given
        pin_memory: bool or None - pinned-memory request, forwarded to the op as given

    Returns: Tensor - the tensor produced by the op
    """
    empty_op = torch.ops.libtorch_agnostic.my_empty.default
    return empty_op(size, dtype, device, pin_memory)
|
||||
|
||||
|
||||
def my_flatten(t, start_dim=0, end_dim=-1) -> Tensor:
    """
    Collapses dimensions start_dim through end_dim of the input into one dimension
    by calling the libtorch_agnostic my_flatten op.

    Args:
        t: Tensor - input tensor
        start_dim: int - first dimension of the flattened range (default: 0)
        end_dim: int - last dimension of the flattened range (default: -1)

    Returns: Tensor - the flattened tensor produced by the op
    """
    flatten_op = torch.ops.libtorch_agnostic.my_flatten.default
    return flatten_op(t, start_dim, end_dim)
|
||||
|
||||
|
||||
def my_reshape(t, shape) -> Tensor:
    """
    Reshapes the input tensor by calling the libtorch_agnostic my_reshape op.

    Args:
        t: Tensor - input tensor
        shape: list[int] - requested shape

    Returns: Tensor - the reshaped tensor produced by the op
    """
    reshape_op = torch.ops.libtorch_agnostic.my_reshape.default
    return reshape_op(t, shape)
|
||||
|
||||
|
||||
def my_view(t, size) -> Tensor:
    """
    Produces a view of the input tensor with a different shape by calling the
    libtorch_agnostic my_view op.

    Args:
        t: Tensor - input tensor
        size: list[int] - requested size of the view

    Returns: Tensor - the viewed tensor produced by the op
    """
    view_op = torch.ops.libtorch_agnostic.my_view.default
    return view_op(t, size)
|
||||
|
||||
@ -33,7 +33,7 @@ class clean(distutils.command.clean.clean):
|
||||
|
||||
def get_extension():
|
||||
extra_compile_args = {
|
||||
"cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
|
||||
"cxx": ["-fdiagnostics-color=always"],
|
||||
}
|
||||
|
||||
extension = CppExtension
|
||||
|
||||
@ -525,6 +525,97 @@ if not IS_WINDOWS:
|
||||
expected_num_threads = torch.get_num_threads()
|
||||
self.assertEqual(num_threads, expected_num_threads)
|
||||
|
||||
def test_my_empty(self, device):
    """Compare my_empty with torch.empty for default, dtype, device and pinned-memory arguments."""
    import libtorch_agnostic

    # Remember the global flag so it can be restored whatever happens below.
    deterministic = torch.are_deterministic_algorithms_enabled()
    try:
        # set use_deterministic_algorithms to fill uninitialized memory
        torch.use_deterministic_algorithms(True)

        size = [2, 3]
        result = libtorch_agnostic.ops.my_empty(size, None, None, None)
        expected = torch.empty(size)
        self.assertEqual(result, expected, exact_device=True)

        result_float = libtorch_agnostic.ops.my_empty(
            size, torch.float32, None, None
        )
        expected_float = torch.empty(size, dtype=torch.float32)
        self.assertEqual(result_float, expected_float, exact_device=True)

        result_with_device = libtorch_agnostic.ops.my_empty(
            size, torch.float64, device, None
        )
        expected_with_device = torch.empty(
            size, dtype=torch.float64, device=device
        )
        self.assertEqual(
            result_with_device, expected_with_device, exact_device=True
        )

        # Compare the device *type*: the device string may carry an index
        # (e.g. "cuda:0"), in which case a plain == "cuda" check would
        # silently skip the pinned-memory coverage.
        if torch.device(device).type == "cuda":
            result_pinned = libtorch_agnostic.ops.my_empty(
                size, torch.float32, "cpu", True
            )
            expected_pinned = torch.empty(
                size, dtype=torch.float32, device="cpu", pin_memory=True
            )
            self.assertEqual(result_pinned, expected_pinned)
            self.assertTrue(result_pinned.is_pinned())
    finally:
        torch.use_deterministic_algorithms(deterministic)
|
||||
|
||||
def test_my_flatten(self, device):
    """my_flatten must match torch.flatten with default, start-only and explicit dim arguments."""
    import libtorch_agnostic

    t = torch.randn(2, 3, 4, device=device)

    # Each tuple holds the optional (start_dim, end_dim) arguments for one case.
    for dim_args in ((), (1,), (2, -1)):
        actual = libtorch_agnostic.ops.my_flatten(t, *dim_args)
        reference = torch.flatten(t, *dim_args)
        self.assertEqual(actual, reference)
|
||||
|
||||
def test_my_reshape(self, device):
    """my_reshape must match torch.reshape for explicit, inferred and fully flat shapes."""
    import libtorch_agnostic

    t = torch.randn(2, 3, 4, device=device)

    for target_shape in ([6, 4], [-1, 4], [-1]):
        actual = libtorch_agnostic.ops.my_reshape(t, target_shape)
        reference = torch.reshape(t, target_shape)
        self.assertEqual(actual, reference)
|
||||
|
||||
def test_my_view(self, device):
    """my_view must match Tensor.view for explicit, inferred and fully flat sizes."""
    import libtorch_agnostic

    t = torch.randn(2, 3, 4, device=device)

    for target_size in ([6, 4], [-1, 4], [-1]):
        actual = libtorch_agnostic.ops.my_view(t, target_size)
        reference = t.view(target_size)
        self.assertEqual(actual, reference)
|
||||
|
||||
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -4,17 +4,12 @@
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
void orCheckFail(
|
||||
const char* func,
|
||||
const char* file,
|
||||
uint32_t line,
|
||||
const char* msg = "");
|
||||
void orCheckFail(const char* func, const char* file, uint32_t line, const char* msg = "");
|
||||
|
||||
#define OPENREG_CHECK(EXPR, ...) \
|
||||
do { \
|
||||
const orError_t __err = EXPR; \
|
||||
if (__err != orSuccess) { \
|
||||
orCheckFail( \
|
||||
__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
|
||||
} \
|
||||
#define OPENREG_CHECK(EXPR, ...) \
|
||||
do { \
|
||||
const orError_t __err = EXPR; \
|
||||
if (C10_UNLIKELY(__err != orSuccess)) { \
|
||||
orCheckFail(__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
#include <c10/util/Exception.h>
|
||||
#include <include/openreg.h>
|
||||
|
||||
#include "OpenRegException.h"
|
||||
@ -9,21 +10,22 @@ orError_t GetDeviceCount(int* dev_count) {
|
||||
return orGetDeviceCount(dev_count);
|
||||
}
|
||||
|
||||
orError_t GetDevice(c10::DeviceIndex* device) {
|
||||
orError_t GetDevice(DeviceIndex* device) {
|
||||
int tmp_device = -1;
|
||||
auto err = orGetDevice(&tmp_device);
|
||||
*device = static_cast<c10::DeviceIndex>(tmp_device);
|
||||
*device = static_cast<DeviceIndex>(tmp_device);
|
||||
return err;
|
||||
}
|
||||
|
||||
orError_t SetDevice(c10::DeviceIndex device) {
|
||||
// LITERALINCLUDE START: OPENREG SetDevice FUNCTION
|
||||
orError_t SetDevice(DeviceIndex device) {
|
||||
int cur_device = -1;
|
||||
orGetDevice(&cur_device);
|
||||
OPENREG_CHECK(orGetDevice(&cur_device));
|
||||
if (device == cur_device) {
|
||||
return orSuccess;
|
||||
}
|
||||
return orSetDevice(device);
|
||||
}
|
||||
// LITERALINCLUDE END: OPENREG SetDevice FUNCTION
|
||||
|
||||
int device_count_impl() {
|
||||
int count = 0;
|
||||
@ -31,34 +33,37 @@ int device_count_impl() {
|
||||
return count;
|
||||
}
|
||||
|
||||
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
|
||||
OPENREG_EXPORT DeviceIndex device_count() noexcept {
|
||||
// initialize number of devices only once
|
||||
static int count = []() {
|
||||
try {
|
||||
auto result = device_count_impl();
|
||||
TORCH_CHECK(
|
||||
result <= std::numeric_limits<c10::DeviceIndex>::max(),
|
||||
result <= std::numeric_limits<DeviceIndex>::max(),
|
||||
"Too many devices, DeviceIndex overflowed");
|
||||
return result;
|
||||
} catch (const c10::Error& ex) {
|
||||
} catch (const Error& ex) {
|
||||
// We don't want to fail, but still log the warning
|
||||
// msg() returns the message without the stack trace
|
||||
TORCH_WARN("Device initialization: ", ex.msg());
|
||||
return 0;
|
||||
}
|
||||
}();
|
||||
return static_cast<c10::DeviceIndex>(count);
|
||||
return static_cast<DeviceIndex>(count);
|
||||
}
|
||||
|
||||
OPENREG_EXPORT c10::DeviceIndex current_device() {
|
||||
c10::DeviceIndex cur_device = -1;
|
||||
GetDevice(&cur_device);
|
||||
OPENREG_EXPORT DeviceIndex current_device() {
|
||||
DeviceIndex cur_device = -1;
|
||||
OPENREG_CHECK(GetDevice(&cur_device));
|
||||
return cur_device;
|
||||
}
|
||||
|
||||
OPENREG_EXPORT void set_device(c10::DeviceIndex device) {
|
||||
SetDevice(device);
|
||||
// LITERALINCLUDE START: OPENREG set_device FUNCTION
|
||||
OPENREG_EXPORT void set_device(DeviceIndex device) {
|
||||
check_device_index(device);
|
||||
OPENREG_CHECK(SetDevice(device));
|
||||
}
|
||||
// LITERALINCLUDE END: OPENREG set_device FUNCTION
|
||||
|
||||
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
|
||||
int current_device = -1;
|
||||
@ -71,4 +76,8 @@ OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
|
||||
return current_device;
|
||||
}
|
||||
|
||||
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device) {
|
||||
check_device_index(to_device);
|
||||
return ExchangeDevice(to_device);
|
||||
}
|
||||
} // namespace c10::openreg
|
||||
|
||||
@ -9,10 +9,20 @@
|
||||
|
||||
namespace c10::openreg {
|
||||
|
||||
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept;
|
||||
OPENREG_EXPORT c10::DeviceIndex current_device();
|
||||
OPENREG_EXPORT void set_device(c10::DeviceIndex device);
|
||||
OPENREG_EXPORT DeviceIndex device_count() noexcept;
|
||||
OPENREG_EXPORT DeviceIndex current_device();
|
||||
OPENREG_EXPORT void set_device(DeviceIndex device);
|
||||
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device);
|
||||
|
||||
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device);
|
||||
|
||||
// Validates that `device` lies in [0, device_count()); otherwise TORCH_CHECK
// fails with a message naming the valid range and the offending index.
static inline void check_device_index(int64_t device) {
  // Query the count once and widen it to int64_t so the comparison and the
  // message use one consistent numeric value.
  const auto count = static_cast<int64_t>(c10::openreg::device_count());
  // `device` is reported as-is: narrowing it to int would misreport indices
  // that do not fit in an int.
  TORCH_CHECK(
      device >= 0 && device < count,
      "The device index is out of range. It must be in [0, ",
      count,
      "), but got ",
      device,
      ".");
}
|
||||
|
||||
} // namespace c10::openreg
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
|
||||
namespace c10::openreg {
|
||||
|
||||
// LITERALINCLUDE START: OPENREG GUARD REGISTRATION
|
||||
C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl);
|
||||
// LITERALINCLUDE END: OPENREG GUARD REGISTRATION
|
||||
|
||||
} // namespace c10::openreg
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
|
||||
namespace c10::openreg {
|
||||
|
||||
// LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
|
||||
struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
static constexpr DeviceType static_type = c10::DeviceType::PrivateUse1;
|
||||
|
||||
@ -58,6 +59,7 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
|
||||
set_device(d.index());
|
||||
}
|
||||
// LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
|
||||
|
||||
/**
|
||||
* Set the current device to c10::Device, without checking for errors
|
||||
|
||||
@ -27,6 +27,10 @@ class TestDevice(TestCase):
|
||||
self.assertEqual(torch.accelerator.current_device_index(), 1)
|
||||
self.assertEqual(torch.accelerator.current_device_index(), device)
|
||||
|
||||
def test_invalid_device_index(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "The device index is out of range"):
|
||||
torch.accelerator.set_device_index(2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -34,18 +34,21 @@ static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
|
||||
}
|
||||
// LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR
|
||||
|
||||
// LITERALINCLUDE START: MODULE SET DEVICE HELPER
|
||||
|
||||
PyObject* _setDevice(PyObject* self, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice");
|
||||
auto device = THPUtils_unpackLong(arg);
|
||||
|
||||
auto device = THPUtils_unpackDeviceIndex(arg);
|
||||
torch::utils::device_lazy_init(at::kPrivateUse1);
|
||||
c10::openreg::set_device(static_cast<c10::DeviceIndex>(device));
|
||||
c10::openreg::set_device(device);
|
||||
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
// LITERALINCLUDE END: MODULE SET DEVICE HELPER
|
||||
|
||||
PyObject* _exchangeDevice(PyObject* self, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice");
|
||||
|
||||
@ -41,8 +41,13 @@ def current_device():
|
||||
return torch_openreg._C._get_device()
|
||||
|
||||
|
||||
# LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
|
||||
def set_device(device) -> None:
|
||||
return torch_openreg._C._set_device(device)
|
||||
if device >= 0:
|
||||
torch_openreg._C._set_device(device)
|
||||
|
||||
|
||||
# LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
|
||||
|
||||
|
||||
def init():
|
||||
|
||||
67
test/cpp_extensions/torch_stable_test_extension/setup.py
Normal file
67
test/cpp_extensions/torch_stable_test_extension/setup.py
Normal file
@ -0,0 +1,67 @@
|
||||
import distutils.command.clean
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
from torch.utils.cpp_extension import BuildExtension, CppExtension
|
||||
|
||||
|
||||
ROOT_DIR = Path(__file__).parent
|
||||
CSRC_DIR = ROOT_DIR / "torch_stable_test" / "csrc"
|
||||
|
||||
|
||||
class clean(distutils.command.clean.clean):
    """``clean`` command that additionally deletes this extension's build outputs."""

    def run(self):
        # Let the stock clean command do its work before the extra cleanup.
        super().run()

        # Delete every compiled shared object found under the package directory.
        package_dir = ROOT_DIR / "torch_stable_test"
        for shared_object in package_dir.glob("**/*.so"):
            shared_object.unlink()

        # Delete the build, dist and egg-info directories when they exist.
        for dirname in ("build", "dist", "torch_stable_test.egg-info"):
            target = ROOT_DIR / dirname
            if target.exists():
                shutil.rmtree(str(target), ignore_errors=True)
|
||||
|
||||
|
||||
def get_extension():
    """Return a one-element list with the ``torch_stable_test._C`` CppExtension.

    The extension is built from every ``.cpp`` file under ``CSRC_DIR`` (paths
    sorted as strings) and is compiled with ``-DTORCH_STABLE_ONLY``.
    """
    cpp_sources = sorted(str(src) for src in CSRC_DIR.glob("**/*.cpp"))
    compile_flags = {
        "cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
    }
    stable_ext = CppExtension(
        "torch_stable_test._C",
        sources=cpp_sources,
        py_limited_api=True,
        extra_compile_args=compile_flags,
        extra_link_args=[],
    )
    return [stable_ext]
|
||||
|
||||
|
||||
setup(
|
||||
name="torch_stable_test",
|
||||
version="0.0",
|
||||
author="PyTorch Core Team",
|
||||
description="Test extension to verify TORCH_STABLE_ONLY flag",
|
||||
packages=find_packages(exclude=("test",)),
|
||||
package_data={"torch_stable_test": ["*.dll", "*.dylib", "*.so"]},
|
||||
install_requires=[
|
||||
"torch",
|
||||
],
|
||||
ext_modules=get_extension(),
|
||||
cmdclass={
|
||||
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
|
||||
"clean": clean,
|
||||
},
|
||||
options={"bdist_wheel": {"py_limited_api": "cp39"}},
|
||||
)
|
||||
@ -0,0 +1 @@
|
||||
#include <ATen/core/TensorBase.h> // This should trigger the TORCH_STABLE_ONLY error
|
||||
@ -0,0 +1,22 @@
|
||||
# Owner(s): ["module: cpp"]
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
install_cpp_extension,
|
||||
IS_WINDOWS,
|
||||
run_tests,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
|
||||
if not IS_WINDOWS:
|
||||
|
||||
class TestTorchStable(TestCase):
|
||||
def test_setup_fails(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "build failed for cpp extension"):
|
||||
install_cpp_extension(extension_root=Path(__file__).parent.parent)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
@ -65,6 +65,7 @@ from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
|
||||
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
curr_backend = dist.get_default_backend_for_device(device_type)
|
||||
|
||||
|
||||
class SimpleModel(nn.Module):
|
||||
@ -422,10 +423,10 @@ class TestFullyShard2DStateDict(DTensorTestBase):
|
||||
@property
|
||||
def backend(self):
|
||||
# need to specify gloo backend for testing cpu offload
|
||||
return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
|
||||
return f"cpu:gloo,{device_type}:{curr_backend}"
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
def test_fully_shard_tp_2d_set_full_state_dict(self):
|
||||
dummy_model = SimpleModel().to(device_type)
|
||||
mesh_2d = init_device_mesh(
|
||||
@ -514,8 +515,8 @@ class Test2dFSDP1ParallelIntegration(DTensorTestBase):
|
||||
).to_local()
|
||||
self.assertEqual(param_m2, param_m1)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
def test_2d_ddp_integration_functionality(self) -> None:
|
||||
model, twod_model, dp_pg = self.init_model(self.device_type)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=3e-5)
|
||||
@ -566,8 +567,8 @@ class TestNew2dParallelTraining(DTensorTestBase):
|
||||
p2 = p2.redistribute(p2.device_mesh, [Replicate()]).to_local()
|
||||
self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}")
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
def test_2d_fsdp_state_enable_extension(self):
|
||||
mesh_2d = init_device_mesh(
|
||||
self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
|
||||
@ -642,18 +643,18 @@ class TestNew2dParallelTraining(DTensorTestBase):
|
||||
# Ensure all params are still the same after optimizer update.
|
||||
self._compare_params(model, model_2d)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
def test_2d_e2e_training_default(self):
|
||||
self._test_2d_e2e_training()
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
def test_2d_e2e_training_use_orig_params(self):
|
||||
self._test_2d_e2e_training(use_orig_params=True)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
def test_2d_e2e_training_not_use_orig_params(self):
|
||||
# TODO: need to revisit input_reshard API about why it failed multi-gpu tests.
|
||||
# self._test_2d_e2e_training(recompute_activation=True)
|
||||
@ -666,10 +667,10 @@ class TestNew2dParallelStateDict(DTensorTestBase):
|
||||
@property
|
||||
def backend(self):
|
||||
# need to specify gloo backend for testing cpu offload
|
||||
return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
|
||||
return f"cpu:gloo,{device_type}:{curr_backend}"
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
def test_fsdp_2d_extension(self):
|
||||
"""
|
||||
Test whether _fsdp_extension from FSDPstate has been set correctly.
|
||||
@ -700,8 +701,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
|
||||
model_1d_fsdp_state = _get_module_fsdp_state(model_1d)
|
||||
self.assertEqual(model_1d_fsdp_state._fsdp_extension, None)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@parametrize("is_even_sharded_model", [True, False])
|
||||
def test_2d_state_dict(self, is_even_sharded_model):
|
||||
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
|
||||
@ -756,8 +757,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
|
||||
torch.allclose(no_wrap_v, all_gather_two_d_v.to_local()), True
|
||||
)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@parametrize("is_even_sharded_model", [True, False])
|
||||
def test_2d_load_state_dict(self, is_even_sharded_model):
|
||||
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
|
||||
@ -811,8 +812,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
|
||||
self.assertEqual(v1.device_mesh, v2.device_mesh)
|
||||
self.assertEqual(v1.placements, v2.placements)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@parametrize("is_even_sharded_model", [True, False])
|
||||
def test_2d_optim_state_dict(self, is_even_sharded_model):
|
||||
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
|
||||
@ -899,9 +900,9 @@ class TestNew2dParallelStateDict(DTensorTestBase):
|
||||
else:
|
||||
self.assertEqual(new_state, state)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@with_temp_dir
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_fsdp1_tp_2d_set_full_state_dict(self):
|
||||
"""
|
||||
This is a workaround for loading full state dict into a FSDP1+TP 2D model.
|
||||
|
||||
@ -29,8 +29,8 @@ from torch.distributed.tensor.parallel import (
|
||||
parallelize_module,
|
||||
RowwiseParallel,
|
||||
)
|
||||
from torch.testing._internal.common_cuda import TEST_MULTIGPU
|
||||
from torch.testing._internal.common_distributed import (
|
||||
at_least_x_gpu,
|
||||
MultiProcessTestCase,
|
||||
requires_accelerator_dist_backend,
|
||||
skip_if_lt_x_gpu,
|
||||
@ -40,7 +40,6 @@ from torch.testing._internal.common_utils import (
|
||||
parametrize,
|
||||
run_tests,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
TEST_XPU,
|
||||
)
|
||||
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
|
||||
|
||||
@ -107,11 +106,9 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
def device(self):
|
||||
return self.rank
|
||||
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
@requires_accelerator_dist_backend()
|
||||
@skip_if_lt_x_gpu(8)
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not TEST_MULTIGPU and not TEST_XPU, "Test requires 4+ GPUs"
|
||||
)
|
||||
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
|
||||
def test_pp_and_dcp(self):
|
||||
"""
|
||||
Test that pipeline parallelism and distributed checkpointing can be used together and
|
||||
@ -201,11 +198,9 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
|
||||
_dcp_test(self)
|
||||
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
@requires_accelerator_dist_backend()
|
||||
@skip_if_lt_x_gpu(8)
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
|
||||
)
|
||||
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
|
||||
@parametrize(
|
||||
"ScheduleClass",
|
||||
[
|
||||
@ -355,11 +350,9 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
@requires_accelerator_dist_backend()
|
||||
@skip_if_lt_x_gpu(8)
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
|
||||
)
|
||||
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
|
||||
@parametrize(
|
||||
"ScheduleClass",
|
||||
[
|
||||
@ -550,11 +543,9 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
@requires_accelerator_dist_backend()
|
||||
@skip_if_lt_x_gpu(8)
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
|
||||
)
|
||||
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
|
||||
@parametrize(
|
||||
"ScheduleClass",
|
||||
[
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
@ -18,8 +17,8 @@ from torch.distributed.algorithms.ddp_comm_hooks import (
|
||||
)
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcessTestCase,
|
||||
requires_nccl,
|
||||
DistributedTestBase,
|
||||
requires_accelerator_dist_backend,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
|
||||
@ -30,9 +29,12 @@ if TEST_WITH_DEV_DBG_ASAN:
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
|
||||
|
||||
def gpus_for_rank(world_size):
|
||||
visible_devices = list(range(torch.cuda.device_count()))
|
||||
gpus_per_process = torch.cuda.device_count() // world_size
|
||||
visible_devices = list(range(torch.accelerator.device_count()))
|
||||
gpus_per_process = torch.accelerator.device_count() // world_size
|
||||
gpus_for_rank = []
|
||||
for rank in range(world_size):
|
||||
gpus_for_rank.append(
|
||||
@ -60,27 +62,7 @@ class TestDdpCommHook(nn.Module):
|
||||
return self.t0(x ** (1 + rank))
|
||||
|
||||
|
||||
class DistributedDataParallelCommHookTest(MultiProcessTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._spawn_processes()
|
||||
|
||||
def tearDown(self):
|
||||
try:
|
||||
os.remove(self.file_name)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _get_process_group_nccl(self):
|
||||
store = dist.FileStore(self.file_name, self.world_size)
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
)
|
||||
return dist.distributed_c10d._get_default_group()
|
||||
|
||||
class DistributedDataParallelCommHookTest(DistributedTestBase):
|
||||
@property
|
||||
def world_size(self):
|
||||
return 2
|
||||
@ -119,14 +101,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
|
||||
param = next(model.parameters())
|
||||
return param.grad
|
||||
|
||||
@requires_nccl()
|
||||
@requires_accelerator_dist_backend()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_ddp_comm_hook_allreduce_hook(self):
|
||||
"""
|
||||
This unit test verifies the ``allreduce`` hook registered case gives same result
|
||||
with no hook registered case.
|
||||
"""
|
||||
process_group = self._get_process_group_nccl()
|
||||
process_group = self.create_pg(device_type)
|
||||
|
||||
# No hook registered case, get the reference grads.
|
||||
reference_grads = self._get_grads(process_group, None)
|
||||
@ -135,14 +117,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
|
||||
|
||||
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0)
|
||||
|
||||
@requires_nccl()
|
||||
@requires_accelerator_dist_backend()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_ddp_comm_hook_fp16compress_hook(self):
|
||||
"""
|
||||
This unit test verifies the ``fp16 compress`` hook registered case
|
||||
gives close result with no hook registered case.
|
||||
"""
|
||||
process_group = self._get_process_group_nccl()
|
||||
process_group = self.create_pg(device_type)
|
||||
|
||||
# No hook registered case, get the reference grads.
|
||||
reference_grads = self._get_grads(process_group, None)
|
||||
@ -151,14 +133,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
|
||||
|
||||
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
|
||||
|
||||
@requires_nccl()
|
||||
@requires_accelerator_dist_backend()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_ddp_comm_hook_quantize_per_tensor_hook(self):
|
||||
"""
|
||||
This unit test verifies the ``quantize per tensor`` hook registered case
|
||||
gives close result with no hook registered case.
|
||||
"""
|
||||
process_group = self._get_process_group_nccl()
|
||||
process_group = self.create_pg(device_type)
|
||||
|
||||
# No hook registered case, get the reference grads.
|
||||
reference_grads = self._get_grads(process_group, None)
|
||||
@ -167,14 +149,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
|
||||
|
||||
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
|
||||
|
||||
@requires_nccl()
|
||||
@requires_accelerator_dist_backend()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_ddp_comm_hook_quantize_per_channel_hook(self):
|
||||
"""
|
||||
This unit test verifies the ``quantize per channel`` hook registered case
|
||||
gives close result with no hook registered case.
|
||||
"""
|
||||
process_group = self._get_process_group_nccl()
|
||||
process_group = self.create_pg(device_type)
|
||||
|
||||
# No hook registered case, get the reference grads.
|
||||
reference_grads = self._get_grads(process_group, None)
|
||||
@ -185,14 +167,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
|
||||
|
||||
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
|
||||
|
||||
@requires_nccl()
|
||||
@requires_accelerator_dist_backend()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_ddp_comm_hook_noop_hook(self):
|
||||
"""
|
||||
This unit test verifies the ``noop`` hook registered case and a subsequent allreduce
|
||||
gives same result with no hook registered case.
|
||||
"""
|
||||
process_group = self._get_process_group_nccl()
|
||||
process_group = self.create_pg(device_type)
|
||||
|
||||
# No hook registered case, get the reference grads.
|
||||
reference_grads = self._get_grads(process_group, None)
|
||||
@ -204,10 +186,10 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
|
||||
|
||||
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0)
|
||||
|
||||
@requires_nccl()
|
||||
@requires_accelerator_dist_backend()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_is_last_hook(self):
|
||||
process_group = self._get_process_group_nccl()
|
||||
process_group = self.create_pg(device_type)
|
||||
|
||||
def hook(flags, bucket):
|
||||
flags.append(bucket.is_last())
|
||||
|
||||
@ -32,7 +32,7 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
class TestStateDictUtils(DTensorTestBase):
|
||||
@property
|
||||
def world_size(self):
|
||||
return min(4, torch.cuda.device_count())
|
||||
return min(4, torch.accelerator.device_count())
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@ -49,7 +49,7 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0)
|
||||
)
|
||||
self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])
|
||||
self.assertTrue(gathered_state_dict["dtensor"].is_cuda)
|
||||
self.assertEqual(gathered_state_dict["dtensor"].device.type, self.device_type)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@ -69,14 +69,16 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
)
|
||||
if dist.get_rank() in (0, 2):
|
||||
self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])
|
||||
self.assertFalse(gathered_state_dict["dtensor"].is_cuda)
|
||||
self.assertNotEqual(
|
||||
gathered_state_dict["dtensor"].device.type, self.device_type
|
||||
)
|
||||
else:
|
||||
self.assertEqual(gathered_state_dict, {})
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_cpu_and_ranks_only(self):
|
||||
device = torch.device("cuda")
|
||||
device = torch.device(self.device_type)
|
||||
state_dict = {
|
||||
"tensor1": torch.arange(10, device=device),
|
||||
"tensor2": torch.ones(10, device=device),
|
||||
@ -85,7 +87,7 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
cpu_state_dict = _offload_state_dict_to_cpu(state_dict, ranks_only=(0, 2))
|
||||
if dist.get_rank() in (0, 2):
|
||||
for v in cpu_state_dict.values():
|
||||
self.assertFalse(v.is_cuda)
|
||||
self.assertNotEqual(v.device.type, self.device_type)
|
||||
self.assertEqual(cpu_state_dict["tensor1"], torch.arange(10))
|
||||
self.assertEqual(cpu_state_dict["tensor2"], torch.ones(10))
|
||||
else:
|
||||
@ -109,27 +111,27 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
for _ in range(10):
|
||||
tensor, dtensor = create_dtensor()
|
||||
ltensor.append(tensor)
|
||||
ltensor.append(torch.ones(10, device=torch.device("cuda")))
|
||||
ltensor.append(torch.ones(10, device=torch.device(self.device_type)))
|
||||
ldtensor.append(dtensor)
|
||||
ldtensor.append(torch.ones(10, device=torch.device("cuda")))
|
||||
ldtensor.append(torch.ones(10, device=torch.device(self.device_type)))
|
||||
|
||||
tensor, dtensor = create_dtensor()
|
||||
dist_state_dict = {
|
||||
"local": dtensor,
|
||||
"list": ldtensor,
|
||||
"arange": torch.arange(10, device=torch.device("cuda")),
|
||||
"arange": torch.arange(10, device=torch.device(self.device_type)),
|
||||
}
|
||||
state_dict = {
|
||||
"local": tensor,
|
||||
"list": ltensor,
|
||||
"arange": torch.arange(10, device=torch.device("cuda")),
|
||||
"arange": torch.arange(10, device=torch.device(self.device_type)),
|
||||
}
|
||||
self.assertEqual(state_dict, _gather_state_dict(dist_state_dict))
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_create_cpu_state_dict(self):
|
||||
device = torch.device("cuda")
|
||||
device = torch.device(self.device_type)
|
||||
rank = dist.get_rank()
|
||||
# Scale tensors based on world size
|
||||
# to fit in the tensor shards accurately.
|
||||
@ -149,7 +151,7 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
metadata=ShardMetadata(
|
||||
shard_offsets=[5 * rank, 0],
|
||||
shard_sizes=[5, 10],
|
||||
placement=f"rank:{rank}/cuda:{rank}",
|
||||
placement=f"rank:{rank}/{self.device_type}:{rank}",
|
||||
),
|
||||
)
|
||||
],
|
||||
@ -159,7 +161,7 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
torch.arange(50 * scale_factor, device=device).reshape(
|
||||
5 * scale_factor, 10
|
||||
),
|
||||
init_device_mesh("cuda", mesh_shape=(self.world_size,)),
|
||||
init_device_mesh(self.device_type, mesh_shape=(self.world_size,)),
|
||||
[Shard(0)],
|
||||
),
|
||||
"non_tensor_bytes_io": copy.deepcopy(buffer),
|
||||
@ -245,7 +247,7 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
even_tensor = torch.randn(self.world_size, 2)
|
||||
uneven_tensor = torch.randn(1, 2)
|
||||
|
||||
mesh = init_device_mesh("cuda", mesh_shape=(self.world_size,))
|
||||
mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
|
||||
even_dtensor = distribute_tensor(
|
||||
torch.randn(self.world_size, 2), mesh, [Shard(0)]
|
||||
)
|
||||
@ -273,10 +275,10 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_cpu_offload_for_dtensor(self):
|
||||
device_mesh = init_device_mesh("cuda", mesh_shape=(self.world_size,))
|
||||
device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
|
||||
sd = {
|
||||
"k": DTensor.from_local(
|
||||
torch.ones(8, 8, device="cuda"), device_mesh, [Shard(0)]
|
||||
torch.ones(8, 8, device=self.device_type), device_mesh, [Shard(0)]
|
||||
)
|
||||
}
|
||||
cpu_sd = _create_cpu_state_dict(sd)
|
||||
@ -290,12 +292,12 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
|
||||
self.assertFalse(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
|
||||
_copy_state_dict(sd, cpu_sd, non_blocking=True)
|
||||
torch.cuda.synchronize()
|
||||
torch.accelerator.synchronize()
|
||||
self.assertTrue(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
|
||||
sd["k"] += 1
|
||||
self.assertFalse(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
|
||||
_copy_state_dict(sd, cpu_sd, non_blocking=True)
|
||||
torch.cuda.synchronize()
|
||||
torch.accelerator.synchronize()
|
||||
self.assertTrue(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
|
||||
|
||||
|
||||
|
||||
@ -743,16 +743,19 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
def test_binary_duplicate_log_filters(self):
|
||||
envs = {0: {"RANK": "0"}, 1: {"RANK": "1"}}
|
||||
logs_specs = DefaultLogsSpecs(
|
||||
log_dir=self.log_dir(),
|
||||
redirects={0: Std.ERR, 1: Std.NONE},
|
||||
tee={0: Std.OUT, 1: Std.ERR},
|
||||
)
|
||||
logs_dest = logs_specs.reify(envs)
|
||||
pc = start_processes(
|
||||
name="trainer",
|
||||
entrypoint=bin("echo1.py"),
|
||||
args={0: ("helloA,helloB",), 1: ("worldA,worldB",)},
|
||||
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
||||
logs_specs=DefaultLogsSpecs(
|
||||
log_dir=self.log_dir(),
|
||||
redirects={0: Std.ERR, 1: Std.NONE},
|
||||
tee={0: Std.OUT, 1: Std.ERR},
|
||||
),
|
||||
envs=envs,
|
||||
logs_specs=logs_specs,
|
||||
log_line_prefixes={0: "[rank0]:", 1: "[rank1]:"},
|
||||
duplicate_stdout_filters=["helloA"],
|
||||
duplicate_stderr_filters=["worldA", "B"],
|
||||
@ -762,12 +765,18 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
||||
result = pc.wait()
|
||||
|
||||
self.assertFalse(result.is_failed())
|
||||
self.assert_in_file(["[rank0]:helloA stdout from 0"], pc.filtered_stdout)
|
||||
self.assert_not_in_file(
|
||||
["[rank0]:helloB stdout from 0"], pc.filtered_stdout
|
||||
self.assert_in_file(
|
||||
["[rank0]:helloA stdout from 0"], logs_dest.filtered_stdout
|
||||
)
|
||||
self.assert_not_in_file(
|
||||
["[rank0]:helloB stdout from 0"], logs_dest.filtered_stdout
|
||||
)
|
||||
self.assert_in_file(
|
||||
["[rank1]:worldA stderr from 1"], logs_dest.filtered_stderr
|
||||
)
|
||||
self.assert_in_file(
|
||||
["[rank1]:worldB stderr from 1"], logs_dest.filtered_stderr
|
||||
)
|
||||
self.assert_in_file(["[rank1]:worldA stderr from 1"], pc.filtered_stderr)
|
||||
self.assert_in_file(["[rank1]:worldB stderr from 1"], pc.filtered_stderr)
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
@ -838,16 +847,19 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
|
||||
def test_function_duplicate_log_filters(self):
|
||||
for start_method in self._start_methods:
|
||||
with self.subTest(start_method=start_method):
|
||||
envs = {0: {"RANK": "0"}, 1: {"RANK": "1"}}
|
||||
logs_specs = DefaultLogsSpecs(
|
||||
log_dir=self.log_dir(),
|
||||
redirects={0: Std.ERR, 1: Std.NONE},
|
||||
tee={0: Std.OUT, 1: Std.ERR},
|
||||
)
|
||||
logs_dest = logs_specs.reify(envs)
|
||||
pc = start_processes(
|
||||
name="trainer",
|
||||
entrypoint=echo1,
|
||||
args={0: ("helloA,helloB",), 1: ("worldA,worldB",)},
|
||||
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
||||
logs_specs=DefaultLogsSpecs(
|
||||
log_dir=self.log_dir(),
|
||||
redirects={0: Std.ERR, 1: Std.NONE},
|
||||
tee={0: Std.OUT, 1: Std.ERR},
|
||||
),
|
||||
envs=envs,
|
||||
logs_specs=logs_specs,
|
||||
duplicate_stdout_filters=["helloA"],
|
||||
duplicate_stderr_filters=["worldA", "B"],
|
||||
start_method="spawn",
|
||||
@ -857,16 +869,16 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
|
||||
|
||||
self.assertFalse(result.is_failed())
|
||||
self.assert_in_file(
|
||||
["[trainer0]:helloA stdout from 0"], pc.filtered_stdout
|
||||
["[trainer0]:helloA stdout from 0"], logs_dest.filtered_stdout
|
||||
)
|
||||
self.assert_not_in_file(
|
||||
["[trainer0]:helloB stdout from 0"], pc.filtered_stdout
|
||||
["[trainer0]:helloB stdout from 0"], logs_dest.filtered_stdout
|
||||
)
|
||||
self.assert_in_file(
|
||||
["[trainer1]:worldA stderr from 1"], pc.filtered_stderr
|
||||
["[trainer1]:worldA stderr from 1"], logs_dest.filtered_stderr
|
||||
)
|
||||
self.assert_in_file(
|
||||
["[trainer1]:worldB stderr from 1"], pc.filtered_stderr
|
||||
["[trainer1]:worldB stderr from 1"], logs_dest.filtered_stderr
|
||||
)
|
||||
for tail_log in pc._tail_logs:
|
||||
self.assertTrue(tail_log.stopped())
|
||||
|
||||
@ -225,9 +225,11 @@ class ApiTest(unittest.TestCase):
|
||||
raise_child_failure_error_fn("trainer", trainer_error_file)
|
||||
pf = cm.exception.get_first_failure()[1]
|
||||
# compare worker error file with reply file and overridden error code
|
||||
expect = json.load(open(pf.error_file))
|
||||
with open(pf.error_file) as f:
|
||||
expect = json.load(f)
|
||||
expect["message"]["errorCode"] = pf.exitcode
|
||||
actual = json.load(open(self.test_error_file))
|
||||
with open(self.test_error_file) as f:
|
||||
actual = json.load(f)
|
||||
self.assertTrue(
|
||||
json.dumps(expect, sort_keys=True),
|
||||
json.dumps(actual, sort_keys=True),
|
||||
|
||||
@ -100,8 +100,9 @@ class TailLogTest(unittest.TestCase):
|
||||
}
|
||||
|
||||
dst = os.path.join(self.test_dir, "tailed_stdout.log")
|
||||
dst_file = open(dst, "w", buffering=1)
|
||||
tail = TailLog(
|
||||
name="writer", log_files=log_files, dst=dst, interval_sec=interval_sec
|
||||
name="writer", log_files=log_files, dst=dst_file, interval_sec=interval_sec
|
||||
).start()
|
||||
# sleep here is intentional to ensure that the log tail
|
||||
# can gracefully handle and wait for non-existent log files
|
||||
@ -117,10 +118,11 @@ class TailLogTest(unittest.TestCase):
|
||||
wait(futs, return_when=ALL_COMPLETED)
|
||||
self.assertFalse(tail.stopped())
|
||||
tail.stop()
|
||||
dst_file.close()
|
||||
|
||||
actual: dict[int, set[int]] = {}
|
||||
with open(dst) as dst_file:
|
||||
for line in dst_file:
|
||||
with open(dst) as read_dst_file:
|
||||
for line in read_dst_file:
|
||||
header, num = line.split(":")
|
||||
nums = actual.setdefault(header, set())
|
||||
nums.add(int(num))
|
||||
@ -256,4 +258,4 @@ class TailLogTest(unittest.TestCase):
|
||||
tail = TailLog("writer", log_files={0: self.test_dir}, dst=sys.stdout).start()
|
||||
tail.stop()
|
||||
|
||||
mock_logger.error.assert_called_once()
|
||||
mock_logger.exception.assert_called_once()
|
||||
|
||||
@ -2,23 +2,16 @@
|
||||
|
||||
import copy
|
||||
import math
|
||||
import pathlib
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
|
||||
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
from tools.flight_recorder.components.builder import build_db
|
||||
from tools.flight_recorder.components.config_manager import JobConfig
|
||||
from tools.flight_recorder.components.types import COLLECTIVES, MatchInfo, MatchState
|
||||
from tools.flight_recorder.components.utils import match_one_event
|
||||
|
||||
|
||||
# Make sure to remove REPO_ROOT after import is done
|
||||
sys.path.remove(str(REPO_ROOT))
|
||||
|
||||
from torch.distributed.flight_recorder.components.builder import build_db
|
||||
from torch.distributed.flight_recorder.components.config_manager import JobConfig
|
||||
from torch.distributed.flight_recorder.components.types import (
|
||||
COLLECTIVES,
|
||||
MatchInfo,
|
||||
MatchState,
|
||||
)
|
||||
from torch.distributed.flight_recorder.components.utils import match_one_event
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@
|
||||
|
||||
import copy
|
||||
import sys
|
||||
from contextlib import nullcontext
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from typing import Any, cast
|
||||
|
||||
import numpy as np
|
||||
@ -40,7 +40,6 @@ from torch.testing._internal.common_distributed import (
|
||||
skip_if_rocm_multiprocess,
|
||||
skip_if_win32,
|
||||
)
|
||||
from torch.testing._internal.common_fsdp import get_devtype
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
@ -57,7 +56,17 @@ except ImportError:
|
||||
HAS_TORCHVISION = False
|
||||
|
||||
|
||||
device_type = str(get_devtype())
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def deterministic_algorithms(enabled=True):
|
||||
prev_state = torch.are_deterministic_algorithms_enabled()
|
||||
torch.use_deterministic_algorithms(enabled)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.use_deterministic_algorithms(prev_state)
|
||||
|
||||
|
||||
class TestZeroRedundancyOptimizer(DistributedTestBase):
|
||||
@ -1241,7 +1250,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
enabled=True, deterministic=True, benchmark=False
|
||||
)
|
||||
if "cuda" in device
|
||||
else torch.use_deterministic_algorithms(True)
|
||||
else deterministic_algorithms(True)
|
||||
)
|
||||
with det_ctx:
|
||||
device_ids = [rank] if requires_ddp_rank(device) else None
|
||||
|
||||
@ -31,6 +31,8 @@ from torch.utils._debug_mode import (
|
||||
_RedistributeCall,
|
||||
_TritonKernelCall,
|
||||
DebugMode,
|
||||
hash_tensor_fn,
|
||||
norm_hash_fn,
|
||||
)
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from torch.utils._triton import has_triton_package
|
||||
@ -115,6 +117,28 @@ class TestDTensorDebugMode(TestCase):
|
||||
"aten::sum(t: f32[1, 32]) # {'hash': " in debug_mode.debug_string()
|
||||
)
|
||||
|
||||
# check tuple hash functions
|
||||
with (
|
||||
DebugMode() as debug_mode,
|
||||
DebugMode.log_tensor_hashes(hash_fn=["norm", "hash_tensor"]),
|
||||
):
|
||||
mm(x_dtensor, y_dtensor)
|
||||
|
||||
output_hash = debug_mode.operators[-1].log["hash"]
|
||||
norm_ = lambda x: norm_hash_fn(x, use_scalar=True) # noqa: E731
|
||||
hash_ = lambda x: hash_tensor_fn(x, use_scalar=True) # noqa: E731
|
||||
|
||||
self.assertEqual(output_hash[0], norm_(eager_out))
|
||||
self.assertEqual(output_hash[1], hash_(eager_out))
|
||||
|
||||
# some edge cases
|
||||
self.assertEqual(norm_(torch.tensor(torch.nan)), torch.nan)
|
||||
self.assertEqual(norm_(torch.tensor(torch.inf)), torch.inf)
|
||||
self.assertEqual(norm_(torch.complex(torch.ones(4), torch.zeros(4))), 4)
|
||||
self.assertEqual(hash_(torch.ones(4, dtype=torch.float8_e5m2)), 0)
|
||||
self.assertEqual(hash_(torch.ones(4, dtype=torch.int8)), 0)
|
||||
self.assertEqual(hash_(torch.ones(5, dtype=torch.int8)), 1)
|
||||
|
||||
def test_debug_string_inside_context(self):
|
||||
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
||||
|
||||
|
||||
@ -331,6 +331,25 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
|
||||
self.assertEqual(z.placements, (Replicate(),))
|
||||
self.assertEqual(z.to_local(), input)
|
||||
|
||||
def test_inplace_op_partial_to_replicate(self):
|
||||
# test that in-place operations that require redistribution raise an error
|
||||
# to preserve aliasing semantics (issue #163374)
|
||||
device_mesh = self.build_device_mesh()
|
||||
|
||||
input_tensor = torch.tensor(64.0, device=self.device_type)
|
||||
partial_dt = DTensor.from_local(
|
||||
input_tensor, device_mesh, placements=(Partial(),)
|
||||
)
|
||||
|
||||
self.assertTrue(partial_dt.placements[0].is_partial())
|
||||
|
||||
# Inplace ops that require placement changes (Partial -> Replicate) should error
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"in-place operations that require placement changes are not supported",
|
||||
):
|
||||
partial_dt.clamp_(max=10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -664,6 +664,101 @@ class TestViewOps(DTensorTestBase):
|
||||
)
|
||||
self.assertEqual(dist_x.placements, [Partial(), Shard(0)])
|
||||
|
||||
@with_comms
|
||||
def test_storage_offset_slice(self):
|
||||
"""
|
||||
Test that storage_offset is properly tracked on DTensor when slicing
|
||||
a replicated tensor.
|
||||
"""
|
||||
mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
|
||||
# Create a replicated DTensor
|
||||
tensor = torch.randn(10, device=self.device_type)
|
||||
dtensor = distribute_tensor(tensor, mesh, [Replicate()])
|
||||
|
||||
# Perform a slice operation [1:]
|
||||
with CommDebugMode() as comm_mode:
|
||||
sliced_dtensor = dtensor[1:]
|
||||
# Slicing should not trigger any communication
|
||||
self.assertEqual(comm_mode.get_total_counts(), 0)
|
||||
|
||||
# Verify that the DTensor's storage_offset matches the expected value
|
||||
self.assertEqual(sliced_dtensor.storage_offset(), 1)
|
||||
|
||||
# Verify that the local tensor also has the correct storage_offset
|
||||
self.assertEqual(sliced_dtensor.to_local().storage_offset(), 1)
|
||||
|
||||
# Verify the shape is correct
|
||||
self.assertEqual(sliced_dtensor.shape, torch.Size([9]))
|
||||
|
||||
# Verify the values are correct
|
||||
expected = tensor[1:]
|
||||
self.assertEqual(sliced_dtensor.full_tensor(), expected)
|
||||
|
||||
@with_comms
|
||||
def test_storage_offset_shard_dim0_slice_dim1(self):
|
||||
"""
|
||||
Test that storage_offset is properly tracked when tensor is sharded on dim 0
|
||||
and sliced on dim 1.
|
||||
"""
|
||||
mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
|
||||
# Create a 2D tensor and shard on dim 0
|
||||
tensor = torch.randn(12, 8, device=self.device_type)
|
||||
dtensor = distribute_tensor(tensor, mesh, [Shard(0)])
|
||||
|
||||
# Perform a slice operation [:, 2:]
|
||||
with CommDebugMode() as comm_mode:
|
||||
sliced_dtensor = dtensor[:, 2:]
|
||||
# Slicing should not trigger any communication
|
||||
self.assertEqual(comm_mode.get_total_counts(), 0)
|
||||
|
||||
# The storage_offset should be 2 (skipping 2 elements in each row)
|
||||
self.assertEqual(sliced_dtensor.storage_offset(), 2)
|
||||
|
||||
# Verify that the local tensor also has the correct storage_offset
|
||||
self.assertEqual(sliced_dtensor.to_local().storage_offset(), 2)
|
||||
|
||||
# Verify the shape is correct
|
||||
expected_shape = torch.Size([12, 6])
|
||||
self.assertEqual(sliced_dtensor.shape, expected_shape)
|
||||
|
||||
# Verify the values are correct
|
||||
expected = tensor[:, 2:]
|
||||
self.assertEqual(sliced_dtensor.full_tensor(), expected)
|
||||
|
||||
@with_comms
|
||||
def test_storage_offset_shard_dim1_slice_dim0(self):
|
||||
"""
|
||||
Test that storage_offset is properly tracked when tensor is sharded on dim 1
|
||||
and sliced on dim 0.
|
||||
"""
|
||||
mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
|
||||
# Create a 2D tensor and shard on dim 1
|
||||
tensor = torch.randn(10, 12, device=self.device_type)
|
||||
dtensor = distribute_tensor(tensor, mesh, [Shard(1)])
|
||||
|
||||
# Perform a slice operation [2:, :]
|
||||
with CommDebugMode() as comm_mode:
|
||||
sliced_dtensor = dtensor[2:, :]
|
||||
# Slicing should not trigger any communication
|
||||
self.assertEqual(comm_mode.get_total_counts(), 0)
|
||||
|
||||
local_dim1_size = 12 // self.world_size
|
||||
expected_offset = 2 * local_dim1_size
|
||||
self.assertEqual(sliced_dtensor.storage_offset(), expected_offset)
|
||||
|
||||
self.assertEqual(sliced_dtensor.to_local().storage_offset(), expected_offset)
|
||||
|
||||
# Verify the shape is correct
|
||||
expected_shape = torch.Size([8, 12])
|
||||
self.assertEqual(sliced_dtensor.shape, expected_shape)
|
||||
|
||||
# Verify the values are correct
|
||||
expected = tensor[2:, :]
|
||||
self.assertEqual(sliced_dtensor.full_tensor(), expected)
|
||||
|
||||
|
||||
TestViewOpsWithLocalTensor = create_local_tensor_test_class(
|
||||
TestViewOps,
|
||||
|
||||
@ -24,7 +24,7 @@ from torch.distributed._functional_collectives import (
|
||||
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
|
||||
from torch.testing._internal.common_device_type import e4m3_type
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcessTestCase,
|
||||
DistributedTestBase,
|
||||
requires_accelerator_dist_backend,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
@ -59,12 +59,8 @@ if not dist.is_available():
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
class TestWithNCCL(MultiProcessTestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self._spawn_processes()
|
||||
|
||||
@requires_accelerator_dist_backend()
|
||||
class TestWithNCCL(DistributedTestBase):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return 2
|
||||
@ -78,16 +74,7 @@ class TestWithNCCL(MultiProcessTestCase):
|
||||
return torch.device(self.rank)
|
||||
|
||||
def _init_process_group(self) -> None:
|
||||
torch.accelerator.set_device_index(self.rank)
|
||||
store = dist.FileStore(self.file_name, self.world_size)
|
||||
backend = dist.get_default_backend_for_device(self.device.type)
|
||||
|
||||
dist.init_process_group(
|
||||
backend=backend,
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
)
|
||||
self.create_pg(self.device.type)
|
||||
torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
|
||||
@ -11,13 +11,10 @@ if not dist.is_available():
|
||||
print("Distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_distributed import DistributedTestBase, TEST_SKIPS
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
skipIfHpu,
|
||||
TEST_CUDA,
|
||||
TEST_HPU,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
)
|
||||
|
||||
@ -29,16 +26,8 @@ if TEST_WITH_DEV_DBG_ASAN:
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
if TEST_HPU:
|
||||
DEVICE = "hpu"
|
||||
elif TEST_CUDA:
|
||||
DEVICE = "cuda"
|
||||
else:
|
||||
DEVICE = "cpu"
|
||||
|
||||
device_module = torch.get_device_module(DEVICE)
|
||||
device_count = device_module.device_count()
|
||||
BACKEND = dist.get_default_backend_for_device(DEVICE)
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
device_count = torch.accelerator.device_count()
|
||||
|
||||
|
||||
def with_comms(func=None):
|
||||
@ -49,11 +38,10 @@ def with_comms(func=None):
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if DEVICE != "cpu" and device_count < self.world_size:
|
||||
if device_type != "cpu" and device_count < self.world_size:
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
|
||||
|
||||
kwargs["device"] = DEVICE
|
||||
self.pg = self.create_pg(device=DEVICE)
|
||||
self.pg = self.create_pg(device=device_type)
|
||||
try:
|
||||
return func(self, *args, **kwargs)
|
||||
finally:
|
||||
@ -64,7 +52,7 @@ def with_comms(func=None):
|
||||
|
||||
class TestObjectCollectives(DistributedTestBase):
|
||||
@with_comms()
|
||||
def test_all_gather_object(self, device):
|
||||
def test_all_gather_object(self):
|
||||
output = [None] * dist.get_world_size()
|
||||
dist.all_gather_object(object_list=output, obj=self.rank)
|
||||
|
||||
@ -72,7 +60,7 @@ class TestObjectCollectives(DistributedTestBase):
|
||||
self.assertEqual(i, v, f"rank: {self.rank}")
|
||||
|
||||
@with_comms()
|
||||
def test_gather_object(self, device):
|
||||
def test_gather_object(self):
|
||||
output = [None] * dist.get_world_size() if self.rank == 0 else None
|
||||
dist.gather_object(obj=self.rank, object_gather_list=output)
|
||||
|
||||
@ -82,7 +70,7 @@ class TestObjectCollectives(DistributedTestBase):
|
||||
|
||||
@skipIfHpu
|
||||
@with_comms()
|
||||
def test_send_recv_object_list(self, device):
|
||||
def test_send_recv_object_list(self):
|
||||
val = 99 if self.rank == 0 else None
|
||||
object_list = [val] * dist.get_world_size()
|
||||
if self.rank == 0:
|
||||
@ -96,7 +84,7 @@ class TestObjectCollectives(DistributedTestBase):
|
||||
self.assertEqual(None, object_list[0])
|
||||
|
||||
@with_comms()
|
||||
def test_broadcast_object_list(self, device):
|
||||
def test_broadcast_object_list(self):
|
||||
val = 99 if self.rank == 0 else None
|
||||
object_list = [val] * dist.get_world_size()
|
||||
# TODO test with broadcast_object_list's device argument
|
||||
@ -105,7 +93,7 @@ class TestObjectCollectives(DistributedTestBase):
|
||||
self.assertEqual(99, object_list[0])
|
||||
|
||||
@with_comms()
|
||||
def test_scatter_object_list(self, device):
|
||||
def test_scatter_object_list(self):
|
||||
input_list = list(range(dist.get_world_size())) if self.rank == 0 else None
|
||||
output_list = [None]
|
||||
dist.scatter_object_list(
|
||||
@ -123,34 +111,30 @@ class TestObjectCollectives(DistributedTestBase):
|
||||
my_pg = dist.new_group(ranks, use_local_synchronization=True)
|
||||
return rank, ranks, my_pg
|
||||
|
||||
@skipIfHpu
|
||||
@with_comms()
|
||||
def test_subpg_scatter_object(self, device):
|
||||
def test_subpg_scatter_object(self):
|
||||
rank, ranks, my_pg = self.setup_sub_pg()
|
||||
out_list = [None]
|
||||
dist.scatter_object_list(out_list, ranks, src=ranks[0], group=my_pg)
|
||||
self.assertEqual(rank, out_list[0])
|
||||
|
||||
@skipIfHpu
|
||||
@with_comms()
|
||||
def test_subpg_all_gather_object(self, device):
|
||||
def test_subpg_all_gather_object(self):
|
||||
rank, ranks, my_pg = self.setup_sub_pg()
|
||||
out_list = [None] * len(ranks)
|
||||
dist.all_gather_object(out_list, rank, group=my_pg)
|
||||
self.assertEqual(ranks, out_list)
|
||||
|
||||
@skipIfHpu
|
||||
@with_comms()
|
||||
def test_subpg_gather_object(self, device):
|
||||
def test_subpg_gather_object(self):
|
||||
rank, ranks, my_pg = self.setup_sub_pg()
|
||||
out_list = [None] * len(ranks) if rank == ranks[0] else None
|
||||
dist.gather_object(rank, out_list, dst=ranks[0], group=my_pg)
|
||||
if rank == ranks[0]:
|
||||
self.assertEqual(ranks, out_list)
|
||||
|
||||
@skipIfHpu
|
||||
@with_comms()
|
||||
def test_subpg_broadcast_object(self, device):
|
||||
def test_subpg_broadcast_object(self):
|
||||
rank, ranks, my_pg = self.setup_sub_pg()
|
||||
out_list = [None]
|
||||
if rank == ranks[0]:
|
||||
@ -159,7 +143,5 @@ class TestObjectCollectives(DistributedTestBase):
|
||||
self.assertEqual(ranks[0], out_list[0])
|
||||
|
||||
|
||||
devices = ("cpu", "cuda", "hpu")
|
||||
instantiate_device_type_tests(TestObjectCollectives, globals(), only_for=devices)
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -29,7 +29,7 @@ from torch.distributed.tensor._collective_utils import (
|
||||
)
|
||||
from torch.distributed.tensor.placement_types import _Partial, Shard
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_HPU, TEST_XPU, TestCase
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
with_comms,
|
||||
@ -58,7 +58,7 @@ def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_ran
|
||||
os.environ["LOCAL_RANK"] = f"{local_rank}"
|
||||
|
||||
|
||||
@unittest.skipIf(TEST_XPU, "XPU does not support gloo backend.")
|
||||
@unittest.skipIf(TEST_XPU or TEST_HPU, "XPU/HPU does not support gloo backend.")
|
||||
class DeviceMeshTestGlooBackend(DTensorTestBase):
|
||||
@property
|
||||
def backend(self):
|
||||
|
||||
@ -208,6 +208,21 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinuousTest):
|
||||
)
|
||||
self.assertEqual(y, expected)
|
||||
|
||||
def test_get_remote_tensors(self) -> None:
|
||||
"""
|
||||
Get all remote tensors
|
||||
"""
|
||||
self._init_device()
|
||||
group_name = dist.group.WORLD.group_name
|
||||
symm_mem.enable_symm_mem_for_group(group_name)
|
||||
|
||||
my_tensor = symm_mem.empty(1, device=self.device).fill_(self.rank)
|
||||
remote_tensors = torch.ops.symm_mem.get_remote_tensors(my_tensor, group_name)
|
||||
dist.barrier()
|
||||
|
||||
for peer, tensor in enumerate(remote_tensors):
|
||||
self.assertEqual(tensor, peer)
|
||||
|
||||
@skipIfRocm
|
||||
def test_nvshmem_put(self) -> None:
|
||||
self._init_device()
|
||||
|
||||
@ -29,6 +29,8 @@ from torch.testing._internal.common_utils import (
|
||||
|
||||
MY_LAMBDA = lambda x: x + 1 # noqa: E731
|
||||
|
||||
EPS = torch.tensor(1e-7)
|
||||
|
||||
|
||||
class CustomCompiledFunction(torch._dynamo.aot_compile.SerializableCallable):
|
||||
def __init__(self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]):
|
||||
@ -587,6 +589,18 @@ from user code:
|
||||
actual = compiled_fn(fn, *inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_aot_compile_with_global_tensor(self):
|
||||
def fn(x, y):
|
||||
return x + y + EPS
|
||||
|
||||
def make_inputs():
|
||||
return (torch.randn(3, 4), torch.randn(3, 4))
|
||||
|
||||
compiled_fn = torch.compile(fn, fullgraph=True).aot_compile((make_inputs(), {}))
|
||||
|
||||
test_inputs = make_inputs()
|
||||
self.assertEqual(compiled_fn(*test_inputs), fn(*test_inputs))
|
||||
|
||||
def test_aot_compile_with_default_args(self):
|
||||
def fn(x, y=1):
|
||||
return x + x
|
||||
|
||||
@ -330,6 +330,13 @@ y = FakeTensor(..., size=(2,))
|
||||
'obj_weakref': None
|
||||
'guarded_class': None
|
||||
}
|
||||
global '' GLOBAL_STATE
|
||||
{
|
||||
'guard_types': None,
|
||||
'code': None,
|
||||
'obj_weakref': None
|
||||
'guarded_class': None
|
||||
}
|
||||
global '' TORCH_FUNCTION_STATE
|
||||
{
|
||||
'guard_types': None,
|
||||
|
||||
@ -952,7 +952,9 @@ User code traceback:
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break: skip: from user code at:
|
||||
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
|
||||
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
|
||||
The graph break occurred in the following user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
assert x is None
|
||||
""",
|
||||
@ -1078,6 +1080,88 @@ from user code:
|
||||
""",
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(verbose=True)
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_skipped_frame_with_verbose_traceback(self, records):
|
||||
def fn(x):
|
||||
with GenericCtxMgr():
|
||||
torch._dynamo.graph_break()
|
||||
return x + 1
|
||||
|
||||
torch.compile(fn, backend="eager")(torch.randn(3))
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
|
||||
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
|
||||
The graph break occurred in the following user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
torch._dynamo.graph_break()
|
||||
""",
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].exc_info[1], suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break under GenericContextWrappingVariable
|
||||
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
|
||||
Hint: Move the offending context manager(s) to outside the compiled region.
|
||||
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
|
||||
|
||||
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
|
||||
|
||||
from user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
torch._dynamo.graph_break()
|
||||
""",
|
||||
)
|
||||
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_skip_frame_in_loop_message(self, records):
|
||||
def fn(x):
|
||||
for i in range(2):
|
||||
with GenericCtxMgr():
|
||||
if x.sum() > 0:
|
||||
x = x + 1
|
||||
return x
|
||||
|
||||
torch.compile(fn, backend="eager")(torch.randn(3))
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
|
||||
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
|
||||
The graph break occurred in the following user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
if x.sum() > 0:
|
||||
""",
|
||||
)
|
||||
|
||||
@make_logging_test(dynamo=logging.DEBUG)
|
||||
def test_skip_frame_empty_function_message(self, records):
|
||||
def empty_fn(x):
|
||||
pass
|
||||
|
||||
torch.compile(empty_fn, backend="eager")(torch.randn(3))
|
||||
skip_messages = [
|
||||
r
|
||||
for r in records
|
||||
if "intentionally decided to skip the frame" in r.getMessage()
|
||||
]
|
||||
self.assertEqual(len(skip_messages), 1)
|
||||
msg = munge_exc(skip_messages[0].getMessage(), suppress_suffix=True, skip=0)
|
||||
msg = re.sub(r" (\d+)$", r" N", msg, flags=re.MULTILINE)
|
||||
|
||||
self.assertExpectedInline(
|
||||
msg,
|
||||
"""\
|
||||
Skipping frame torch.compile intentionally decided to skip the frame empty_fn (test_error_messages.py line N) and fall back to eager.
|
||||
Reason: no content in function call empty_fn test_error_messages.py N""",
|
||||
)
|
||||
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_nested_compile_user_frames(self, records):
|
||||
def fn(x):
|
||||
@ -1624,6 +1708,110 @@ from user code:
|
||||
)
|
||||
|
||||
|
||||
class NestedGraphBreakLoggingTests(
|
||||
LoggingTestCase, torch._dynamo.test_case.TestCaseWithNestedGraphBreaks
|
||||
):
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_skipped_frame_with_verbose_traceback_nested(self, records):
|
||||
global f1, f2, f3
|
||||
|
||||
class GenericCtxMgr:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
pass
|
||||
|
||||
def f1(x):
|
||||
with GenericCtxMgr():
|
||||
torch._dynamo.graph_break()
|
||||
return x + 1
|
||||
|
||||
def f2(x):
|
||||
return f1(x + 2)
|
||||
|
||||
def f3(x):
|
||||
return f2(x + 3)
|
||||
|
||||
torch.compile(f3, backend="eager")(torch.randn(3))
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break in user code at test_error_messages.py:N
|
||||
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
|
||||
Graph break under GenericContextWrappingVariable
|
||||
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
|
||||
Hint: Move the offending context manager(s) to outside the compiled region.
|
||||
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
|
||||
|
||||
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
|
||||
User code traceback:
|
||||
File "test_error_messages.py", line N, in test_skipped_frame_with_verbose_traceback_nested
|
||||
torch.compile(f3, backend="eager")(torch.randn(3))
|
||||
File "test_error_messages.py", line N, in f3
|
||||
return f2(x + 3)
|
||||
File "test_error_messages.py", line N, in f2
|
||||
return f1(x + 2)
|
||||
File "test_error_messages.py", line N, in f1
|
||||
torch._dynamo.graph_break()
|
||||
""",
|
||||
)
|
||||
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_skip_frame_in_loop_message_nested(self, records):
|
||||
global f1, f2, f3
|
||||
|
||||
class GenericCtxMgr:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
pass
|
||||
|
||||
def f1(x):
|
||||
for i in range(2):
|
||||
with GenericCtxMgr():
|
||||
if x.sum() > 0:
|
||||
x = x + 1
|
||||
return x
|
||||
|
||||
def f2(x):
|
||||
return f1(x + 4)
|
||||
|
||||
def f3(x):
|
||||
return f2(x + 5)
|
||||
|
||||
result = torch.compile(f3, backend="eager")(torch.randn(3)) # noqa: F841
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break in user code at test_error_messages.py:N
|
||||
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
|
||||
Data-dependent branching
|
||||
Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
|
||||
Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
|
||||
Hint: Use `torch.cond` to express dynamic control flow.
|
||||
|
||||
Developer debug context: attempted to jump with TensorVariable()
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html
|
||||
User code traceback:
|
||||
File "test_error_messages.py", line N, in test_skip_frame_in_loop_message_nested
|
||||
result = torch.compile(f3, backend="eager")(torch.randn(3)) # noqa: F841
|
||||
File "test_error_messages.py", line N, in f3
|
||||
return f2(x + 5)
|
||||
File "test_error_messages.py", line N, in f2
|
||||
return f1(x + 4)
|
||||
File "test_error_messages.py", line N, in f1
|
||||
if x.sum() > 0:
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
|
||||
@ -1214,7 +1214,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
|
||||
|
||||
x = torch.randn(3, 2)
|
||||
with torch.enable_grad():
|
||||
ref, loaded = self._test_serialization("GRAD_MODE", fn, x)
|
||||
ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
|
||||
with torch.no_grad():
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
with torch.enable_grad():
|
||||
@ -1226,7 +1226,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
|
||||
|
||||
x = torch.randn(3, 2)
|
||||
with torch.enable_grad():
|
||||
ref, _ = self._test_serialization("GRAD_MODE", fn, x)
|
||||
ref, _ = self._test_serialization("GLOBAL_STATE", fn, x)
|
||||
with torch.no_grad():
|
||||
# Ensure guards state loading is not affected by the current global grad mode.
|
||||
guards_state = pickle.loads(self._cached_guards_state)
|
||||
@ -1246,7 +1246,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
|
||||
try:
|
||||
x = torch.randn(3, 2)
|
||||
torch.use_deterministic_algorithms(True)
|
||||
ref, loaded = self._test_serialization("DETERMINISTIC_ALGORITHMS", fn, x)
|
||||
ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
|
||||
torch.use_deterministic_algorithms(False)
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
torch.use_deterministic_algorithms(True)
|
||||
@ -1270,6 +1270,9 @@ class TestGuardSerialization(TestGuardSerializationBase):
|
||||
ref, loaded = self._test_serialization("TORCH_FUNCTION_STATE", fn, x)
|
||||
self._test_check_fn(ref, loaded, {"x": x}, True)
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
with GlobalTorchFunctionMode():
|
||||
ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
|
||||
self._test_check_fn(ref, loaded, {"x": x}, True)
|
||||
with GlobalTorchFunctionMode():
|
||||
with torch._C.DisableTorchFunction():
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
@ -1306,7 +1309,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
|
||||
x = torch.randn(3, 2)
|
||||
|
||||
with torch.enable_grad():
|
||||
ref, loaded = self._test_serialization("FSDP_TRAINING_STATE", fn, x)
|
||||
ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
|
||||
with torch.no_grad():
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
with torch.enable_grad():
|
||||
@ -1690,6 +1693,38 @@ class TestGuardSerialization(TestGuardSerializationBase):
|
||||
ref, loaded, {"x": x, "d": ModWithDict({"b": 1e-9, "a": 1e9})}, False
|
||||
)
|
||||
|
||||
def test_global_state_guard_filter(self):
|
||||
def foo(x):
|
||||
return x + 1
|
||||
|
||||
x = torch.randn(3, 2)
|
||||
|
||||
with torch.no_grad():
|
||||
compiled_fn = torch.compile(
|
||||
foo, options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}
|
||||
)
|
||||
compiled_fn(x)
|
||||
|
||||
# Check global guards are gone.
|
||||
with torch.enable_grad(), torch.compiler.set_stance("fail_on_recompile"):
|
||||
self.assertEqual(compiled_fn(x), foo(x))
|
||||
|
||||
def test_torch_function_state_filter(self):
|
||||
def foo(x):
|
||||
return x + 1
|
||||
|
||||
x = torch.randn(3, 2)
|
||||
|
||||
with GlobalTorchFunctionMode():
|
||||
compiled_fn = torch.compile(
|
||||
foo, options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}
|
||||
)
|
||||
compiled_fn(x)
|
||||
|
||||
# Check global guards are gone.
|
||||
with torch.compiler.set_stance("fail_on_recompile"):
|
||||
self.assertEqual(compiled_fn(x), foo(x))
|
||||
|
||||
|
||||
class SimpleModule(torch.nn.Module):
|
||||
def __init__(self, c):
|
||||
|
||||
@ -131,7 +131,7 @@ def default_args_generator(seed_value):
|
||||
yield new_args
|
||||
|
||||
|
||||
class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
||||
class HigherOrderOpTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
|
||||
def _assert_wrap_fallback(self, func, args, setup=lambda: None):
|
||||
counters.clear()
|
||||
backend = EagerAndRecordGraphs()
|
||||
@ -3396,7 +3396,9 @@ class GraphModule(torch.nn.Module):
|
||||
fn_with_hints(x, y)
|
||||
|
||||
|
||||
class HigherOrderOpVmapGuardTests(LoggingTestCase):
|
||||
class HigherOrderOpVmapGuardTests(
|
||||
torch._dynamo.test_case.TestCaseWithNestedGraphBreaks, LoggingTestCase
|
||||
):
|
||||
@make_logging_test(recompiles=True)
|
||||
def test_vmap_grad_guard_ok(self, records):
|
||||
vmap = torch.vmap
|
||||
@ -3665,7 +3667,9 @@ class HigherOrderOpVmapGuardTests(LoggingTestCase):
|
||||
self.assertGreater(len(records), 0)
|
||||
|
||||
|
||||
class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
||||
class FuncTorchHigherOrderOpTests(
|
||||
torch._dynamo.test_case.TestCaseWithNestedGraphBreaks
|
||||
):
|
||||
def tearDown(self):
|
||||
# Ensure that in the case of a test failure, the next test won't fail
|
||||
# because of a previous call to _vmap_increment_nesting that wasn't undone
|
||||
@ -6782,7 +6786,9 @@ class GraphModule(torch.nn.Module):
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
|
||||
class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
|
||||
class ActivationCheckpointingTests(
|
||||
torch._dynamo.test_case.TestCaseWithNestedGraphBreaks
|
||||
):
|
||||
def _validate(self, fn, backend, *args, skip_check=False, fullgraph=True):
|
||||
cloned_args = []
|
||||
for arg in args:
|
||||
@ -7173,7 +7179,7 @@ xfail_hops_compile = {
|
||||
}
|
||||
|
||||
|
||||
class TestHigherOrderOpsOpInfo(torch._dynamo.test_case.TestCase):
|
||||
class TestHigherOrderOpsOpInfo(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
|
||||
@requires_cuda_and_triton
|
||||
@parametrize("backend", ("aot_eager", "inductor"))
|
||||
@ops(
|
||||
|
||||
@ -244,6 +244,61 @@ class MiscTests(torch._inductor.test_case.TestCase):
|
||||
self.assertTrue(same(val4, correct1))
|
||||
self.assertEqual(counter.frame_count, 3)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "cuda needed")
|
||||
def test_assume_32_bit_indexing(self):
|
||||
@torch.compile(backend="inductor")
|
||||
def func(a, b):
|
||||
# Multiple concat operations
|
||||
x = torch.concat([a, b], dim=0)
|
||||
y = torch.concat([a, b], dim=1)
|
||||
|
||||
# Reshape to create indexing patterns
|
||||
x_flat = x.reshape(-1)
|
||||
y_flat = y.reshape(-1)
|
||||
|
||||
# Take the smaller one and expand
|
||||
min_size = min(x_flat.shape[0], y_flat.shape[0])
|
||||
x_trunc = x_flat[:min_size]
|
||||
y_trunc = y_flat[:min_size]
|
||||
|
||||
# Combine and compute
|
||||
result = (x_trunc + y_trunc) * 10
|
||||
|
||||
# Cumulative operations create complex indexing
|
||||
cumsum = result.cumsum(dim=0)
|
||||
|
||||
return cumsum.sum()
|
||||
|
||||
a = torch.rand(100, 30, device="cuda")
|
||||
b = torch.rand(100, 30, device="cuda")
|
||||
|
||||
torch._dynamo.decorators.mark_unbacked(a, 0)
|
||||
torch._dynamo.decorators.mark_unbacked(a, 1)
|
||||
torch._dynamo.decorators.mark_unbacked(b, 0)
|
||||
torch._dynamo.decorators.mark_unbacked(b, 1)
|
||||
|
||||
source_code = run_and_get_code(func, a, b)[1]
|
||||
|
||||
self.assertTrue(
|
||||
"xindex = xoffset + tl.arange(0, XBLOCK)[:].to(tl.int64)\\n"
|
||||
in str(source_code)
|
||||
)
|
||||
self.assertFalse(
|
||||
"xindex = xoffset + tl.arange(0, XBLOCK)[:]\\n" in str(source_code)
|
||||
)
|
||||
|
||||
torch._dynamo.reset()
|
||||
|
||||
with torch._inductor.config.patch(assume_32bit_indexing=True):
|
||||
source_code = run_and_get_code(func, a, b)[1]
|
||||
self.assertFalse(
|
||||
"xindex = xoffset + tl.arange(0, XBLOCK)[:].to(tl.int64)\\n"
|
||||
in str(source_code)
|
||||
)
|
||||
self.assertTrue(
|
||||
"xindex = xoffset + tl.arange(0, XBLOCK)[:]\\n" in str(source_code)
|
||||
)
|
||||
|
||||
def test_dynamo_inside_custom_op(self):
|
||||
cnt = torch._dynamo.testing.InductorAndRecordGraphs()
|
||||
cnt1 = torch._dynamo.testing.InductorAndRecordGraphs()
|
||||
@ -14036,6 +14091,44 @@ class DynamoOpPromotionTests(torch._dynamo.test_case.TestCase):
|
||||
except Exception as e:
|
||||
self.fail(f"torch.compile failed with error: {e}")
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_tensorify_track_item_symint(self):
|
||||
def _random_resize(image: torch.Tensor):
|
||||
image_metanet = image
|
||||
default_patch_size = 14
|
||||
rand_cnn_resolution = (224, 256)
|
||||
min_nump = rand_cnn_resolution[0] // default_patch_size
|
||||
max_nump = rand_cnn_resolution[1] // default_patch_size
|
||||
new_nump = torch.randint(min_nump, max_nump + 1, (1,)).item()
|
||||
torch._check(new_nump > 0)
|
||||
torch._check(new_nump * default_patch_size > 1)
|
||||
|
||||
image_metanet = F.interpolate(
|
||||
image_metanet,
|
||||
size=(new_nump * default_patch_size, new_nump * default_patch_size),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
img_h_new, img_w_new = image_metanet.shape[2:]
|
||||
|
||||
return (img_h_new, img_w_new), image_metanet
|
||||
|
||||
_random_resize_compiled = torch.compile(fullgraph=True)(_random_resize)
|
||||
|
||||
# Test the function
|
||||
input_tensor = torch.rand(1, 3, 224, 224)
|
||||
(h, w), output = _random_resize_compiled(input_tensor)
|
||||
|
||||
# Verify output properties
|
||||
self.assertEqual(output.shape[0], 1)
|
||||
self.assertEqual(output.shape[1], 3)
|
||||
self.assertEqual(output.shape[2], h)
|
||||
self.assertEqual(output.shape[3], w)
|
||||
self.assertTrue(h % 14 == 0)
|
||||
self.assertTrue(w % 14 == 0)
|
||||
self.assertTrue(224 <= h <= 256)
|
||||
self.assertTrue(224 <= w <= 256)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
@ -874,6 +874,32 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreak
|
||||
self.assertEqual(cnts.frame_count, 8)
|
||||
self.assertEqual(cnts.op_count, 10)
|
||||
|
||||
def test_functorch_with_nested_graph_break(self):
|
||||
def f1(x):
|
||||
x = x * 2
|
||||
torch._dynamo.graph_break()
|
||||
return x * 4
|
||||
|
||||
def f2(x):
|
||||
return (f1(x * 8) * 16).sum()
|
||||
|
||||
def f3(x):
|
||||
return torch.func.grad(f2)(x * 32) * 64
|
||||
|
||||
def f4(x):
|
||||
return f3(x * 128) * 256
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
||||
x = torch.randn(3)
|
||||
actual = f4(x)
|
||||
expected = torch.compile(f4, backend=cnts, fullgraph=False)(x)
|
||||
self.assertEqual(actual, expected)
|
||||
self.assertEqual(len(torch._dynamo.utils.counters["graph_break"]), 1)
|
||||
# f4 + f3, f3 end + f4 end
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
# multiplication by 32, 64, 128, 256
|
||||
self.assertEqual(cnts.op_count, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
@ -562,7 +562,7 @@ class TestDynamoTimed(TestCase):
|
||||
'graph_node_count': 3,
|
||||
'graph_node_shapes': None,
|
||||
'graph_op_count': 1,
|
||||
'guard_count': 9,
|
||||
'guard_count': 10,
|
||||
'has_guarded_code': True,
|
||||
'inductor_code_gen_cumulative_compile_time_us': 0,
|
||||
'inductor_compile_time_s': 0.0,
|
||||
@ -608,7 +608,7 @@ class TestDynamoTimed(TestCase):
|
||||
'tensorify_float_attempt': None,
|
||||
'tensorify_float_failure': None,
|
||||
'tensorify_float_success': None,
|
||||
'triton_compile_time_us': None,
|
||||
'triton_compile_time_us': 0,
|
||||
'triton_kernel_compile_times_us': None,
|
||||
'triton_version': None}"""
|
||||
if _IS_WINDOWS
|
||||
@ -649,7 +649,7 @@ class TestDynamoTimed(TestCase):
|
||||
'graph_node_count': 3,
|
||||
'graph_node_shapes': None,
|
||||
'graph_op_count': 1,
|
||||
'guard_count': 9,
|
||||
'guard_count': 10,
|
||||
'has_guarded_code': True,
|
||||
'inductor_code_gen_cumulative_compile_time_us': 0,
|
||||
'inductor_compile_time_s': 0.0,
|
||||
@ -920,7 +920,7 @@ class TestDynamoTimed(TestCase):
|
||||
first, second = {
|
||||
(3, 9): (10, 6),
|
||||
(3, 10): (10, 6),
|
||||
(3, 11): (10, 6),
|
||||
(3, 11): (11, 7),
|
||||
(3, 12): (11, 7),
|
||||
(3, 13): (11, 7),
|
||||
(3, 14): (11, 7),
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
import copy
|
||||
import types
|
||||
import unittest
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
@ -18,6 +19,9 @@ from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import TEST_CUDA
|
||||
|
||||
|
||||
GLOBAL_LIST = []
|
||||
|
||||
|
||||
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
|
||||
class TestExperiment(TestCase):
|
||||
def test_joint_basic(self) -> None:
|
||||
@ -585,9 +589,9 @@ def forward(self, args_0):
|
||||
_tree_leaf_0, _tree_leaf_1, = pytree.tree_leaves((self, args_0,))
|
||||
L_args_0_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1)
|
||||
l_args_0_ = L_args_0_
|
||||
add = l_args_0_ + 1
|
||||
add = l_args_0_ + 1; add = None
|
||||
mul = l_args_0_ * 2; l_args_0_ = None
|
||||
return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, mul, add), self._out_spec)""",
|
||||
return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, mul), self._out_spec)""",
|
||||
)
|
||||
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
|
||||
|
||||
@ -611,6 +615,34 @@ def forward(self, args_0):
|
||||
self.assertEqual(len(list(gm.buffers())), len(list(foo.buffers())))
|
||||
self.assertEqual(len(list(gm.parameters())), len(list(foo.parameters())))
|
||||
|
||||
def test_dynamo_graph_capture_side_effects(self):
|
||||
GLOBAL_LIST.clear()
|
||||
|
||||
def foo(x):
|
||||
z = x + 1
|
||||
GLOBAL_LIST.append(z)
|
||||
return z
|
||||
|
||||
def make_inputs():
|
||||
return (torch.randn(2, 3),)
|
||||
|
||||
trace_inputs = make_inputs()
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
|
||||
cnt = 0
|
||||
for entry in w:
|
||||
if "While compiling, we found certain side effects happened" in str(
|
||||
entry.message
|
||||
):
|
||||
cnt += 1
|
||||
self.assertEqual(cnt, 1)
|
||||
self.assertEqual(len(GLOBAL_LIST), 0)
|
||||
test_inputs = make_inputs()
|
||||
gm_results = gm(*test_inputs)
|
||||
self.assertEqual(len(GLOBAL_LIST), 0)
|
||||
self.assertEqual(gm_results, foo(*test_inputs))
|
||||
self.assertEqual(len(GLOBAL_LIST), 1)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
|
||||
def test_dynamo_graph_capture_fx_graph_annotate_overlap_pass(self):
|
||||
class DummyOp(torch.autograd.Function):
|
||||
|
||||
@ -11,6 +11,7 @@ import math
|
||||
import operator
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import unittest
|
||||
import warnings
|
||||
@ -739,18 +740,26 @@ class TestExport(TestCase):
|
||||
dynamic_shapes={"x": {0: Dim("b")}, "y": None},
|
||||
)
|
||||
|
||||
# clean up _torchdynamo related meta data as it could vary depending on the caller
|
||||
# https://github.com/pytorch/pytorch/issues/167432
|
||||
for node in ep.graph.nodes:
|
||||
if "custom" in node.meta:
|
||||
node.meta["custom"] = {
|
||||
k: v
|
||||
for k, v in node.meta["custom"].items()
|
||||
if "_torchdynamo_disable" not in k
|
||||
}
|
||||
|
||||
custom_metadata = torch.fx.traceback._get_custom_metadata(ep.module())
|
||||
|
||||
self.assertExpectedInline(
|
||||
str(custom_metadata),
|
||||
"""\
|
||||
('placeholder', 'x', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
|
||||
('placeholder', 'y', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
|
||||
('call_function', 'cat', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
|
||||
('call_function', 'item', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
|
||||
('call_function', 'ge_1', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
|
||||
('call_function', '_assert_scalar_default', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
|
||||
('call_function', 'mul', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
|
||||
('output', 'output', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})""",
|
||||
('call_function', 'cat', {'moo': 0})
|
||||
('call_function', 'item', {'moo': 0})
|
||||
('call_function', 'ge_1', {'moo': 0})
|
||||
('call_function', '_assert_scalar_default', {'moo': 0})
|
||||
('call_function', 'mul', {'moo': 0})""",
|
||||
)
|
||||
|
||||
@requires_gpu
|
||||
@ -12257,8 +12266,15 @@ graph():
|
||||
def forward(self, x):
|
||||
return x + 2
|
||||
|
||||
def fancy_forward(x, y):
|
||||
return x + 2 + y
|
||||
if sys.version_info >= (3, 14):
|
||||
# functools.partial is now a method descriptor:
|
||||
# https://docs.python.org/3/whatsnew/3.14.html#changes-in-the-python-api
|
||||
def fancy_forward(self, x, y):
|
||||
return x + 2 + y
|
||||
else:
|
||||
|
||||
def fancy_forward(x, y):
|
||||
return x + 2 + y
|
||||
|
||||
Foo.forward = functools.partial(fancy_forward, y=torch.randn(4, 4))
|
||||
x = torch.randn(4, 4)
|
||||
@ -15295,12 +15311,12 @@ graph():
|
||||
def forward(self, block):
|
||||
return block.a + block.b
|
||||
|
||||
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
|
||||
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError, "It looks like one of the inputs with type"
|
||||
):
|
||||
_dynamo_graph_capture_for_export(Foo())(
|
||||
dynamo_graph_capture_for_export(Foo())(
|
||||
Block(torch.randn(4, 4), torch.randn(4, 4))
|
||||
)
|
||||
|
||||
|
||||
@ -1,71 +0,0 @@
|
||||
# Owner(s): ["oncall: export"]
|
||||
|
||||
|
||||
from torch._dynamo import config as dynamo_config
|
||||
from torch._dynamo.testing import make_test_cls_with_patches
|
||||
from torch._export import config as export_config
|
||||
|
||||
|
||||
try:
|
||||
from . import test_export, testing
|
||||
except ImportError:
|
||||
import test_export # @manual=fbcode//caffe2/test:test_export-library
|
||||
import testing # @manual=fbcode//caffe2/test:test_export-library
|
||||
|
||||
from torch.export import export
|
||||
|
||||
|
||||
test_classes = {}
|
||||
|
||||
|
||||
def mocked_strict_export(*args, **kwargs):
|
||||
# If user already specified strict, don't make it strict
|
||||
if "strict" in kwargs:
|
||||
return export(*args, **kwargs)
|
||||
return export(*args, **kwargs, strict=True)
|
||||
|
||||
|
||||
def make_dynamic_cls(cls):
|
||||
# Some test check for ending in suffix; need to make
|
||||
# the `_strict` for end of string as a result
|
||||
suffix = test_export.INLINE_AND_INSTALL_STRICT_SUFFIX
|
||||
|
||||
cls_prefix = "InlineAndInstall"
|
||||
|
||||
cls_a = testing.make_test_cls_with_mocked_export(
|
||||
cls,
|
||||
"StrictExport",
|
||||
suffix,
|
||||
mocked_strict_export,
|
||||
xfail_prop="_expected_failure_strict",
|
||||
)
|
||||
test_class = make_test_cls_with_patches(
|
||||
cls_a,
|
||||
cls_prefix,
|
||||
"",
|
||||
(export_config, "use_new_tracer_experimental", True),
|
||||
(dynamo_config, "install_free_tensors", True),
|
||||
(dynamo_config, "inline_inbuilt_nn_modules", True),
|
||||
xfail_prop="_expected_failure_inline_and_install",
|
||||
)
|
||||
|
||||
test_classes[test_class.__name__] = test_class
|
||||
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
|
||||
globals()[test_class.__name__] = test_class
|
||||
test_class.__module__ = __name__
|
||||
return test_class
|
||||
|
||||
|
||||
tests = [
|
||||
test_export.TestDynamismExpression,
|
||||
test_export.TestExport,
|
||||
]
|
||||
for test in tests:
|
||||
make_dynamic_cls(test)
|
||||
del test
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
@ -4,6 +4,7 @@ from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._functorch.aot_autograd import aot_export_module
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
@ -90,6 +91,99 @@ def forward(self, arg0_1):
|
||||
|
||||
self.assertEqual(printed_output, f"moo 1 2\nmoo {new_inp}\nmoo 1 2\nyeehop 4")
|
||||
|
||||
def test_print_with_side_effect(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)
|
||||
res = x + x
|
||||
torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)
|
||||
return (res,)
|
||||
|
||||
inputs = (torch.randn(3),)
|
||||
|
||||
# With functionalization, it should appear wrapped with with_effects()
|
||||
gm, gs = aot_export_module(M(), inputs, trace_joint=False)
|
||||
self.assertExpectedInline(
|
||||
str(gm.code).strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1):
|
||||
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.higher_order.print, 'moo {x} {y}', x = 1, y = 2); \
|
||||
arg0_1 = None
|
||||
getitem = with_effects[0]; with_effects = None
|
||||
add = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
|
||||
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.higher_order.print, 'moo {x} {y}', x = 1, y = 2); \
|
||||
getitem = None
|
||||
getitem_2 = with_effects_1[0]; with_effects_1 = None
|
||||
return (getitem_2, add)""",
|
||||
)
|
||||
self.assertEqual(len(gs.input_tokens), 1)
|
||||
self.assertEqual(len(gs.output_tokens), 1)
|
||||
|
||||
def test_print_with_input_mutations(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
torch._higher_order_ops.print("moo {x} {y}", x=x, y=2)
|
||||
res = x + x
|
||||
x.add_(res)
|
||||
res = x + x
|
||||
torch._higher_order_ops.print("moo {x} {y}", x=x, y=res)
|
||||
return (res,)
|
||||
|
||||
inputs = (torch.randn(3),)
|
||||
|
||||
# With functionalization, it should appear wrapped with with_effects()
|
||||
gm, gs = aot_export_module(M(), inputs, trace_joint=False)
|
||||
self.assertEqual(len(gs.input_tokens), 1)
|
||||
self.assertEqual(len(gs.output_tokens), 1)
|
||||
self.assertEqual(len(gs.user_inputs_to_mutate), 1)
|
||||
self.assertExpectedInline(
|
||||
str(gm.code).strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1):
|
||||
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.higher_order.print, 'moo {x} {y}', \
|
||||
x = arg1_1, y = 2); arg0_1 = None
|
||||
getitem = with_effects[0]; with_effects = None
|
||||
add = torch.ops.aten.add.Tensor(arg1_1, arg1_1)
|
||||
add_1 = torch.ops.aten.add.Tensor(arg1_1, add); arg1_1 = add = None
|
||||
add_2 = torch.ops.aten.add.Tensor(add_1, add_1)
|
||||
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.higher_order.print, 'moo {x} {y}', \
|
||||
x = add_1, y = add_2); getitem = None
|
||||
getitem_2 = with_effects_1[0]; with_effects_1 = None
|
||||
return (getitem_2, add_1, add_2)""",
|
||||
)
|
||||
|
||||
def test_print_gen_schema(self):
|
||||
from torch._higher_order_ops.print import print as print_op
|
||||
|
||||
# Test basic schema generation with simple kwargs int
|
||||
format_str = "Hello {x} {y}"
|
||||
schema = print_op.gen_schema(format_str, x=1, y=2)
|
||||
self.assertExpectedInline(
|
||||
str(schema),
|
||||
"""print(str format_str, *, int x, int y) -> ()""",
|
||||
)
|
||||
# Test schema generation with different types of inputs
|
||||
|
||||
# Tensor input
|
||||
tensor = torch.randn(2, 2)
|
||||
schema_tensor = print_op.gen_schema("Tensor: {x}", x=tensor)
|
||||
self.assertExpectedInline(
|
||||
str(schema_tensor),
|
||||
"""print(str format_str, *, Tensor x) -> ()""",
|
||||
)
|
||||
|
||||
# TODO: Add schema support with kwargs with value of list type
|
||||
|
||||
# No kwargs
|
||||
schema_no_kwargs = print_op.gen_schema("Simple message")
|
||||
self.assertExpectedInline(
|
||||
str(schema_no_kwargs),
|
||||
"""print(str format_str) -> ()""",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -235,6 +235,7 @@ class TestCKBackend(TestCase):
|
||||
Y_eager = a @ b
|
||||
torch.testing.assert_close(Y_compiled, Y_eager, equal_nan=True)
|
||||
|
||||
@unittest.skip("Autotune Mismatch being investigated")
|
||||
@unittest.skipIf(not torch.version.hip, "ROCM only")
|
||||
@unittest.mock.patch.dict(os.environ, _test_env)
|
||||
@parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK"))
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user