Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-19 01:54:54 +08:00)
Comparing 1 commit: f203b98062
@@ -1,19 +0,0 @@
# Aarch64 (ARM/Graviton) Support Scripts
Scripts for building aarch64 PyTorch PIP wheels. These scripts build the following wheels:
* torch
* torchvision
* torchaudio
* torchtext
* torchdata
## Aarch64_ci_build.sh
This script is designed to support CD operations within the PyPI manylinux aarch64 container and is executed inside that container. It prepares the container and then executes __aarch64_wheel_ci_build.py__ to build the wheels. The script assumes the PyTorch repo is located at ```/pytorch``` and puts the wheels into ```/artifacts```.
### Usage
```DESIRED_PYTHON=<PythonVersion> aarch64_ci_build.sh```
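For local testing, a minimal sketch of how this script might be invoked inside a manylinux container; the image tag, mount points, script path, and Python version below are illustrative assumptions, not the exact CI configuration:

```bash
# Hypothetical local run; CI uses its own builder image and environment.
docker run --rm \
  -v "$PWD/pytorch:/pytorch" \
  -v "$PWD/artifacts:/artifacts" \
  quay.io/pypa/manylinux_2_28_aarch64 \
  bash -c "DESIRED_PYTHON=3.11 /pytorch/.ci/aarch64_linux/aarch64_ci_build.sh"
```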

__NOTE:__ CI build is currently __EXPERIMENTAL__

## Build_aarch64_wheel.py
This script builds the wheels on AWS EC2 resources. It requires the AWS CLI and Boto3, with AWS credentials that can launch EC2 instances for the wheel builds. It can be used from a CodeBuild CD pipeline or from a local system.

### Usage
```build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.8 --branch <RCtag>```
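The same script can also manage the EC2 instances it launches; the flags below come from its argument parser:

```bash
# List running build instances of the default type (t4g.2xlarge)
build_aarch64_wheel.py --list-instances

# Terminate them once the wheels have been collected
build_aarch64_wheel.py --terminate-instances --instance-type t4g.2xlarge
```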
@@ -1,53 +0,0 @@
#!/bin/bash
set -eux -o pipefail

GPU_ARCH_VERSION=${GPU_ARCH_VERSION:-}

# Set CUDA architecture lists to match x86 build_cuda.sh
if [[ "$GPU_ARCH_VERSION" == *"12.6"* ]]; then
    export TORCH_CUDA_ARCH_LIST="8.0;9.0"
elif [[ "$GPU_ARCH_VERSION" == *"12.8"* ]]; then
    export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
elif [[ "$GPU_ARCH_VERSION" == *"12.9"* ]]; then
    export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
elif [[ "$GPU_ARCH_VERSION" == *"13.0"* ]]; then
    export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;11.0;12.0+PTX"
fi

# Compress the fatbin with -compress-mode=size for CUDA 13
if [[ "$DESIRED_CUDA" == *"13"* ]]; then
    export TORCH_NVCC_FLAGS="-compress-mode=size"
    # Bundle ptxas into the cu13 wheel, see https://github.com/pytorch/pytorch/issues/163801
    export BUILD_BUNDLE_PTXAS=1
fi

SCRIPTPATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
source "$SCRIPTPATH/aarch64_ci_setup.sh"

###############################################################################
# Run aarch64 builder python
###############################################################################
cd /
# Mark the mounted pytorch repo as a safe directory for git, since its
# ownership differs from the container user.
git config --global --add safe.directory /pytorch
pip install -r /pytorch/requirements.txt
pip install auditwheel==6.2.0 wheel
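# DESIRED_CUDA selects between a CPU-only wheel and a CUDA wheel build; for CUDA,
# PYTORCH_EXTRA_INSTALL_REQUIREMENTS decides whether NVIDIA libraries are bundled
# into the wheel or pulled in as nvidia-* PyPI dependencies.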
if [ "$DESIRED_CUDA" = "cpu" ]; then
|
||||
echo "BASE_CUDA_VERSION is not set. Building cpu wheel."
|
||||
python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn
|
||||
else
|
||||
echo "BASE_CUDA_VERSION is set to: $DESIRED_CUDA"
|
||||
export USE_SYSTEM_NCCL=1
|
||||
|
||||
# Check if we should use NVIDIA libs from PyPI (similar to x86 build_cuda.sh logic)
|
||||
if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then
|
||||
echo "Bundling CUDA libraries with wheel for aarch64."
|
||||
else
|
||||
echo "Using nvidia libs from pypi for aarch64."
|
||||
echo "Updated PYTORCH_EXTRA_INSTALL_REQUIREMENTS for aarch64: $PYTORCH_EXTRA_INSTALL_REQUIREMENTS"
|
||||
export USE_NVIDIA_PYPI_LIBS=1
|
||||
fi
|
||||
|
||||
python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn --enable-cuda
|
||||
fi
|
||||
@@ -1,21 +0,0 @@
#!/bin/bash
set -eux -o pipefail

# This script prepares the Docker container for the aarch64_wheel_ci_build.py script
# by creating symlinks from the desired /opt/python version to /usr/local/bin/.

NUMPY_VERSION=2.0.2
if [[ "$DESIRED_PYTHON" == "3.13" || "$DESIRED_PYTHON" == "3.13t" ]]; then
    NUMPY_VERSION=2.1.2
fi

SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"
source "$SCRIPTPATH/../manywheel/set_desired_python.sh"

pip install -q numpy==${NUMPY_VERSION} pyyaml==6.0.2 scons==4.7.0 ninja==1.11.1 patchelf==0.17.2

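# Make the selected Python's tools the system defaults so subsequent build steps
# resolve python/pip/ninja/scons/patchelf from /usr/local/bin.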
for tool in python python3 pip pip3 ninja scons patchelf; do
    ln -sf ${DESIRED_PYTHON_BIN_DIR}/${tool} /usr/local/bin;
done

python --version
@@ -1,333 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# encoding: UTF-8
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from subprocess import check_call, check_output
|
||||
|
||||
|
||||
def list_dir(path: str) -> list[str]:
|
||||
"""'
|
||||
Helper for getting paths for Python
|
||||
"""
|
||||
return check_output(["ls", "-1", path]).decode().split("\n")
|
||||
|
||||
|
||||
def replace_tag(filename) -> None:
|
||||
with open(filename) as f:
|
||||
lines = f.readlines()
|
||||
for i, line in enumerate(lines):
|
||||
if line.startswith("Tag:"):
|
||||
lines[i] = line.replace("-linux_", "-manylinux_2_28_")
|
||||
print(f"Updated tag from {line} to {lines[i]}")
|
||||
break
|
||||
|
||||
with open(filename, "w") as f:
|
||||
f.writelines(lines)
|
||||
|
||||
|
||||
def patch_library_rpath(
|
||||
folder: str,
|
||||
lib_name: str,
|
||||
use_nvidia_pypi_libs: bool = False,
|
||||
desired_cuda: str = "",
|
||||
) -> None:
|
||||
"""Apply patchelf to set RPATH for a library in torch/lib"""
|
||||
lib_path = f"{folder}/tmp/torch/lib/{lib_name}"
|
||||
|
||||
if use_nvidia_pypi_libs:
|
||||
# For PyPI NVIDIA libraries, construct CUDA RPATH
|
||||
cuda_rpaths = [
|
||||
"$ORIGIN/../../nvidia/cudnn/lib",
|
||||
"$ORIGIN/../../nvidia/nvshmem/lib",
|
||||
"$ORIGIN/../../nvidia/nccl/lib",
|
||||
"$ORIGIN/../../nvidia/cusparselt/lib",
|
||||
]
|
||||
|
||||
if "130" in desired_cuda:
|
||||
cuda_rpaths.append("$ORIGIN/../../nvidia/cu13/lib")
|
||||
else:
|
||||
cuda_rpaths.extend(
|
||||
[
|
||||
"$ORIGIN/../../nvidia/cublas/lib",
|
||||
"$ORIGIN/../../nvidia/cuda_cupti/lib",
|
||||
"$ORIGIN/../../nvidia/cuda_nvrtc/lib",
|
||||
"$ORIGIN/../../nvidia/cuda_runtime/lib",
|
||||
"$ORIGIN/../../nvidia/cufft/lib",
|
||||
"$ORIGIN/../../nvidia/curand/lib",
|
||||
"$ORIGIN/../../nvidia/cusolver/lib",
|
||||
"$ORIGIN/../../nvidia/cusparse/lib",
|
||||
"$ORIGIN/../../nvidia/nvtx/lib",
|
||||
"$ORIGIN/../../nvidia/cufile/lib",
|
||||
]
|
||||
)
|
||||
|
||||
# Add $ORIGIN for local torch libs
|
||||
rpath = ":".join(cuda_rpaths) + ":$ORIGIN"
|
||||
else:
|
||||
# For bundled libraries, just use $ORIGIN
|
||||
rpath = "$ORIGIN"
|
||||
|
||||
if os.path.exists(lib_path):
|
||||
os.system(
|
||||
f"cd {folder}/tmp/torch/lib/; "
|
||||
f"patchelf --set-rpath '{rpath}' --force-rpath {lib_name}"
|
||||
)
|
||||
|
||||
|
||||
def copy_and_patch_library(
|
||||
src_path: str,
|
||||
folder: str,
|
||||
use_nvidia_pypi_libs: bool = False,
|
||||
desired_cuda: str = "",
|
||||
) -> None:
|
||||
"""Copy a library to torch/lib and patch its RPATH"""
|
||||
if os.path.exists(src_path):
|
||||
lib_name = os.path.basename(src_path)
|
||||
shutil.copy2(src_path, f"{folder}/tmp/torch/lib/{lib_name}")
|
||||
patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda)
|
||||
|
||||
|
||||
def package_cuda_wheel(wheel_path, desired_cuda) -> None:
|
||||
"""
|
||||
Package the cuda wheel libraries
|
||||
"""
|
||||
folder = os.path.dirname(wheel_path)
|
||||
os.mkdir(f"{folder}/tmp")
|
||||
os.system(f"unzip {wheel_path} -d {folder}/tmp")
|
||||
# Delete original wheel since it will be repackaged
|
||||
os.system(f"rm {wheel_path}")
|
||||
|
||||
# Check if we should use PyPI NVIDIA libraries or bundle system libraries
|
||||
use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"
|
||||
|
||||
if use_nvidia_pypi_libs:
|
||||
print("Using nvidia libs from pypi - skipping CUDA library bundling")
|
||||
# For PyPI approach, we don't bundle CUDA libraries - they come from PyPI packages
|
||||
# We only need to bundle non-NVIDIA libraries
|
||||
minimal_libs_to_copy = [
|
||||
"/lib64/libgomp.so.1",
|
||||
"/usr/lib64/libgfortran.so.5",
|
||||
"/acl/build/libarm_compute.so",
|
||||
"/acl/build/libarm_compute_graph.so",
|
||||
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
|
||||
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
|
||||
"/usr/local/lib/libnvpl_lapack_core.so.0",
|
||||
"/usr/local/lib/libnvpl_blas_core.so.0",
|
||||
]
|
||||
|
||||
# Copy minimal libraries to unzipped_folder/torch/lib
|
||||
for lib_path in minimal_libs_to_copy:
|
||||
copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda)
|
||||
|
||||
# Patch torch libraries used for searching libraries
|
||||
torch_libs_to_patch = [
|
||||
"libtorch.so",
|
||||
"libtorch_cpu.so",
|
||||
"libtorch_cuda.so",
|
||||
"libtorch_cuda_linalg.so",
|
||||
"libtorch_global_deps.so",
|
||||
"libtorch_python.so",
|
||||
"libtorch_nvshmem.so",
|
||||
"libc10.so",
|
||||
"libc10_cuda.so",
|
||||
"libcaffe2_nvrtc.so",
|
||||
"libshm.so",
|
||||
]
|
||||
for lib_name in torch_libs_to_patch:
|
||||
patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda)
|
||||
else:
|
||||
print("Bundling CUDA libraries with wheel")
|
||||
# Original logic for bundling system CUDA libraries
|
||||
# Common libraries for all CUDA versions
|
||||
common_libs = [
|
||||
# Non-NVIDIA system libraries
|
||||
"/lib64/libgomp.so.1",
|
||||
"/usr/lib64/libgfortran.so.5",
|
||||
"/acl/build/libarm_compute.so",
|
||||
"/acl/build/libarm_compute_graph.so",
|
||||
# Common CUDA libraries (same for all versions)
|
||||
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
|
||||
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
|
||||
"/usr/local/lib/libnvpl_lapack_core.so.0",
|
||||
"/usr/local/lib/libnvpl_blas_core.so.0",
|
||||
"/usr/local/cuda/extras/CUPTI/lib64/libnvperf_host.so",
|
||||
"/usr/local/cuda/lib64/libcudnn.so.9",
|
||||
"/usr/local/cuda/lib64/libcusparseLt.so.0",
|
||||
"/usr/local/cuda/lib64/libcurand.so.10",
|
||||
"/usr/local/cuda/lib64/libnccl.so.2",
|
||||
"/usr/local/cuda/lib64/libnvshmem_host.so.3",
|
||||
"/usr/local/cuda/lib64/libcudnn_adv.so.9",
|
||||
"/usr/local/cuda/lib64/libcudnn_cnn.so.9",
|
||||
"/usr/local/cuda/lib64/libcudnn_graph.so.9",
|
||||
"/usr/local/cuda/lib64/libcudnn_ops.so.9",
|
||||
"/usr/local/cuda/lib64/libcudnn_engines_runtime_compiled.so.9",
|
||||
"/usr/local/cuda/lib64/libcudnn_engines_precompiled.so.9",
|
||||
"/usr/local/cuda/lib64/libcudnn_heuristic.so.9",
|
||||
"/usr/local/cuda/lib64/libcufile.so.0",
|
||||
"/usr/local/cuda/lib64/libcufile_rdma.so.1",
|
||||
"/usr/local/cuda/lib64/libcusparse.so.12",
|
||||
]
|
||||
|
||||
# CUDA version-specific libraries
|
||||
if "13" in desired_cuda:
|
||||
minor_version = desired_cuda[-1]
|
||||
version_specific_libs = [
|
||||
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.13",
|
||||
"/usr/local/cuda/lib64/libcublas.so.13",
|
||||
"/usr/local/cuda/lib64/libcublasLt.so.13",
|
||||
"/usr/local/cuda/lib64/libcudart.so.13",
|
||||
"/usr/local/cuda/lib64/libcufft.so.12",
|
||||
"/usr/local/cuda/lib64/libcusolver.so.12",
|
||||
"/usr/local/cuda/lib64/libnvJitLink.so.13",
|
||||
"/usr/local/cuda/lib64/libnvrtc.so.13",
|
||||
f"/usr/local/cuda/lib64/libnvrtc-builtins.so.13.{minor_version}",
|
||||
]
|
||||
elif "12" in desired_cuda:
|
||||
# Get the last character for libnvrtc-builtins version (e.g., "129" -> "9")
|
||||
minor_version = desired_cuda[-1]
|
||||
version_specific_libs = [
|
||||
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12",
|
||||
"/usr/local/cuda/lib64/libcublas.so.12",
|
||||
"/usr/local/cuda/lib64/libcublasLt.so.12",
|
||||
"/usr/local/cuda/lib64/libcudart.so.12",
|
||||
"/usr/local/cuda/lib64/libcufft.so.11",
|
||||
"/usr/local/cuda/lib64/libcusolver.so.11",
|
||||
"/usr/local/cuda/lib64/libnvJitLink.so.12",
|
||||
"/usr/local/cuda/lib64/libnvrtc.so.12",
|
||||
f"/usr/local/cuda/lib64/libnvrtc-builtins.so.12.{minor_version}",
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unsupported CUDA version: {desired_cuda}.")
|
||||
|
||||
# Combine all libraries
|
||||
libs_to_copy = common_libs + version_specific_libs
|
||||
|
||||
# Copy libraries to unzipped_folder/torch/lib
|
||||
for lib_path in libs_to_copy:
|
||||
copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda)
|
||||
|
||||
# Make sure the wheel is tagged with manylinux_2_28
|
||||
for f in os.scandir(f"{folder}/tmp/"):
|
||||
if f.is_dir() and f.name.endswith(".dist-info"):
|
||||
replace_tag(f"{f.path}/WHEEL")
|
||||
break
|
||||
|
||||
os.system(f"wheel pack {folder}/tmp/ -d {folder}")
|
||||
os.system(f"rm -rf {folder}/tmp/")
|
||||
|
||||
|
||||
def complete_wheel(folder: str) -> str:
|
||||
"""
|
||||
Complete wheel build and put in artifact location
|
||||
"""
|
||||
wheel_name = list_dir(f"/{folder}/dist")[0]
|
||||
|
||||
# Please note for cuda we don't run auditwheel since we use custom script to package
|
||||
# the cuda dependencies into the wheel file via the package_cuda_wheel() function.
|
||||
# However we need to make sure filename reflects the correct Manylinux platform.
|
||||
if "pytorch" in folder and not enable_cuda:
|
||||
print("Repairing Wheel with AuditWheel")
|
||||
check_call(["auditwheel", "repair", f"dist/{wheel_name}"], cwd=folder)
|
||||
repaired_wheel_name = list_dir(f"/{folder}/wheelhouse")[0]
|
||||
|
||||
print(f"Moving {repaired_wheel_name} wheel to /{folder}/dist")
|
||||
os.rename(
|
||||
f"/{folder}/wheelhouse/{repaired_wheel_name}",
|
||||
f"/{folder}/dist/{repaired_wheel_name}",
|
||||
)
|
||||
else:
|
||||
repaired_wheel_name = list_dir(f"/{folder}/dist")[0]
|
||||
|
||||
print(f"Copying {repaired_wheel_name} to artifacts")
|
||||
shutil.copy2(
|
||||
f"/{folder}/dist/{repaired_wheel_name}", f"/artifacts/{repaired_wheel_name}"
|
||||
)
|
||||
|
||||
return repaired_wheel_name
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
"""
|
||||
Parse inline arguments
|
||||
"""
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser("AARCH64 wheels python CD")
|
||||
parser.add_argument("--debug", action="store_true")
|
||||
parser.add_argument("--build-only", action="store_true")
|
||||
parser.add_argument("--test-only", type=str)
|
||||
parser.add_argument("--enable-mkldnn", action="store_true")
|
||||
parser.add_argument("--enable-cuda", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
Entry Point
|
||||
"""
|
||||
args = parse_arguments()
|
||||
enable_mkldnn = args.enable_mkldnn
|
||||
enable_cuda = args.enable_cuda
|
||||
branch = check_output(
|
||||
["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd="/pytorch"
|
||||
).decode()
|
||||
|
||||
print("Building PyTorch wheel")
|
||||
build_vars = ""
|
||||
# MAX_JOBS=5 is not required for the CPU backend (see commit 465d98b)
|
||||
if enable_cuda:
|
||||
build_vars += "MAX_JOBS=5 "
|
||||
|
||||
# Handle PyPI NVIDIA libraries vs bundled libraries
|
||||
use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"
|
||||
if use_nvidia_pypi_libs:
|
||||
print("Configuring build for PyPI NVIDIA libraries")
|
||||
# Configure for dynamic linking (matching x86 logic)
|
||||
build_vars += "ATEN_STATIC_CUDA=0 USE_CUDA_STATIC_LINK=0 USE_CUPTI_SO=1 "
|
||||
else:
|
||||
print("Configuring build for bundled NVIDIA libraries")
|
||||
# Keep existing static linking approach - already configured above
|
||||
|
||||
override_package_version = os.getenv("OVERRIDE_PACKAGE_VERSION")
|
||||
desired_cuda = os.getenv("DESIRED_CUDA")
|
||||
if override_package_version is not None:
|
||||
version = override_package_version
|
||||
build_vars += (
|
||||
f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version} PYTORCH_BUILD_NUMBER=1 "
|
||||
)
|
||||
elif branch in ["nightly", "main"]:
|
||||
build_date = (
|
||||
check_output(["git", "log", "--pretty=format:%cs", "-1"], cwd="/pytorch")
|
||||
.decode()
|
||||
.replace("-", "")
|
||||
)
|
||||
version = (
|
||||
check_output(["cat", "version.txt"], cwd="/pytorch").decode().strip()[:-2]
|
||||
)
|
||||
if enable_cuda:
|
||||
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date}+{desired_cuda} PYTORCH_BUILD_NUMBER=1 "
|
||||
else:
|
||||
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1 "
|
||||
elif branch.startswith(("v1.", "v2.")):
|
||||
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1 "
|
||||
|
||||
if enable_mkldnn:
|
||||
print("build pytorch with mkldnn+acl backend")
|
||||
build_vars += "USE_MKLDNN=ON USE_MKLDNN_ACL=ON "
|
||||
build_vars += "ACL_ROOT_DIR=/acl "
|
||||
if enable_cuda:
|
||||
build_vars += "BLAS=NVPL "
|
||||
else:
|
||||
build_vars += "BLAS=OpenBLAS OpenBLAS_HOME=/opt/OpenBLAS "
|
||||
else:
|
||||
print("build pytorch without mkldnn backend")
|
||||
|
||||
os.system(f"cd /pytorch; {build_vars} python3 -m build --wheel --no-isolation")
|
||||
if enable_cuda:
|
||||
print("Updating Cuda Dependency")
|
||||
filename = os.listdir("/pytorch/dist/")
|
||||
wheel_path = f"/pytorch/dist/{filename[0]}"
|
||||
package_cuda_wheel(wheel_path, desired_cuda)
|
||||
pytorch_wheel_name = complete_wheel("/pytorch/")
|
||||
print(f"Build Complete. Created {pytorch_wheel_name}..")
|
||||
@@ -1,999 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# This script is for building AARCH64 wheels using AWS EC2 instances.
|
||||
# To generate binaries for the release follow these steps:
|
||||
# 1. Update mappings for each of the Domain Libraries by adding a new row to the table, like this:
|
||||
# "v1.11.0": ("0.11.0", "rc1"),
|
||||
# 2. Run script with following arguments for each of the supported python versions and required tag, for example:
|
||||
# build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.8 --branch v1.11.0-rc3
|
||||
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
|
||||
import boto3
|
||||
|
||||
|
||||
# AMI images for us-east-1, change the following based on your ~/.aws/config
|
||||
os_amis = {
|
||||
"ubuntu20_04": "ami-052eac90edaa9d08f", # login_name: ubuntu
|
||||
"ubuntu22_04": "ami-0c6c29c5125214c77", # login_name: ubuntu
|
||||
"redhat8": "ami-0698b90665a2ddcf1", # login_name: ec2-user
|
||||
}
|
||||
|
||||
ubuntu20_04_ami = os_amis["ubuntu20_04"]
|
||||
|
||||
|
||||
def compute_keyfile_path(key_name: Optional[str] = None) -> tuple[str, str]:
|
||||
if key_name is None:
|
||||
key_name = os.getenv("AWS_KEY_NAME")
|
||||
if key_name is None:
|
||||
return os.getenv("SSH_KEY_PATH", ""), ""
|
||||
|
||||
homedir_path = os.path.expanduser("~")
|
||||
default_path = os.path.join(homedir_path, ".ssh", f"{key_name}.pem")
|
||||
return os.getenv("SSH_KEY_PATH", default_path), key_name
|
||||
|
||||
|
||||
ec2 = boto3.resource("ec2")
|
||||
|
||||
|
||||
def ec2_get_instances(filter_name, filter_value):
|
||||
return ec2.instances.filter(
|
||||
Filters=[{"Name": filter_name, "Values": [filter_value]}]
|
||||
)
|
||||
|
||||
|
||||
def ec2_instances_of_type(instance_type="t4g.2xlarge"):
|
||||
return ec2_get_instances("instance-type", instance_type)
|
||||
|
||||
|
||||
def ec2_instances_by_id(instance_id):
|
||||
rc = list(ec2_get_instances("instance-id", instance_id))
|
||||
return rc[0] if len(rc) > 0 else None
|
||||
|
||||
|
||||
def start_instance(
|
||||
key_name, ami=ubuntu20_04_ami, instance_type="t4g.2xlarge", ebs_size: int = 50
|
||||
):
|
||||
inst = ec2.create_instances(
|
||||
ImageId=ami,
|
||||
InstanceType=instance_type,
|
||||
SecurityGroups=["ssh-allworld"],
|
||||
KeyName=key_name,
|
||||
MinCount=1,
|
||||
MaxCount=1,
|
||||
BlockDeviceMappings=[
|
||||
{
|
||||
"DeviceName": "/dev/sda1",
|
||||
"Ebs": {
|
||||
"DeleteOnTermination": True,
|
||||
"VolumeSize": ebs_size,
|
||||
"VolumeType": "standard",
|
||||
},
|
||||
}
|
||||
],
|
||||
)[0]
|
||||
print(f"Create instance {inst.id}")
|
||||
inst.wait_until_running()
|
||||
running_inst = ec2_instances_by_id(inst.id)
|
||||
print(f"Instance started at {running_inst.public_dns_name}")
|
||||
return running_inst
|
||||
|
||||
|
||||
class RemoteHost:
|
||||
addr: str
|
||||
keyfile_path: str
|
||||
login_name: str
|
||||
container_id: Optional[str] = None
|
||||
ami: Optional[str] = None
|
||||
|
||||
def __init__(self, addr: str, keyfile_path: str, login_name: str = "ubuntu"):
|
||||
self.addr = addr
|
||||
self.keyfile_path = keyfile_path
|
||||
self.login_name = login_name
|
||||
|
||||
def _gen_ssh_prefix(self) -> list[str]:
|
||||
return [
|
||||
"ssh",
|
||||
"-o",
|
||||
"StrictHostKeyChecking=no",
|
||||
"-i",
|
||||
self.keyfile_path,
|
||||
f"{self.login_name}@{self.addr}",
|
||||
"--",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _split_cmd(args: Union[str, list[str]]) -> list[str]:
|
||||
return args.split() if isinstance(args, str) else args
|
||||
|
||||
def run_ssh_cmd(self, args: Union[str, list[str]]) -> None:
|
||||
subprocess.check_call(self._gen_ssh_prefix() + self._split_cmd(args))
|
||||
|
||||
def check_ssh_output(self, args: Union[str, list[str]]) -> str:
|
||||
return subprocess.check_output(
|
||||
self._gen_ssh_prefix() + self._split_cmd(args)
|
||||
).decode("utf-8")
|
||||
|
||||
def scp_upload_file(self, local_file: str, remote_file: str) -> None:
|
||||
subprocess.check_call(
|
||||
[
|
||||
"scp",
|
||||
"-i",
|
||||
self.keyfile_path,
|
||||
local_file,
|
||||
f"{self.login_name}@{self.addr}:{remote_file}",
|
||||
]
|
||||
)
|
||||
|
||||
def scp_download_file(
|
||||
self, remote_file: str, local_file: Optional[str] = None
|
||||
) -> None:
|
||||
if local_file is None:
|
||||
local_file = "."
|
||||
subprocess.check_call(
|
||||
[
|
||||
"scp",
|
||||
"-i",
|
||||
self.keyfile_path,
|
||||
f"{self.login_name}@{self.addr}:{remote_file}",
|
||||
local_file,
|
||||
]
|
||||
)
|
||||
|
||||
def start_docker(self, image="quay.io/pypa/manylinux2014_aarch64:latest") -> None:
|
||||
self.run_ssh_cmd("sudo apt-get install -y docker.io")
|
||||
self.run_ssh_cmd(f"sudo usermod -a -G docker {self.login_name}")
|
||||
self.run_ssh_cmd("sudo service docker start")
|
||||
self.run_ssh_cmd(f"docker pull {image}")
|
||||
self.container_id = self.check_ssh_output(
|
||||
f"docker run -t -d -w /root {image}"
|
||||
).strip()
|
||||
|
||||
def using_docker(self) -> bool:
|
||||
return self.container_id is not None
|
||||
|
||||
def run_cmd(self, args: Union[str, list[str]]) -> None:
|
||||
if not self.using_docker():
|
||||
return self.run_ssh_cmd(args)
|
||||
assert self.container_id is not None
|
||||
docker_cmd = self._gen_ssh_prefix() + [
|
||||
"docker",
|
||||
"exec",
|
||||
"-i",
|
||||
self.container_id,
|
||||
"bash",
|
||||
]
|
||||
p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE)
|
||||
p.communicate(
|
||||
input=" ".join(["source .bashrc && "] + self._split_cmd(args)).encode(
|
||||
"utf-8"
|
||||
)
|
||||
)
|
||||
rc = p.wait()
|
||||
if rc != 0:
|
||||
raise subprocess.CalledProcessError(rc, docker_cmd)
|
||||
|
||||
def check_output(self, args: Union[str, list[str]]) -> str:
|
||||
if not self.using_docker():
|
||||
return self.check_ssh_output(args)
|
||||
assert self.container_id is not None
|
||||
docker_cmd = self._gen_ssh_prefix() + [
|
||||
"docker",
|
||||
"exec",
|
||||
"-i",
|
||||
self.container_id,
|
||||
"bash",
|
||||
]
|
||||
p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
|
||||
(out, err) = p.communicate(
|
||||
input=" ".join(["source .bashrc && "] + self._split_cmd(args)).encode(
|
||||
"utf-8"
|
||||
)
|
||||
)
|
||||
rc = p.wait()
|
||||
if rc != 0:
|
||||
raise subprocess.CalledProcessError(rc, docker_cmd, output=out, stderr=err)
|
||||
return out.decode("utf-8")
|
||||
|
||||
def upload_file(self, local_file: str, remote_file: str) -> None:
|
||||
if not self.using_docker():
|
||||
return self.scp_upload_file(local_file, remote_file)
|
||||
tmp_file = os.path.join("/tmp", os.path.basename(local_file))
|
||||
self.scp_upload_file(local_file, tmp_file)
|
||||
self.run_ssh_cmd(
|
||||
["docker", "cp", tmp_file, f"{self.container_id}:/root/{remote_file}"]
|
||||
)
|
||||
self.run_ssh_cmd(["rm", tmp_file])
|
||||
|
||||
def download_file(self, remote_file: str, local_file: Optional[str] = None) -> None:
|
||||
if not self.using_docker():
|
||||
return self.scp_download_file(remote_file, local_file)
|
||||
tmp_file = os.path.join("/tmp", os.path.basename(remote_file))
|
||||
self.run_ssh_cmd(
|
||||
["docker", "cp", f"{self.container_id}:/root/{remote_file}", tmp_file]
|
||||
)
|
||||
self.scp_download_file(tmp_file, local_file)
|
||||
self.run_ssh_cmd(["rm", tmp_file])
|
||||
|
||||
def download_wheel(
|
||||
self, remote_file: str, local_file: Optional[str] = None
|
||||
) -> None:
|
||||
if self.using_docker() and local_file is None:
|
||||
basename = os.path.basename(remote_file)
|
||||
local_file = basename.replace(
|
||||
"-linux_aarch64.whl", "-manylinux2014_aarch64.whl"
|
||||
)
|
||||
self.download_file(remote_file, local_file)
|
||||
|
||||
def list_dir(self, path: str) -> list[str]:
|
||||
return self.check_output(["ls", "-1", path]).split("\n")
|
||||
|
||||
|
||||
def wait_for_connection(addr, port, timeout=15, attempt_cnt=5):
|
||||
import socket
|
||||
|
||||
for i in range(attempt_cnt):
|
||||
try:
|
||||
with socket.create_connection((addr, port), timeout=timeout):
|
||||
return
|
||||
except (ConnectionRefusedError, TimeoutError): # noqa: PERF203
|
||||
if i == attempt_cnt - 1:
|
||||
raise
|
||||
time.sleep(timeout)
|
||||
|
||||
|
||||
def update_apt_repo(host: RemoteHost) -> None:
|
||||
time.sleep(5)
|
||||
host.run_cmd("sudo systemctl stop apt-daily.service || true")
|
||||
host.run_cmd("sudo systemctl stop unattended-upgrades.service || true")
|
||||
host.run_cmd(
|
||||
"while systemctl is-active --quiet apt-daily.service; do sleep 1; done"
|
||||
)
|
||||
host.run_cmd(
|
||||
"while systemctl is-active --quiet unattended-upgrades.service; do sleep 1; done"
|
||||
)
|
||||
host.run_cmd("sudo apt-get update")
|
||||
time.sleep(3)
|
||||
host.run_cmd("sudo apt-get update")
|
||||
|
||||
|
||||
def install_condaforge(
|
||||
host: RemoteHost, suffix: str = "latest/download/Miniforge3-Linux-aarch64.sh"
|
||||
) -> None:
|
||||
print("Install conda-forge")
|
||||
host.run_cmd(f"curl -OL https://github.com/conda-forge/miniforge/releases/{suffix}")
|
||||
host.run_cmd(f"sh -f {os.path.basename(suffix)} -b")
|
||||
host.run_cmd(f"rm -f {os.path.basename(suffix)}")
|
||||
if host.using_docker():
|
||||
host.run_cmd("echo 'PATH=$HOME/miniforge3/bin:$PATH'>>.bashrc")
|
||||
else:
|
||||
host.run_cmd(
|
||||
[
|
||||
"sed",
|
||||
"-i",
|
||||
"'/^# If not running interactively.*/i PATH=$HOME/miniforge3/bin:$PATH'",
|
||||
".bashrc",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def install_condaforge_python(host: RemoteHost, python_version="3.8") -> None:
|
||||
if python_version == "3.6":
|
||||
# Python-3.6 EOLed and not compatible with conda-4.11
|
||||
install_condaforge(
|
||||
host, suffix="download/4.10.3-10/Miniforge3-4.10.3-10-Linux-aarch64.sh"
|
||||
)
|
||||
host.run_cmd(f"conda install -y python={python_version} numpy pyyaml")
|
||||
else:
|
||||
install_condaforge(
|
||||
host, suffix="download/4.11.0-4/Miniforge3-4.11.0-4-Linux-aarch64.sh"
|
||||
)
|
||||
# Pytorch-1.10 or older are not compatible with setuptools=59.6 or newer
|
||||
host.run_cmd(
|
||||
f"conda install -y python={python_version} numpy pyyaml setuptools>=59.5.0"
|
||||
)
|
||||
|
||||
|
||||
def embed_libgomp(host: RemoteHost, use_conda, wheel_name) -> None:
|
||||
host.run_cmd("pip3 install auditwheel")
|
||||
host.run_cmd(
|
||||
"conda install -y patchelf" if use_conda else "sudo apt-get install -y patchelf"
|
||||
)
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
with NamedTemporaryFile() as tmp:
|
||||
tmp.write(embed_library_script.encode("utf-8"))
|
||||
tmp.flush()
|
||||
host.upload_file(tmp.name, "embed_library.py")
|
||||
|
||||
print("Embedding libgomp into wheel")
|
||||
if host.using_docker():
|
||||
host.run_cmd(f"python3 embed_library.py {wheel_name} --update-tag")
|
||||
else:
|
||||
host.run_cmd(f"python3 embed_library.py {wheel_name}")
|
||||
|
||||
|
||||
def checkout_repo(
|
||||
host: RemoteHost,
|
||||
*,
|
||||
branch: str = "main",
|
||||
url: str,
|
||||
git_clone_flags: str,
|
||||
mapping: dict[str, tuple[str, str]],
|
||||
) -> Optional[str]:
|
||||
for prefix in mapping:
|
||||
if not branch.startswith(prefix):
|
||||
continue
|
||||
tag = f"v{mapping[prefix][0]}-{mapping[prefix][1]}"
|
||||
host.run_cmd(f"git clone {url} -b {tag} {git_clone_flags}")
|
||||
return mapping[prefix][0]
|
||||
|
||||
host.run_cmd(f"git clone {url} -b {branch} {git_clone_flags}")
|
||||
return None
|
||||
|
||||
|
||||
def build_torchvision(
|
||||
host: RemoteHost,
|
||||
*,
|
||||
branch: str = "main",
|
||||
use_conda: bool = True,
|
||||
git_clone_flags: str,
|
||||
run_smoke_tests: bool = True,
|
||||
) -> str:
|
||||
print("Checking out TorchVision repo")
|
||||
build_version = checkout_repo(
|
||||
host,
|
||||
branch=branch,
|
||||
url="https://github.com/pytorch/vision",
|
||||
git_clone_flags=git_clone_flags,
|
||||
mapping={
|
||||
"v1.7.1": ("0.8.2", "rc2"),
|
||||
"v1.8.0": ("0.9.0", "rc3"),
|
||||
"v1.8.1": ("0.9.1", "rc1"),
|
||||
"v1.9.0": ("0.10.0", "rc1"),
|
||||
"v1.10.0": ("0.11.1", "rc1"),
|
||||
"v1.10.1": ("0.11.2", "rc1"),
|
||||
"v1.10.2": ("0.11.3", "rc1"),
|
||||
"v1.11.0": ("0.12.0", "rc1"),
|
||||
"v1.12.0": ("0.13.0", "rc4"),
|
||||
"v1.12.1": ("0.13.1", "rc6"),
|
||||
"v1.13.0": ("0.14.0", "rc4"),
|
||||
"v1.13.1": ("0.14.1", "rc2"),
|
||||
"v2.0.0": ("0.15.1", "rc2"),
|
||||
"v2.0.1": ("0.15.2", "rc2"),
|
||||
},
|
||||
)
|
||||
print("Building TorchVision wheel")
|
||||
|
||||
# Please note libpng and libjpeg are required to build the image.so extension
|
||||
if use_conda:
|
||||
host.run_cmd("conda install -y libpng jpeg")
|
||||
# Remove .so files to force static linking
|
||||
host.run_cmd(
|
||||
"rm miniforge3/lib/libpng.so miniforge3/lib/libpng16.so miniforge3/lib/libjpeg.so"
|
||||
)
|
||||
# And patch setup.py to include libz dependency for libpng
|
||||
host.run_cmd(
|
||||
[
|
||||
'sed -i -e \'s/image_link_flags\\.append("png")/image_link_flags += ["png", "z"]/\' vision/setup.py'
|
||||
]
|
||||
)
|
||||
|
||||
build_vars = ""
|
||||
if branch == "nightly":
|
||||
version = host.check_output(
|
||||
["if [ -f vision/version.txt ]; then cat vision/version.txt; fi"]
|
||||
).strip()
|
||||
if len(version) == 0:
|
||||
# In older revisions, version was embedded in setup.py
|
||||
version = (
|
||||
host.check_output(["grep", '"version = \'"', "vision/setup.py"])
|
||||
.strip()
|
||||
.split("'")[1][:-2]
|
||||
)
|
||||
build_date = (
|
||||
host.check_output("cd vision && git log --pretty=format:%s -1")
|
||||
.strip()
|
||||
.split()[0]
|
||||
.replace("-", "")
|
||||
)
|
||||
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
|
||||
elif build_version is not None:
|
||||
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
|
||||
if host.using_docker():
|
||||
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
|
||||
|
||||
host.run_cmd(f"cd vision && {build_vars} python3 -m build --wheel --no-isolation")
|
||||
vision_wheel_name = host.list_dir("vision/dist")[0]
|
||||
embed_libgomp(host, use_conda, os.path.join("vision", "dist", vision_wheel_name))
|
||||
|
||||
print("Copying TorchVision wheel")
|
||||
host.download_wheel(os.path.join("vision", "dist", vision_wheel_name))
|
||||
if run_smoke_tests:
|
||||
host.run_cmd(
|
||||
f"pip3 install {os.path.join('vision', 'dist', vision_wheel_name)}"
|
||||
)
|
||||
host.run_cmd("python3 vision/test/smoke_test.py")
|
||||
print("Delete vision checkout")
|
||||
host.run_cmd("rm -rf vision")
|
||||
|
||||
return vision_wheel_name
|
||||
|
||||
|
||||
def build_torchdata(
|
||||
host: RemoteHost,
|
||||
*,
|
||||
branch: str = "main",
|
||||
use_conda: bool = True,
|
||||
git_clone_flags: str = "",
|
||||
) -> str:
|
||||
print("Checking out TorchData repo")
|
||||
git_clone_flags += " --recurse-submodules"
|
||||
build_version = checkout_repo(
|
||||
host,
|
||||
branch=branch,
|
||||
url="https://github.com/pytorch/data",
|
||||
git_clone_flags=git_clone_flags,
|
||||
mapping={
|
||||
"v1.13.1": ("0.5.1", ""),
|
||||
"v2.0.0": ("0.6.0", "rc5"),
|
||||
"v2.0.1": ("0.6.1", "rc1"),
|
||||
},
|
||||
)
|
||||
print("Building TorchData wheel")
|
||||
build_vars = ""
|
||||
if branch == "nightly":
|
||||
version = host.check_output(
|
||||
["if [ -f data/version.txt ]; then cat data/version.txt; fi"]
|
||||
).strip()
|
||||
build_date = (
|
||||
host.check_output("cd data && git log --pretty=format:%s -1")
|
||||
.strip()
|
||||
.split()[0]
|
||||
.replace("-", "")
|
||||
)
|
||||
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
|
||||
elif build_version is not None:
|
||||
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
|
||||
if host.using_docker():
|
||||
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
|
||||
|
||||
host.run_cmd(f"cd data && {build_vars} python3 -m build --wheel --no-isolation")
|
||||
wheel_name = host.list_dir("data/dist")[0]
|
||||
embed_libgomp(host, use_conda, os.path.join("data", "dist", wheel_name))
|
||||
|
||||
print("Copying TorchData wheel")
|
||||
host.download_wheel(os.path.join("data", "dist", wheel_name))
|
||||
|
||||
return wheel_name
|
||||
|
||||
|
||||
def build_torchtext(
|
||||
host: RemoteHost,
|
||||
*,
|
||||
branch: str = "main",
|
||||
use_conda: bool = True,
|
||||
git_clone_flags: str = "",
|
||||
) -> str:
|
||||
print("Checking out TorchText repo")
|
||||
git_clone_flags += " --recurse-submodules"
|
||||
build_version = checkout_repo(
|
||||
host,
|
||||
branch=branch,
|
||||
url="https://github.com/pytorch/text",
|
||||
git_clone_flags=git_clone_flags,
|
||||
mapping={
|
||||
"v1.9.0": ("0.10.0", "rc1"),
|
||||
"v1.10.0": ("0.11.0", "rc2"),
|
||||
"v1.10.1": ("0.11.1", "rc1"),
|
||||
"v1.10.2": ("0.11.2", "rc1"),
|
||||
"v1.11.0": ("0.12.0", "rc1"),
|
||||
"v1.12.0": ("0.13.0", "rc2"),
|
||||
"v1.12.1": ("0.13.1", "rc5"),
|
||||
"v1.13.0": ("0.14.0", "rc3"),
|
||||
"v1.13.1": ("0.14.1", "rc1"),
|
||||
"v2.0.0": ("0.15.1", "rc2"),
|
||||
"v2.0.1": ("0.15.2", "rc2"),
|
||||
},
|
||||
)
|
||||
print("Building TorchText wheel")
|
||||
build_vars = ""
|
||||
if branch == "nightly":
|
||||
version = host.check_output(
|
||||
["if [ -f text/version.txt ]; then cat text/version.txt; fi"]
|
||||
).strip()
|
||||
build_date = (
|
||||
host.check_output("cd text && git log --pretty=format:%s -1")
|
||||
.strip()
|
||||
.split()[0]
|
||||
.replace("-", "")
|
||||
)
|
||||
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
|
||||
elif build_version is not None:
|
||||
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
|
||||
if host.using_docker():
|
||||
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
|
||||
|
||||
host.run_cmd(f"cd text && {build_vars} python3 -m build --wheel --no-isolation")
|
||||
wheel_name = host.list_dir("text/dist")[0]
|
||||
embed_libgomp(host, use_conda, os.path.join("text", "dist", wheel_name))
|
||||
|
||||
print("Copying TorchText wheel")
|
||||
host.download_wheel(os.path.join("text", "dist", wheel_name))
|
||||
|
||||
return wheel_name
|
||||
|
||||
|
||||
def build_torchaudio(
|
||||
host: RemoteHost,
|
||||
*,
|
||||
branch: str = "main",
|
||||
use_conda: bool = True,
|
||||
git_clone_flags: str = "",
|
||||
) -> str:
|
||||
print("Checking out TorchAudio repo")
|
||||
git_clone_flags += " --recurse-submodules"
|
||||
build_version = checkout_repo(
|
||||
host,
|
||||
branch=branch,
|
||||
url="https://github.com/pytorch/audio",
|
||||
git_clone_flags=git_clone_flags,
|
||||
mapping={
|
||||
"v1.9.0": ("0.9.0", "rc2"),
|
||||
"v1.10.0": ("0.10.0", "rc5"),
|
||||
"v1.10.1": ("0.10.1", "rc1"),
|
||||
"v1.10.2": ("0.10.2", "rc1"),
|
||||
"v1.11.0": ("0.11.0", "rc1"),
|
||||
"v1.12.0": ("0.12.0", "rc3"),
|
||||
"v1.12.1": ("0.12.1", "rc5"),
|
||||
"v1.13.0": ("0.13.0", "rc4"),
|
||||
"v1.13.1": ("0.13.1", "rc2"),
|
||||
"v2.0.0": ("2.0.1", "rc3"),
|
||||
"v2.0.1": ("2.0.2", "rc2"),
|
||||
},
|
||||
)
|
||||
print("Building TorchAudio wheel")
|
||||
build_vars = ""
|
||||
if branch == "nightly":
|
||||
version = (
|
||||
host.check_output(["grep", '"version = \'"', "audio/setup.py"])
|
||||
.strip()
|
||||
.split("'")[1][:-2]
|
||||
)
|
||||
build_date = (
|
||||
host.check_output("cd audio && git log --pretty=format:%s -1")
|
||||
.strip()
|
||||
.split()[0]
|
||||
.replace("-", "")
|
||||
)
|
||||
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
|
||||
elif build_version is not None:
|
||||
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
|
||||
if host.using_docker():
|
||||
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
|
||||
|
||||
host.run_cmd(
|
||||
f"cd audio && export FFMPEG_ROOT=$(pwd)/third_party/ffmpeg && export USE_FFMPEG=1 \
|
||||
&& ./packaging/ffmpeg/build.sh \
|
||||
&& {build_vars} python3 -m build --wheel --no-isolation"
|
||||
)
|
||||
|
||||
wheel_name = host.list_dir("audio/dist")[0]
|
||||
embed_libgomp(host, use_conda, os.path.join("audio", "dist", wheel_name))
|
||||
|
||||
print("Copying TorchAudio wheel")
|
||||
host.download_wheel(os.path.join("audio", "dist", wheel_name))
|
||||
|
||||
return wheel_name
|
||||
|
||||
|
||||
def configure_system(
|
||||
host: RemoteHost,
|
||||
*,
|
||||
compiler: str = "gcc-8",
|
||||
use_conda: bool = True,
|
||||
python_version: str = "3.8",
|
||||
) -> None:
|
||||
if use_conda:
|
||||
install_condaforge_python(host, python_version)
|
||||
|
||||
print("Configuring the system")
|
||||
if not host.using_docker():
|
||||
update_apt_repo(host)
|
||||
host.run_cmd("sudo apt-get install -y ninja-build g++ git cmake gfortran unzip")
|
||||
else:
|
||||
host.run_cmd("yum install -y sudo")
|
||||
host.run_cmd("conda install -y ninja scons")
|
||||
|
||||
if not use_conda:
|
||||
host.run_cmd(
|
||||
"sudo apt-get install -y python3-dev python3-yaml python3-setuptools python3-wheel python3-pip"
|
||||
)
|
||||
host.run_cmd("pip3 install dataclasses typing-extensions")
|
||||
if not use_conda:
|
||||
print("Installing Cython + numpy from PyPy")
|
||||
host.run_cmd("sudo pip3 install Cython")
|
||||
host.run_cmd("sudo pip3 install numpy")
|
||||
|
||||
|
||||
def build_domains(
|
||||
host: RemoteHost,
|
||||
*,
|
||||
branch: str = "main",
|
||||
use_conda: bool = True,
|
||||
git_clone_flags: str = "",
|
||||
) -> tuple[str, str, str, str]:
|
||||
vision_wheel_name = build_torchvision(
|
||||
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
|
||||
)
|
||||
audio_wheel_name = build_torchaudio(
|
||||
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
|
||||
)
|
||||
data_wheel_name = build_torchdata(
|
||||
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
|
||||
)
|
||||
text_wheel_name = build_torchtext(
|
||||
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
|
||||
)
|
||||
return (vision_wheel_name, audio_wheel_name, data_wheel_name, text_wheel_name)
|
||||
|
||||
|
||||
def start_build(
|
||||
host: RemoteHost,
|
||||
*,
|
||||
branch: str = "main",
|
||||
compiler: str = "gcc-8",
|
||||
use_conda: bool = True,
|
||||
python_version: str = "3.8",
|
||||
pytorch_only: bool = False,
|
||||
pytorch_build_number: Optional[str] = None,
|
||||
shallow_clone: bool = True,
|
||||
enable_mkldnn: bool = False,
|
||||
) -> tuple[str, str, str, str, str]:
|
||||
git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
|
||||
if host.using_docker() and not use_conda:
|
||||
print("Auto-selecting conda option for docker images")
|
||||
use_conda = True
|
||||
if not host.using_docker():
|
||||
print("Disable mkldnn for host builds")
|
||||
enable_mkldnn = False
|
||||
|
||||
configure_system(
|
||||
host, compiler=compiler, use_conda=use_conda, python_version=python_version
|
||||
)
|
||||
|
||||
if host.using_docker():
|
||||
print("Move libgfortant.a into a standard location")
|
||||
# HACK: pypa gfortran.a is compiled without PIC, which leads to the following error
|
||||
# libgfortran.a(error.o)(.text._gfortrani_st_printf+0x34): unresolvable R_AARCH64_ADR_PREL_PG_HI21 relocation against symbol `__stack_chk_guard@@GLIBC_2.17' # noqa: E501, B950
|
||||
# Workaround by copying gfortran library from the host
|
||||
host.run_ssh_cmd("sudo apt-get install -y gfortran-8")
|
||||
host.run_cmd("mkdir -p /usr/lib/gcc/aarch64-linux-gnu/8")
|
||||
host.run_ssh_cmd(
|
||||
[
|
||||
"docker",
|
||||
"cp",
|
||||
"/usr/lib/gcc/aarch64-linux-gnu/8/libgfortran.a",
|
||||
f"{host.container_id}:/opt/rh/devtoolset-10/root/usr/lib/gcc/aarch64-redhat-linux/10/",
|
||||
]
|
||||
)
|
||||
|
||||
print("Checking out PyTorch repo")
|
||||
host.run_cmd(
|
||||
f"git clone --recurse-submodules -b {branch} https://github.com/pytorch/pytorch {git_clone_flags}"
|
||||
)
|
||||
|
||||
host.run_cmd("pytorch/.ci/docker/common/install_openblas.sh")
|
||||
|
||||
print("Building PyTorch wheel")
|
||||
build_opts = ""
|
||||
if pytorch_build_number is not None:
|
||||
build_opts += f" -C--build-option=--build-number={pytorch_build_number}"
|
||||
# Breakpad build fails on aarch64
|
||||
build_vars = "USE_BREAKPAD=0 "
|
||||
if branch == "nightly":
|
||||
build_date = (
|
||||
host.check_output("cd pytorch && git log --pretty=format:%s -1")
|
||||
.strip()
|
||||
.split()[0]
|
||||
.replace("-", "")
|
||||
)
|
||||
version = host.check_output("cat pytorch/version.txt").strip()[:-2]
|
||||
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1"
|
||||
if branch.startswith(("v1.", "v2.")):
|
||||
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1"
|
||||
if host.using_docker():
|
||||
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
|
||||
if enable_mkldnn:
|
||||
host.run_cmd("pytorch/.ci/docker/common/install_acl.sh")
|
||||
print("build pytorch with mkldnn+acl backend")
|
||||
build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
|
||||
build_vars += " BLAS=OpenBLAS"
|
||||
build_vars += " OpenBLAS_HOME=/opt/OpenBLAS"
|
||||
build_vars += " ACL_ROOT_DIR=/acl"
|
||||
host.run_cmd(
|
||||
f"cd $HOME/pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
|
||||
)
|
||||
print("Repair the wheel")
|
||||
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
|
||||
ld_library_path = "/acl/build:$HOME/pytorch/build/lib"
|
||||
host.run_cmd(
|
||||
f"export LD_LIBRARY_PATH={ld_library_path} && auditwheel repair $HOME/pytorch/dist/{pytorch_wheel_name}"
|
||||
)
|
||||
print("replace the original wheel with the repaired one")
|
||||
pytorch_repaired_wheel_name = host.list_dir("wheelhouse")[0]
|
||||
host.run_cmd(
|
||||
f"cp $HOME/wheelhouse/{pytorch_repaired_wheel_name} $HOME/pytorch/dist/{pytorch_wheel_name}"
|
||||
)
|
||||
else:
|
||||
print("build pytorch without mkldnn backend")
|
||||
host.run_cmd(
|
||||
f"cd pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
|
||||
)
|
||||
|
||||
print("Deleting build folder")
|
||||
host.run_cmd("cd pytorch && rm -rf build")
|
||||
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
|
||||
embed_libgomp(host, use_conda, os.path.join("pytorch", "dist", pytorch_wheel_name))
|
||||
print("Copying the wheel")
|
||||
host.download_wheel(os.path.join("pytorch", "dist", pytorch_wheel_name))
|
||||
|
||||
print("Installing PyTorch wheel")
|
||||
host.run_cmd(f"pip3 install pytorch/dist/{pytorch_wheel_name}")
|
||||
|
||||
if pytorch_only:
|
||||
return (pytorch_wheel_name, None, None, None, None)
|
||||
domain_wheels = build_domains(
|
||||
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
|
||||
)
|
||||
|
||||
return (pytorch_wheel_name, *domain_wheels)
|
||||
|
||||
|
||||
embed_library_script = """
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from auditwheel.patcher import Patchelf
|
||||
from auditwheel.wheeltools import InWheelCtx
|
||||
from auditwheel.elfutils import elf_file_filter
|
||||
from auditwheel.repair import copylib
|
||||
from auditwheel.lddtree import lddtree
|
||||
from subprocess import check_call
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
|
||||
def replace_tag(filename):
|
||||
with open(filename, 'r') as f:
|
||||
lines = f.read().split("\\n")
|
||||
for i,line in enumerate(lines):
|
||||
if not line.startswith("Tag: "):
|
||||
continue
|
||||
lines[i] = line.replace("-linux_", "-manylinux2014_")
|
||||
print(f'Updated tag from {line} to {lines[i]}')
|
||||
|
||||
with open(filename, 'w') as f:
|
||||
f.write("\\n".join(lines))
|
||||
|
||||
|
||||
class AlignedPatchelf(Patchelf):
|
||||
def set_soname(self, file_name: str, new_soname: str) -> None:
|
||||
check_call(['patchelf', '--page-size', '65536', '--set-soname', new_soname, file_name])
|
||||
|
||||
def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
|
||||
check_call(['patchelf', '--page-size', '65536', '--replace-needed', soname, new_soname, file_name])
|
||||
|
||||
|
||||
def embed_library(whl_path, lib_soname, update_tag=False):
|
||||
patcher = AlignedPatchelf()
|
||||
out_dir = TemporaryDirectory()
|
||||
whl_name = os.path.basename(whl_path)
|
||||
tmp_whl_name = os.path.join(out_dir.name, whl_name)
|
||||
with InWheelCtx(whl_path) as ctx:
|
||||
torchlib_path = os.path.join(ctx._tmpdir.name, 'torch', 'lib')
|
||||
ctx.out_wheel=tmp_whl_name
|
||||
new_lib_path, new_lib_soname = None, None
|
||||
for filename, elf in elf_file_filter(ctx.iter_files()):
|
||||
if not filename.startswith('torch/lib'):
|
||||
continue
|
||||
libtree = lddtree(filename)
|
||||
if lib_soname not in libtree['needed']:
|
||||
continue
|
||||
lib_path = libtree['libs'][lib_soname]['path']
|
||||
if lib_path is None:
|
||||
print(f"Can't embed {lib_soname} as it could not be found")
|
||||
break
|
||||
if lib_path.startswith(torchlib_path):
|
||||
continue
|
||||
|
||||
if new_lib_path is None:
|
||||
new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
|
||||
patcher.replace_needed(filename, lib_soname, new_lib_soname)
|
||||
print(f'Replacing {lib_soname} with {new_lib_soname} for {filename}')
|
||||
if update_tag:
|
||||
# Add manylinux2014 tag
|
||||
for filename in ctx.iter_files():
|
||||
if os.path.basename(filename) != 'WHEEL':
|
||||
continue
|
||||
replace_tag(filename)
|
||||
shutil.move(tmp_whl_name, whl_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
embed_library(sys.argv[1], 'libgomp.so.1', len(sys.argv) > 2 and sys.argv[2] == '--update-tag')
|
||||
"""
|
||||
|
||||
|
||||
def run_tests(host: RemoteHost, whl: str, branch="main") -> None:
|
||||
print("Configuring the system")
|
||||
update_apt_repo(host)
|
||||
host.run_cmd("sudo apt-get install -y python3-pip git")
|
||||
host.run_cmd("sudo pip3 install Cython")
|
||||
host.run_cmd("sudo pip3 install numpy")
|
||||
host.upload_file(whl, ".")
|
||||
host.run_cmd(f"sudo pip3 install {whl}")
|
||||
host.run_cmd("python3 -c 'import torch;print(torch.rand((3,3))'")
|
||||
host.run_cmd(f"git clone -b {branch} https://github.com/pytorch/pytorch")
|
||||
host.run_cmd("cd pytorch/test; python3 test_torch.py -v")
|
||||
|
||||
|
||||
def get_instance_name(instance) -> Optional[str]:
|
||||
if instance.tags is None:
|
||||
return None
|
||||
for tag in instance.tags:
|
||||
if tag["Key"] == "Name":
|
||||
return tag["Value"]
|
||||
return None
|
||||
|
||||
|
||||
def list_instances(instance_type: str) -> None:
|
||||
print(f"All instances of type {instance_type}")
|
||||
for instance in ec2_instances_of_type(instance_type):
|
||||
ifaces = instance.network_interfaces
|
||||
az = ifaces[0].subnet.availability_zone if len(ifaces) > 0 else None
|
||||
print(
|
||||
f"{instance.id} {get_instance_name(instance)} {instance.public_dns_name} {instance.state['Name']} {az}"
|
||||
)
|
||||
|
||||
|
||||
def terminate_instances(instance_type: str) -> None:
|
||||
print(f"Terminating all instances of type {instance_type}")
|
||||
instances = list(ec2_instances_of_type(instance_type))
|
||||
for instance in instances:
|
||||
print(f"Terminating {instance.id}")
|
||||
instance.terminate()
|
||||
print("Waiting for termination to complete")
|
||||
for instance in instances:
|
||||
instance.wait_until_terminated()
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser("Build and test AARCH64 wheels using EC2")
|
||||
parser.add_argument("--key-name", type=str)
|
||||
parser.add_argument("--debug", action="store_true")
|
||||
parser.add_argument("--build-only", action="store_true")
|
||||
parser.add_argument("--test-only", type=str)
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument("--os", type=str, choices=list(os_amis.keys()))
|
||||
group.add_argument("--ami", type=str)
|
||||
parser.add_argument(
|
||||
"--python-version",
|
||||
type=str,
|
||||
choices=[f"3.{d}" for d in range(6, 12)],
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument("--alloc-instance", action="store_true")
|
||||
parser.add_argument("--list-instances", action="store_true")
|
||||
parser.add_argument("--pytorch-only", action="store_true")
|
||||
parser.add_argument("--keep-running", action="store_true")
|
||||
parser.add_argument("--terminate-instances", action="store_true")
|
||||
parser.add_argument("--instance-type", type=str, default="t4g.2xlarge")
|
||||
parser.add_argument("--ebs-size", type=int, default=50)
|
||||
parser.add_argument("--branch", type=str, default="main")
|
||||
parser.add_argument("--use-docker", action="store_true")
|
||||
parser.add_argument(
|
||||
"--compiler",
|
||||
type=str,
|
||||
choices=["gcc-7", "gcc-8", "gcc-9", "clang"],
|
||||
default="gcc-8",
|
||||
)
|
||||
parser.add_argument("--use-torch-from-pypi", action="store_true")
|
||||
parser.add_argument("--pytorch-build-number", type=str, default=None)
|
||||
parser.add_argument("--disable-mkldnn", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments()
|
||||
ami = (
|
||||
args.ami
|
||||
if args.ami is not None
|
||||
else os_amis[args.os]
|
||||
if args.os is not None
|
||||
else ubuntu20_04_ami
|
||||
)
|
||||
keyfile_path, key_name = compute_keyfile_path(args.key_name)
|
||||
|
||||
if args.list_instances:
|
||||
list_instances(args.instance_type)
|
||||
sys.exit(0)
|
||||
|
||||
if args.terminate_instances:
|
||||
terminate_instances(args.instance_type)
|
||||
sys.exit(0)
|
||||
|
||||
if len(key_name) == 0:
|
||||
raise RuntimeError("""
|
||||
Cannot start build without key_name, please specify
|
||||
--key-name argument or AWS_KEY_NAME environment variable.""")
|
||||
if len(keyfile_path) == 0 or not os.path.exists(keyfile_path):
|
||||
raise RuntimeError(f"""
|
||||
Cannot find keyfile with name: [{key_name}] in path: [{keyfile_path}], please
|
||||
check `~/.ssh/` folder or manually set SSH_KEY_PATH environment variable.""")
|
||||
|
||||
# Starting the instance
|
||||
inst = start_instance(
|
||||
key_name, ami=ami, instance_type=args.instance_type, ebs_size=args.ebs_size
|
||||
)
|
||||
instance_name = f"{args.key_name}-{args.os}"
|
||||
if args.python_version is not None:
|
||||
instance_name += f"-py{args.python_version}"
|
||||
inst.create_tags(
|
||||
DryRun=False,
|
||||
Tags=[
|
||||
{
|
||||
"Key": "Name",
|
||||
"Value": instance_name,
|
||||
}
|
||||
],
|
||||
)
|
||||
addr = inst.public_dns_name
|
||||
wait_for_connection(addr, 22)
|
||||
host = RemoteHost(addr, keyfile_path)
|
||||
host.ami = ami
|
||||
if args.use_docker:
|
||||
update_apt_repo(host)
|
||||
host.start_docker()
|
||||
|
||||
if args.test_only:
|
||||
run_tests(host, args.test_only)
|
||||
sys.exit(0)
|
||||
|
||||
if args.alloc_instance:
|
||||
if args.python_version is None:
|
||||
sys.exit(0)
|
||||
install_condaforge_python(host, args.python_version)
|
||||
sys.exit(0)
|
||||
|
||||
python_version = args.python_version if args.python_version is not None else "3.10"
|
||||
|
||||
if args.use_torch_from_pypi:
|
||||
configure_system(host, compiler=args.compiler, python_version=python_version)
|
||||
print("Installing PyTorch wheel")
|
||||
host.run_cmd("pip3 install torch")
|
||||
build_domains(
|
||||
host, branch=args.branch, git_clone_flags=" --depth 1 --shallow-submodules"
|
||||
)
|
||||
else:
|
||||
start_build(
|
||||
host,
|
||||
branch=args.branch,
|
||||
compiler=args.compiler,
|
||||
python_version=python_version,
|
||||
pytorch_only=args.pytorch_only,
|
||||
pytorch_build_number=args.pytorch_build_number,
|
||||
enable_mkldnn=not args.disable_mkldnn,
|
||||
)
|
||||
if not args.keep_running:
|
||||
print(f"Waiting for instance {inst.id} to terminate")
|
||||
inst.terminate()
|
||||
inst.wait_until_terminated()
@@ -1,87 +0,0 @@
#!/usr/bin/env python3

import os
import shutil
import sys
from subprocess import check_call
from tempfile import TemporaryDirectory

from auditwheel.elfutils import elf_file_filter
from auditwheel.lddtree import lddtree
from auditwheel.patcher import Patchelf
from auditwheel.repair import copylib
from auditwheel.wheeltools import InWheelCtx


def replace_tag(filename):
    with open(filename) as f:
        lines = f.read().split("\n")
    for i, line in enumerate(lines):
        if not line.startswith("Tag: "):
            continue
        lines[i] = line.replace("-linux_", "-manylinux2014_")
        print(f"Updated tag from {line} to {lines[i]}")

    with open(filename, "w") as f:
        f.write("\n".join(lines))


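# The subclass below forces a 64 KiB --page-size on every patchelf call so the
# patched libraries remain loadable on aarch64 kernels configured with 64K pages.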
class AlignedPatchelf(Patchelf):
    def set_soname(self, file_name: str, new_soname: str) -> None:
        check_call(
            ["patchelf", "--page-size", "65536", "--set-soname", new_soname, file_name]
        )

    def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
        check_call(
            [
                "patchelf",
                "--page-size",
                "65536",
                "--replace-needed",
                soname,
                new_soname,
                file_name,
            ]
        )


def embed_library(whl_path, lib_soname, update_tag=False):
    patcher = AlignedPatchelf()
    out_dir = TemporaryDirectory()
    whl_name = os.path.basename(whl_path)
    tmp_whl_name = os.path.join(out_dir.name, whl_name)
    with InWheelCtx(whl_path) as ctx:
        torchlib_path = os.path.join(ctx._tmpdir.name, "torch", "lib")
        ctx.out_wheel = tmp_whl_name
        new_lib_path, new_lib_soname = None, None
        for filename, _ in elf_file_filter(ctx.iter_files()):
            if not filename.startswith("torch/lib"):
                continue
            libtree = lddtree(filename)
            if lib_soname not in libtree["needed"]:
                continue
            lib_path = libtree["libs"][lib_soname]["path"]
            if lib_path is None:
                print(f"Can't embed {lib_soname} as it could not be found")
                break
            if lib_path.startswith(torchlib_path):
                continue

            if new_lib_path is None:
                new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
            patcher.replace_needed(filename, lib_soname, new_lib_soname)
            print(f"Replacing {lib_soname} with {new_lib_soname} for {filename}")
        if update_tag:
            # Add manylinux2014 tag
            for filename in ctx.iter_files():
                if os.path.basename(filename) != "WHEEL":
                    continue
                replace_tag(filename)
    shutil.move(tmp_whl_name, whl_path)


if __name__ == "__main__":
    embed_library(
        sys.argv[1], "libgomp.so.1", len(sys.argv) > 2 and sys.argv[2] == "--update-tag"
    )
@@ -4,14 +4,17 @@ set -ex

SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

# Source the common build script for architecture-specific configurations (MKLDNN, ACL, etc.)
source "${SCRIPTPATH}/../pytorch/build.sh" || true

case "${GPU_ARCH_TYPE:-BLANK}" in
    cuda)
    cuda | cuda-aarch64)
        bash "${SCRIPTPATH}/build_cuda.sh"
        ;;
    rocm)
        bash "${SCRIPTPATH}/build_rocm.sh"
        ;;
    cpu | cpu-cxx11-abi | cpu-s390x)
    cpu | cpu-cxx11-abi | cpu-aarch64 | cpu-s390x)
        bash "${SCRIPTPATH}/build_cpu.sh"
        ;;
    xpu)
@@ -18,12 +18,31 @@ retry () {
    $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*)
}
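The `retry` helper above simply re-runs its arguments with an exponential backoff of 1, 2, 4 and 8 seconds; a small usage sketch matching how it is called later in this script:

```bash
# Retry a flaky package installation up to five times before giving up.
retry yum install -q -y zip openssl
```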
# Detect architecture first
ARCH=$(uname -m)
echo "Detected architecture: $ARCH"

PLATFORM=""
# TODO move this into the Docker images
OS_NAME=$(awk -F= '/^NAME/{print $2}' /etc/os-release)
if [[ "$OS_NAME" == *"AlmaLinux"* ]]; then
    retry yum install -q -y zip openssl
    PLATFORM="manylinux_2_28_x86_64"
    # Set platform based on architecture
    case $ARCH in
        x86_64)
            PLATFORM="manylinux_2_28_x86_64"
            ;;
        aarch64)
            PLATFORM="manylinux_2_28_aarch64"
            ;;
        s390x)
            PLATFORM="manylinux_2_28_s390x"
            ;;
        *)
            echo "Unsupported architecture: $ARCH"
            exit 1
            ;;
    esac
elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then
    retry dnf install -q -y zip openssl
elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then

@@ -38,6 +57,8 @@ else
    exit 1
fi

echo "Platform set to: $PLATFORM"

# We use the package name to test the package by passing this to 'pip install'
# This is the env variable that setup.py uses to name the package. Note that
# pip 'normalizes' the name first by changing all - to _
@@ -299,8 +320,8 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
        # ROCm workaround for roctracer dlopens
        if [[ "$DESIRED_CUDA" == *"rocm"* ]]; then
            patchedpath=$(fname_without_so_number $destpath)
        # Keep the so number for XPU dependencies and libgomp.so.1 to avoid twice load
        elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" ]]; then
        # Keep the so number for XPU dependencies, libgomp.so.1, ACL libraries, and NVPL libraries to avoid twice load
        elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" || "$filename" == libarm_compute* || "$filename" == libnvpl* || "$filename" == "libgfortran.so.5" ]]; then
            patchedpath=$destpath
        else
            patchedpath=$(fname_with_sha256 $destpath)

@@ -346,9 +367,22 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
    done

    # create Manylinux 2_28 tag this needs to happen before regenerate the RECORD
    if [[ $PLATFORM == "manylinux_2_28_x86_64" && $GPU_ARCH_TYPE != "cpu-s390x" && $GPU_ARCH_TYPE != "xpu" ]]; then
    # Support all architectures (x86_64, aarch64, s390x)
    if [[ "$IS_MANYLINUX2_28" == "1" && $GPU_ARCH_TYPE != "xpu" ]]; then
        wheel_file=$(echo $(basename $pkg) | sed -e 's/-cp.*$/.dist-info\/WHEEL/g')
        sed -i -e s#linux_x86_64#"${PLATFORM}"# $wheel_file;
        echo "Updating wheel tag for $ARCH architecture"
        # Replace linux_* with manylinux_2_28_* based on architecture
        case $ARCH in
            x86_64)
                sed -i -e 's#linux_x86_64#manylinux_2_28_x86_64#g' $wheel_file
                ;;
            aarch64)
                sed -i -e 's#linux_aarch64#manylinux_2_28_aarch64#g' $wheel_file
                ;;
            s390x)
                sed -i -e 's#linux_s390x#manylinux_2_28_s390x#g' $wheel_file
                ;;
        esac
    fi
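The retag block above only touches the `Tag:` line inside the wheel's `*.dist-info/WHEEL` metadata; a hedged illustration of the effect for an aarch64 wheel (tag values and the dist-info path are illustrative):

```bash
# Before the rewrite the WHEEL file contains, e.g.:
#   Tag: cp310-cp310-linux_aarch64
# and after the architecture-specific sed it reads:
#   Tag: cp310-cp310-manylinux_2_28_aarch64
sed -i -e 's#linux_aarch64#manylinux_2_28_aarch64#g' torch-2.1.0.dist-info/WHEEL
```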
    # regenerate the RECORD file with new hashes
@ -15,6 +15,10 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
|
||||
EXTRA_CAFFE2_CMAKE_FLAGS=()
|
||||
fi
|
||||
|
||||
# Detect architecture
|
||||
ARCH=$(uname -m)
|
||||
echo "Building CPU wheel for architecture: $ARCH"
|
||||
|
||||
WHEELHOUSE_DIR="wheelhousecpu"
|
||||
LIBTORCH_HOUSE_DIR="libtorch_housecpu"
|
||||
if [[ -z "$PYTORCH_FINAL_PACKAGE_DIR" ]]; then
|
||||
@ -34,8 +38,10 @@ elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then
|
||||
elif [[ "$OS_NAME" == *"AlmaLinux"* ]]; then
|
||||
LIBGOMP_PATH="/usr/lib64/libgomp.so.1"
|
||||
elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
|
||||
if [[ "$(uname -m)" == "s390x" ]]; then
|
||||
if [[ "$ARCH" == "s390x" ]]; then
|
||||
LIBGOMP_PATH="/usr/lib/s390x-linux-gnu/libgomp.so.1"
|
||||
elif [[ "$ARCH" == "aarch64" ]]; then
|
||||
LIBGOMP_PATH="/usr/lib/aarch64-linux-gnu/libgomp.so.1"
|
||||
else
|
||||
LIBGOMP_PATH="/usr/lib/x86_64-linux-gnu/libgomp.so.1"
|
||||
fi
|
||||
@ -49,6 +55,34 @@ DEPS_SONAME=(
|
||||
"libgomp.so.1"
|
||||
)
|
||||
|
||||
# Add ARM-specific library dependencies for CPU builds
|
||||
if [[ "$ARCH" == "aarch64" ]]; then
|
||||
echo "Adding ARM-specific CPU library dependencies"
|
||||
|
||||
# ARM Compute Library (if available)
|
||||
if [[ -d "/acl/build" ]]; then
|
||||
echo "Adding ARM Compute Library for CPU"
|
||||
DEPS_LIST+=(
|
||||
"/acl/build/libarm_compute.so"
|
||||
"/acl/build/libarm_compute_graph.so"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libarm_compute.so"
|
||||
"libarm_compute_graph.so"
|
||||
)
|
||||
fi
|
||||
|
||||
# ARM system libraries
|
||||
DEPS_LIST+=(
|
||||
"/usr/lib64/libgfortran.so.5"
|
||||
"/opt/OpenBLAS/lib/libopenblas.so.0"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libgfortran.so.5"
|
||||
"libopenblas.so.0"
|
||||
)
|
||||
fi
|
||||
|
||||
rm -rf /usr/local/cuda*
|
||||
|
||||
SOURCE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )"
|
||||
|
||||
@ -29,6 +29,10 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
|
||||
EXTRA_CAFFE2_CMAKE_FLAGS=()
|
||||
fi
|
||||
|
||||
# Detect architecture
|
||||
ARCH=$(uname -m)
|
||||
echo "Building for architecture: $ARCH"
|
||||
|
||||
# Determine CUDA version and architectures to build for
|
||||
#
|
||||
# NOTE: We should first check `DESIRED_CUDA` when determining `CUDA_VERSION`,
|
||||
@ -53,34 +57,60 @@ fi
|
||||
cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.')
|
||||
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
|
||||
|
||||
# Function to remove architectures from a list
remove_archs() {
    local result="$1"
    shift
    for arch in "$@"; do
        result="${result//${arch};/}"
    done
    echo "$result"
}

# Function to filter CUDA architectures for aarch64
# aarch64 ARM GPUs only support certain compute capabilities
# Keep: 8.0 (A100), 9.0+ (Hopper, Grace Hopper, newer)
# Remove: < 8.0 (no ARM GPUs), 8.6 (x86_64 RTX 3090/A6000 only)
filter_aarch64_archs() {
    local arch_list="$1"
    # Explicitly remove architectures not needed on aarch64
    arch_list=$(remove_archs "$arch_list" "5.0" "6.0" "7.0" "7.5" "8.6")
    echo "$arch_list"
}
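As a sanity check of the two helpers above, a sketch of what they produce for the aarch64 CUDA 13.0 list assembled later in this script (output shown as a comment):

```bash
# Hypothetical dry run of filter_aarch64_archs on the 13.0 aarch64 list.
arch_list="7.5;8.0;8.6;9.0;10.0;11.0;12.0+PTX"
filter_aarch64_archs "$arch_list"
# prints: 8.0;9.0;10.0;11.0;12.0+PTX  (7.5 and 8.6 are stripped; 5.0/6.0/7.0 were never present)
```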
|
||||
# Base: Common architectures across all modern CUDA versions
|
||||
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0"
|
||||
|
||||
case ${CUDA_VERSION} in
|
||||
#removing sm_50-sm_60 as these architectures are deprecated in CUDA 12.8/9 and will be removed in future releases
|
||||
#however we would like to keep sm_70 architecture see: https://github.com/pytorch/pytorch/issues/157517
|
||||
12.8)
|
||||
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0"
|
||||
;;
|
||||
12.9)
|
||||
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0+PTX"
|
||||
# WAR to resolve the ld error in libtorch build with CUDA 12.9
|
||||
12.6) TORCH_CUDA_ARCH_LIST="5.0;6.0;${TORCH_CUDA_ARCH_LIST}" ;; # Only 12.6 includes Legacy Maxwell/Pascal that will be removed in future releases
|
||||
12.8) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0" ;; # +Hopper/Blackwell support
|
||||
12.9) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0+PTX" # +Hopper/Blackwell support + PTX for forward compatibility
|
||||
if [[ "$PACKAGE_TYPE" == "libtorch" ]]; then
|
||||
TORCH_CUDA_ARCH_LIST="7.5;8.0;9.0;10.0;12.0+PTX"
|
||||
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//7.0;/}" # Remove 7.0 to resolve the ld error
|
||||
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//8.6;/}" # Remove 8.6 for libtorch
|
||||
fi
|
||||
;;
|
||||
13.0)
|
||||
TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;12.0+PTX"
|
||||
;;
|
||||
12.6)
|
||||
TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6;9.0"
|
||||
;;
|
||||
*)
|
||||
echo "unknown cuda version $CUDA_VERSION"
|
||||
exit 1
|
||||
TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;$([[ "$ARCH" == "aarch64" ]] && echo "11.0;" || echo "")12.0+PTX"
|
||||
export TORCH_NVCC_FLAGS="-compress-mode=size"
|
||||
export BUILD_BUNDLE_PTXAS=1
|
||||
;;
|
||||
*) echo "unknown cuda version $CUDA_VERSION"; exit 1 ;;
|
||||
esac
|
||||
|
||||
# Filter for aarch64: Remove < 8.0 and 8.6
|
||||
[[ "$ARCH" == "aarch64" ]] && TORCH_CUDA_ARCH_LIST=$(filter_aarch64_archs "$TORCH_CUDA_ARCH_LIST")
|
||||
|
||||
echo "TORCH_CUDA_ARCH_LIST set to: $TORCH_CUDA_ARCH_LIST"
|
||||
export TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST}
|
||||
echo "${TORCH_CUDA_ARCH_LIST}"
|
||||
|
||||
# Disable MAGMA for aarch64 as pre-built libraries are x86-64 only
|
||||
if [[ "$ARCH" == "aarch64" ]]; then
|
||||
echo "Disabling MAGMA for aarch64 architecture"
|
||||
export USE_MAGMA=0
|
||||
fi
|
||||
|
||||
# Package directories
|
||||
WHEELHOUSE_DIR="wheelhouse$cuda_version_nodot"
|
||||
LIBTORCH_HOUSE_DIR="libtorch_house$cuda_version_nodot"
|
||||
@ -244,6 +274,51 @@ else
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Add ARM-specific library dependencies
|
||||
if [[ "$ARCH" == "aarch64" ]]; then
|
||||
echo "Adding ARM-specific library dependencies"
|
||||
|
||||
# ARM Compute Library (if available)
|
||||
if [[ -d "/acl/build" ]]; then
|
||||
echo "Adding ARM Compute Library"
|
||||
DEPS_LIST+=(
|
||||
"/acl/build/libarm_compute.so"
|
||||
"/acl/build/libarm_compute_graph.so"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libarm_compute.so"
|
||||
"libarm_compute_graph.so"
|
||||
)
|
||||
fi
|
||||
|
||||
# ARM system libraries
|
||||
DEPS_LIST+=(
|
||||
"/lib64/libgomp.so.1"
|
||||
"/usr/lib64/libgfortran.so.5"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libgomp.so.1"
|
||||
"libgfortran.so.5"
|
||||
)
|
||||
|
||||
# NVPL libraries (ARM optimized BLAS/LAPACK)
|
||||
if [[ -d "/usr/local/lib" && -f "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0" ]]; then
|
||||
echo "Adding NVPL libraries for ARM"
|
||||
DEPS_LIST+=(
|
||||
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0"
|
||||
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0"
|
||||
"/usr/local/lib/libnvpl_lapack_core.so.0"
|
||||
"/usr/local/lib/libnvpl_blas_core.so.0"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libnvpl_lapack_lp64_gomp.so.0"
|
||||
"libnvpl_blas_lp64_gomp.so.0"
|
||||
"libnvpl_lapack_core.so.0"
|
||||
"libnvpl_blas_core.so.0"
|
||||
)
|
||||
fi
|
||||
fi
|
||||
|
||||
# run_tests.sh requires DESIRED_CUDA to know what tests to exclude
|
||||
export DESIRED_CUDA="$cuda_version_nodot"
|
||||
|
||||
@ -251,9 +326,11 @@ export DESIRED_CUDA="$cuda_version_nodot"
|
||||
rm -rf /usr/local/cuda || true
|
||||
ln -s "/usr/local/cuda-${CUDA_VERSION}" /usr/local/cuda
|
||||
|
||||
# Switch `/usr/local/magma` to the desired CUDA version
|
||||
rm -rf /usr/local/magma || true
|
||||
ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma
|
||||
# Switch `/usr/local/magma` to the desired CUDA version (skip for aarch64)
|
||||
if [[ "$ARCH" != "aarch64" ]]; then
|
||||
rm -rf /usr/local/magma || true
|
||||
ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma
|
||||
fi
|
||||
|
||||
export CUDA_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev) # 10.0.130
|
||||
export CUDA_VERSION_SHORT=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev | cut -f1,2 -d".") # 10.0
|
||||
|
||||
@ -21,87 +21,3 @@ if [[ "${BUILD_ENVIRONMENT}" == *rocm* ]]; then
|
||||
fi
|
||||
|
||||
mkdir -p "$pytest_reports_dir" || true
|
||||
|
||||
##########################################
|
||||
# copied from .ci/pytorch/common_utils.sh
|
||||
##########################################
|
||||
|
||||
function get_pinned_commit() {
|
||||
cat .github/ci_commit_pins/"${1}".txt
|
||||
}
|
||||
|
||||
function pip_install_whl() {
|
||||
# This is used to install PyTorch and other build artifacts wheel locally
|
||||
# without using any network connection
|
||||
|
||||
# Convert the input arguments into an array
|
||||
local args=("$@")
|
||||
|
||||
# Check if the first argument contains multiple paths separated by spaces
|
||||
if [[ "${args[0]}" == *" "* ]]; then
|
||||
# Split the string by spaces into an array
|
||||
IFS=' ' read -r -a paths <<< "${args[0]}"
|
||||
# Loop through each path and install individually
|
||||
for path in "${paths[@]}"; do
|
||||
echo "Installing $path"
|
||||
python3 -mpip install --no-index --no-deps "$path"
|
||||
done
|
||||
else
|
||||
# Loop through each argument and install individually
|
||||
for path in "${args[@]}"; do
|
||||
echo "Installing $path"
|
||||
python3 -mpip install --no-index --no-deps "$path"
|
||||
done
|
||||
fi
|
||||
}
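A minimal usage sketch for the helper above; it accepts either separate wheel paths or a single space-separated string, and the wheel filenames below are illustrative:

```bash
# Install locally built artifacts without hitting the network.
pip_install_whl dist/torch-2.9.0-cp310-cp310-linux_x86_64.whl

# Equivalent form with one space-separated argument, as some CI steps produce.
pip_install_whl "dist/torch-2.9.0-cp310-cp310-linux_x86_64.whl dist/torchvision-0.24.0-cp310-cp310-linux_x86_64.whl"
```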
|
||||
|
||||
function pip_build_and_install() {
|
||||
local build_target=$1
|
||||
local wheel_dir=$2
|
||||
|
||||
local found_whl=0
|
||||
for file in "${wheel_dir}"/*.whl
|
||||
do
|
||||
if [[ -f "${file}" ]]; then
|
||||
found_whl=1
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
# Build the wheel if it doesn't exist
|
||||
if [ "${found_whl}" == "0" ]; then
|
||||
python3 -m pip wheel \
|
||||
--no-build-isolation \
|
||||
--no-deps \
|
||||
-w "${wheel_dir}" \
|
||||
"${build_target}"
|
||||
fi
|
||||
|
||||
for file in "${wheel_dir}"/*.whl
|
||||
do
|
||||
pip_install_whl "${file}"
|
||||
done
|
||||
}
|
||||
|
||||
function install_torchvision() {
|
||||
local orig_preload
|
||||
local commit
|
||||
commit=$(get_pinned_commit vision)
|
||||
orig_preload=${LD_PRELOAD}
|
||||
if [ -n "${LD_PRELOAD}" ]; then
|
||||
# Silence dlerror to work-around glibc ASAN bug, see https://sourceware.org/bugzilla/show_bug.cgi?id=27653#c9
|
||||
echo 'char* dlerror(void) { return "";}'|gcc -fpic -shared -o "${HOME}/dlerror.so" -x c -
|
||||
LD_PRELOAD=${orig_preload}:${HOME}/dlerror.so
|
||||
fi
|
||||
|
||||
if [[ "${BUILD_ENVIRONMENT}" == *cuda* ]]; then
|
||||
# Not sure if both are needed, but why not
|
||||
export FORCE_CUDA=1
|
||||
export WITH_CUDA=1
|
||||
fi
|
||||
pip_build_and_install "git+https://github.com/pytorch/vision.git@${commit}" dist/vision
|
||||
|
||||
if [ -n "${LD_PRELOAD}" ]; then
|
||||
LD_PRELOAD=${orig_preload}
|
||||
fi
|
||||
}
|
||||
|
||||
@ -19,7 +19,7 @@ git config --global --add safe.directory /var/lib/jenkins/workspace
|
||||
|
||||
if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
|
||||
# TODO: This can be removed later once vision is also part of the Docker image
|
||||
install_torchvision
|
||||
pip install -q --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)"
|
||||
# JIT C++ extensions require ninja, so put it into PATH.
|
||||
export PATH="/var/lib/jenkins/.local/bin:$PATH"
|
||||
# NB: ONNX test is fast (~15m) so it's ok to retry it few more times to avoid any flaky issue, we
|
||||
|
||||
@ -86,10 +86,20 @@ else
|
||||
fi
|
||||
fi
|
||||
|
||||
# Enable MKLDNN with ARM Compute Library for ARM builds
if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then
    export USE_MKLDNN=1

    # ACL is required for aarch64 builds
    if [[ ! -d "/acl" ]]; then
        echo "ERROR: ARM Compute Library not found at /acl"
        echo "ACL is required for aarch64 builds. Check Docker image setup."
        exit 1
    fi

    export USE_MKLDNN_ACL=1
    export ACL_ROOT_DIR=/acl
    echo "ARM Compute Library enabled for MKLDNN: ACL_ROOT_DIR=/acl"
fi
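For reference, the environment this block exports on aarch64 can be reproduced outside CI as a standalone sketch (assuming ACL has been staged at /acl, as the check above requires):

```bash
# MKLDNN + ARM Compute Library configuration used for aarch64 CI builds.
export USE_MKLDNN=1
export USE_MKLDNN_ACL=1
export ACL_ROOT_DIR=/acl
```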
|
||||
if [[ "$BUILD_ENVIRONMENT" == *riscv64* ]]; then
|
||||
|
||||
.github/workflows/_binary-build-linux.yml
@@ -260,11 +260,8 @@ jobs:
            "${DOCKER_IMAGE}"
          )
          docker exec -t -w "${PYTORCH_ROOT}" "${container_name}" bash -c "bash .circleci/scripts/binary_populate_env.sh"
          if [[ ${BUILD_ENVIRONMENT} == *"aarch64"* ]]; then
            docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/aarch64_linux/aarch64_ci_build.sh"
          else
            docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh"
          fi
          # Unified build script for all architectures (x86_64, aarch64, s390x)
          docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh"

      - name: Chown artifacts
        if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' && inputs.build_environment != 'linux-s390x-binary-manywheel' }}
@ -22,6 +22,7 @@ enum class MacOSVersion : uint32_t {
|
||||
MACOS_VER_15_0_PLUS,
|
||||
MACOS_VER_15_1_PLUS,
|
||||
MACOS_VER_15_2_PLUS,
|
||||
MACOS_VER_26_0_PLUS,
|
||||
};
|
||||
|
||||
//-----------------------------------------------------------------
|
||||
|
||||
@ -65,6 +65,7 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
|
||||
static bool _macos_15_0_plus = is_os_version_at_least(15, 0);
|
||||
static bool _macos_15_1_plus = is_os_version_at_least(15, 1);
|
||||
static bool _macos_15_2_plus = is_os_version_at_least(15, 2);
|
||||
static bool _macos_26_0_plus = is_os_version_at_least(26, 0);
|
||||
|
||||
switch (version) {
|
||||
case MacOSVersion::MACOS_VER_14_4_PLUS:
|
||||
@ -75,6 +76,8 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
|
||||
return _macos_15_1_plus;
|
||||
case MacOSVersion::MACOS_VER_15_2_PLUS:
|
||||
return _macos_15_2_plus;
|
||||
case MacOSVersion::MACOS_VER_26_0_PLUS:
|
||||
return _macos_26_0_plus;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
|
||||
#include <ATen/native/xpu/Blas.h>
|
||||
#include <ATen/xpu/XPUScaledBlas.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
@ -340,399 +339,4 @@ Tensor _scaled_mm_xpu(
|
||||
out);
|
||||
}
|
||||
|
||||
using acceptance_fn = std::function<bool(
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&)>;
|
||||
using namespace std::placeholders;
|
||||
|
||||
namespace scaled_blas = at::native::onednn::scaled;
|
||||
using scaled_blas::convert_int_to_enum;
|
||||
using scaled_blas::ScaledGemmImplementation;
|
||||
|
||||
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2>
|
||||
scale_kernel_dispatch = {{
|
||||
{"tensorwise_tensorwise",
|
||||
scaled_blas::check_tensorwise_recipe,
|
||||
ScaledGemmImplementation::TENSORWISE_TENSORWISE},
|
||||
{"rowwise_rowwise",
|
||||
scaled_blas::check_rowwise_recipe,
|
||||
ScaledGemmImplementation::ROWWISE_ROWWISE},
|
||||
|
||||
}};
|
||||
|
||||
Tensor& _scaled_tensorwise_tensorwise(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<Tensor>& bias,
|
||||
const c10::ScalarType out_dtype,
|
||||
bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32
|
||||
|
||||
TORCH_CHECK_VALUE(
|
||||
isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()),
|
||||
"mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(),
|
||||
mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.numel() == 1 && scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have 1 Float element")
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.numel() == 1 && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have 1 Float element")
|
||||
|
||||
auto scaling_choice_a = ScalingType::TensorWise;
|
||||
auto scaling_choice_b = ScalingType::TensorWise;
|
||||
|
||||
_scaled_gemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
scaling_choice_a,
|
||||
scaling_choice_b,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
Tensor& _scaled_rowwise_rowwise(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<Tensor>& bias,
|
||||
const c10::ScalarType out_dtype,
|
||||
bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, shape M/N for A/B
|
||||
TORCH_CHECK_VALUE(
|
||||
isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()),
|
||||
"mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(),
|
||||
mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.size(0) == mat_a.size(0) && scale_a.size(1) == 1,
|
||||
"scale_a must have shape [",
|
||||
mat_a.size(0),
|
||||
", 1], got [",
|
||||
scale_a.sizes(),
|
||||
"]");
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.numel() == mat_a.size(0) && scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have ",
|
||||
mat_a.size(0),
|
||||
" Float elements, got ",
|
||||
scale_a.numel())
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.numel() == mat_b.size(1) && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have ",
|
||||
mat_b.size(1),
|
||||
" Float elements, got ",
|
||||
scale_b.numel())
|
||||
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(1) == 1,
|
||||
"expected scale_a.stride(1) to be 1, but got ",
|
||||
scale_a.stride(1));
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.stride(1) == 1,
|
||||
"expected scale_b.stride(1) to be 1, but got ",
|
||||
scale_b.stride(1));
|
||||
|
||||
auto scaling_choice_a = ScalingType::RowWise;
|
||||
auto scaling_choice_b = ScalingType::RowWise;
|
||||
|
||||
_scaled_gemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
scaling_choice_a,
|
||||
scaling_choice_b,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
// V2: Computes matrix multiply + bias while applying scaling to input and
|
||||
// output matrices Scales are only applicable when matrices are of Float8 type
|
||||
// and assumed to be equal to 1.0 by default. If output matrix type is 16 or
|
||||
// 32-bit type, scale_result is not applied. Known limitations:
|
||||
// - Only works if mat1 is row-major and mat2 is column-major
|
||||
// - Only works if matrices sizes are divisible by 32
|
||||
// - If 1-dimensional tensors are used then scale_a should be size =
|
||||
// mat1.size(0)
|
||||
// and scale_b should have size = to mat2.size(1)
|
||||
// Arguments:
|
||||
// - `mat_a`: the first operand of the matrix multiply, can be type
|
||||
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
|
||||
// - `mat_b`: the second operand of the matrix multiply, can be type
|
||||
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
|
||||
// - `scale_a`: a tensor with the inverse scale of `mat1`, whose
|
||||
// shape/strides/dtype depend on the scaling scheme
|
||||
// - `scale_recipe_a`: An integer corresponding to an enum describing the
|
||||
// scaling scheme used for `scale_a`
|
||||
// - `swizzle_a`: An integer corresponding to a `SwizzleType` enum describing
|
||||
// the swizzling scheme for `scale_a`.
|
||||
// Not supported for XPU for now.
|
||||
// - `scale_b`: a tensor with the inverse scale of `mat2`, whose
|
||||
// shape/strides/dtype depend on the scaling scheme
|
||||
// - `scale_recipe_b`: An integer corresponding to an enum describing the
|
||||
// scaling scheme used for `scale_b`
|
||||
// - `swizzle_b`: An integer corresponding to a `SwizzleType` enum describing
|
||||
// the swizzling scheme for `scale_b`.
|
||||
// Not supported for XPU for now.
|
||||
// - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
|
||||
// - `out_dtype`: the output dtype, can either be a float8 or a higher
|
||||
// precision floating point type
|
||||
// - `contraction_dim`: describe which dimensions are `K` in the matmul.
|
||||
// Not supported for XPU. Should always be empty.
|
||||
// - `use_fast_accum`: Not supported for XPU, should always be false.
|
||||
// - `out`: a reference to the output tensor
|
||||
Tensor& _scaled_mm_xpu_v2_out(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
ArrayRef<Tensor> scale_a,
|
||||
IntArrayRef scale_recipe_a,
|
||||
IntArrayRef swizzle_a,
|
||||
ArrayRef<Tensor> scale_b,
|
||||
IntArrayRef scale_recipe_b,
|
||||
IntArrayRef swizzle_b,
|
||||
const std::optional<Tensor>& bias,
|
||||
const std::optional<c10::ScalarType> out_dtype,
|
||||
IntArrayRef contraction_dim,
|
||||
bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
TORCH_CHECK_VALUE(mat_a.dim() == 2, "mat_a must be a matrix");
|
||||
TORCH_CHECK_VALUE(mat_b.dim() == 2, "mat_b must be a matrix");
|
||||
|
||||
// If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm
|
||||
// kernels do not support this case).
|
||||
if (mat_a.size(0) == 0 || mat_a.size(1) == 0 || mat_b.size(1) == 0) {
|
||||
// `out` was created with `at::empty`. In the case where we are multiplying
|
||||
// MxK by KxN and K is the zero dim, we need to initialize here to properly
|
||||
// return a tensor of zeros.
|
||||
at::native::resize_output(out, {mat_a.size(0), mat_b.size(1)});
|
||||
if (mat_a.size(1) == 0) {
|
||||
out.zero_();
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
// Note: The `contraction_dim` is not actually used for now. We will need to
|
||||
// align this code when upstreamed CUDA code is done. Currently, only keeps
|
||||
// the code here for check.
|
||||
|
||||
// Check if the input matrix sizes can be multiplied
|
||||
// - if optional contraction dims are provided, use those
|
||||
// -- mostly for < 1B formats (i.e. nvfp4x2) where cheap .t() is not
|
||||
// available.
|
||||
if (contraction_dim.size() > 0) {
|
||||
TORCH_CHECK_VALUE(
|
||||
contraction_dim.size() == 2,
|
||||
"contraction_dim must have exactly 2 elements");
|
||||
auto mat_a_dim = contraction_dim[0];
|
||||
auto mat_b_dim = contraction_dim[1];
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.size(mat_a_dim) == mat_b.size(mat_b_dim),
|
||||
"mat_a and mat_b shapes cannot be multiplied (",
|
||||
mat_a.size(0),
|
||||
"x",
|
||||
mat_a.size(1),
|
||||
" and ",
|
||||
mat_b.size(0),
|
||||
"x",
|
||||
mat_b.size(1),
|
||||
") ",
|
||||
"with contraction dims mat_a: ",
|
||||
mat_a_dim,
|
||||
", mat_b: ",
|
||||
mat_b_dim);
|
||||
} else {
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.size(1) == mat_b.size(0),
|
||||
"mat_a and mat_b shapes cannot be multiplied (",
|
||||
mat_a.size(0),
|
||||
"x",
|
||||
mat_a.size(1),
|
||||
" and ",
|
||||
mat_b.size(0),
|
||||
"x",
|
||||
mat_b.size(1),
|
||||
")");
|
||||
}
|
||||
|
||||
TORCH_CHECK_VALUE(
|
||||
!bias || bias->numel() == mat_b.sizes()[1],
|
||||
"Bias must be size ",
|
||||
mat_b.sizes()[1],
|
||||
" but got ",
|
||||
bias->numel());
|
||||
|
||||
TORCH_CHECK_VALUE(
|
||||
!out_dtype || *out_dtype == out.scalar_type(),
|
||||
"out_dtype must match output matrix type");
|
||||
|
||||
if (bias) {
|
||||
TORCH_CHECK_VALUE(
|
||||
bias->scalar_type() == kFloat ||
|
||||
bias->scalar_type() == c10::ScalarType::BFloat16 ||
|
||||
bias->scalar_type() == c10::ScalarType::Half,
|
||||
"Bias must be Float32 or BFloat16 or Half, but got ",
|
||||
bias->scalar_type());
|
||||
}
|
||||
{
|
||||
auto bias_ = bias.value_or(Tensor());
|
||||
// NOLINTNEXTLINE(*c-array*)
|
||||
TensorArg targs[]{
|
||||
{out, "out", 0},
|
||||
{mat_a, "mat_a", 1},
|
||||
{mat_b, "mat_b", 2},
|
||||
{bias_, "bias", 3},
|
||||
{scale_a[0], "scale_a", 4},
|
||||
{scale_b[0], "scale_b", 5}};
|
||||
checkAllSameGPU(__func__, targs);
|
||||
}
|
||||
// Align with CUDA's default out to be bf16
|
||||
auto out_dtype_ = out_dtype.value_or(c10::ScalarType::BFloat16);
|
||||
|
||||
// Conversion of implicitly-defined enums to explicit
|
||||
auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
|
||||
auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
|
||||
auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
|
||||
auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
|
||||
|
||||
// XPU does not support swizzle for now. So directly return false.
|
||||
TORCH_CHECK_VALUE(
|
||||
swizzle_a_enum[0] == at::blas::SwizzleType::NO_SWIZZLE &&
|
||||
swizzle_b_enum[0] == at::blas::SwizzleType::NO_SWIZZLE,
|
||||
"XPU does not support swizzle yet.");
|
||||
|
||||
// at this point we can start working out what we want to be doing
|
||||
// Try to do as few steps as possible.
|
||||
// NOTE: support is deliberately sparse, can explicitly enumerate all
|
||||
// combinations allowed. Do this via a list of defined (name, acceptance,
|
||||
// concrete_impl) tuples.
|
||||
bool found_impl = false;
|
||||
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
|
||||
|
||||
for (const auto& fn_entry : scale_kernel_dispatch) {
|
||||
const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
|
||||
bool ok = accept_fn(
|
||||
mat_a.scalar_type(),
|
||||
scale_recipe_a_enum,
|
||||
scale_a,
|
||||
mat_b.scalar_type(),
|
||||
scale_recipe_b_enum,
|
||||
scale_b);
|
||||
if (ok) {
|
||||
gemm_impl = scaled_gemm_impl;
|
||||
found_impl = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(
|
||||
found_impl,
|
||||
"Invalid scaling configuration.\n"
|
||||
"- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
|
||||
"- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (",
|
||||
mat_a.size(0),
|
||||
", 1) and scale_b should be (1, ",
|
||||
mat_b.size(1),
|
||||
"), and both should be contiguous.\n"
|
||||
"Got mat_a.dtype()=",
|
||||
mat_a.scalar_type(),
|
||||
", scale_a[0].dtype()=",
|
||||
scale_a[0].scalar_type(),
|
||||
", scale_a[0].size()=",
|
||||
scale_a[0].sizes(),
|
||||
", scale_a[0].stride()=",
|
||||
scale_a[0].strides(),
|
||||
", ",
|
||||
"mat_b.dtype()=",
|
||||
mat_b.scalar_type(),
|
||||
", scale_b[0].dtype()=",
|
||||
scale_b[0].scalar_type(),
|
||||
", scale_b[0].size()=",
|
||||
scale_b[0].sizes(),
|
||||
" and scale_b[0].stride()=",
|
||||
scale_b[0].strides());
|
||||
|
||||
at::native::resize_output(out, {mat_a.size(0), mat_b.size(1)});
|
||||
|
||||
auto bias_ = bias.value_or(Tensor());
|
||||
|
||||
// dispatch to appropriate lower-level calls for error checking & execution
|
||||
if (gemm_impl == ScaledGemmImplementation::TENSORWISE_TENSORWISE) {
|
||||
return _scaled_tensorwise_tensorwise(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
scale_b[0],
|
||||
bias,
|
||||
out_dtype_,
|
||||
use_fast_accum,
|
||||
out);
|
||||
} else if (gemm_impl == ScaledGemmImplementation::ROWWISE_ROWWISE) {
|
||||
return _scaled_rowwise_rowwise(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
scale_b[0],
|
||||
bias,
|
||||
out_dtype_,
|
||||
use_fast_accum,
|
||||
out);
|
||||
} else {
|
||||
TORCH_CHECK_VALUE(
|
||||
false, "Invalid state - found an implementation, but not really");
|
||||
}
|
||||
}
|
||||
|
||||
Tensor _scaled_mm_xpu_v2(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
ArrayRef<Tensor> scale_a,
|
||||
IntArrayRef scale_recipe_a,
|
||||
IntArrayRef swizzle_a,
|
||||
ArrayRef<Tensor> scale_b,
|
||||
IntArrayRef scale_recipe_b,
|
||||
IntArrayRef swizzle_b,
|
||||
const std::optional<Tensor>& bias,
|
||||
const std::optional<c10::ScalarType> out_dtype,
|
||||
IntArrayRef contraction_dim,
|
||||
bool use_fast_accum) {
|
||||
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
|
||||
Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
|
||||
|
||||
return _scaled_mm_xpu_v2_out(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_recipe_a,
|
||||
swizzle_a,
|
||||
scale_b,
|
||||
scale_recipe_b,
|
||||
swizzle_b,
|
||||
bias,
|
||||
out_dtype,
|
||||
contraction_dim,
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
@ -69,75 +69,139 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
|
||||
auto out = at::empty({batchSize, num_head, qSize, headSize}, query.options());
|
||||
auto attn = at::empty({batchSize, num_head, qSize, maxSeqLength}, query.options());
|
||||
auto scale_factor = sdp::calculate_scale(query, scale).expect_float();
|
||||
static const bool is_macOS_26_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_26_0_PLUS);
|
||||
@autoreleasepool {
|
||||
auto mkey = __func__ + getTensorsStringKey({query, key, value}) + ":" + std::to_string(is_causal) + ":" +
|
||||
std::to_string(attn_mask.has_value());
|
||||
auto cachedGraph =
|
||||
LookUpOrCreateCachedGraph<CachedGraph>(mkey, [&, q_ = query, k_ = key, v_ = value](auto mpsGraph, auto graph) {
|
||||
auto qTensor = mpsGraphRankedPlaceHolder(mpsGraph, q_);
|
||||
auto kTensor = mpsGraphRankedPlaceHolder(mpsGraph, k_);
|
||||
auto vTensor = mpsGraphRankedPlaceHolder(mpsGraph, v_);
|
||||
auto kT = [mpsGraph transposeTensor:kTensor dimension:2 withDimension:3 name:nil];
|
||||
auto scaleTensor = [mpsGraph constantWithScalar:scale_factor
|
||||
shape:getMPSShape({1})
|
||||
dataType:MPSDataTypeFloat32];
|
||||
|
||||
auto maskedMM = [mpsGraph matrixMultiplicationWithPrimaryTensor:qTensor secondaryTensor:kT name:nil];
|
||||
CachedGraph* cachedGraph;
|
||||
//if(is_macOS_26_0_or_newer) {
|
||||
if(true) {
|
||||
cachedGraph =
|
||||
LookUpOrCreateCachedGraph<CachedGraph>(mkey, [&, q_ = query, k_ = key, v_ = value](auto mpsGraph, auto graph) {
|
||||
auto qTensor = mpsGraphRankedPlaceHolder(mpsGraph, q_);
|
||||
auto kTensor = mpsGraphRankedPlaceHolder(mpsGraph, k_);
|
||||
auto vTensor = mpsGraphRankedPlaceHolder(mpsGraph, v_);
|
||||
|
||||
if (macOS15_0_plus && [maskedMM dataType] == MPSDataTypeFloat32) {
|
||||
// bug in MacOS15, without this trick SDPA leaks memory, adding 0.0f gets ignored(still takes SDPA sequence
|
||||
// path which leaks)
|
||||
auto oneTensor = [mpsGraph constantWithScalar:1e-20f shape:getMPSShape({1}) dataType:MPSDataTypeFloat32];
|
||||
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:oneTensor name:nil];
|
||||
}
|
||||
if (is_causal) {
|
||||
MPSShape* maskShape = @[@(qSize), @(maxSeqLength)];
|
||||
auto x = [mpsGraph coordinateAlongAxis:-1 withShape:@[@(qSize), @1] name:nil];
|
||||
auto y = [mpsGraph coordinateAlongAxis:-2 withShape:@[@1, @(maxSeqLength)] name:nil];
|
||||
auto isLess = [mpsGraph lessThanOrEqualToWithPrimaryTensor:x secondaryTensor:y name:nil];
|
||||
auto causalMask = [mpsGraph selectWithPredicateTensor:isLess
|
||||
truePredicateTensor:[mpsGraph constantWithScalar:0 dataType:qTensor.dataType]
|
||||
falsePredicateTensor:[mpsGraph constantWithScalar:-INFINITY dataType:qTensor.dataType]
|
||||
name:nil];
|
||||
graph->maskTensor = causalMask;
|
||||
} else if (attn_mask) {
|
||||
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
|
||||
}
|
||||
|
||||
// upcasting to float32 if needed to improve precision when multiplying by the scale factor
|
||||
maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
|
||||
maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
|
||||
// Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
|
||||
// Overwrites expected NANs in sm with zeros.
|
||||
// auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
|
||||
// auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
|
||||
// auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
|
||||
// auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
|
||||
// auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
|
||||
//
|
||||
// auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
|
||||
// MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
|
||||
// truePredicateTensor:zeroTensor
|
||||
// falsePredicateTensor:sm
|
||||
// name:nil];
|
||||
//
|
||||
// auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
|
||||
|
||||
if (is_causal) {
|
||||
auto causalMask = [mpsGraph constantWithScalar:1.0f
|
||||
shape:getMPSShape({qSize, maxSeqLength})
|
||||
dataType:MPSDataTypeBool];
|
||||
causalMask = [mpsGraph bandPartWithTensor:causalMask numLower:-1 numUpper:0 name:nil];
|
||||
auto minusInf = [mpsGraph constantWithScalar:-1e20 shape:maskedMM.shape dataType:maskedMM.dataType];
|
||||
maskedMM = [mpsGraph selectWithPredicateTensor:causalMask
|
||||
truePredicateTensor:maskedMM
|
||||
falsePredicateTensor:minusInf
|
||||
name:nil];
|
||||
} else if (attn_mask) {
|
||||
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
|
||||
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
|
||||
secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
|
||||
name:nil];
|
||||
}
|
||||
MPSGraphTensor* output;
|
||||
if(graph->maskTensor != nil) {
|
||||
output = [mpsGraph scaledDotProductAttentionWithQueryTensor:qTensor
|
||||
keyTensor:kTensor
|
||||
valueTensor:vTensor
|
||||
maskTensor:graph->maskTensor
|
||||
scale:scale_factor
|
||||
name:@"MPSGraph SDPA"];
|
||||
} else {
|
||||
output = [mpsGraph scaledDotProductAttentionWithQueryTensor:qTensor
|
||||
keyTensor:kTensor
|
||||
valueTensor:vTensor
|
||||
scale:scale_factor
|
||||
name:@"MPSGraph SDPA"];
|
||||
}
|
||||
graph->qTensor = qTensor;
|
||||
graph->kTensor = kTensor;
|
||||
graph->vTensor = vTensor;
|
||||
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
|
||||
// graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
|
||||
});
|
||||
} else {
|
||||
cachedGraph =
|
||||
LookUpOrCreateCachedGraph<CachedGraph>(mkey, [&, q_ = query, k_ = key, v_ = value](auto mpsGraph, auto graph) {
|
||||
auto qTensor = mpsGraphRankedPlaceHolder(mpsGraph, q_);
|
||||
auto kTensor = mpsGraphRankedPlaceHolder(mpsGraph, k_);
|
||||
auto vTensor = mpsGraphRankedPlaceHolder(mpsGraph, v_);
|
||||
auto kT = [mpsGraph transposeTensor:kTensor dimension:2 withDimension:3 name:nil];
|
||||
auto scaleTensor = [mpsGraph constantWithScalar:scale_factor
|
||||
shape:getMPSShape({1})
|
||||
dataType:MPSDataTypeFloat32];
|
||||
|
||||
// Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
|
||||
// Overwrites expected NANs in sm with zeros.
|
||||
auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
|
||||
auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
|
||||
auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
|
||||
auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
|
||||
auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
|
||||
auto maskedMM = [mpsGraph matrixMultiplicationWithPrimaryTensor:qTensor secondaryTensor:kT name:nil];
|
||||
|
||||
auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
|
||||
MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
|
||||
truePredicateTensor:zeroTensor
|
||||
falsePredicateTensor:sm
|
||||
name:nil];
|
||||
if (macOS15_0_plus && [maskedMM dataType] == MPSDataTypeFloat32) {
|
||||
// bug in MacOS15, without this trick SDPA leaks memory, adding 0.0f gets ignored(still takes SDPA sequence
|
||||
// path which leaks)
|
||||
auto oneTensor = [mpsGraph constantWithScalar:1e-20f shape:getMPSShape({1}) dataType:MPSDataTypeFloat32];
|
||||
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:oneTensor name:nil];
|
||||
}
|
||||
|
||||
auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
|
||||
graph->qTensor = qTensor;
|
||||
graph->kTensor = kTensor;
|
||||
graph->vTensor = vTensor;
|
||||
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
|
||||
graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
|
||||
});
|
||||
// upcasting to float32 if needed to improve precision when multiplying by the scale factor
|
||||
maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
|
||||
maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
|
||||
|
||||
if (is_causal) {
|
||||
auto causalMask = [mpsGraph constantWithScalar:1.0f
|
||||
shape:getMPSShape({qSize, maxSeqLength})
|
||||
dataType:MPSDataTypeBool];
|
||||
causalMask = [mpsGraph bandPartWithTensor:causalMask numLower:-1 numUpper:0 name:nil];
|
||||
auto minusInf = [mpsGraph constantWithScalar:-1e20 shape:maskedMM.shape dataType:maskedMM.dataType];
|
||||
maskedMM = [mpsGraph selectWithPredicateTensor:causalMask
|
||||
truePredicateTensor:maskedMM
|
||||
falsePredicateTensor:minusInf
|
||||
name:nil];
|
||||
} else if (attn_mask) {
|
||||
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
|
||||
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
|
||||
secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
|
||||
name:nil];
|
||||
}
|
||||
|
||||
// Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
|
||||
// Overwrites expected NANs in sm with zeros.
|
||||
auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
|
||||
auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
|
||||
auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
|
||||
auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
|
||||
auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
|
||||
|
||||
auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
|
||||
MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
|
||||
truePredicateTensor:zeroTensor
|
||||
falsePredicateTensor:sm
|
||||
name:nil];
|
||||
|
||||
auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
|
||||
graph->qTensor = qTensor;
|
||||
graph->kTensor = kTensor;
|
||||
graph->vTensor = vTensor;
|
||||
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
|
||||
graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
|
||||
});
|
||||
}
|
||||
auto qPlaceholder = Placeholder(cachedGraph->qTensor, query);
|
||||
auto kPlaceholder = Placeholder(cachedGraph->kTensor, key);
|
||||
auto vPlaceholder = Placeholder(cachedGraph->vTensor, value);
|
||||
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, out);
|
||||
auto attnPlaceholder = Placeholder(cachedGraph->attnTensor, attn);
|
||||
// auto attnPlaceholder = Placeholder(cachedGraph->attnTensor, attn);
|
||||
NSDictionary* feeds = nil;
|
||||
if (!attn_mask) {
|
||||
feeds = dictionaryFromPlaceholders(qPlaceholder, kPlaceholder, vPlaceholder);
|
||||
@ -145,7 +209,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
|
||||
auto mPlaceholder = Placeholder(cachedGraph->maskTensor, *attn_mask);
|
||||
feeds = dictionaryFromPlaceholders(qPlaceholder, kPlaceholder, vPlaceholder, mPlaceholder);
|
||||
}
|
||||
NSDictionary* outs = dictionaryFromPlaceholders(outputPlaceholder, attnPlaceholder);
|
||||
// NSDictionary* outs = dictionaryFromPlaceholders(outputPlaceholder, attnPlaceholder);
|
||||
NSDictionary* outs = dictionaryFromPlaceholders(outputPlaceholder);
|
||||
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outs);
|
||||
}
|
||||
|
||||
|
||||
@ -1,122 +0,0 @@
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
#include <c10/util/typeid.h>
|
||||
#include <cstdint>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/BlasBackend.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/core/NamedTensor.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/native/GroupedMMUtils.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <c10/util/MaybeOwned.h>
|
||||
|
||||
#include <ATen/ceil_div.h>
|
||||
#include <ATen/xpu/XPUScaledBlas.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_addmm_activation_native.h>
|
||||
#include <ATen/ops/_efficientzerotensor.h>
|
||||
#include <ATen/ops/_scaled_mm_native.h>
|
||||
#include <ATen/ops/_unsafe_view_native.h>
|
||||
#include <ATen/ops/abs.h>
|
||||
#include <ATen/ops/addmm_native.h>
|
||||
#include <ATen/ops/addmv_native.h>
|
||||
#include <ATen/ops/baddbmm_native.h>
|
||||
#include <ATen/ops/bmm_native.h>
|
||||
#include <ATen/ops/copy_native.h>
|
||||
#include <ATen/ops/dot_native.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/empty_strided.h>
|
||||
#include <ATen/ops/gelu.h>
|
||||
#include <ATen/ops/max.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/mul.h>
|
||||
#include <ATen/ops/ones.h>
|
||||
#include <ATen/ops/relu.h>
|
||||
#include <ATen/ops/scalar_tensor_native.h>
|
||||
#include <ATen/ops/vdot_native.h>
|
||||
#endif
|
||||
|
||||
using at::blas::ScalingType;
|
||||
|
||||
namespace at::native::onednn::scaled {
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8,
|
||||
* Each needs a single scale, {Tensorwise (float)}
|
||||
*/
|
||||
bool check_tensorwise_recipe(
|
||||
c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scale each, {Tensorwise, float}
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 ||
|
||||
recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
// Need {Blockwise_1x32, e8m0} for A & B
|
||||
if (recipe_a[0] != ScalingType::TensorWise)
|
||||
return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float)
|
||||
return false;
|
||||
if (recipe_b[0] != ScalingType::TensorWise)
|
||||
return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Both inputs must be fp8,
|
||||
* Each needs scales, {Rowwise (float)}
|
||||
*/
|
||||
bool check_rowwise_recipe(
|
||||
c10::ScalarType type_a,
|
||||
std::vector<ScalingType>& recipe_a,
|
||||
ArrayRef<Tensor>& scales_a,
|
||||
c10::ScalarType type_b,
|
||||
std::vector<ScalingType>& recipe_b,
|
||||
ArrayRef<Tensor>& scales_b) {
|
||||
// both types must be fp8
|
||||
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1 scale each, {Tensorwise, float}
|
||||
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 ||
|
||||
recipe_b.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Need {RowWise, dp32} for A & B
|
||||
if (recipe_a[0] != ScalingType::RowWise)
|
||||
return false;
|
||||
if (scales_a[0].scalar_type() != ScalarType::Float)
|
||||
return false;
|
||||
if (recipe_b[0] != ScalingType::RowWise)
|
||||
return false;
|
||||
if (scales_b[0].scalar_type() != ScalarType::Float)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace at::native::onednn::scaled
|
||||
@ -1,95 +0,0 @@
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
#include <c10/util/typeid.h>
|
||||
#include <cstdint>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/core/NamedTensor.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <c10/util/MaybeOwned.h>
|
||||
|
||||
#include <ATen/BlasBackend.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
|
||||
#ifdef USE_FBGEMM_GENAI
|
||||
#include <fbgemm_gpu/torch_ops.h>
|
||||
#endif
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_addmm_activation_native.h>
|
||||
#include <ATen/ops/_efficientzerotensor.h>
|
||||
#include <ATen/ops/_scaled_mm_native.h>
|
||||
#include <ATen/ops/_unsafe_view_native.h>
|
||||
#include <ATen/ops/abs.h>
|
||||
#include <ATen/ops/addmm_native.h>
|
||||
#include <ATen/ops/addmv_native.h>
|
||||
#include <ATen/ops/baddbmm_native.h>
|
||||
#include <ATen/ops/bmm_native.h>
|
||||
#include <ATen/ops/copy_native.h>
|
||||
#include <ATen/ops/dot_native.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/empty_strided.h>
|
||||
#include <ATen/ops/gelu.h>
|
||||
#include <ATen/ops/max.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/mul.h>
|
||||
#include <ATen/ops/ones.h>
|
||||
#include <ATen/ops/relu.h>
|
||||
#include <ATen/ops/scalar_tensor_native.h>
|
||||
#include <ATen/ops/vdot_native.h>
|
||||
#endif
|
||||
|
||||
using at::blas::ScalingType;
|
||||
|
||||
namespace at::native::onednn::scaled {
|
||||
|
||||
/**
|
||||
* Track concrete implementations available
|
||||
*/
|
||||
enum class ScaledGemmImplementation {
|
||||
NONE = 0,
|
||||
TENSORWISE_TENSORWISE = 1,
|
||||
ROWWISE_ROWWISE = 2,
|
||||
};
|
||||
|
||||
/**
|
||||
* Convert passed int (enum) from python back into a
|
||||
* strictly-typed enum
|
||||
*/
|
||||
template <class EnumType, class ArrayType>
|
||||
std::vector<EnumType> convert_int_to_enum(ArrayType& v) {
|
||||
std::vector<EnumType> converted;
|
||||
converted.reserve(v.size());
|
||||
|
||||
for (auto vi : v) {
|
||||
converted.push_back(static_cast<EnumType>(vi));
|
||||
}
|
||||
return converted;
|
||||
}
|
||||
|
||||
bool check_tensorwise_recipe(
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
bool check_rowwise_recipe(
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&,
|
||||
c10::ScalarType,
|
||||
std::vector<ScalingType>&,
|
||||
ArrayRef<Tensor>&);
|
||||
|
||||
} // namespace at::native::onednn::scaled
|
||||
@ -1,238 +0,0 @@
|
||||
# Owner(s): ["module: complex"]
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
# Support both when imported from elsewhere or directly as a file
|
||||
try:
|
||||
from .utils import (
|
||||
COMPLEX_DTYPES,
|
||||
Descriptor,
|
||||
force_test_op_db,
|
||||
get_overload_packet_from_name,
|
||||
implemented_op_db,
|
||||
TestCase,
|
||||
Variant,
|
||||
)
|
||||
except ImportError:
|
||||
from utils import (
|
||||
COMPLEX_DTYPES,
|
||||
Descriptor,
|
||||
force_test_op_db,
|
||||
get_overload_packet_from_name,
|
||||
implemented_op_db,
|
||||
TestCase,
|
||||
Variant,
|
||||
)
|
||||
|
||||
from torch._subclasses.complex_tensor._ops.common import ComplexTensorMode
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
OpDTypes,
|
||||
ops,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
TestGradients,
|
||||
unMarkDynamoStrictTest,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.testing._internal.opinfo.core import OpInfo
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
SKIPS = {
|
||||
Descriptor(op=aten.empty_like, variant=None): "Non-deterministic output",
|
||||
Descriptor(op=aten.randn_like, variant=None): "Non-deterministic output",
|
||||
Descriptor(op=aten.angle, variant=Variant.GradCheck): "Numerical inconsistency",
|
||||
Descriptor(op=aten.asinh, variant=Variant.GradCheck): "Numerical inconsistency",
|
||||
Descriptor(op=aten.atanh, variant=Variant.GradCheck): "Numerical inconsistency",
|
||||
Descriptor(
|
||||
op=aten.reciprocal, variant=Variant.GradCheck
|
||||
): "Numerical inconsistency",
|
||||
Descriptor(op=aten.rsqrt, variant=Variant.GradCheck): "Numerical inconsistency",
|
||||
Descriptor(op=aten.select, variant=Variant.GradCheck): "Numerical inconsistency",
|
||||
    Descriptor(op=aten.asin, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.log, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.sgn, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.cumprod, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.slice, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.sqrt, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.tan, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(
        op=aten.true_divide, variant=Variant.GradCheck
    ): "Numerical inconsistency",
    Descriptor(op=aten.prod, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.div, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.expm1, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.var, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.bmm, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.diagonal, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.sinh, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.abs, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.sin, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.atan, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.acos, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.acosh, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.cos, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.cosh, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.addmm, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.pow, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.log1p, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.tanh, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.mm, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.dot, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.mul, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.exp, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(op=aten.to, variant=Variant.GradCheck): "Numerical inconsistency",
    Descriptor(
        op=aten.any, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.all, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.allclose, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.conj_physical, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten._conj_physical, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.cumprod, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.index_add, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.diagonal_scatter, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.flip, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.masked_fill, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.masked_scatter, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.rsub, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.ne, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.squeeze, variant=Variant.Distributed
    ): "does not have a sharding strategy registered",
    Descriptor(
        op=aten.index_select, variant=Variant.Distributed
    ): "Sharding propagation failed",
    Descriptor(op=aten.real, variant=Variant.Distributed): "No scalar support",
    Descriptor(op=aten.imag, variant=Variant.Distributed): "No scalar support",
    Descriptor(op=aten.isfinite, variant=Variant.Distributed): "No scalar support",
    Descriptor(op=aten.transpose, variant=Variant.Distributed): "No scalar support",
    Descriptor(op=aten.view_as_real, variant=Variant.Distributed): "No scalar support",
}

EXTRA_KWARGS = {
    Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Op): {
        "rtol": 2e-5,
        "atol": 5e-5,
    },
    Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Op): {
        "rtol": 1e-4,
        "atol": 1e-5,
    },
    Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Op): {
        "rtol": 2e-2,
        "atol": 2e-6,
    },
    Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Distributed): {
        "rtol": 2e-5,
        "atol": 5e-5,
    },
    Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Distributed): {
        "rtol": 1e-4,
        "atol": 1e-5,
    },
    Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Distributed): {
        "rtol": 2e-2,
        "atol": 2e-6,
    },
    Descriptor(op=aten.tan, dtype=torch.complex64, variant=Variant.Distributed): {
        "rtol": 2e-6,
        "atol": 1e-2,
    },
}


class TestComplexTensor(TestCase):
    _default_dtype_check_enabled = True

    @ops(
        implemented_op_db,
        dtypes=OpDTypes.supported,
        allowed_dtypes=list(COMPLEX_DTYPES),
    )
    def test_consistency(self, device, dtype, op: OpInfo):
        self.check_consistency(device, dtype, op, Variant.Op)

    @ops(force_test_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
    def test_maybe_error(self, device, dtype, op: OpInfo):
        self.check_consistency(device, dtype, op, Variant.Op)


@unMarkDynamoStrictTest
class TestComplexBwdGradients(TestGradients):
    _default_dtype_check_enabled = True

    @ops(
        implemented_op_db,
        dtypes=OpDTypes.supported_backward,
        allowed_dtypes=[torch.complex128],
    )
    def test_fn_grad(self, device: str, dtype: torch.dtype, op: OpInfo) -> None:
        test_info = Descriptor(
            op=get_overload_packet_from_name(op.name),
            device_type=torch.device(device).type,
            dtype=dtype,
            variant=Variant.GradCheck,
        )
        for xfail_info, reason in SKIPS.items():
            if xfail_info.matches(test_info):
                self.skipTest(reason)

        if dtype not in op.supported_backward_dtypes(torch.device(device).type):
            self.skipTest(f"Skipped! {dtype=} is not in supported backward dtypes!")

        with ComplexTensorMode():
            op.gradcheck_fast_mode = False
            self._grad_test_helper(device, dtype, op, op.get_op())


instantiate_device_type_tests(TestComplexTensor, globals())
instantiate_device_type_tests(TestComplexBwdGradients, globals())


if dist.is_available():
    from torch.testing._internal.common_distributed import MultiProcessTestCase

    @unMarkDynamoStrictTest
    class TestComplexDistributed(TestCase, MultiProcessTestCase):
        @ops(implemented_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
        def test_distributed(self, device, dtype, op: OpInfo):
            self.check_consistency(device, dtype, op, Variant.Distributed)

    instantiate_device_type_tests(TestComplexDistributed, globals())

if __name__ == "__main__":
    run_tests()
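For readers skimming this compare view: the classes above lean entirely on PyTorch's OpInfo machinery, in which a single `@ops`-decorated method is expanded into one generated test per (device, dtype, OpInfo) combination by `instantiate_device_type_tests`. Below is a minimal, self-contained sketch of that machinery; the class name, test body, and complex64 filter are illustrative placeholders and are not part of this patch.

```python
# Minimal sketch of the @ops / instantiate_device_type_tests pattern used above.
# Hypothetical example, not part of this diff.
import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import run_tests, TestCase


class MyOpSmokeTest(TestCase):
    @ops(op_db, allowed_dtypes=[torch.complex64])
    def test_runs(self, device, dtype, op):
        # Run every registered sample input for the op and make sure it executes.
        for sample in op.sample_inputs(device, dtype):
            op(sample.input, *sample.args, **sample.kwargs)


# Generates device-specific classes and per-op tests,
# e.g. MyOpSmokeTestCPU.test_runs_abs_cpu_complex64.
instantiate_device_type_tests(MyOpSmokeTest, globals())

if __name__ == "__main__":
    run_tests()
```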
@ -1,214 +0,0 @@
from __future__ import annotations

from dataclasses import dataclass, field, fields
from enum import auto, Enum
from typing import Any, TYPE_CHECKING

import torch
import torch.distributed as dist
from torch._subclasses.complex_tensor._ops.common import (
    _as_complex_tensor,
    _as_interleaved,
    _get_op_name,
    COMPLEX_OPS_TABLE,
    COMPLEX_TO_REAL,
    FORCE_TEST_LIST,
    OpOverloadPacket,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
from torch.utils._pytree import tree_flatten


if TYPE_CHECKING:
    from collections.abc import Callable

    from torch.distributed.tensor import DTensor
    from torch.testing._internal.opinfo.core import OpInfo

COMPLEX_DTYPES = set(COMPLEX_TO_REAL)


class Variant(Enum):
    Op = auto()
    GradCheck = auto()
    Distributed = auto()


def _as_local(arg: DTensor | Any) -> torch.Tensor | Any:
    if not (dist.is_available() and isinstance(arg, dist.tensor.DTensor)):
        return arg

    return arg.full_tensor()


def _as_complex_dtensor(arg: torch.Tensor | Any) -> torch.Tensor | Any:
    if not isinstance(arg, torch.Tensor):
        return arg

    return dist.tensor.DTensor.from_local(_as_complex_tensor(arg))


TRANSFORM_FUNCS = {
    Variant.Op: _as_complex_tensor,
    Variant.Distributed: _as_complex_dtensor,
}


@dataclass(frozen=True, kw_only=True)
class Descriptor:
    op: OpOverloadPacket
    variant: Variant | None
    device_type: str | None = field(default=None)
    dtype: torch.dtype | None = field(default=None)

    def matches(self, other: Descriptor) -> bool:
        fields1 = fields(self)
        fields2 = fields(other)
        if fields1 != fields2:
            return False

        for f in fields1:
            f1 = getattr(self, f.name)
            f2 = getattr(other, f.name)
            if f1 is not None and f2 is not None and f1 != f2:
                return False

        return True


class TestCase(PytorchTestCase):
    def assertSameResult(
        self,
        expected: Callable[[], Any],
        actual: Callable[[], Any],
        *args,
        **kwargs,
    ) -> None:
        try:
            result_e = expected()
            exception_e = None
        except Exception as e:  # noqa: BLE001
            result_e = None
            exception_e = e

        try:
            result_a = actual()
            exception_a = None
        except Exception as e:  # noqa: BLE001
            result_a = None
            exception_a = e

        if (exception_e is None) != (exception_a is None):
            if exception_a is not None and exception_e is None:
                raise exception_a
        self.assertIs(
            type(exception_e),
            type(exception_a),
            f"\n{exception_e=}\n{exception_a=}",
        )

        if exception_e is None:
            flattened_e, spec_e = tree_flatten(result_e)
            flattened_a, spec_a = tree_flatten(result_a)

            self.assertEqual(
                spec_e,
                spec_a,
                "Both functions must return a result with the same tree structure.",
            )
            for value_e, value_a in zip(flattened_e, flattened_a, strict=True):
                value_e = _as_interleaved(_as_local(value_e))
                value_a = _as_interleaved(_as_local(value_a))

                self.assertEqual(value_e, value_a, *args, **kwargs)

    def check_consistency(
        self, device: str, dtype, op: OpInfo, variant: Variant
    ) -> None:
        try:
            from .test_complex_tensor import EXTRA_KWARGS, SKIPS
        except ImportError:
            from test_complex_tensor import EXTRA_KWARGS, SKIPS
        test_info = Descriptor(
            op=get_overload_packet_from_name(op.name),
            device_type=torch.device(device).type,
            dtype=dtype,
            variant=variant,
        )
        for xfail_info, reason in SKIPS.items():
            if xfail_info.matches(test_info):
                self.skipTest(reason)

        kwargs = {}
        for extra_info, extra_kw in EXTRA_KWARGS.items():
            if extra_info.matches(test_info):
                kwargs = extra_kw
                break
        sample_inputs = op.sample_inputs(device, dtype)
        transform_fn = TRANSFORM_FUNCS[variant]

        for sample_input in sample_inputs:

            def expected(sample_input=sample_input):
                return op(sample_input.input, *sample_input.args, **sample_input.kwargs)

            subclass_sample = sample_input.transform(transform_fn)

            def actual(subclass_sample=subclass_sample):
                return op(
                    subclass_sample.input,
                    *subclass_sample.args,
                    **subclass_sample.kwargs,
                )

            self.assertSameResult(expected, actual, **kwargs)


aten = torch.ops.aten

complex_op_db = tuple(
    filter(lambda op: any(op.supports_dtype(ct, "cpu") for ct in COMPLEX_DTYPES), op_db)
)


def get_overload_packet_from_name(name: str) -> OpOverloadPacket:
    for domain_name in torch.ops:
        op_namespace = getattr(torch.ops, domain_name)
        op: OpOverloadPacket | None = getattr(op_namespace, name, None)
        if op is not None:
            return op

    raise RuntimeError(f"No op with {name=} found.")


force_test_names = set(map(_get_op_name, FORCE_TEST_LIST))
implemented_op_names = (
    set(map(_get_op_name, COMPLEX_OPS_TABLE.keys())) - force_test_names
)
implemented_op_db = tuple(
    filter(lambda op: op.name in implemented_op_names, complex_op_db)
)
force_test_op_db = tuple(filter(lambda op: op.name in force_test_names, op_db))

tested_op_names = {op.name for op in implemented_op_db} | {
    op.name for op in force_test_op_db
}
non_tested_ops = {
    op for op in COMPLEX_OPS_TABLE if _get_op_name(op) not in tested_op_names
}


# TODO (hameerabbasi): There are a number of ops that don't have any associated
# OpInfos. We still need to write tests for those ops.
if len(non_tested_ops) != 0:
    import textwrap
    import warnings

    list_missing_ops = "\n".join(sorted([str(op) for op in non_tested_ops]))
    warnings.warn(
        "Not all implemented ops are tested. List of ops missing tests:"
        f"\n{textwrap.indent(list_missing_ops, '  ')}",
        UserWarning,
        stacklevel=2,
    )
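The skip and tolerance tables in the test file above are keyed through `Descriptor.matches`, where any field left as `None` acts as a wildcard. A small sketch (illustrative values, not part of this patch) of how a `SKIPS` entry without a device or dtype applies to every matching run, while an `EXTRA_KWARGS` entry only fires when op, dtype, and variant all line up; it assumes the `Descriptor` and `Variant` definitions from the helper file above are in scope.

```python
# Hypothetical illustration of the Descriptor matching used by SKIPS / EXTRA_KWARGS.
# Assumes Descriptor and Variant from the helper module above; a None field on
# either side is treated as "match anything".
import torch

aten = torch.ops.aten

# Entry as it would appear in SKIPS: no device_type/dtype, so it applies everywhere.
skip_entry = Descriptor(op=aten.tan, variant=Variant.GradCheck)

# Descriptor built for one concrete test run.
test_info = Descriptor(
    op=aten.tan,
    device_type="cpu",
    dtype=torch.complex128,
    variant=Variant.GradCheck,
)

assert skip_entry.matches(test_info)  # run is skipped: "Numerical inconsistency"

# An EXTRA_KWARGS entry only supplies rtol/atol when op, dtype and variant all match.
tol_entry = Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Op)
assert not tol_entry.matches(test_info)  # different op, dtype and variant
```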
@ -230,98 +230,6 @@ class DistConvolutionOpsTest(DTensorTestBase):
|
||||
out_dt, out = self._run_single_arg_fwd(model, x, [Shard(0)])
|
||||
self.assertEqual(out_dt, out)
|
||||
|
||||
@with_comms
|
||||
def test_conv2d_no_bias_compile(self):
|
||||
"""Test Conv2d with bias=False in compile mode (Issue #167091)
|
||||
|
||||
Regression test: Previously this would fail during torch.compile
|
||||
tracing with AssertionError when bias_spec was None.
|
||||
"""
|
||||
device_mesh = self.build_device_mesh()
|
||||
|
||||
def conv_fn(x, w):
|
||||
return F.conv2d(x, w, bias=None, padding=1)
|
||||
|
||||
compiled_fn = torch.compile(conv_fn)
|
||||
|
||||
# Create tensors
|
||||
x = torch.randn(1, 4, 5, 5, device=self.device_type)
|
||||
w = torch.randn(8, 4, 3, 3, device=self.device_type)
|
||||
|
||||
# Distribute tensors
|
||||
x_dt = distribute_tensor(x, device_mesh, [Replicate()])
|
||||
w_dt = distribute_tensor(w, device_mesh, [Replicate()])
|
||||
|
||||
# Test eager mode for comparison
|
||||
result_eager = conv_fn(x_dt, w_dt)
|
||||
|
||||
# Test compiled mode - this should not crash
|
||||
result_compiled = compiled_fn(x_dt, w_dt)
|
||||
|
||||
# Verify shape is correct (the key regression test)
|
||||
self.assertEqual(result_compiled.shape, torch.Size([1, 8, 5, 5]))
|
||||
|
||||
# Verify numerical correctness
|
||||
torch.testing.assert_close(result_compiled.to_local(), result_eager.to_local())
|
||||
|
||||
@with_comms
|
||||
def test_conv2d_no_bias_backward(self):
|
||||
"""Test Conv2d backward pass with bias=False (Issue #167091)
|
||||
|
||||
Regression test: Previously backward pass would fail when
|
||||
grad_bias_spec was None.
|
||||
"""
|
||||
device_mesh = self.build_device_mesh()
|
||||
|
||||
# Create tensors with requires_grad
|
||||
x = torch.randn(1, 4, 5, 5, device=self.device_type)
|
||||
w = torch.randn(8, 4, 3, 3, device=self.device_type, requires_grad=True)
|
||||
|
||||
# Distribute tensors
|
||||
x_dt = distribute_tensor(x, device_mesh, [Replicate()])
|
||||
w_dt = torch.nn.Parameter(distribute_tensor(w, device_mesh, [Replicate()]))
|
||||
|
||||
# Forward pass
|
||||
result = F.conv2d(x_dt, w_dt, bias=None, padding=1)
|
||||
|
||||
# Backward pass - this should not crash
|
||||
grad_output = torch.randn_like(result)
|
||||
result.backward(grad_output)
|
||||
|
||||
# Check weight gradient exists (the key regression test)
|
||||
self.assertIsNotNone(w_dt.grad)
|
||||
self.assertEqual(w_dt.grad.shape, torch.Size([8, 4, 3, 3]))
|
||||
|
||||
@with_comms
|
||||
def test_conv2d_module_no_bias(self):
|
||||
"""Test nn.Conv2d module with bias=False (Issue #167091)
|
||||
|
||||
Regression test: Ensures nn.Conv2d with bias=False works with DTensor.
|
||||
"""
|
||||
device_mesh = self.build_device_mesh()
|
||||
|
||||
# Create model with bias=False
|
||||
model = nn.Conv2d(4, 8, kernel_size=3, padding=1, bias=False).to(
|
||||
self.device_type
|
||||
)
|
||||
nn.init.ones_(model.weight)
|
||||
|
||||
# Distribute model
|
||||
model_dt = distribute_module(model, device_mesh, _conv_fn)
|
||||
|
||||
# Create input
|
||||
x = torch.randn(1, 4, 5, 5, device=self.device_type)
|
||||
x_dt = distribute_tensor(x, device_mesh, [Replicate()])
|
||||
|
||||
# Forward pass - this should not crash
|
||||
output_dt = model_dt(x_dt)
|
||||
|
||||
# Check outputs shape is correct
|
||||
self.assertEqual(output_dt.shape, torch.Size([1, 8, 5, 5]))
|
||||
|
||||
# Check that model.bias is None
|
||||
self.assertIsNone(model.bias)
|
||||
|
||||
|
||||
DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
|
||||
DistConvolutionOpsTest,
|
||||
@ -330,10 +238,6 @@ DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
|
||||
"test_conv_backward_none_grad_inp",
|
||||
"test_depthwise_convolution",
|
||||
"test_downsampling_convolution",
|
||||
# New tests for Issue #167091 - use send/recv via tp_convolution
|
||||
"test_conv2d_no_bias_compile",
|
||||
"test_conv2d_no_bias_backward",
|
||||
"test_conv2d_module_no_bias",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@ -1061,63 +1061,6 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
|
||||
correct = func(a, b, c)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches())
|
||||
def test_multiple_hiding_nodes_bucketing(self):
|
||||
"""Test that collectives hidden by multiple compute ops can bucket together."""
|
||||
|
||||
# Use 0.5 compute multiplier so each collective needs 2 matmuls to be fully hidden
|
||||
def estimate_with_half_compute(fx_node, override_size=None):
|
||||
return estimate_aten_runtime(fx_node, compute_multiplier=0.5)
|
||||
|
||||
def func(a, b, *, ranks):
|
||||
# Two all_gathers that will be hidden by multiple compute operations
|
||||
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
|
||||
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
|
||||
|
||||
# Multiple compute operations that can hide the collectives
|
||||
# With 0.5 multiplier: mm1 and mm2 together hide ag1, mm2 and mm3 together hide ag2
|
||||
mm1 = torch.matmul(a, a.T)
|
||||
mm2 = torch.matmul(b, b.T)
|
||||
mm3 = torch.matmul(a + b, (a + b).T)
|
||||
|
||||
return ag1.sum() + ag2.sum() + mm1.sum() + mm2.sum() + mm3.sum()
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
|
||||
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
|
||||
ranks = list(range(self.world_size))
|
||||
|
||||
func_c = functools.partial(func, ranks=ranks)
|
||||
|
||||
# Patch with custom estimation that uses 0.5 multiplier
|
||||
with torch._inductor.config.patch(
|
||||
{
|
||||
"aten_distributed_optimizations.custom_runtime_estimation": estimate_with_half_compute
|
||||
}
|
||||
):
|
||||
compiled = torch.compile(func_c)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b)
|
||||
|
||||
# Should have 1 bucketed all_gather (both ag1 and ag2 bucketed together)
|
||||
FileCheck().check_count(
|
||||
"torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True
|
||||
).run(aten_graph_str)
|
||||
|
||||
# Verify bucketed collective is scheduled before all matmuls
|
||||
FileCheck().check("functional.all_gather_into_tensor").check(
|
||||
"aten.mm"
|
||||
).check("aten.mm").check("aten.mm").check("wait_tensor").run(aten_graph_str)
|
||||
|
||||
# Verify correctness
|
||||
correct = func(a, b, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
|
||||
def get_toy_model(device_type: str):
|
||||
"""
|
||||
|
||||
@ -49,8 +49,7 @@ def build_collective_info(graph, hiding_annotations):
|
||||
"""
|
||||
Build CollectiveInfo dict from manual hiding annotations.
|
||||
|
||||
hiding_annotations: dict mapping collective_start -> hiding_compute_node(s)
|
||||
Can be a single node or a list/OrderedSet of nodes
|
||||
hiding_annotations: dict mapping collective_start -> hiding_compute_node
|
||||
"""
|
||||
from torch._inductor.fx_passes.overlap_scheduling import CollectiveInfo
|
||||
|
||||
@ -66,20 +65,12 @@ def build_collective_info(graph, hiding_annotations):
|
||||
|
||||
# Build CollectiveInfo for each collective
|
||||
for start_node, wait_node in start_to_wait.items():
|
||||
hiding_annotation = hiding_annotations.get(start_node)
|
||||
|
||||
# Convert to OrderedSet
|
||||
hiding_nodes = OrderedSet()
|
||||
if hiding_annotation is not None:
|
||||
if isinstance(hiding_annotation, list | OrderedSet):
|
||||
hiding_nodes = OrderedSet(hiding_annotation)
|
||||
else:
|
||||
hiding_nodes = OrderedSet([hiding_annotation])
|
||||
hiding_node = hiding_annotations.get(start_node)
|
||||
|
||||
# Estimate size and time
|
||||
size_bytes = 16 * 4 # 4x4 tensor of floats
|
||||
estimated_time_ms = 1.0 # Dummy time
|
||||
exposed_time_ms = 0.0 if hiding_nodes else 1.0 # Hidden if has hiding_nodes
|
||||
exposed_time_ms = 0.0 if hiding_node else 1.0 # Hidden if has hiding_node
|
||||
|
||||
collective_info[start_node] = CollectiveInfo(
|
||||
start_node=start_node,
|
||||
@ -87,7 +78,7 @@ def build_collective_info(graph, hiding_annotations):
|
||||
size_bytes=size_bytes,
|
||||
estimated_time_ms=estimated_time_ms,
|
||||
exposed_time_ms=exposed_time_ms,
|
||||
hiding_nodes=hiding_nodes,
|
||||
hiding_node=hiding_node,
|
||||
)
|
||||
|
||||
return collective_info
|
||||
@ -576,97 +567,6 @@ class TestOverlapPreservingBucketing(InductorTestCase):
|
||||
graph_str
|
||||
)
|
||||
|
||||
def test_can_bucket_with_multiple_hiding_nodes(self):
|
||||
"""
|
||||
Test that collectives with multiple hiding nodes CAN bucket.
|
||||
|
||||
Graph structure:
|
||||
ag1_start -> ag2_start -> mm1 -> mm2 -> mm3 -> ag1_wait -> ag2_wait
|
||||
|
||||
Where:
|
||||
- ag1 is hidden by mm1 and mm2
|
||||
- ag2 is hidden by mm2 and mm3
|
||||
- Both collectives share mm2 as a hiding node
|
||||
"""
|
||||
|
||||
def func(a, b):
|
||||
group_name = "0"
|
||||
group_size = 1
|
||||
|
||||
# Start both collectives
|
||||
ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
|
||||
a, group_size, group_name
|
||||
)
|
||||
ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
|
||||
b, group_size, group_name
|
||||
)
|
||||
|
||||
# Three compute operations that hide the collectives
|
||||
mm1 = torch.mm(a, a)
|
||||
mm2 = torch.mm(b, b)
|
||||
mm3 = torch.mm(a + b, a + b)
|
||||
|
||||
# Wait for both
|
||||
ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)
|
||||
ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)
|
||||
|
||||
return ag1_out.sum() + ag2_out.sum() + mm1.sum() + mm2.sum() + mm3.sum()
|
||||
|
||||
# Use fake mode to trace without executing
|
||||
with FakeTensorMode():
|
||||
a = torch.ones(4, 4, device=self.device)
|
||||
b = torch.ones(4, 4, device=self.device) * 2
|
||||
|
||||
# Trace with make_fx
|
||||
traced = make_fx(func)(a, b)
|
||||
|
||||
# Find nodes using find_nodes
|
||||
ag1, ag2 = traced.graph.find_nodes(
|
||||
op="call_function",
|
||||
target=torch.ops._c10d_functional.all_gather_into_tensor.default,
|
||||
)
|
||||
mm1, mm2, mm3 = traced.graph.find_nodes(
|
||||
op="call_function", target=torch.ops.aten.mm.default
|
||||
)
|
||||
|
||||
# Manually annotate hiding relationships with multiple hiding nodes
|
||||
hiding_annotations = {
|
||||
ag1: [mm1, mm2], # ag1 is hidden by mm1 and mm2
|
||||
ag2: [mm2, mm3], # ag2 is hidden by mm2 and mm3
|
||||
}
|
||||
|
||||
# Build collective info and ancestors
|
||||
collective_info = build_collective_info(traced.graph, hiding_annotations)
|
||||
node_ancestors = compute_ancestors(traced.graph)
|
||||
scheduled = OrderedSet(traced.graph.nodes)
|
||||
|
||||
# Verify hiding_nodes are correctly set
|
||||
self.assertEqual(len(collective_info[ag1].hiding_nodes), 2)
|
||||
self.assertIn(mm1, collective_info[ag1].hiding_nodes)
|
||||
self.assertIn(mm2, collective_info[ag1].hiding_nodes)
|
||||
self.assertEqual(len(collective_info[ag2].hiding_nodes), 2)
|
||||
self.assertIn(mm2, collective_info[ag2].hiding_nodes)
|
||||
self.assertIn(mm3, collective_info[ag2].hiding_nodes)
|
||||
|
||||
# Run bucketing
|
||||
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
|
||||
OverlapPreservingBucketer,
|
||||
)
|
||||
|
||||
bucketer = OverlapPreservingBucketer(
|
||||
traced.graph,
|
||||
collective_info,
|
||||
node_ancestors,
|
||||
scheduled,
|
||||
)
|
||||
bucketer.bucket_collectives()
|
||||
|
||||
FileCheck().check_count(
|
||||
"all_gather_into_tensor_out", 1, exactly=False
|
||||
).check_count("torch.ops.aten.mm.default", 3, exactly=True).run(
|
||||
str(traced.graph)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -253,14 +253,6 @@ class StoreTestBase:
|
||||
a.set("foo", "bar")
|
||||
self.assertEqual(b.get("foo"), b"bar")
|
||||
|
||||
def test_list_keys(self):
|
||||
a = self._create_store()
|
||||
a.set("foo", "bar")
|
||||
a.set("baz", "qux")
|
||||
keys = a.list_keys()
|
||||
self.assertIn("foo", keys)
|
||||
self.assertIn("baz", keys)
|
||||
|
||||
# This is the number of keys used in test_set_get. Adding this as a class
|
||||
# property instead of hardcoding in the test since some Store
|
||||
# implementations will have differing number of keys. In the base case,
|
||||
|
||||
@ -470,7 +470,7 @@ class <lambda>(torch.nn.Module):
|
||||
)
|
||||
|
||||
@requires_cuda
|
||||
def test_stream_backward_simple(self) -> None:
|
||||
def test_stream_backward(self) -> None:
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
s0 = torch.Stream()
|
||||
@ -524,68 +524,7 @@ class GraphModule(torch.nn.Module):
|
||||
# Annotation: {'stream': 1}
|
||||
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
|
||||
|
||||
# Annotation: {'stream': 0}
|
||||
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
|
||||
return (add_3, add_2)
|
||||
""",
|
||||
)
|
||||
|
||||
@requires_cuda
|
||||
def test_stream_backward_sync(self) -> None:
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
s0 = torch.Stream()
|
||||
with s0:
|
||||
y0 = 2 * x + y
|
||||
with s2:
|
||||
z = 2 * x + y
|
||||
|
||||
return y0, z
|
||||
|
||||
inp = (
|
||||
torch.ones(2, 2, device="cuda:0", requires_grad=True) + 1,
|
||||
torch.ones(2, 2, device="cuda:0", requires_grad=True),
|
||||
)
|
||||
expected = fn(*inp)
|
||||
(
|
||||
actual,
|
||||
_,
|
||||
fw_graphs,
|
||||
bw_graphs,
|
||||
) = extract_graph(fn, *inp)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertExpectedInline(
|
||||
print_graph(fw_graphs[0]),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
|
||||
# Annotation: {'stream': 1}
|
||||
mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)
|
||||
|
||||
# Annotation: {'stream': 0}
|
||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
|
||||
return (add, add_1)
|
||||
""",
|
||||
)
|
||||
|
||||
actual[1].sum().backward()
|
||||
self.assertExpectedInline(
|
||||
print_graph(bw_graphs[0]),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
|
||||
# Annotation: {'stream': 0}
|
||||
mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)
|
||||
|
||||
#
|
||||
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None
|
||||
|
||||
# Annotation: {'stream': 1}
|
||||
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
|
||||
|
||||
# Annotation: {'stream': 0}
|
||||
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
|
||||
return (add_3, add_2)
|
||||
""",
|
||||
|
||||
@ -456,31 +456,6 @@ def forward(self, x):
|
||||
test_inputs = make_inputs()
|
||||
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
|
||||
|
||||
def test_dynamo_graph_capture_with_call_override(self):
|
||||
class _InterestingModule(torch.nn.Module):
|
||||
def __init__(self, module):
|
||||
super().__init__()
|
||||
self._module = module
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self._module(*args, **kwargs)
|
||||
|
||||
class MyModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x + 1
|
||||
|
||||
foo = _InterestingModule(MyModel())
|
||||
|
||||
def make_inputs():
|
||||
return (torch.randn(2, 3),)
|
||||
|
||||
trace_inputs = make_inputs()
|
||||
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
|
||||
test_inputs = make_inputs()
|
||||
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
|
||||
self.assertEqual(len(list(gm.buffers())), len(list(foo.buffers())))
|
||||
self.assertEqual(len(list(gm.parameters())), len(list(foo.parameters())))
|
||||
|
||||
def test_dynamo_graph_capture_custom_pytree_type(self):
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
|
||||
@ -3,17 +3,12 @@ import io
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._functorch.aot_autograd import aot_export_module
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
class TestHopPrint(TestCase):
|
||||
def test_base_print(self):
|
||||
def f(x):
|
||||
@ -23,6 +18,7 @@ class TestHopPrint(TestCase):
|
||||
torch._higher_order_ops.print("moo")
|
||||
return x
|
||||
|
||||
counters.clear()
|
||||
x = torch.randn(3, 3)
|
||||
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
|
||||
f(x)
|
||||
@ -37,6 +33,7 @@ class TestHopPrint(TestCase):
|
||||
x = x * x
|
||||
return x
|
||||
|
||||
counters.clear()
|
||||
x = torch.randn(3, 3)
|
||||
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
|
||||
f(x)
|
||||
@ -187,62 +184,6 @@ x = add_1, y = add_2); getitem = None
|
||||
"""print(str format_str) -> ()""",
|
||||
)
|
||||
|
||||
@parametrize("backend", ["eager", "aot_eager"])
|
||||
def test_reorder_print_no_graph_break(self, backend):
|
||||
def f(x):
|
||||
x1 = x + x
|
||||
torch._higher_order_ops.print("moo {x}", x=x1)
|
||||
x2 = x1 * x1
|
||||
torch._higher_order_ops.print("moo {x}", x=x2)
|
||||
x3 = x2 + x2
|
||||
return (x1, x3)
|
||||
|
||||
# Eager and aot_eager backend for dynamo tracing testing
|
||||
x = torch.randn(3, 3)
|
||||
opt_f = torch.compile(backend=backend, fullgraph=True)(f)
|
||||
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
|
||||
opt_out = opt_f(x)
|
||||
printed_output = mock_stdout.getvalue().strip()
|
||||
orig_out = f(x)
|
||||
|
||||
self.assertEqual(
|
||||
printed_output,
|
||||
f"moo {x * 2}\nmoo {x * 2 * x * 2}",
|
||||
)
|
||||
self.assertEqual(orig_out, opt_out)
|
||||
|
||||
x_new = torch.randn(2, 2)
|
||||
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
|
||||
opt_out = opt_f(x_new)
|
||||
printed_output = mock_stdout.getvalue().strip()
|
||||
|
||||
self.assertEqual(
|
||||
printed_output,
|
||||
f"moo {x_new * 2}\nmoo {x_new * 2 * x_new * 2}",
|
||||
)
|
||||
|
||||
@parametrize("backend", ["eager", "aot_eager"])
|
||||
def test_constant_mutation(self, backend):
|
||||
def f(x):
|
||||
alist = [x]
|
||||
alist.append(x + 1)
|
||||
torch._higher_order_ops.print("moo {x}", x=alist[-1])
|
||||
alist[0].sum().item() # graph break
|
||||
res = alist.pop()
|
||||
torch._higher_order_ops.print("moo {x}", x=alist[-1])
|
||||
res.sum().item() # graph break
|
||||
return res
|
||||
|
||||
inputs = (torch.tensor([1]),)
|
||||
opt_f = torch.compile(backend=backend, fullgraph=True)(f)
|
||||
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
|
||||
opt_out = opt_f(*inputs)
|
||||
printed_output = mock_stdout.getvalue().strip()
|
||||
orig_out = f(*inputs)
|
||||
|
||||
self.assertEqual(printed_output, "moo tensor([2])\nmoo tensor([1])")
|
||||
self.assertEqual(orig_out, opt_out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -1188,22 +1188,6 @@ class TestTiling(TestCase):
|
||||
with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad():
|
||||
torch.compile(f)(x)
|
||||
|
||||
def test_find_broadcast_var(self):
|
||||
"""Test broadcast variable detection for tiling improvements."""
|
||||
from torch._inductor import tiling_utils
|
||||
|
||||
i, j, k = sympy.symbols("i j k", integer=True)
|
||||
|
||||
# Test broadcast pattern detection: FloorDiv creates broadcast
|
||||
result = tiling_utils.find_broadcast_var(
|
||||
FloorDiv(i, 10), {i: 100, j: 50, k: 20}
|
||||
)
|
||||
self.assertEqual(result, i)
|
||||
|
||||
# Test non-broadcast: linear access pattern
|
||||
result = tiling_utils.find_broadcast_var(i + j * 10, {i: 10, j: 8, k: 20})
|
||||
self.assertEqual(result, None)
|
||||
|
||||
|
||||
class TestIndexInversion(TestCase):
|
||||
@classmethod
|
||||
|
||||
@ -630,6 +630,7 @@ class TestSparse(TestSparseBase):
|
||||
i[0][0] = 0
|
||||
self.assertEqual(torch.empty((3, 0), dtype=dtype, device=device), self.safeToDense(x))
|
||||
|
||||
@expectedFailureMPS
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
@dtypesIfMPS(torch.float32, torch.complex64)
|
||||
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error")
|
||||
@ -646,8 +647,7 @@ class TestSparse(TestSparseBase):
|
||||
def fn(x):
|
||||
return x.to_dense(masked_grad=gradcheck.masked)
|
||||
x.requires_grad_(True)
|
||||
kwargs = {"eps": 1e-4} if device == "mps:0" else {}
|
||||
gradcheck(fn, (x,), **kwargs)
|
||||
gradcheck(fn, (x,))
|
||||
|
||||
i = self.index_tensor([
|
||||
[0, 1, 2, 2],
|
||||
|
||||
@ -423,156 +423,6 @@ class TestFuzzerCompileIssues(TestCase):
|
||||
out_compiled.sum().backward()
|
||||
print("Compile Success! ✅")
|
||||
|
||||
@pytest.mark.xfail(reason="Issue #167937")
|
||||
def test_fuzzer_issue_167937(self):
|
||||
torch.manual_seed(1251149731)
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
def fuzzed_program(
|
||||
arg_0,
|
||||
arg_1,
|
||||
arg_2,
|
||||
arg_3,
|
||||
arg_4,
|
||||
arg_5,
|
||||
arg_6,
|
||||
arg_7,
|
||||
arg_8,
|
||||
arg_9,
|
||||
sentinel,
|
||||
):
|
||||
var_node_3 = arg_0 # size=(27, 28, 7), stride=(196, 7, 1), dtype=bfloat16, device=cuda
|
||||
var_node_4 = (
|
||||
arg_1 # size=(27, 7, 6), stride=(42, 6, 1), dtype=bfloat16, device=cuda
|
||||
)
|
||||
var_node_2 = torch.matmul(
|
||||
var_node_3.to(torch.bfloat16), var_node_4.to(torch.bfloat16)
|
||||
) # size=(27, 28, 6), stride=(168, 6, 1), dtype=bfloat16, device=cuda
|
||||
var_node_6 = (
|
||||
arg_2 # size=(27, 6, 9), stride=(54, 9, 1), dtype=bfloat16, device=cuda
|
||||
)
|
||||
var_node_7 = torch.full(
|
||||
(27, 9, 1), -0.310546875, dtype=torch.bfloat16
|
||||
) # size=(27, 9, 1), stride=(9, 1, 1), dtype=bfloat16, device=cuda
|
||||
var_node_5 = torch.matmul(
|
||||
var_node_6.to(torch.bfloat16), var_node_7.to(torch.bfloat16)
|
||||
) # size=(27, 6, 1), stride=(6, 1, 1), dtype=bfloat16, device=cuda
|
||||
var_node_1 = torch.matmul(
|
||||
var_node_2.to(torch.bfloat16), var_node_5.to(torch.bfloat16)
|
||||
) # size=(27, 28, 1), stride=(28, 1, 1), dtype=bfloat16, device=cuda
|
||||
var_node_8 = arg_3 # size=(27, 28, 1), stride=(28, 1, 1), dtype=bfloat16, device=cuda
|
||||
var_node_9 = torch.full(
|
||||
(27, 28, 1), 0.76953125, dtype=torch.bfloat16
|
||||
) # size=(27, 28, 1), stride=(28, 1, 1), dtype=bfloat16, device=cuda
|
||||
var_node_12 = (
|
||||
arg_4 # size=(3, 4), stride=(4, 1), dtype=bfloat16, device=cuda
|
||||
)
|
||||
var_node_13 = (
|
||||
arg_5 # size=(4, 15), stride=(15, 1), dtype=bfloat16, device=cuda
|
||||
)
|
||||
var_node_11 = torch.matmul(
|
||||
var_node_12.to(torch.bfloat16), var_node_13.to(torch.bfloat16)
|
||||
) # size=(3, 15), stride=(15, 1), dtype=bfloat16, device=cuda
|
||||
var_node_15 = (
|
||||
arg_6 # size=(15, 12), stride=(12, 1), dtype=bfloat16, device=cuda
|
||||
)
|
||||
var_node_16 = (
|
||||
arg_7 # size=(12, 1), stride=(1, 1), dtype=bfloat16, device=cuda
|
||||
)
|
||||
var_node_14 = torch.matmul(
|
||||
var_node_15.to(torch.bfloat16), var_node_16.to(torch.bfloat16)
|
||||
) # size=(15, 1), stride=(1, 1), dtype=bfloat16, device=cuda
|
||||
var_node_10 = torch.matmul(
|
||||
var_node_11.to(torch.bfloat16), var_node_14.to(torch.bfloat16)
|
||||
) # size=(3, 1), stride=(1, 1), dtype=bfloat16, device=cuda
|
||||
var_node_19 = (
|
||||
arg_8 # size=(1, 8), stride=(8, 1), dtype=bfloat16, device=cuda
|
||||
)
|
||||
var_node_20 = (
|
||||
arg_9 # size=(8, 2), stride=(2, 1), dtype=bfloat16, device=cuda
|
||||
)
|
||||
var_node_18 = torch.matmul(
|
||||
var_node_19.to(torch.bfloat16), var_node_20.to(torch.bfloat16)
|
||||
) # size=(1, 2), stride=(2, 1), dtype=bfloat16, device=cuda
|
||||
var_node_21 = torch.full(
|
||||
(2, 1), 0.000762939453125, dtype=torch.bfloat16
|
||||
) # size=(2, 1), stride=(1, 1), dtype=bfloat16, device=cuda
|
||||
var_node_17 = torch.matmul(
|
||||
var_node_18.to(torch.bfloat16), var_node_21.to(torch.bfloat16)
|
||||
) # size=(1, 1), stride=(1, 1), dtype=bfloat16, device=cuda
|
||||
var_node_0, _ = torch.nn.functional.multi_head_attention_forward(
|
||||
var_node_1.to(torch.bfloat16),
|
||||
var_node_8.to(torch.bfloat16),
|
||||
var_node_9.to(torch.bfloat16),
|
||||
1,
|
||||
1,
|
||||
var_node_10.to(torch.bfloat16),
|
||||
None, # in_proj_bias
|
||||
None, # bias_k
|
||||
None, # bias_v
|
||||
False, # add_zero_attn
|
||||
0.0, # dropout_p (no dropout for testing)
|
||||
var_node_17.to(torch.bfloat16),
|
||||
None, # out_proj_bias
|
||||
training=False, # Use eval mode for deterministic behavior
|
||||
need_weights=False, # Don't compute attention weights for performance
|
||||
) # size=(27, 28, 1), stride=(28, 1, 1), dtype=bfloat16, device=cuda
|
||||
# Ensure gradient computation by multiplying with sentinel and taking real part
|
||||
result = var_node_0 * sentinel
|
||||
if result.is_complex():
|
||||
result = result.real
|
||||
return result
|
||||
|
||||
try:
|
||||
# Sentinel tensor to ensure gradient computation
|
||||
sentinel = torch.tensor(1.0, requires_grad=True)
|
||||
arg_0 = torch.as_strided(
|
||||
torch.randn(5292).to(torch.bfloat16), (27, 28, 7), (196, 7, 1)
|
||||
)
|
||||
arg_1 = torch.as_strided(
|
||||
torch.randn(1134).to(torch.bfloat16), (27, 7, 6), (42, 6, 1)
|
||||
)
|
||||
arg_2 = torch.as_strided(
|
||||
torch.randn(1458).to(torch.bfloat16), (27, 6, 9), (54, 9, 1)
|
||||
)
|
||||
arg_3 = torch.as_strided(
|
||||
torch.randn(756).to(torch.bfloat16), (27, 28, 1), (28, 1, 1)
|
||||
)
|
||||
arg_4 = torch.as_strided(torch.randn(12).to(torch.bfloat16), (3, 4), (4, 1))
|
||||
arg_5 = torch.as_strided(
|
||||
torch.randn(60).to(torch.bfloat16), (4, 15), (15, 1)
|
||||
)
|
||||
arg_6 = torch.as_strided(
|
||||
torch.randn(180).to(torch.bfloat16), (15, 12), (12, 1)
|
||||
)
|
||||
arg_7 = torch.as_strided(
|
||||
torch.randn(12).to(torch.bfloat16), (12, 1), (1, 1)
|
||||
)
|
||||
arg_8 = torch.as_strided(torch.randn(8).to(torch.bfloat16), (1, 8), (8, 1))
|
||||
arg_9 = torch.as_strided(torch.randn(16).to(torch.bfloat16), (8, 2), (2, 1))
|
||||
args = (
|
||||
arg_0,
|
||||
arg_1,
|
||||
arg_2,
|
||||
arg_3,
|
||||
arg_4,
|
||||
arg_5,
|
||||
arg_6,
|
||||
arg_7,
|
||||
arg_8,
|
||||
arg_9,
|
||||
) + (sentinel,)
|
||||
|
||||
out_eager = fuzzed_program(*args)
|
||||
out_eager.sum().backward()
|
||||
print("Eager Success! ✅")
|
||||
compiled_foo = torch.compile(fuzzed_program, fullgraph=True, dynamic=True)
|
||||
out_compiled = compiled_foo(*args)
|
||||
out_compiled.sum().backward()
|
||||
print("Compile Success! ✅")
|
||||
finally:
|
||||
torch.set_default_device(None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -267,10 +267,7 @@ class DefaultFuzzTemplate(FuzzTemplate):
|
||||
]
|
||||
|
||||
def flags_codegen(self):
|
||||
return [
|
||||
"torch.set_default_device('cuda')",
|
||||
"torch._dynamo.config.capture_scalar_outputs = True",
|
||||
]
|
||||
return ["torch._dynamo.config.capture_scalar_outputs = True"]
|
||||
|
||||
def epilogue_codegen(self):
|
||||
return []
|
||||
@ -493,7 +490,6 @@ class UnbackedFuzzTemplate(FuzzTemplate):
|
||||
|
||||
def flags_codegen(self):
|
||||
return [
|
||||
"torch.set_default_device('cuda')",
|
||||
"torch._dynamo.config.capture_scalar_outputs = True",
|
||||
"torch._dynamo.config.capture_dynamic_output_shape_ops = True",
|
||||
]
|
||||
|
||||
@ -48,8 +48,54 @@ def persist_print(msg):
|
||||
# List of regex patterns for ignore bucket
|
||||
IGNORE_PATTERNS: list[re.Pattern] = [
|
||||
re.compile(
|
||||
r"torch\._inductor\.exc\.InductorError: AssertionError: -1"
|
||||
), # https://github.com/pytorch/pytorch/issues/167937
|
||||
r"Dynamo failed to run FX node with fake tensors: call_method fill_diagonal_"
|
||||
), # https://github.com/pytorch/pytorch/issues/163420
|
||||
re.compile(
|
||||
r"TypeError: unsupported operand type\(s\) for divmod\(\): 'SymInt' and 'int'"
|
||||
), # https://github.com/pytorch/pytorch/issues/163457
|
||||
re.compile(
|
||||
r"RuntimeError: self\.stride\(-1\) must be 1 to view ComplexDouble as"
|
||||
), # https://github.com/pytorch/pytorch/issues/162561
|
||||
re.compile(
|
||||
r"BooleanAtom not allowed in this context"
|
||||
), # https://github.com/pytorch/pytorch/issues/160726
|
||||
re.compile(
|
||||
r"TypeError\(\"unsupported operand type\(s\) for \*: 'SymBool' and 'FakeTensor'\"\)"
|
||||
), # https://github.com/pytorch/pytorch/issues/164684
|
||||
re.compile(r"KeyError: u\d+"), # https://github.com/pytorch/pytorch/issues/164685
|
||||
re.compile(
|
||||
r"torch\._inductor\.exc\.InductorError: CppCompileError: C\+\+ compile error"
|
||||
), # https://github.com/pytorch/pytorch/issues/164686
|
||||
re.compile(
|
||||
r"\.item\(\) # dtype="
|
||||
), # https://github.com/pytorch/pytorch/issues/164725
|
||||
re.compile(
|
||||
r"dimensionality of sizes \(0\) must match dimensionality of strides \(1\)"
|
||||
), # https://github.com/pytorch/pytorch/issues/164814
|
||||
re.compile(
|
||||
r"self and mat2 must have the same dtype"
|
||||
), # https://github.com/pytorch/pytorch/issues/165718
|
||||
re.compile(
|
||||
r"free\(\): invalid next size \(fast\)"
|
||||
), # TODO: figure out why sometimes heap metadata gets corrupted on program exit (checks actually pass successfully)
|
||||
re.compile(
|
||||
r'assert "int" in str\(indices\.get_dtype\(\)\)'
|
||||
), # https://github.com/pytorch/pytorch/issues/166042
|
||||
re.compile(
|
||||
r'self\.shape_env\.guard_or_defer_runtime_assert\(expr, "guard_equals"\)'
|
||||
), # https://github.com/pytorch/pytorch/issues/166245
|
||||
re.compile(
|
||||
r"assert len\(self\.stride\) == len\(order\)"
|
||||
), # https://github.com/pytorch/pytorch/issues/166270
|
||||
re.compile(
|
||||
r"assert len\(input_size\) == len\(new_size\)"
|
||||
), # https://github.com/pytorch/pytorch/issues/166279
|
||||
re.compile(
|
||||
r"torch\._inductor\.exc\.InductorError: IndexError: list index out of range"
|
||||
), # https://github.com/pytorch/pytorch/issues/166290
|
||||
re.compile(
|
||||
r"assert bool\(static_expr\)"
|
||||
), # https://github.com/pytorch/pytorch/issues/166319
|
||||
# Add more patterns here as needed, e.g.:
|
||||
# re.compile(r"Some other error message"),
|
||||
]
|
||||
|
||||
@ -215,7 +215,6 @@ class Store:
|
||||
def queue_pop(self, key: str, block: bool = True) -> bytes: ...
|
||||
def queue_push(self, key: str, value: Union[bytes, str]) -> None: ...
|
||||
def queue_len(self, key: str) -> int: ...
|
||||
def list_keys(self) -> list[str]: ...
|
||||
|
||||
class FileStore(Store):
|
||||
def __init__(self, path: str, numWorkers: int = ...) -> None: ...
|
||||
|
||||
@ -1043,11 +1043,6 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
|
||||
import inspect
|
||||
|
||||
if isinstance(mod, torch.nn.Module):
|
||||
resolved_forward = mod.forward
|
||||
if hasattr(resolved_forward, "__self__"):
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
resolved_forward = resolved_forward.__func__
|
||||
|
||||
# Mirrored from NNModuleVariable.call_function:
|
||||
# https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/variables/nn_module.py#L1035
|
||||
if (
|
||||
@ -1059,12 +1054,7 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
|
||||
and len(mod._backward_hooks) == 0
|
||||
and len(torch.nn.modules.module._global_backward_pre_hooks) == 0
|
||||
and len(torch.nn.modules.module._global_backward_hooks) == 0
|
||||
and resolved_forward != torch.nn.Module.forward
|
||||
):
|
||||
# We cannot trace __call__ by default because it will break
|
||||
# the legacy dynamo export. If we want to revisit this,
|
||||
# feel free to remove this path and try unittests in
|
||||
# test_strict_export_v2.py
|
||||
mod = mod.forward
|
||||
elif isinstance(mod, torch.fx.GraphModule):
|
||||
mod = mod._call_impl
|
||||
|
||||
@ -1528,6 +1528,37 @@ class OutputGraph(OutputGraphCommon):
|
||||
|
||||
from .decorators import disable
|
||||
|
||||
if has_user_objects():
|
||||
# NB: This is where we store possible user objects before running the graph
|
||||
# index_to_user_object_weakref is the function used in the graph to translate
|
||||
# the dynamo-generated index into the actual object passed to the compiled function.
|
||||
# We generate bytecode to store all user objects at the proper index in the below
|
||||
# call.
|
||||
codegen = PyCodegen(
|
||||
self.root_tx, root, overridden_sources=overridden_sources
|
||||
)
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(
|
||||
torch._dynamo.graph_bytecode_inputs.__name__,
|
||||
"store_user_object_weakrefs",
|
||||
)
|
||||
)
|
||||
tmp_vars = []
|
||||
for constructor in index_to_bytecode_constructor.values():
|
||||
constructor(codegen)
|
||||
var_name = (
|
||||
self.new_var()
|
||||
) # keep alive any temp objects for the rest of the frame
|
||||
codegen.store(var_name)
|
||||
tmp_vars.append(var_name)
|
||||
|
||||
for var_name in tmp_vars:
|
||||
codegen.append_output(codegen.create_load(var_name))
|
||||
|
||||
codegen.call_function(len(index_to_bytecode_constructor), False)
|
||||
codegen.pop_top()
|
||||
self.add_output_instructions(codegen.get_instructions())
|
||||
|
||||
# to handle random calls
|
||||
if len(self.random_calls) > 0:
|
||||
random_calls_instructions = []
|
||||
@ -2312,33 +2343,6 @@ class OutputGraph(OutputGraphCommon):
|
||||
assert self.root_tx is not None
|
||||
cg = PyCodegen(self.root_tx)
|
||||
|
||||
if has_user_objects():
|
||||
# NB: This is where we store possible user objects before running the graph
|
||||
# index_to_user_object_weakref is the function used in the graph to translate
|
||||
# the dynamo-generated index into the actual object passed to the compiled function.
|
||||
# We generate bytecode to store all user objects at the proper index in the below
|
||||
# call.
|
||||
cg.add_push_null(
|
||||
lambda: cg.load_import_from(
|
||||
torch._dynamo.graph_bytecode_inputs.__name__,
|
||||
"store_user_object_weakrefs",
|
||||
)
|
||||
)
|
||||
tmp_vars = []
|
||||
for constructor in index_to_bytecode_constructor.values():
|
||||
constructor(cg)
|
||||
var_name = (
|
||||
self.new_var()
|
||||
) # keep alive any temp objects for the rest of the frame
|
||||
cg.store(var_name)
|
||||
tmp_vars.append(var_name)
|
||||
|
||||
for var_name in tmp_vars:
|
||||
cg.append_output(cg.create_load(var_name))
|
||||
|
||||
cg.call_function(len(index_to_bytecode_constructor), False)
|
||||
cg.pop_top()
|
||||
|
||||
for idx, arg in enumerate(self.graphargs):
|
||||
self.export_metadata.graph_input_idx_to_local_source[idx] = arg.source
|
||||
|
||||
@ -3007,7 +3011,7 @@ class SubgraphTracer(fx.Tracer):
|
||||
|
||||
self.tracked_tensor_or_symint_vt: OrderedSet[VariableTracker] = OrderedSet()
|
||||
|
||||
def record_tensor_or_symint_vt(self, vt: VariableTracker):
|
||||
def record_tensor_or_symint_vt(self, vt):
|
||||
self.tracked_tensor_or_symint_vt.add(vt)
|
||||
|
||||
# preserve original meta if it is available
|
||||
|
||||
@ -903,33 +903,11 @@ def reset_graph_break_dup_checker() -> None:
|
||||
graph_break_dup_warning_checker.reset()
|
||||
|
||||
|
||||
# Matches ANSI escape sequences (CSI)
|
||||
ANSI_ESCAPE_PATTERN = re.compile(
|
||||
r"""
|
||||
\x1B # ESC
|
||||
\[ # [
|
||||
[0-?]* # Parameter bytes
|
||||
[ -/]* # Intermediate bytes
|
||||
[@-~] # Final byte
|
||||
""",
|
||||
re.VERBOSE,
|
||||
)
|
||||
|
||||
|
||||
class StripAnsiFormatter(logging.Formatter):
|
||||
"""Logging formatter that strips ANSI escape codes."""
|
||||
|
||||
def format(self, record):
|
||||
msg = super().format(record)
|
||||
return ANSI_ESCAPE_PATTERN.sub("", msg)
|
||||
|
||||
|
||||
def add_file_handler() -> contextlib.ExitStack:
|
||||
log_path = os.path.join(get_debug_dir(), "torchdynamo")
|
||||
os.makedirs(log_path, exist_ok=True)
|
||||
|
||||
log_file_handler = logging.FileHandler(os.path.join(log_path, "debug.log"))
|
||||
log_file_handler.setFormatter(StripAnsiFormatter("%(message)s"))
|
||||
logger = logging.getLogger("torch._dynamo")
|
||||
logger.addHandler(log_file_handler)
|
||||
|
||||
|
||||
@ -2673,30 +2673,6 @@ class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||
)
|
||||
|
||||
|
||||
class PrintHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||
def _call_function(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
from .builder import wrap_fx_proxy
|
||||
|
||||
args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
|
||||
|
||||
args_proxy = [arg.as_proxy() for arg in args]
|
||||
kwargs_proxy = {k: v.as_proxy() for k, v in kwargs.items()}
|
||||
return wrap_fx_proxy(
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_function",
|
||||
self.value,
|
||||
args=tuple(args_proxy),
|
||||
kwargs=kwargs_proxy,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ExecutorchCallDelegateHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||
def _call_function(
|
||||
self,
|
||||
@ -4561,7 +4537,6 @@ _hop_name_to_variable_class = {
|
||||
"associative_scan": AssociativeScanHigherOrderVariable,
|
||||
"scan": ScanHigherOrderVariable,
|
||||
"call_torchbind": CallTorchbindHigherOrderVariable,
|
||||
"print": PrintHigherOrderVariable,
|
||||
"wrap_with_set_grad_enabled": WrapWithSetGradEnabledHigherOrderVariable,
|
||||
"wrap_with_autocast": WrapWithAutocastHigherOrderVariable,
|
||||
"dynamo_bypassing_wrapper": DynamoBypassingWrapperHigherOrderVariable,
|
||||
|
||||
@ -33,7 +33,6 @@ from .graph_capture_wrappers import (
|
||||
handle_effect_tokens_fn,
|
||||
)
|
||||
from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta
|
||||
from .streams import assign_backward_streams
|
||||
from .utils import (
|
||||
call_and_expect_output_descs,
|
||||
copy_fwd_metadata_to_bw_nodes,
|
||||
@ -474,9 +473,6 @@ def aot_dispatch_autograd_graph(
|
||||
# fw node match might be erased
|
||||
copy_fwd_metadata_to_bw_nodes(fx_g)
|
||||
|
||||
# After copying metadata, assign streams to gradient accumulation nodes
|
||||
assign_backward_streams(fx_g)
|
||||
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
if not aot_config.disable_functionalization:
|
||||
# There should be *NO* mutating ops in the graph at this point.
|
||||
|
||||
@ -1,53 +0,0 @@
from typing import Optional, TypeAlias

import torch.fx
import torch.fx.traceback
from torch._dynamo.graph_utils import _get_flat_args


Node: TypeAlias = torch.fx.Node


def is_gradient_acc(node: Node) -> bool:
    return node.meta.get("is_gradient_acc", False)


def get_stream(node: Node) -> Optional[int]:
    maybe_annotation = node.meta.get("custom", None)
    if maybe_annotation is not None:
        return node.meta["custom"].get("stream", None)
    else:
        return None


def set_stream(node: Node, ind: int) -> None:
    if "custom" in node.meta:
        node.meta["custom"].update({"stream": ind})
    else:
        node.meta["custom"] = {"stream": ind}


def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
    """Assigns backward streams to gradient accumulation nodes"""

    # NB: iterate in reverse order to more closely match eager
    # the user node stream will be populated first
    for node in reversed(list(gm.graph.nodes)):
        if is_gradient_acc(node):
            # Accumulation stream selection. Follow the rules from top to bottom to determine the accumulation stream:
            # 1. Match first stream assignment of the first user with a stream
            # 2. Match first stream assignment encountered in the args from left to right
            # This differs from eager in some cases:
            # Specifically the eager code uses the autograd node to determine the stream,
            # crucially this does not necessarily correspond to the FX graph node. For example,
            # in the backward for an add node with a constant we will passthrough and during backward tracing,
            # no op will be added to the FX graph, so our stream assignment will differ in this case.
            gradients = _get_flat_args(node, {})
            users = list(node.users.keys())

            # All gradients will be on same device, they will be coerced if they were not with a .to() node
            for neighbor in users + gradients:
                ind = get_stream(neighbor)
                if ind is not None:
                    set_stream(node, ind)
                    break
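The helpers above encode stream assignments as plain FX node metadata. As a rough illustration (hypothetical graph and stream indices, not part of this patch) of the `node.meta["custom"]["stream"]` convention that `assign_backward_streams` reads and propagates; it assumes the `get_stream`/`set_stream` helpers defined in the file above are in scope.

```python
# Hypothetical illustration of the stream-annotation convention used above.
# Stream indices live in node.meta["custom"]["stream"]; gradient accumulation
# nodes are tagged with node.meta["is_gradient_acc"] = True and inherit the
# first stream index found on their users or inputs.
import torch
import torch.fx


def f(x, y):
    return x + y


gm = torch.fx.symbolic_trace(f)
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")

# Pretend the add was produced under stream 1 and is a gradient accumulation node.
add_node.meta["custom"] = {"stream": 1}
add_node.meta["is_gradient_acc"] = True

assert get_stream(add_node) == 1  # reads meta["custom"]["stream"]
set_stream(add_node, 0)           # overwrites it in place
assert add_node.meta["custom"]["stream"] == 0
```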
@ -78,13 +78,11 @@ class TuningProcess:
|
||||
|
||||
def workloop():
|
||||
while True:
|
||||
job, extra_env = TuningProcess.recv(read_pipe)
|
||||
job = TuningProcess.recv(read_pipe)
|
||||
if job is None:
|
||||
# None is a sentinel for the child to shut down
|
||||
break
|
||||
try:
|
||||
if extra_env:
|
||||
os.environ.update(extra_env)
|
||||
result = job()
|
||||
except Exception as e:
|
||||
result = e
|
||||
@ -97,10 +95,8 @@ class TuningProcess:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def send(
|
||||
obj: Any, write_pipe: IO[bytes], extra_env: dict[str, str] | None = None
|
||||
) -> None:
|
||||
pickle.dump((obj, extra_env), write_pipe)
|
||||
def send(obj: Any, write_pipe: IO[bytes]) -> None:
|
||||
pickle.dump(obj, write_pipe)
|
||||
write_pipe.flush()
|
||||
|
||||
@staticmethod
|
||||
@ -162,13 +158,13 @@ class TuningProcess:
|
||||
"""
|
||||
return self.running and self.process.poll() is None
|
||||
|
||||
def put(self, req: Any, extra_env: dict[str, str] | None = None) -> None:
|
||||
def put(self, req: Any) -> None:
|
||||
"""
|
||||
Push a work item to the child process.
|
||||
"""
|
||||
if not self.alive():
|
||||
self.start()
|
||||
TuningProcess.send(req, self.write_pipe, extra_env=extra_env)
|
||||
TuningProcess.send(req, self.write_pipe)
|
||||
|
||||
def get(self, timeout: float = 120.0) -> Any:
|
||||
"""
|
||||
@ -178,7 +174,7 @@ class TuningProcess:
|
||||
try:
|
||||
if not self.selector.select(timeout):
|
||||
raise TimeoutError(f"Timeout in autotune subprocess {self.process.pid}")
|
||||
result, _ = TuningProcess.recv(self.read_pipe)
|
||||
result = TuningProcess.recv(self.read_pipe)
|
||||
except TimeoutError:
|
||||
self.kill()
|
||||
raise
|
||||
@ -309,10 +305,8 @@ class TuningProcessPool:
|
||||
"""
|
||||
assert choice.bmreq is not None
|
||||
|
||||
env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
|
||||
extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}
|
||||
process = self.process_queue.get()
|
||||
process.put(choice.bmreq.benchmark, extra_env=extra_env)
|
||||
process.put(choice.bmreq.benchmark)
|
||||
try:
|
||||
return process.get(
|
||||
config.max_autotune_subproc_result_timeout_seconds,
|
||||
|
||||
@ -2819,8 +2819,6 @@ class SIMDScheduling(BaseScheduling):
|
||||
bad_size_additional_tiling_penalty = 1.025
|
||||
good_size_tiling_penalty = 1.005
|
||||
|
||||
total_uncoalesced = sum(coalesce_analysis.uncoalesced_addrs.values())
|
||||
|
||||
def score_mod(t):
|
||||
score_factor = 1.0
|
||||
for tile_size in t[0].tiling.values():
|
||||
@ -2829,19 +2827,12 @@ class SIMDScheduling(BaseScheduling):
|
||||
else:
|
||||
score_factor = score_factor / good_size_tiling_penalty
|
||||
|
||||
# Add uncoalesced memory score to prevent small coalesced benefits
|
||||
# from dominating large amounts of uncoalesced memory
|
||||
uncoalesced_penalty = total_uncoalesced * 0.05
|
||||
|
||||
return -(t[0].score + uncoalesced_penalty) * score_factor
|
||||
return -t[0].score * score_factor
|
||||
|
||||
# apply penalty for longer tilings that dont increase score much
|
||||
for cand, tiling_score in sorted(tilings, key=score_mod):
|
||||
if (
|
||||
cls.tiling_is_compatible(
|
||||
node_schedule, pointwise_numel, reduction_numel, cand.tiling
|
||||
)
|
||||
or cand.tiling == default_tiling
|
||||
if cls.tiling_is_compatible(
|
||||
node_schedule, pointwise_numel, reduction_numel, cand.tiling
|
||||
):
|
||||
# we always include default reduction numel == 1, dont include
|
||||
tiling_len = len(cand.tiling) - (1 if reduction_numel == 1 else 0)
|
||||
|
||||
@@ -176,7 +176,6 @@ class OverlapPreservingBucketer:
        head = None
        prev_event = None
        position = 0
        hiding_nodes = OrderedSet()

        for node in self.scheduled:
            node_type = None
@@ -184,12 +183,11 @@ class OverlapPreservingBucketer:
            # Determine if this node is relevant for this PG
            if node in self.collective_info and get_group_name(node) == pg:
                node_type = "starts"
                hiding_nodes |= self.collective_info[node].hiding_nodes
            elif is_wait_tensor(node):
                wait_input = node.args[0]
                if isinstance(wait_input, fx.Node) and get_group_name(wait_input) == pg:
                    node_type = "waits"
            elif is_compute_node(node) or node in hiding_nodes:
            elif is_compute_node(node):
                node_type = "compute"

            if node_type is None:
@@ -207,6 +205,7 @@ class OverlapPreservingBucketer:

            prev_event = event
            position += 1

        return head

    def _populate_node_to_event(self, pg: str) -> None:
@@ -223,12 +222,10 @@ class OverlapPreservingBucketer:
        Add hiding interval constraints: start -> compute -> wait.
        """
        for start, info in self.collective_info.items():
            if info.is_exposed:
                continue
            for hn in info.hiding_nodes:
            if info.hiding_node and not info.is_exposed:
                # Enforce: start -> compute -> wait
                self.aug_graph.add_extra_dep(n=hn, dep=start)
                self.aug_graph.add_extra_dep(n=info.wait_node, dep=hn)
                self.aug_graph.add_extra_dep(n=info.hiding_node, dep=start)
                self.aug_graph.add_extra_dep(n=info.wait_node, dep=info.hiding_node)

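For each collective that ends up hidden, the loop above pins the hiding compute between the collective's start and its wait by adding extra dependencies to the augmented graph. The same constraint shape on a plain adjacency dict, with invented node names (this is not the bucketer's data structure, only an illustration):

```python
# start -> compute -> wait, expressed as "node X must run after node Y" edges.
from collections import defaultdict

extra_deps: dict[str, set[str]] = defaultdict(set)


def add_hiding_constraints(start: str, wait: str, hiding_nodes: list[str]) -> None:
    for hn in hiding_nodes:
        extra_deps[hn].add(start)  # compute may only run once the collective started
        extra_deps[wait].add(hn)   # the wait may only run once the hiding compute finished


add_hiding_constraints("all_gather_1", "wait_1", ["mm_3", "mm_4"])
# wait_1 now depends on both hiding computes; each compute depends on the start
print({k: sorted(v) for k, v in extra_deps.items()})
```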
    def bucket_collectives(self) -> None:
        """Main entry point for bucketing collectives."""
@@ -361,13 +358,13 @@ class OverlapPreservingBucketer:

    def _get_intervals(
        self, event: PGEvent
    ) -> tuple[Optional[tuple[int, int]], list[tuple[int, int]]]:
        """Get (execution_interval, hiding_intervals) for a collective event.
    ) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]:
        """Get (execution_interval, hiding_interval) for a collective event.

        Returns:
            (execution_interval, hiding_intervals) where:
            (execution_interval, hiding_interval) where:
            - execution_interval is (start_pos, wait_pos) or None
            - hiding_intervals is a list of (start_pos, compute_pos) tuples, one for each hiding node
            - hiding_interval is (start_pos, compute_pos) or None if no hiding node

        Works for both start and wait events by looking up the collective info.
        """
@@ -378,13 +375,13 @@ class OverlapPreservingBucketer:
        elif event.is_wait:
            wait_input = event.node.args[0]
            if not isinstance(wait_input, fx.Node):
                return None, []
                return None, None
            coll = wait_input
        else:
            return None, []
            return None, None

        if coll not in self.collective_info:
            return None, []
            return None, None

        info = self.collective_info[coll]
        start_event = self.node_to_event[coll]
@@ -392,17 +389,14 @@ class OverlapPreservingBucketer:

        execution_interval = (start_event.position, wait_event.position)

        hiding_intervals = []
        if info.hiding_nodes:
            for hiding_node in info.hiding_nodes:
                hiding_intervals.append(
                    (
                        start_event.position,
                        self.node_to_event[hiding_node].position,
                    )
                )
        hiding_interval = None
        if info.hiding_node:
            hiding_interval = (
                start_event.position,
                self.node_to_event[info.hiding_node].position,
            )

        return execution_interval, hiding_intervals
        return execution_interval, hiding_interval

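Later in this file the check is phrased as "no hiding interval should be enclosed by any execution interval": a bucketing candidate is rejected if some collective's hiding compute would end up strictly inside another collective's start-to-wait span. A tiny sketch of that enclosure test on plain position tuples (the positions and the strictness of the comparison are illustrative assumptions):

```python
# A hiding interval (collective start -> hiding compute) must not be enclosed
# by an execution interval (collective start -> wait) of the proposed bucket.
def encloses(outer: tuple[int, int], inner: tuple[int, int]) -> bool:
    return outer[0] < inner[0] and inner[1] < outer[1]


execution_intervals = [(0, 10), (12, 20)]
hiding_intervals = [(3, 7), (15, 25)]

violations = [
    (e, h)
    for e in execution_intervals
    for h in hiding_intervals
    if encloses(e, h)
]
print(violations)  # [((0, 10), (3, 7))] -> this candidate bucketing is rejected
```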
def _preserves_hiding_intervals(
|
||||
self,
|
||||
@ -430,9 +424,9 @@ class OverlapPreservingBucketer:
|
||||
# Collect hiding compute positions for the bucket
|
||||
bucket_hiding_compute_positions = []
|
||||
for coll in all_bucketed_colls:
|
||||
for coll_hiding_node in self.collective_info[coll].hiding_nodes:
|
||||
if hiding_node := self.collective_info[coll].hiding_node:
|
||||
bucket_hiding_compute_positions.append(
|
||||
self.node_to_event[coll_hiding_node].position
|
||||
self.node_to_event[hiding_node].position
|
||||
)
|
||||
|
||||
# Get new positions
|
||||
@ -484,10 +478,11 @@ class OverlapPreservingBucketer:
|
||||
curr_event.node not in all_bucketed_colls
|
||||
and curr_event.node not in all_bucketed_waits
|
||||
):
|
||||
exec_interval, hiding_interval_list = self._get_intervals(curr_event)
|
||||
exec_interval, hiding_interval = self._get_intervals(curr_event)
|
||||
if exec_interval:
|
||||
execution_intervals.append(exec_interval)
|
||||
hiding_intervals.extend(hiding_interval_list)
|
||||
if hiding_interval:
|
||||
hiding_intervals.append(hiding_interval)
|
||||
curr_event = curr_event.next
|
||||
|
||||
curr_event = new_wait_event.prev
|
||||
@ -496,10 +491,11 @@ class OverlapPreservingBucketer:
|
||||
curr_event.node not in all_bucketed_colls
|
||||
and curr_event.node not in all_bucketed_waits
|
||||
):
|
||||
exec_interval, hiding_interval_list = self._get_intervals(curr_event)
|
||||
exec_interval, hiding_interval = self._get_intervals(curr_event)
|
||||
if exec_interval:
|
||||
execution_intervals.append(exec_interval)
|
||||
hiding_intervals.extend(hiding_interval_list)
|
||||
if hiding_interval:
|
||||
hiding_intervals.append(hiding_interval)
|
||||
curr_event = curr_event.prev
|
||||
|
||||
# Check: no hiding interval should be enclosed by any execution interval
|
||||
@ -663,12 +659,12 @@ class OverlapPreservingBucketer:
|
||||
return True
|
||||
|
||||
# Check if existing hiding node conflicts with candidate wait
|
||||
for old_hiding_node in self.collective_info[coll].hiding_nodes:
|
||||
if self._ancestor_dep(old_hiding_node, candidate_wait):
|
||||
if hiding_node := self.collective_info[coll].hiding_node:
|
||||
if self._ancestor_dep(hiding_node, candidate_wait):
|
||||
return True
|
||||
|
||||
# Check if candidate hiding node conflicts with existing wait
|
||||
for new_hiding_node in candidate_info.hiding_nodes:
|
||||
if new_hiding_node := candidate_info.hiding_node:
|
||||
if self._ancestor_dep(new_hiding_node, coll_wait):
|
||||
return True
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
import sys
from collections import Counter, defaultdict
from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Any, Literal

import torch
@@ -190,7 +190,7 @@ class CollectiveInfo:
    size_bytes: int
    estimated_time_ms: float
    exposed_time_ms: float  # How much of this collective is still exposed
    hiding_nodes: OrderedSet[fx.Node] = field(default_factory=OrderedSet)
    hiding_node: fx.Node | None = None  # Node that hides this collective

    @property
    def is_exposed(self) -> bool:
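The field change above replaces the set of hiding nodes with a single `hiding_node`. A minimal sketch of the post-change shape of the dataclass and the exposure check it supports (the `is_exposed` body is an assumption inferred from how `exposed_time_ms` is decremented elsewhere in these hunks, not copied from the file):

```python
# Post-change CollectiveInfo shape: one optional hiding node per collective.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class CollectiveInfoSketch:
    size_bytes: int
    estimated_time_ms: float
    exposed_time_ms: float             # how much of the collective is still exposed
    hiding_node: Optional[Any] = None  # the single compute node hiding it, if any

    @property
    def is_exposed(self) -> bool:
        # assumed: a collective counts as hidden once no exposed time remains
        return self.exposed_time_ms > 0


info = CollectiveInfoSketch(
    size_bytes=1 << 20, estimated_time_ms=2.0, exposed_time_ms=0.0, hiding_node="mm_3"
)
print(info.is_exposed)  # False -> treated as hidden behind hiding_node
```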
@ -533,8 +533,6 @@ class OverlapScheduler:
|
||||
self._handle_collective_start(node)
|
||||
elif is_wait_tensor(node):
|
||||
self._handle_wait(node)
|
||||
elif node.op == "placeholder":
|
||||
self._schedule(node)
|
||||
else:
|
||||
self._handle_other(node)
|
||||
|
||||
@ -563,13 +561,11 @@ class OverlapScheduler:
|
||||
additional_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
|
||||
|
||||
for start_node, info in self.collective_info.items():
|
||||
if info.is_exposed:
|
||||
continue
|
||||
for hn in info.hiding_nodes:
|
||||
if info.hiding_node and not info.is_exposed:
|
||||
# Compute depends on collective start (compute must wait for collective to start)
|
||||
additional_deps[hn].add(start_node)
|
||||
additional_deps[info.hiding_node].add(start_node)
|
||||
# Wait depends on compute (wait must wait for compute to finish)
|
||||
additional_deps[info.wait_node].add(hn)
|
||||
additional_deps[info.wait_node].add(info.hiding_node)
|
||||
|
||||
# Apply effect tokens to preserve these dependencies
|
||||
if additional_deps:
|
||||
@ -710,8 +706,9 @@ class OverlapScheduler:
|
||||
overlap_amount = min(info.exposed_time_ms, available_compute)
|
||||
info.exposed_time_ms -= overlap_amount
|
||||
available_compute -= overlap_amount
|
||||
info.hiding_nodes.add(node)
|
||||
if available_compute == 0:
|
||||
if info.exposed_time_ms == 0:
|
||||
info.hiding_node = node
|
||||
elif available_compute == 0:
|
||||
break
|
||||
|
||||
# Then, look for unscheduled collectives we can overlap
|
||||
@ -804,7 +801,8 @@ class OverlapScheduler:
|
||||
# after scheduling, which will account for latency reduction of bucketing
|
||||
overlap_amount = min(available_compute_time, info.exposed_time_ms)
|
||||
info.exposed_time_ms -= overlap_amount
|
||||
info.hiding_nodes.add(compute_node)
|
||||
if info.exposed_time_ms == 0:
|
||||
info.hiding_node = compute_node
|
||||
available_compute_time -= overlap_amount
|
||||
|
||||
def _find_schedulable_path(
|
||||
@ -831,7 +829,7 @@ class OverlapScheduler:
|
||||
# it's fine to schedule it
|
||||
if is_wait_tensor(node):
|
||||
info = self.collective_info[self.wait_to_start[node]]
|
||||
if info.hiding_nodes and curr_compute_node not in info.hiding_nodes:
|
||||
if info.hiding_node and info.hiding_node != curr_compute_node:
|
||||
continue
|
||||
elif node not in self.potentially_hidden_waits:
|
||||
continue
|
||||
@ -867,7 +865,7 @@ class OverlapScheduler:
|
||||
) -> bool:
|
||||
assert is_wait_tensor(wait_node)
|
||||
info = self.collective_info[self.wait_to_start[wait_node]]
|
||||
return not info.is_exposed and compute_node not in info.hiding_nodes
|
||||
return not info.is_exposed and info.hiding_node != compute_node
|
||||
|
||||
def _schedule_path_to_collective(
|
||||
self, path: OrderedSet[fx.Node], curr_compute_node: fx.Node
|
||||
@ -886,7 +884,7 @@ class OverlapScheduler:
|
||||
continue
|
||||
|
||||
info = self.collective_info[self.wait_to_start[node]]
|
||||
assert curr_compute_node not in info.hiding_nodes
|
||||
assert info.hiding_node != curr_compute_node
|
||||
self._handle_wait(node)
|
||||
continue
|
||||
|
||||
|
||||
@@ -145,41 +145,6 @@ def solve_for_tiling(expr: sympy.Expr) -> Optional[sympy.Expr]:
    return None


def find_broadcast_var(
    index: sympy.Expr, var_ranges: dict[sympy.Expr, int]
) -> Optional[sympy.Expr]:
    """
    Try to find the variable that this index is broadcast over.
    A broadcast pattern is one where consecutive values of a variable
    access the same memory location (e.g., x // 10).
    """
    # Approximate analysis by evaluating at 1 and 0
    variables: dict[sympy.Symbol, int] = {}
    for v in index.free_symbols:
        if v in var_ranges:
            variables[v] = 0
        else:
            variables[v] = get_hint(v)

    zero_index = sympy_subs(index, variables)
    for v in var_ranges.keys():
        if v not in index.free_symbols:
            continue

        variables[v] = 1
        try:
            new_val = sympy_subs(index, variables)
        except ZeroDivisionError:
            loop_tiling_log.info("zero division error %s %s", index, variables)
            continue
        # Broadcast means the value doesn't change when the variable increments
        if new_val == zero_index:
            return v
        variables[v] = 0

    return None


def find_coalesced_var(
    index: sympy.Expr, var_ranges: dict[sympy.Expr, int]
) -> Optional[sympy.Expr]:
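The removed `find_broadcast_var` detects a broadcast dimension by substituting 0 and then 1 for each candidate variable and checking whether the indexing expression moves. A standalone illustration of that trick using sympy directly (the inductor helpers `get_hint`/`sympy_subs` and the logging are omitted, and the example index is invented):

```python
# Substitute-0/1 probe: if bumping a variable from 0 to 1 leaves the address
# unchanged, consecutive values of that variable hit the same memory location.
import sympy

x0, x1 = sympy.symbols("x0 x1", integer=True, nonnegative=True)
index = sympy.floor(x0 / 10) + 100 * x1  # x0 is the broadcast dimension
var_ranges = {x0: 64, x1: 8}


def find_broadcast_var_sketch(index, var_ranges):
    zeros = {v: 0 for v in var_ranges}
    zero_index = index.subs(zeros)
    for v in var_ranges:
        if v not in index.free_symbols:
            continue
        if index.subs({**zeros, v: 1}) == zero_index:
            return v
    return None


print(find_broadcast_var_sketch(index, var_ranges))  # x0
```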
@ -603,12 +568,11 @@ def extract_normalized_read_writes(
|
||||
return fused_out
|
||||
|
||||
|
||||
def get_score(
|
||||
addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int], buf_names: OrderedSet[str]
|
||||
) -> int:
|
||||
def get_score(addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int]) -> int:
|
||||
"""
|
||||
Score addr according to its approximate size.
|
||||
Score addr according to its approximate size
|
||||
"""
|
||||
|
||||
# TODO - deduplicate with candidate_tilings
|
||||
var_sizes = []
|
||||
for v in addr.free_symbols:
|
||||
@ -623,15 +587,6 @@ def get_score(
|
||||
)
|
||||
|
||||
|
||||
def try_get_buf_size(buf_name: str) -> Optional[int]:
|
||||
buf = V.graph.try_get_buffer(buf_name)
|
||||
if not buf:
|
||||
return None
|
||||
return V.graph.sizevars.atomically_apply_size_hint(
|
||||
sympy_product(buf.get_size()), fallback=config.unbacked_symint_fallback
|
||||
)
|
||||
|
||||
|
||||
def get_hint(v: Union[sympy.Expr, int]) -> int:
|
||||
if isinstance(v, int):
|
||||
return v
|
||||
@ -657,8 +612,6 @@ class CoalesceVarAnalysis:
|
||||
# TODO: separate into dataclass that holds mem, dtype, is_write
|
||||
coalesced_by_var: dict[sympy.Expr, int]
|
||||
|
||||
uncoalesced_addrs: dict[sympy.Expr, int]
|
||||
|
||||
norm_read_writes: FusedNormalizedReadsWrites
|
||||
|
||||
suggested_split: Optional[VarTiling] = None
|
||||
@ -704,40 +657,28 @@ def analyze_memory_coalescing(
|
||||
if indirect_expr:
|
||||
continue
|
||||
|
||||
size = get_score(memory_expr, var_ranges, buf_names)
|
||||
|
||||
size = get_score(memory_expr, var_ranges)
|
||||
if size == 0:
|
||||
continue
|
||||
|
||||
maybe_coalesced_var = find_coalesced_var(memory_expr, var_ranges)
|
||||
# while broadcasting vars are not technically coalesced,
|
||||
# accesses at least stay in cache, so they provide most of the benefit.
|
||||
# treat the same for now.
|
||||
if maybe_coalesced_var is None:
|
||||
maybe_coalesced_var = find_broadcast_var(memory_expr, var_ranges)
|
||||
|
||||
total_score = 0
|
||||
byte_multipler = 0
|
||||
for buf_name in buf_names:
|
||||
if (buf := V.graph.try_get_buffer(buf_name)) and (
|
||||
buf_size := try_get_buf_size(buf_name)
|
||||
):
|
||||
# constrain by buf size since we'll read at most that many elements
|
||||
# score could be more through either masking or by broadcasting (e.g. x // 16)
|
||||
total_score += min(buf_size, size) * buf.dtype.itemsize
|
||||
if buf := V.graph.try_get_buffer(buf_name):
|
||||
byte_multipler += buf.dtype.itemsize
|
||||
|
||||
# coalesced writes more important
|
||||
total_score *= 1 if is_read else 2
|
||||
byte_multipler *= 1 if is_read else 2
|
||||
|
||||
if maybe_coalesced_var:
|
||||
coalesced_by_var[maybe_coalesced_var] += total_score
|
||||
coalesced_by_var[maybe_coalesced_var] += size * byte_multipler
|
||||
else:
|
||||
uncoalesced_addrs[memory_expr] += total_score
|
||||
uncoalesced_addrs[memory_expr] += size * byte_multipler
|
||||
|
||||
if not uncoalesced_addrs:
|
||||
return CoalesceVarAnalysis(
|
||||
coalesced_by_var=coalesced_by_var,
|
||||
uncoalesced_addrs=uncoalesced_addrs,
|
||||
norm_read_writes=norm_read_writes,
|
||||
coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes
|
||||
)
|
||||
|
||||
# map from var -> tiling -> total_score
|
||||
@ -781,9 +722,7 @@ def analyze_memory_coalescing(
|
||||
|
||||
if len(tiling_scores) == 0:
|
||||
return CoalesceVarAnalysis(
|
||||
coalesced_by_var=coalesced_by_var,
|
||||
uncoalesced_addrs=uncoalesced_addrs,
|
||||
norm_read_writes=norm_read_writes,
|
||||
coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes
|
||||
)
|
||||
|
||||
best_tiling: Optional[tuple[sympy.Expr, int]] = None
|
||||
@ -797,9 +736,7 @@ def analyze_memory_coalescing(
|
||||
|
||||
if best_tiling is None:
|
||||
return CoalesceVarAnalysis(
|
||||
coalesced_by_var=coalesced_by_var,
|
||||
uncoalesced_addrs=uncoalesced_addrs,
|
||||
norm_read_writes=norm_read_writes,
|
||||
coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes
|
||||
)
|
||||
|
||||
# TODO - for strictly pointwise fusions,
|
||||
@ -808,7 +745,6 @@ def analyze_memory_coalescing(
|
||||
# TODO - could also prefer index var splits to reduction, better tested
|
||||
return CoalesceVarAnalysis(
|
||||
coalesced_by_var=coalesced_by_var,
|
||||
uncoalesced_addrs=uncoalesced_addrs,
|
||||
norm_read_writes=norm_read_writes,
|
||||
suggested_split=VarTiling(best_tiling[0], best_tiling[1], best_tiling_score),
|
||||
)
|
||||
|
||||
@@ -1,9 +0,0 @@
from ._core import ComplexTensor
from ._ops import ComplexTensorMode, is_complex_tensor


__all__ = ["ComplexTensor", "ComplexTensorMode", "is_complex_tensor"]

ComplexTensor.__module__ = __name__
ComplexTensorMode.__module__ = __name__
is_complex_tensor.__module__ = __name__
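The hunks below delete the experimental `ComplexTensor` subclass package whose `__init__.py` is shown above. Its core idea, visible in the removed `_core.py` and `_ops` files that follow, is to store separate real and imaginary tensors and decompose every complex op onto that pair. A toy illustration of the decomposition arithmetic only, using plain tensors rather than the removed subclass:

```python
# Complex multiply on (re, im) pairs: (a + bi)(c + di) = (ac - bd) + (ad + bc)i,
# checked against PyTorch's native complex tensors.
import torch

re1, im1 = torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])
re2, im2 = torch.tensor([5.0, 6.0]), torch.tensor([7.0, 8.0])

out_re = re1 * re2 - im1 * im2
out_im = re1 * im2 + im1 * re2

ref = torch.complex(re1, im1) * torch.complex(re2, im2)
print(torch.allclose(out_re, ref.real), torch.allclose(out_im, ref.imag))  # True True
```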
@ -1,151 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing_extensions import Self
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._ops import OpOverload
|
||||
from torch._prims_common import DeviceLikeType
|
||||
from torch.autograd.function import FunctionCtx
|
||||
|
||||
|
||||
class ComplexTensor(Tensor):
|
||||
"""A class that decomposes all ops on complex Tensors into their real and imaginary parts."""
|
||||
|
||||
_re: Tensor
|
||||
_im: Tensor
|
||||
|
||||
def __new__(cls, real: Tensor, imag: Tensor) -> Self:
|
||||
"""Initialize a ComplexTensor from its real and imaginary parts."""
|
||||
from ._ops.common import REAL_TO_COMPLEX
|
||||
|
||||
shape = real.shape
|
||||
device = real.device
|
||||
|
||||
# TODO (hameerabbasi): `torch.compile` sometimes fails here without making these
|
||||
# contiguous. Why?
|
||||
real = real.contiguous()
|
||||
imag = imag.contiguous()
|
||||
|
||||
# TODO (hameerabbasi):
|
||||
# What should we do with dtype?
|
||||
# We could convert to the complex type (float32 -> complex64), but we
|
||||
# can't use that model for say `bfloat16` which does not have a
|
||||
# corresponding complex dtype.
|
||||
# If we want to support this complex rep using any float type (see
|
||||
# https://github.com/pytorch/pytorch/issues/95100)
|
||||
# We either need to:
|
||||
# 1) add the complex types for say `complexbf32`, knowing they can't really be used anywhere
|
||||
# else.
|
||||
# 2) We use the real float dtype here, and it is up to the user to know
|
||||
# that dtype=float<size> here really means complex<2xSize> with dtype
|
||||
# matching that of re/im parts alone
|
||||
# I'm going with 1 for now, so that I can make gradcheck and some complex
|
||||
# ops work properly, but might want to discuss this in the RFP.
|
||||
dtype = REAL_TO_COMPLEX.get(real.dtype)
|
||||
if dtype is None:
|
||||
raise TypeError(
|
||||
"Unsupported dtype for constituent tensors. Supported dtypes are: "
|
||||
f"{set(REAL_TO_COMPLEX.keys())!r}."
|
||||
)
|
||||
storage_offset = real.storage_offset()
|
||||
strides = real.stride()
|
||||
layout = real.layout
|
||||
pin_memory = real.is_pinned()
|
||||
|
||||
assert shape == imag.shape, f"Expected imag shape {shape}, got {imag.shape}"
|
||||
assert device == imag.device, (
|
||||
f"Expected imag device {device}, got {imag.device}"
|
||||
)
|
||||
assert real.dtype == imag.dtype, (
|
||||
f"Expected imag dtype {real.dtype}, got {imag.dtype}"
|
||||
)
|
||||
assert pin_memory == imag.is_pinned(), (
|
||||
f"Expected imag pinning {pin_memory}, got {imag.is_pinned()}"
|
||||
)
|
||||
|
||||
res = Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
|
||||
cls,
|
||||
shape,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
storage_offset=storage_offset,
|
||||
strides=strides,
|
||||
pin_memory=pin_memory,
|
||||
layout=layout,
|
||||
requires_grad=False,
|
||||
)
|
||||
res._re = real.clone().detach()
|
||||
res._im = imag.clone().detach()
|
||||
|
||||
return res
|
||||
|
||||
@property
|
||||
def re(self) -> Tensor:
|
||||
return self._re
|
||||
|
||||
@property
|
||||
def im(self) -> Tensor:
|
||||
return self._im
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(
|
||||
cls,
|
||||
func: OpOverload,
|
||||
types: tuple[type, ...],
|
||||
args: tuple = (),
|
||||
kwargs: dict | None = None,
|
||||
):
|
||||
from ._ops.common import lookup_complex
|
||||
|
||||
kwargs = {} if kwargs is None else kwargs
|
||||
|
||||
impl = lookup_complex(func, *args, **kwargs)
|
||||
if impl is None:
|
||||
return NotImplemented
|
||||
|
||||
return impl(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def from_interleaved(t: Tensor) -> ComplexTensor:
|
||||
t_real = torch.real(t)
|
||||
t_imag = torch.imag(t) if t.dtype.is_complex else torch.zeros_like(t_real)
|
||||
return Complex.apply(t_real, t_imag)
|
||||
|
||||
def as_interleaved(self) -> Tensor:
|
||||
return torch.complex(self.real, self.imag)
|
||||
|
||||
@staticmethod
|
||||
def __tensor_unflatten__(
|
||||
inner_tensors: dict[str, Tensor],
|
||||
meta: Any,
|
||||
outer_size: tuple[int, ...],
|
||||
outer_stride: tuple[int, ...],
|
||||
) -> ComplexTensor:
|
||||
assert meta is None
|
||||
re, im = inner_tensors["re"], inner_tensors["im"]
|
||||
return ComplexTensor(re, im)
|
||||
|
||||
def __tensor_flatten__(self) -> tuple[list[str], Any]:
|
||||
return ["re", "im"], None
|
||||
|
||||
def __repr__(self, *, tensor_contents=None) -> str:
|
||||
return f"ComplexTensor(real={self.re!r}, imag={self.im!r})"
|
||||
|
||||
def is_pinned(self, device: DeviceLikeType | None = None) -> bool:
|
||||
return self.re.is_pinned(device)
|
||||
|
||||
|
||||
class Complex(Function):
|
||||
@staticmethod
|
||||
def forward(ctx: FunctionCtx, real: Tensor, imag: Tensor) -> ComplexTensor: # type: ignore[bad-override]
|
||||
return ComplexTensor(real, imag)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: FunctionCtx, grad_output: ComplexTensor) -> tuple[Tensor, Tensor]: # type: ignore[bad-override]
|
||||
return grad_output.real, grad_output.imag
|
||||
@ -1,5 +0,0 @@
|
||||
from . import aten, prims
|
||||
from .common import ComplexTensorMode, is_complex_tensor
|
||||
|
||||
|
||||
__all__ = ["ComplexTensorMode", "is_complex_tensor", "aten", "prims"]
|
||||
@ -1,921 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from .._core import ComplexTensor
|
||||
from .common import (
|
||||
_get_func_name,
|
||||
COMPLEX_TO_REAL,
|
||||
complex_to_real_dtype,
|
||||
is_complex,
|
||||
OpType,
|
||||
promote_tensors,
|
||||
register_binary_nonlinear,
|
||||
register_complex,
|
||||
register_error,
|
||||
register_force_test,
|
||||
register_simple,
|
||||
split_complex_arg,
|
||||
split_complex_tensor,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
|
||||
def register_binary_linear(op: OpType):
|
||||
def impl_with_alpha(
|
||||
lhs: ComplexTensor, rhs: ComplexTensor, *args, alpha, **kwargs
|
||||
) -> ComplexTensor:
|
||||
return op(lhs, aten.mul(rhs, alpha, *args, **kwargs), *args, **kwargs)
|
||||
|
||||
def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:
|
||||
alpha = kwargs.pop("alpha", None)
|
||||
if alpha is not None:
|
||||
return impl_with_alpha(lhs, rhs, *args, alpha=alpha, **kwargs)
|
||||
a_r, a_i = split_complex_arg(lhs)
|
||||
b_r, b_i = split_complex_arg(rhs)
|
||||
out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i)
|
||||
u = op(a_r, b_r, *args, **kwargs)
|
||||
v = op(a_i, b_i, *args, **kwargs)
|
||||
return ComplexTensor(u.to(out_dt), v.to(out_dt))
|
||||
|
||||
return register_complex(op, impl)
|
||||
|
||||
|
||||
@register_complex(aten.real)
|
||||
def real_impl(self: ComplexTensor) -> torch.Tensor:
|
||||
re, _ = split_complex_tensor(self)
|
||||
return re
|
||||
|
||||
|
||||
@register_complex(aten.imag)
|
||||
def imag_impl(self: ComplexTensor) -> torch.Tensor:
|
||||
_, im = split_complex_tensor(self)
|
||||
return im
|
||||
|
||||
|
||||
@register_complex(aten.is_pinned)
|
||||
def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> bool:
|
||||
return self.is_pinned(device)
|
||||
|
||||
|
||||
SIMPLE_OPS_LIST = [
|
||||
aten.slice,
|
||||
aten.flatten,
|
||||
aten.view,
|
||||
aten.diagonal,
|
||||
aten.expand,
|
||||
aten.unsqueeze,
|
||||
aten.unsqueeze_,
|
||||
aten.mean,
|
||||
aten.sum,
|
||||
aten.clone,
|
||||
aten.neg,
|
||||
aten.flip,
|
||||
aten.permute,
|
||||
aten.repeat,
|
||||
aten.index_select,
|
||||
aten.split,
|
||||
aten.split_with_sizes,
|
||||
aten.cumsum,
|
||||
aten.detach,
|
||||
aten.select,
|
||||
aten.squeeze,
|
||||
aten.zero_,
|
||||
aten.transpose,
|
||||
aten.t,
|
||||
aten.gather,
|
||||
]
|
||||
|
||||
for simple_op in SIMPLE_OPS_LIST:
|
||||
globals()[_get_func_name(simple_op)] = register_simple(simple_op)
|
||||
|
||||
# TODO (hameerabbasi): Not being tested
|
||||
SIMPLE_FORCE_TESTED_OPS = [
|
||||
aten.copy,
|
||||
aten.col2im,
|
||||
aten.alias,
|
||||
aten.lift_fresh,
|
||||
aten._unsafe_view,
|
||||
aten.index,
|
||||
aten._neg_view,
|
||||
aten.avg_pool2d,
|
||||
aten.avg_pool3d,
|
||||
aten.avg_pool2d_backward,
|
||||
aten.avg_pool3d_backward,
|
||||
aten.masked_scatter_backward,
|
||||
aten.select_backward,
|
||||
aten.slice_backward,
|
||||
aten.embedding,
|
||||
]
|
||||
|
||||
for simple_op in SIMPLE_FORCE_TESTED_OPS:
|
||||
globals()[_get_func_name(simple_op)] = register_force_test(
|
||||
simple_op, register_simple(simple_op)
|
||||
)
|
||||
|
||||
del simple_op
|
||||
|
||||
# some binary ops which we can stamp out
|
||||
mul_impl = register_binary_nonlinear(aten.mul)
|
||||
mul__impl = register_binary_nonlinear(aten.mul_)
|
||||
mm_impl = register_binary_nonlinear(aten.mm)
|
||||
dot_impl = register_binary_nonlinear(aten.dot)
|
||||
bmm_impl = register_binary_nonlinear(aten.bmm)
|
||||
|
||||
# TODO (hameerabbasi): Not being tested
|
||||
convolution_impl = register_force_test(
|
||||
aten.convolution, register_binary_nonlinear(aten.convolution)
|
||||
)
|
||||
|
||||
slice_scatter_impl = register_force_test(
|
||||
aten.slice_scatter, register_binary_linear(aten.slice_scatter)
|
||||
)
|
||||
select_scatter_impl = register_force_test(
|
||||
aten.select_scatter, register_binary_linear(aten.select_scatter)
|
||||
)
|
||||
|
||||
add_impl = register_binary_linear(aten.add)
|
||||
add__impl = register_binary_linear(aten.add_)
|
||||
sub_impl = register_binary_linear(aten.sub)
|
||||
sub__impl = register_binary_linear(aten.sub_)
|
||||
diagonal_scatter_impl = register_binary_linear(aten.diagonal_scatter)
|
||||
fill__impl = register_binary_linear(aten.fill_)
|
||||
|
||||
|
||||
@register_complex(aten.rsub)
|
||||
def rsub_impl(lhs: ComplexTensor, rhs: ComplexTensor, alpha=None) -> ComplexTensor:
|
||||
if alpha is None:
|
||||
return torch.sub(rhs, lhs) # type: ignore[bad-return]
|
||||
return torch.sub(rhs, lhs, alpha=alpha) # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.div)
|
||||
@register_complex(aten.true_divide)
|
||||
def div_impl(lhs: ComplexTensor, rhs: ComplexTensor, *, rounding_mode=None):
|
||||
if rounding_mode is not None:
|
||||
raise NotImplementedError(
|
||||
"`rounding_mode` other than `None` not implemented for`ComplexTensor`."
|
||||
)
|
||||
a_r, a_i = split_complex_tensor(lhs)
|
||||
if not is_complex(rhs):
|
||||
return ComplexTensor(a_r / rhs, a_i / rhs)
|
||||
b_r, b_i = split_complex_arg(rhs)
|
||||
out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i)
|
||||
num_r = a_r * b_r + a_i * b_i
|
||||
num_i = a_i * b_r - a_r * b_i
|
||||
den = b_r * b_r + b_i * b_i
|
||||
return ComplexTensor(
|
||||
(num_r / den).to(out_dt),
|
||||
(num_i / den).to(out_dt),
|
||||
)
|
||||
|
||||
|
||||
@register_complex(aten.reciprocal)
|
||||
def reciprocal_impl(self: ComplexTensor):
|
||||
self_r, self_i = split_complex_tensor(self)
|
||||
out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
|
||||
den = self_r * self_r + self_i * self_i
|
||||
return ComplexTensor(
|
||||
aten.div(self_r, den).to(out_dt),
|
||||
aten.div(-self_i, den).to(out_dt),
|
||||
)
|
||||
|
||||
|
||||
# reductions
|
||||
@register_complex(aten.prod)
|
||||
def prod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
|
||||
out_dt, (self,) = promote_tensors(self)
|
||||
dtype = kwargs.pop("dtype", out_dt)
|
||||
kwargs["dtype"] = complex_to_real_dtype(self.dtype)
|
||||
|
||||
prod_r = torch.prod(torch.abs(self), *args, **kwargs)
|
||||
sum_phi = torch.sum(torch.angle(self), *args, **kwargs)
|
||||
u = prod_r * torch.cos(sum_phi)
|
||||
v = prod_r * torch.sin(sum_phi)
|
||||
return ComplexTensor(u, v).to(dtype) # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.pow)
|
||||
def pow_impl(self: ComplexTensor, exponent: ComplexTensor) -> ComplexTensor:
|
||||
out_dt, (self, exponent) = promote_tensors(self, exponent)
|
||||
return torch.exp(exponent * torch.log(self)).to(out_dt) # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.cumprod)
|
||||
def cumprod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
|
||||
dtype = kwargs.pop("dtype", self.dtype)
|
||||
kwargs["dtype"] = complex_to_real_dtype(dtype)
|
||||
|
||||
prod_r = torch.cumprod(torch.abs(self), *args, **kwargs)
|
||||
sum_phi = torch.cumsum(torch.angle(self), *args, **kwargs)
|
||||
u = prod_r * torch.cos(sum_phi)
|
||||
v = prod_r * torch.sin(sum_phi)
|
||||
return ComplexTensor(u, v)
|
||||
|
||||
|
||||
# unary funcs,
|
||||
# most of these are simple or require some kind of identity
|
||||
@register_complex(aten.abs)
|
||||
def abs_impl(self: ComplexTensor) -> torch.Tensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
out_dt, (x, y) = promote_tensors(x, y)
|
||||
result = torch.hypot(x, y)
|
||||
return result.to(out_dt)
|
||||
|
||||
|
||||
@register_complex(aten.angle)
|
||||
def angle_impl(self: ComplexTensor) -> torch.Tensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
return torch.atan2(y, x)
|
||||
|
||||
|
||||
@register_complex(aten.acos)
|
||||
def acos_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
_, y = split_complex_tensor(self)
|
||||
acosh_z = torch.acosh(self)
|
||||
assert isinstance(acosh_z, ComplexTensor)
|
||||
acosh_z_re, acosh_z_im = split_complex_tensor(acosh_z)
|
||||
sign_im = 2 * torch.signbit(y) - 1
|
||||
return ComplexTensor(torch.abs(acosh_z_im), sign_im * torch.abs(acosh_z_re))
|
||||
|
||||
|
||||
@register_complex(aten.asin)
|
||||
def asin_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
asinh_iz = torch.asinh(ComplexTensor(-y, x))
|
||||
assert isinstance(asinh_iz, ComplexTensor)
|
||||
asinh_iz_re, asinh_iz_im = split_complex_tensor(asinh_iz)
|
||||
return ComplexTensor(asinh_iz_im, -asinh_iz_re)
|
||||
|
||||
|
||||
@register_complex(aten.atan)
|
||||
def atan_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
tanh_iz = torch.atanh(ComplexTensor(-y, x))
|
||||
assert isinstance(tanh_iz, ComplexTensor)
|
||||
tanh_iz_re, tanh_iz_im = split_complex_tensor(tanh_iz)
|
||||
return ComplexTensor(tanh_iz_im, -tanh_iz_re)
|
||||
|
||||
|
||||
@register_complex(aten.asinh)
|
||||
def asinh_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
out_dt, (self,) = promote_tensors(self)
|
||||
return torch.log(self + torch.sqrt(self * self + 1)).to(out_dt) # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.acosh)
|
||||
def acosh_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
out_dt, (self,) = promote_tensors(self)
|
||||
return torch.log(self + torch.sqrt(self * self - 1)).to(out_dt) # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.atanh)
|
||||
def atanh_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
out_dt, (x, y) = promote_tensors(x, y)
|
||||
|
||||
ret = 0.5 * (
|
||||
torch.log(ComplexTensor(1 + x, y)) - torch.log(ComplexTensor(1 - x, -y))
|
||||
)
|
||||
assert isinstance(ret, ComplexTensor)
|
||||
ret_re, ret_im = split_complex_tensor(ret)
|
||||
|
||||
return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt))
|
||||
|
||||
|
||||
@register_complex(aten.cos)
|
||||
def cos_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
return torch.cosh(ComplexTensor(-y, x)) # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.cosh)
|
||||
def cosh_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
out_dt, (x, y) = promote_tensors(x, y)
|
||||
u = torch.cosh(x) * torch.cos(y)
|
||||
v = torch.sinh(x) * torch.sin(y)
|
||||
return ComplexTensor(u.to(out_dt), v.to(out_dt))
|
||||
|
||||
|
||||
@register_complex(aten.sin)
|
||||
def sin_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
sinh_iz = torch.sinh(ComplexTensor(-y, x))
|
||||
assert isinstance(sinh_iz, ComplexTensor)
|
||||
sinh_iz_re, sinh_iz_im = split_complex_tensor(sinh_iz)
|
||||
return ComplexTensor(sinh_iz_im, -sinh_iz_re)
|
||||
|
||||
|
||||
@register_complex(aten.sinh)
|
||||
def sinh_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
out_dt, (x, y) = promote_tensors(x, y)
|
||||
u = torch.sinh(x) * torch.cos(y)
|
||||
v = torch.cosh(x) * torch.sin(y)
|
||||
return ComplexTensor(u.to(out_dt), v.to(out_dt))
|
||||
|
||||
|
||||
@register_complex(aten.tan)
|
||||
def tan_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
tanh_iz = torch.tanh(ComplexTensor(-y, x))
|
||||
assert isinstance(tanh_iz, ComplexTensor)
|
||||
tanh_iz_re, tanh_iz_im = split_complex_tensor(tanh_iz)
|
||||
return ComplexTensor(tanh_iz_im, -tanh_iz_re)
|
||||
|
||||
|
||||
@register_complex(aten.tanh)
|
||||
def tanh_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
out_dt, (x, y) = promote_tensors(x, y)
|
||||
|
||||
_2x = 2 * x
|
||||
_2y = 2 * y
|
||||
_d = torch.cosh(_2x) + torch.cos(_2y)
|
||||
_2xsh = torch.sinh(_2x)
|
||||
|
||||
out_re = _2xsh / _d
|
||||
out_im = torch.sin(_2y) / _d
|
||||
|
||||
return ComplexTensor(out_re.to(out_dt), out_im.to(out_dt))
|
||||
|
||||
|
||||
@register_complex(aten.exp)
|
||||
def exp_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
out_dt, (x, y) = promote_tensors(x, y)
|
||||
ex = torch.exp(x)
|
||||
u = ex * torch.cos(y)
|
||||
v = ex * torch.sin(y)
|
||||
return ComplexTensor(u.to(out_dt), v.to(out_dt))
|
||||
|
||||
|
||||
@register_complex(aten.expm1)
|
||||
def expm1_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
out_dt, (x, y) = promote_tensors(x, y)
|
||||
# TODO (hameerabbasi): The two lines below may have numerical issues
|
||||
ex = torch.exp(x)
|
||||
u = ex * torch.cos(y) - 1
|
||||
v = ex * torch.sin(y)
|
||||
return ComplexTensor(u.to(out_dt), v.to(out_dt))
|
||||
|
||||
|
||||
@register_complex(aten.log)
|
||||
def log_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
out_dt, (self,) = promote_tensors(self)
|
||||
re = torch.log(torch.abs(self))
|
||||
im = torch.angle(self)
|
||||
return ComplexTensor(re, im).to(out_dt) # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.log1p)
|
||||
def log1p_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
# TODO (hameerabbasi): The line below may have numerical issues
|
||||
return torch.log(ComplexTensor(x + 1, y)) # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.any)
|
||||
def any_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
return torch.any(x, *args, **kwargs) | torch.any(y, *args, **kwargs)
|
||||
|
||||
|
||||
@register_complex(aten.all)
|
||||
def all_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
return torch.any(x, *args, **kwargs) & torch.any(y, *args, **kwargs)
|
||||
|
||||
|
||||
@register_complex(aten.eq)
|
||||
def eq_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor:
|
||||
a_r, a_i = split_complex_arg(self)
|
||||
b_r, b_i = split_complex_arg(rhs)
|
||||
return torch.eq(a_r, b_r, *args, **kwargs) & torch.eq(a_i, b_i, *args, **kwargs)
|
||||
|
||||
|
||||
@register_complex(aten.ne)
|
||||
def ne_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor:
|
||||
a_r, a_i = split_complex_tensor(self)
|
||||
b_r, b_i = split_complex_arg(rhs)
|
||||
return torch.ne(a_r, b_r, *args, **kwargs) | torch.ne(a_i, b_i, *args, **kwargs)
|
||||
|
||||
|
||||
@register_complex(aten.isnan)
|
||||
def isnan_impl(self: ComplexTensor) -> torch.Tensor:
|
||||
re, im = split_complex_tensor(self)
|
||||
return torch.isnan(re) | torch.isnan(im)
|
||||
|
||||
|
||||
@register_complex(aten.isinf)
|
||||
def isinf_impl(self: ComplexTensor) -> torch.Tensor:
|
||||
re, im = split_complex_tensor(self)
|
||||
return torch.isinf(re) | torch.isinf(im)
|
||||
|
||||
|
||||
@register_complex(aten.isfinite)
|
||||
def isfinite_impl(self: ComplexTensor) -> torch.Tensor:
|
||||
re, im = split_complex_tensor(self)
|
||||
return torch.isfinite(re) & torch.isfinite(im)
|
||||
|
||||
|
||||
@register_complex(aten.isclose)
|
||||
def isclose_impl(
|
||||
self: ComplexTensor,
|
||||
rhs: ComplexTensor,
|
||||
rtol=1e-5,
|
||||
atol=1e-8,
|
||||
equal_nan: bool = False,
|
||||
) -> torch.Tensor:
|
||||
abs_diff = torch.abs(self - rhs)
|
||||
abs_other = torch.abs(rhs)
|
||||
basic_condition = abs_diff <= (rtol * abs_other + atol)
|
||||
|
||||
# This is the nontrivial part
|
||||
if equal_nan:
|
||||
a_r, a_i = split_complex_tensor(self)
|
||||
b_r, b_i = split_complex_arg(rhs)
|
||||
|
||||
a_r_nan = torch.isnan(a_r)
|
||||
b_r_nan = torch.isnan(b_r)
|
||||
a_i_nan = torch.isnan(a_i)
|
||||
b_i_nan = torch.isnan(b_i)
|
||||
a_nan = a_r_nan | a_i_nan
|
||||
|
||||
# This logical expression makes sure that the isnan of both the real and imaginary parts
|
||||
# matches (so 1 + nan*i doesn't equal nan + 1*i)
|
||||
equal_nan_condition = ((a_r_nan == b_r_nan) & (a_i_nan == b_i_nan)) & a_nan
|
||||
return basic_condition | equal_nan_condition
|
||||
|
||||
return basic_condition
|
||||
|
||||
|
||||
ERROR_OPS_LIST = [
|
||||
aten.lt,
|
||||
aten.le,
|
||||
aten.gt,
|
||||
aten.ge,
|
||||
aten.amin,
|
||||
aten.amax,
|
||||
aten.clamp,
|
||||
aten.ceil,
|
||||
aten.floor,
|
||||
aten.minimum,
|
||||
aten.maximum,
|
||||
aten.trunc,
|
||||
aten.sign,
|
||||
aten.argmax,
|
||||
aten.argmin,
|
||||
aten.sort,
|
||||
aten.topk,
|
||||
aten.round,
|
||||
aten.fmod,
|
||||
]
|
||||
|
||||
|
||||
ERROR_TYPES = {
|
||||
aten.minimum: RuntimeError,
|
||||
aten.maximum: RuntimeError,
|
||||
aten.argmax: RuntimeError,
|
||||
aten.argmin: RuntimeError,
|
||||
aten.sort: RuntimeError,
|
||||
aten.topk: RuntimeError,
|
||||
}
|
||||
|
||||
|
||||
for err_op in ERROR_OPS_LIST:
|
||||
globals()[_get_func_name(err_op)] = register_error(
|
||||
err_op, ERROR_TYPES.get(err_op, NotImplementedError)
|
||||
)
|
||||
|
||||
del err_op
|
||||
|
||||
|
||||
@register_complex(aten.masked_scatter)
|
||||
def masked_scatter_impl(
|
||||
self: ComplexTensor, mask: torch.Tensor, source: ComplexTensor
|
||||
) -> ComplexTensor:
|
||||
self_r, self_i = split_complex_tensor(self)
|
||||
source_r, source_i = split_complex_arg(source)
|
||||
ret_r = torch.masked_scatter(self_r, mask, source_r)
|
||||
ret_i = torch.masked_scatter(self_i, mask, source_i)
|
||||
|
||||
return ComplexTensor(ret_r, ret_i)
|
||||
|
||||
|
||||
@register_complex(aten.where)
|
||||
def where_impl(mask: torch.Tensor, x: ComplexTensor, y: ComplexTensor) -> ComplexTensor:
|
||||
x_r, x_i = split_complex_arg(x)
|
||||
y_r, y_i = split_complex_arg(y)
|
||||
|
||||
ret_r = torch.where(mask, x_r, y_r)
|
||||
ret_i = torch.where(mask, x_i, y_i)
|
||||
|
||||
return ComplexTensor(ret_r, ret_i)
|
||||
|
||||
|
||||
@register_complex(aten.full_like)
|
||||
def full_like_impl(
|
||||
input: ComplexTensor,
|
||||
fill_value: complex,
|
||||
*args,
|
||||
dtype: torch.dtype | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | ComplexTensor:
|
||||
# Note: Cannot be merged with the cases below due to the `fill_value` argument
|
||||
input_r, input_i = split_complex_tensor(input)
|
||||
if dtype is not None and dtype not in COMPLEX_TO_REAL:
|
||||
return torch.full_like(input_r, fill_value, *args, dtype=dtype, **kwargs)
|
||||
|
||||
if dtype is not None:
|
||||
kwargs["dtype"] = COMPLEX_TO_REAL[dtype]
|
||||
|
||||
fv_r, fv_i = split_complex_arg(fill_value)
|
||||
ret_r = torch.full_like(input_r, fv_r, *args, **kwargs)
|
||||
ret_i = torch.full_like(input_i, fv_i, *args, **kwargs)
|
||||
|
||||
return ComplexTensor(ret_r, ret_i)
|
||||
|
||||
|
||||
def register_like(op: OpType) -> Callable[..., torch.Tensor | ComplexTensor]:
|
||||
def impl(
|
||||
self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs
|
||||
) -> torch.Tensor | ComplexTensor:
|
||||
self_re, self_im = split_complex_tensor(self)
|
||||
|
||||
if dtype is not None and dtype not in COMPLEX_TO_REAL:
|
||||
return op(self_re, *args, dtype=dtype, **kwargs)
|
||||
|
||||
if dtype is not None:
|
||||
kwargs["dtype"] = COMPLEX_TO_REAL[dtype]
|
||||
|
||||
ret_re = op(self_re, *args, **kwargs)
|
||||
ret_im = op(self_im, *args, **kwargs)
|
||||
|
||||
return ComplexTensor(ret_re, ret_im)
|
||||
|
||||
func_name = _get_func_name(op)
|
||||
impl.__name__ = func_name
|
||||
impl.__qualname__ = func_name
|
||||
|
||||
return register_complex(op, impl)
|
||||
|
||||
|
||||
LIKE_OPS_LIST = [
|
||||
aten.empty_like,
|
||||
aten.zeros_like,
|
||||
aten.randn_like,
|
||||
aten.new_zeros,
|
||||
]
|
||||
|
||||
for like_op in LIKE_OPS_LIST:
|
||||
globals()[_get_func_name(like_op)] = register_like(like_op)
|
||||
|
||||
del like_op
|
||||
|
||||
|
||||
@register_complex(aten.cat)
|
||||
def cat_impl(tensors: Sequence[ComplexTensor], dim: int = 0) -> ComplexTensor:
|
||||
tensors_r = []
|
||||
tensors_i = []
|
||||
|
||||
for t in tensors:
|
||||
t_r, t_i = split_complex_arg(t)
|
||||
tensors_r.append(t_r)
|
||||
tensors_i.append(t_i)
|
||||
|
||||
ret_r = torch.cat(tensors_r, dim=dim)
|
||||
ret_i = torch.cat(tensors_i, dim=dim)
|
||||
|
||||
return ComplexTensor(ret_r, ret_i)
|
||||
|
||||
|
||||
@register_complex(aten.sgn)
|
||||
def sgn_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
self_r, self_i = split_complex_tensor(self)
|
||||
out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
|
||||
abs_self = torch.abs(ComplexTensor(self_r, self_i))
|
||||
mask = (self_r != 0) | (self_i != 0)
|
||||
masked_sgn = ComplexTensor(
|
||||
(self_r / abs_self).to(out_dt), (self_i / abs_self).to(out_dt)
|
||||
)
|
||||
return torch.where(mask, masked_sgn, 0) # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.sqrt)
|
||||
def sqrt_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
self_r, self_i = split_complex_tensor(self)
|
||||
out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
|
||||
self = ComplexTensor(self_r, self_i)
|
||||
self_abs_sqrt = torch.sqrt(torch.abs(self))
|
||||
self_half_angle = 0.5 * torch.angle(self)
|
||||
|
||||
ret_r = self_abs_sqrt * torch.cos(self_half_angle)
|
||||
ret_i = self_abs_sqrt * torch.sin(self_half_angle)
|
||||
|
||||
return ComplexTensor(ret_r.to(out_dt), ret_i.to(out_dt))
|
||||
|
||||
|
||||
@register_complex(aten.rsqrt)
|
||||
def rsqrt_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
self_r, self_i = split_complex_tensor(self)
|
||||
out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
|
||||
self = ComplexTensor(self_r, self_i)
|
||||
self_abs_rsqrt = torch.rsqrt(torch.abs(self))
|
||||
self_neg_half_angle = -0.5 * torch.angle(self)
|
||||
|
||||
ret_r = self_abs_rsqrt * torch.cos(self_neg_half_angle)
|
||||
ret_i = self_abs_rsqrt * torch.sin(self_neg_half_angle)
|
||||
|
||||
return ComplexTensor(ret_r.to(out_dt), ret_i.to(out_dt))
|
||||
|
||||
|
||||
@register_complex(aten.addmm)
|
||||
def addmm_impl(
|
||||
input: ComplexTensor,
|
||||
mat1: ComplexTensor,
|
||||
mat2: ComplexTensor,
|
||||
out_dtype: torch.dtype | None = None,
|
||||
beta: complex = 1,
|
||||
alpha: complex = 1,
|
||||
) -> ComplexTensor:
|
||||
ret = beta * input + alpha * torch.mm(mat1, mat2)
|
||||
assert isinstance(ret, ComplexTensor)
|
||||
ret_r, ret_i = split_complex_tensor(ret)
|
||||
if out_dtype is not None:
|
||||
out_dtype = COMPLEX_TO_REAL[out_dtype]
|
||||
ret_r, ret_i = ret_r.to(out_dtype), ret_i.to(out_dtype)
|
||||
return ComplexTensor(ret_r, ret_i)
|
||||
|
||||
|
||||
def elemwise_nonzero(self: ComplexTensor) -> torch.Tensor:
|
||||
re, im = split_complex_tensor(self)
|
||||
return (re != 0) | (im != 0)
|
||||
|
||||
|
||||
def register_nonzero_impl(op: OpType):
|
||||
def nonzero_impl(
|
||||
self: ComplexTensor, other: ComplexTensor, *args, **kwargs
|
||||
) -> torch.Tensor:
|
||||
return op(elemwise_nonzero(self), elemwise_nonzero(other), *args, **kwargs)
|
||||
|
||||
func_name = _get_func_name(op)
|
||||
nonzero_impl.__name__ = func_name
|
||||
nonzero_impl.__qualname__ = func_name
|
||||
|
||||
return register_complex(op, nonzero_impl)
|
||||
|
||||
|
||||
logical_and_impl = register_nonzero_impl(aten.logical_and)
|
||||
logical_or_impl = register_nonzero_impl(aten.logical_or)
|
||||
logical_xor_impl = register_nonzero_impl(aten.logical_xor)
|
||||
|
||||
|
||||
@register_complex(aten.logical_not)
|
||||
def logical_not_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
|
||||
return torch.logical_not(elemwise_nonzero(self), *args, **kwargs)
|
||||
|
||||
|
||||
@register_complex(aten.view_as_real)
|
||||
def view_as_real_impl(self: ComplexTensor) -> torch.Tensor:
|
||||
re, im = split_complex_tensor(self)
|
||||
return torch.stack([re, im], dim=-1)
|
||||
|
||||
|
||||
@register_complex(aten.linalg_vector_norm)
|
||||
def linalg_vector_norm_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
|
||||
return torch.linalg.vector_norm(torch.abs(self), *args, **kwargs)
|
||||
|
||||
|
||||
@register_force_test(aten.copy_)
|
||||
def copy__impl(self: ComplexTensor, src, *args, **kwargs):
|
||||
self_re, self_im = split_complex_tensor(self)
|
||||
src_re, src_im = split_complex_arg(src)
|
||||
|
||||
ret_re = self_re.copy_(src_re, *args, **kwargs)
|
||||
ret_im = self_im.copy_(src_im, *args, **kwargs)
|
||||
|
||||
return ComplexTensor(ret_re, ret_im)
|
||||
|
||||
|
||||
@register_complex(aten._local_scalar_dense)
|
||||
def _local_scalar_dense_impl(self: ComplexTensor, *args, **kwargs) -> complex:
|
||||
x, y = split_complex_tensor(self)
|
||||
u = aten._local_scalar_dense(x, *args, **kwargs)
|
||||
v = aten._local_scalar_dense(y, *args, **kwargs)
|
||||
return complex(u, v)
|
||||
|
||||
|
||||
@register_complex(aten.allclose)
|
||||
def allclose_impl(
|
||||
input: torch.Tensor,
|
||||
other: torch.Tensor,
|
||||
rtol: float = 1e-05,
|
||||
atol: float = 1e-08,
|
||||
equal_nan: bool = False,
|
||||
) -> bool:
|
||||
return torch.all(
|
||||
torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)
|
||||
).item() # type: ignore[bad-return]
|
||||
|
||||
|
||||
@register_complex(aten.stack)
|
||||
def stack_impl(self: list[ComplexTensor], *args, **kwargs) -> ComplexTensor:
|
||||
re_im_tuples = [split_complex_arg(self_i) for self_i in self]
|
||||
u = torch.stack([c[0] for c in re_im_tuples], *args, **kwargs)
|
||||
v = torch.stack([c[1] for c in re_im_tuples], *args, **kwargs)
|
||||
return ComplexTensor(u, v)
|
||||
|
||||
|
||||
# TODO (hameerabbasi): Not being tested
|
||||
@register_complex(aten._conj_physical)
|
||||
@register_complex(aten.conj_physical)
|
||||
def conj_physical_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
re, im = split_complex_tensor(self)
|
||||
return ComplexTensor(re, -im)
|
||||
|
||||
|
||||
# TODO (hameerabbasi): Not being tested
|
||||
@register_complex(aten._conj)
|
||||
def _conj_impl(self: ComplexTensor) -> ComplexTensor:
|
||||
re, im = split_complex_tensor(self)
|
||||
return ComplexTensor(re, torch._neg_view(im))
|
||||
|
||||
|
||||
@register_complex(aten.index_add)
|
||||
def index_add_impl(
|
||||
self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs
|
||||
) -> ComplexTensor:
|
||||
alpha = kwargs.pop("alpha", None)
|
||||
if alpha is not None:
|
||||
source = source * alpha
|
||||
self_re, self_im = split_complex_arg(self)
|
||||
source_re, source_im = split_complex_arg(source)
|
||||
|
||||
ret_re = self_re.index_add(dim, index, source_re)
|
||||
ret_im = self_im.index_add(dim, index, source_im)
|
||||
|
||||
return ComplexTensor(ret_re, ret_im)
|
||||
|
||||
|
||||
# TODO (hameerabbasi): Not being tested
|
||||
@register_complex(aten.index_add_)
|
||||
def index_add__impl(
|
||||
self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs
|
||||
) -> ComplexTensor:
|
||||
alpha = kwargs.pop("alpha", None)
|
||||
if alpha is not None:
|
||||
source = source * alpha
|
||||
|
||||
self_re, self_im = split_complex_arg(self)
|
||||
source_re, source_im = split_complex_arg(source)
|
||||
|
||||
ret_re = self_re.index_add_(dim, index, source_re)
|
||||
ret_im = self_im.index_add_(dim, index, source_im)
|
||||
|
||||
return ComplexTensor(ret_re, ret_im)
|
||||
|
||||
|
||||
@register_complex(aten.masked_fill)
|
||||
def masked_fill_impl(
|
||||
self: ComplexTensor, mask: torch.Tensor, value: complex
|
||||
) -> ComplexTensor:
|
||||
self_re, self_im = split_complex_arg(self)
|
||||
value_re, value_im = split_complex_arg(value)
|
||||
|
||||
ret_re = self_re.masked_fill(mask, value_re)
|
||||
ret_im = self_im.masked_fill(mask, value_im)
|
||||
|
||||
return ComplexTensor(ret_re, ret_im)
|
||||
|
||||
|
||||
# TODO (hameerabbasi): Not being tested
|
||||
@register_complex(aten.masked_fill_)
|
||||
def masked_fill__impl(
|
||||
self: ComplexTensor, mask: torch.Tensor, value: complex
|
||||
) -> ComplexTensor:
|
||||
self_re, self_im = split_complex_arg(self)
|
||||
value_re, value_im = split_complex_arg(value)
|
||||
|
||||
ret_re = self_re.masked_fill_(mask, value_re)
|
||||
ret_im = self_im.masked_fill_(mask, value_im)
|
||||
|
||||
return ComplexTensor(ret_re, ret_im)
|
||||
|
||||
|
||||
@register_complex(aten.constant_pad_nd)
|
||||
def constant_pad_nd_impl(
|
||||
self: ComplexTensor, pad, value: complex | None = None
|
||||
) -> ComplexTensor:
|
||||
self_re, self_im = split_complex_tensor(self)
|
||||
if value is None:
|
||||
ret_re = aten.constant_pad_nd(self_re, pad)
|
||||
ret_im = aten.constant_pad_nd(self_im, pad)
|
||||
else:
|
||||
value_re, value_im = split_complex_arg(value)
|
||||
ret_re = aten.constant_pad_nd(self_re, pad, value_re)
|
||||
ret_im = aten.constant_pad_nd(self_im, pad, value_im)
|
||||
|
||||
return ComplexTensor(ret_re, ret_im)
|
||||
|
||||
|
||||
@register_complex(aten.var)
|
||||
def var_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
|
||||
self_re, self_im = split_complex_tensor(self)
|
||||
return torch.var(self_re, *args, **kwargs) + torch.var(self_im, *args, **kwargs)
|
||||
|
||||
|
||||
@register_complex(aten.scatter_add)
|
||||
def scatter_add_impl(
|
||||
self: ComplexTensor, dim, index, src: ComplexTensor
|
||||
) -> ComplexTensor:
|
||||
self_re, self_im = split_complex_arg(self)
|
||||
src_re, src_im = split_complex_arg(src)
|
||||
|
||||
ret_re = torch.scatter_add(self_re, dim, index, src_re)
|
||||
ret_im = torch.scatter_add(self_im, dim, index, src_im)
|
||||
|
||||
return ComplexTensor(ret_re, ret_im)
|
||||
|
||||
|
||||
@register_complex(aten.scatter_add_)
|
||||
def scatter_add__impl(
|
||||
self: ComplexTensor, dim, index, src: ComplexTensor
|
||||
) -> ComplexTensor:
|
||||
self_re, self_im = split_complex_arg(self)
|
||||
src_re, src_im = split_complex_arg(src)
|
||||
|
||||
out_re = self_re.scatter_add_(dim, index, src_re)
|
||||
out_im = self_im.scatter_add_(dim, index, src_im)
|
||||
|
||||
return ComplexTensor(out_re, out_im)
|
||||
|
||||
|
||||
@register_complex(aten.index_put_)
|
||||
def index_put__impl(
|
||||
self: ComplexTensor,
|
||||
indices: tuple[torch.Tensor, ...],
|
||||
values: ComplexTensor,
|
||||
accumulate: bool = False,
|
||||
) -> ComplexTensor:
|
||||
self_re, self_im = split_complex_arg(self)
|
||||
values_re, values_im = split_complex_arg(values)
|
||||
|
||||
out_re = self_re.index_put_(indices, values_re, accumulate=accumulate)
|
||||
out_im = self_im.index_put_(indices, values_im, accumulate=accumulate)
|
||||
|
||||
return ComplexTensor(out_re, out_im)
|
||||
|
||||
|
||||
@register_complex(aten.tanh_backward)
|
||||
def tanh_backward(out_grad: torch.Tensor, y: torch.Tensor):
|
||||
return out_grad * (1.0 - y * y).conj_physical()
|
||||
|
||||
|
||||
@register_complex(aten.diagonal_backward)
|
||||
def diagonal_backward(
|
||||
grad_output: torch.Tensor, input_sizes: list[int], offset: int, dim1: int, dim2: int
|
||||
):
|
||||
grad_input = grad_output.new_zeros(input_sizes)
|
||||
return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2)
|
||||
|
||||
|
||||
def _dt_to_real(dt: torch.dtype | Any) -> torch.dtype | Any:
|
||||
if not isinstance(dt, torch.dtype):
|
||||
return dt
|
||||
|
||||
return COMPLEX_TO_REAL[dt]
|
||||
|
||||
|
||||
def register_to_impl(op: OpType):
|
||||
"""Register an op similar to `aten.to`, but may have different signatures."""
|
||||
|
||||
def impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor | ComplexTensor:
|
||||
x, y = split_complex_tensor(self)
|
||||
try:
|
||||
args = tuple(_dt_to_real(a) for a in args)
|
||||
kwargs = {k: _dt_to_real(v) for k, v in kwargs.items()}
|
||||
except KeyError:
|
||||
return op(x, *args, **kwargs)
|
||||
|
||||
return ComplexTensor(op(x, *args, **kwargs), op(y, *args, **kwargs))
|
||||
|
||||
func_name = _get_func_name(op)
|
||||
impl.__name__ = func_name
|
||||
impl.__qualname__ = func_name
|
||||
|
||||
return register_complex(op, impl)
|
||||
|
||||
|
||||
to_impl = register_to_impl(aten.to)
|
||||
_to_copy_impl = register_to_impl(aten._to_copy)
|
||||
@ -1,317 +0,0 @@
from collections.abc import Callable
from typing import Any, overload, TypeAlias
from typing_extensions import TypeIs

import torch
from torch import Tensor
from torch._decomp import get_decompositions
from torch._ops import OpOverload, OpOverloadPacket
from torch._refs import is_complex as _is_complex
from torch.types import Number
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

from .._core import ComplexTensor


OpType: TypeAlias = OpOverloadPacket | OpOverload

TableType: TypeAlias = dict[OpType, Callable]

# Mapping from ops to implementations
COMPLEX_OPS_TABLE: TableType = {}

COMPLEX_TO_REAL = {
    torch.complex128: torch.float64,
    torch.complex64: torch.float32,
    torch.complex32: torch.float16,
}

REAL_TO_COMPLEX = {v: k for k, v in COMPLEX_TO_REAL.items()}

# Used to promote dtypes in `promote_tensors`
PROMOTE_TYPES = {
    torch.float16: torch.float32,
    torch.bfloat16: torch.float32,
    torch.complex32: torch.complex64,
}


def is_complex_tensor(obj: Any, /) -> TypeIs[ComplexTensor]:
    r"""Returns True if the input is a ComplexTensor, else False

    Args:
        obj: any input

    Examples:

        >>> # xdoctest: +SKIP
        >>> from torch.complex import ComplexTensor
        >>> data = torch.zeros((3, 2), dtype=torch.complex64)
        >>> ct = ComplexTensor.from_interleaved(data)
        >>> is_complex_tensor(ct)
        True
    """
    return isinstance(obj, ComplexTensor)


@overload
def promote_tensors(
    *tensors: ComplexTensor,
) -> tuple[torch.dtype, tuple[ComplexTensor, ...]]: ...


@overload
def promote_tensors(
    *tensors: Tensor,
) -> tuple[torch.dtype, tuple[Tensor, ...]]: ...


def promote_tensors(
    *tensors: Tensor | ComplexTensor,
) -> tuple[torch.dtype, tuple[Tensor | ComplexTensor, ...]]:
    """
    Promotes all tensors to a common dtype.
    Additionally promotes CPU tensors to at least `float32`.
    """
    tensor = next(t for t in tensors if isinstance(t, Tensor))
    out_dt = tensor.dtype
    for t in tensors:
        if isinstance(t, Tensor):
            out_dt = torch.promote_types(out_dt, t.dtype)

    prom_dt = PROMOTE_TYPES.get(out_dt, out_dt)
    return out_dt, tuple(
        t.to(prom_dt) if isinstance(t, Tensor) else torch.asarray(t, dtype=prom_dt)
        for t in tensors
    )


def register_complex(
    op: OpType,
    func_impl: Callable | None = None,
):
    """Decorator to register an implementation for some ops in some dispatch tables"""

    def inner(func):
        if COMPLEX_OPS_TABLE.get(op, func) is not func:
            raise RuntimeError(f"Attempted to register multiple functions for {op}")
        COMPLEX_OPS_TABLE[op] = func
        return func

    if func_impl is None:
        return inner

    return inner(func_impl)


FORCE_TEST_LIST: list[OpType] = []


def register_force_test(op: OpType, *args, **kwargs):
    """Will attempt to test these ops even if they err on "normal" inputs"""
    FORCE_TEST_LIST.append(op)
    return register_complex(op, *args, **kwargs)


DECOMPOSITIONS = get_decompositions(list(torch.ops.aten))  # type: ignore[no-matching-overload]


def lookup_complex(func: OpOverload, *args, **kwargs) -> Callable | None:
    """
    Look up an impl from the table.

    Try the particular overload first, then the overload packet.

    If nothing is found, try the decompositions with both.
    """
    return COMPLEX_OPS_TABLE.get(
        func,
        COMPLEX_OPS_TABLE.get(
            func.overloadpacket,
            DECOMPOSITIONS.get(func, DECOMPOSITIONS.get(func.overloadpacket)),
        ),
    )


def is_complex(x: Any, /) -> bool:
    """Utility to detect if a given object is (known) to be complex."""
    return (isinstance(x, Tensor) and _is_complex(x)) or isinstance(x, complex)


@overload
def split_complex_arg(
    arg: Tensor | ComplexTensor,
) -> tuple[Tensor, Tensor]: ...


@overload
def split_complex_arg(
    arg: complex | Number,
) -> tuple[Number, Number]: ...


def split_complex_arg(
    arg: Tensor | ComplexTensor | complex | Number,
) -> tuple[Tensor, Tensor] | tuple[Number, Number]:
    """
    Split a complex argument into real and imaginary components.

    If real, use zero for the imaginary part.
    """
    if isinstance(arg, ComplexTensor):
        return split_complex_tensor(arg)
    if isinstance(arg, Tensor):
        if is_complex(arg):
            return arg.real, arg.imag
        return arg, torch.zeros_like(arg)
    # TODO (hameerabbasi): Should there be a `torch.SymComplex`?
    if isinstance(arg, complex):
        return arg.real, arg.imag
    if isinstance(arg, float | torch.SymFloat):
        return arg, 0.0
    if isinstance(arg, int | torch.SymInt):
        return arg, 0
    if isinstance(arg, bool | torch.SymBool):
        return arg, False
    raise TypeError(f"Expected tensor or number, got {type(arg)}")


def split_complex_tensor(complex_tensor: ComplexTensor) -> tuple[Tensor, Tensor]:
    """Split a ComplexTensor into its real and imaginary parts."""
    return complex_tensor.re, complex_tensor.im


def complex_to_real_dtype(dtype: torch.dtype) -> torch.dtype:
    """Convert a complex dtype to the dtype of its real part. Return other dtypes as-is."""
    return COMPLEX_TO_REAL.get(dtype, dtype)


def _get_op_name(op: OpType) -> str:
    """Get the op name from the op."""
    if isinstance(op, OpOverload):
        op = op.overloadpacket
    return str(op).split(".", 1)[1]


def _get_func_name(op: OpType) -> str:
    """Get the name of the implementation function from the op."""
    return f"{_get_op_name(op)}_impl"


def register_error(op: OpType, exc_type: type[Exception] = NotImplementedError):
    msg = f"`aten.{_get_op_name(op)}` not implemented for `{ComplexTensor.__name__}`."

    def ordered_impl(*args, **kwargs):
        raise exc_type(msg)

    func_name = _get_func_name(op)
    ordered_impl.__name__ = func_name
    ordered_impl.__qualname__ = func_name

    return register_force_test(op, ordered_impl)


def register_binary_nonlinear(op: OpType) -> Callable:
    """Register a "multiplication-style" op, e.g. aten.mul, aten.mm, ..."""

    def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:
        a_r, a_i = split_complex_arg(lhs)
        b_r, b_i = split_complex_arg(rhs)
        out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i)
        real = op(a_r, b_r, *args, **kwargs) - op(a_i, b_i, *args, **kwargs)
        imag = op(a_r, b_i, *args, **kwargs) + op(a_i, b_r, *args, **kwargs)
        return ComplexTensor(real.to(out_dt), imag.to(out_dt))

    func_name = _get_func_name(op)
    impl.__name__ = func_name
    impl.__qualname__ = func_name

    return register_complex(op, impl)


def register_simple(op: OpType):
    """Register an op which can be applied independently to the real and imaginary parts to get the result."""

    def impl(
        self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs
    ) -> ComplexTensor:
        x, y = split_complex_tensor(self)
        if dtype is not None and dtype not in COMPLEX_TO_REAL:
            raise RuntimeError(
                "Non-complex `dtype` specified, please write custom impl."
            )

        if dtype in COMPLEX_TO_REAL:
            assert dtype is not None
            kwargs["dtype"] = COMPLEX_TO_REAL[dtype]

        u = op(x, *args, **kwargs)
        v = op(y, *args, **kwargs)

        u_flat, u_spec = tree_flatten(u)
        v_flat, v_spec = tree_flatten(v)
        assert u_spec == v_spec
        out_flat = [
            ComplexTensor(ui, vi) for ui, vi in zip(u_flat, v_flat, strict=False)
        ]
        return tree_unflatten(out_flat, u_spec)

    func_name = _get_func_name(op)
    impl.__name__ = func_name
    impl.__qualname__ = func_name

    return register_complex(op, impl)


def _as_complex_tensor(arg: Tensor | Any) -> Tensor | ComplexTensor | Any:
    """Convert a Tensor with complex dtypes to a ComplexTensor. Pass along other args as-is."""
    if (
        not isinstance(arg, ComplexTensor)
        and isinstance(arg, Tensor)
        and arg.dtype in COMPLEX_TO_REAL
    ):
        return ComplexTensor.from_interleaved(arg)
    return arg


def _as_interleaved(arg: ComplexTensor | Any) -> Tensor | Any:
    """Convert a ComplexTensor to a Tensor with a complex dtype. Pass other arguments as-is."""
    if isinstance(arg, ComplexTensor):
        return arg.as_interleaved()
    return arg


class ComplexTensorMode(TorchDispatchMode):
    """A TorchDispatchMode that replaces any Tensor with a complex dtype by a ComplexTensor for the computation."""

    _compile: bool

    def __init__(self, _dispatch_key=None, *, _compile: bool = False):
        """Initialize a ComplexTensorMode.

        Args:
            _dispatch_key: passed on to TorchDispatchMode
            _compile: Compile the op before the computation
        """
        super().__init__(_dispatch_key)
        self._compile = _compile

    def __torch_dispatch__(
        self,
        func: OpOverload,
        types: tuple[type],
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ):
        if kwargs is None:
            kwargs = {}

        # TODO (hameerabbasi): Test perf with `_compile` set to `True`
        if self._compile:
            func = torch.compile(func)  # type: ignore[bad-assignment]

        args = tree_map(_as_complex_tensor, args)
        kwargs = tree_map(_as_complex_tensor, kwargs)

        return tree_map(_as_interleaved, func(*args, **kwargs))
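For orientation, here is a minimal usage sketch of the `ComplexTensorMode` defined above. It is not part of the diff, and the import path is hypothetical, since the prototype's package location is not shown here.

```python
import torch

# Hypothetical import path; adjust to wherever the `_ops/common.py` module
# above actually lives relative to `ComplexTensor`.
from complex_tensor._ops.common import ComplexTensorMode

x = torch.randn(4, 4, dtype=torch.complex64)

with ComplexTensorMode():
    # Inside the mode, complex-dtype Tensors are wrapped as ComplexTensor
    # before dispatch (via `_as_complex_tensor`) and converted back to
    # interleaved complex Tensors on the way out (via `_as_interleaved`),
    # provided an impl for the op has been registered.
    y = x * x

print(y.dtype)  # torch.complex64
```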
@ -1,34 +0,0 @@
import torch

from .._core import ComplexTensor
from .common import (
    complex_to_real_dtype,
    register_complex,
    register_force_test,
    split_complex_tensor,
)


prims = torch.ops.prims
aten = torch.ops.aten


# TODO (hameerabbasi): Not being tested
@register_force_test(prims.convert_element_type)
def convert_element_type_impl(x: ComplexTensor, dtype: torch.dtype) -> ComplexTensor:
    dtype = complex_to_real_dtype(dtype)
    u, v = split_complex_tensor(x)
    u_out = prims.convert_element_type(u, dtype)
    v_out = prims.convert_element_type(v, dtype)

    return ComplexTensor(u_out, v_out)


@register_complex(prims.conj_physical)
def conj_physical_impl(self: ComplexTensor) -> ComplexTensor:
    return aten._conj_physical(self)


@register_complex(prims.conj)
def conj_impl(self: ComplexTensor) -> ComplexTensor:
    return aten._conj(self)
@ -492,17 +492,4 @@ void FileStore::wait(
  }
}

std::vector<std::string> FileStore::listKeys() {
  std::unique_lock<std::mutex> l(activeFileOpLock_);
  File file(path_, O_RDONLY, timeout_);
  auto lock = file.lockShared();
  pos_ = refresh(file, pos_, cache_, deletePrefix_);
  std::vector<std::string> keys;
  keys.reserve(cache_.size());
  for (const auto& kv : cache_) {
    keys.push_back(kv.first.substr(regularPrefix_.size()));
  }
  return keys;
}

} // namespace c10d
@ -45,8 +45,6 @@ class TORCH_API FileStore : public Store {
    return path_;
  }

  std::vector<std::string> listKeys() override;

 protected:
  int64_t addHelper(const std::string& key, int64_t i);
@ -217,14 +217,4 @@ int64_t HashStore::queueLen(const std::string& key) {
  return static_cast<int64_t>(it->second.size());
}

std::vector<std::string> HashStore::listKeys() {
  std::unique_lock<std::mutex> lock(m_);
  std::vector<std::string> keys;
  keys.reserve(map_.size());
  for (const auto& kv : map_) {
    keys.push_back(kv.first);
  }
  return keys;
}

} // namespace c10d
@ -59,8 +59,6 @@ class TORCH_API HashStore : public Store {

  int64_t queueLen(const std::string& key) override;

  std::vector<std::string> listKeys() override;

 protected:
  bool checkLocked(
      const std::unique_lock<std::mutex>& lock,
@ -146,18 +146,4 @@ c10::intrusive_ptr<Store> PrefixStore::getUnderlyingNonPrefixStore() {
  return store;
}

std::vector<std::string> PrefixStore::listKeys() {
  auto keys = store_->listKeys();
  std::vector<std::string> filteredKeys;
  filteredKeys.reserve(keys.size());

  for (auto& key : keys) {
    if (key.find(prefix_) == 0) {
      key = key.substr(prefix_.size() + 1);
      filteredKeys.push_back(std::move(key));
    }
  }
  return filteredKeys;
}

} // namespace c10d
@ -64,8 +64,6 @@ class TORCH_API PrefixStore : public Store {
  // Recursively fetches the store beneath the layers of PrefixStore wrapping.
  c10::intrusive_ptr<Store> getUnderlyingNonPrefixStore();

  std::vector<std::string> listKeys() override;

 protected:
  std::string prefix_;
  c10::intrusive_ptr<Store> store_;
@ -114,11 +114,6 @@ class TORCH_API Store : public torch::CustomClassHolder {
    C10_THROW_ERROR(NotImplementedError, "queue support is not implemented.");
  }

  virtual std::vector<std::string> listKeys() {
    C10_THROW_ERROR(
        NotImplementedError, "listKeys support is not implemented.");
  }

 protected:
  std::chrono::milliseconds timeout_;
};
@ -723,30 +723,6 @@ int64_t TCPStore::queueLen(const std::string& key) {
  return client_->receiveValue<int64_t>();
}

std::vector<std::string> TCPStore::listKeys() {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__list);

  const std::lock_guard<std::mutex> lock(activeOpLock_);

  detail::SendBuffer buffer(*client_, detail::QueryType::LIST_KEYS);
  buffer.flush();

  auto numKeys = client_->receiveValue<int64_t>();
  std::vector<std::string> keys;
  keys.reserve(numKeys);
  for (auto i = 0; i < numKeys; ++i) {
    auto bits = client_->receiveBits();
    std::string str(bits.begin(), bits.end());
    if (str.find(keyPrefix_) == 0) {
      str = str.substr(keyPrefix_.size());
    } else {
      continue;
    }
    keys.emplace_back(str);
  }
  return keys;
}

bool TCPStore::hasExtendedApi() const {
  return true;
}
@ -121,8 +121,6 @@ class TORCH_API TCPStore : public Store {

  int64_t queueLen(const std::string& key) override;

  std::vector<std::string> listKeys() override;

  // Waits for all workers to join.
  void waitForWorkers();
@ -78,7 +78,6 @@ class TCPStoreMasterDaemon : public BackgroundThread {
  void multiGetHandler(int socket);
  void multiSetHandler(int socket);
  void cancelWaitHandler(int socket);
  void listKeysHandler(int socket);
  void addMiscellaneousSocket(int socket);
  void removeMiscellaneousSocket(int socket);
  bool isMiscellaneousSocket(int socket);
@ -296,8 +295,6 @@ void TCPStoreMasterDaemon::query(int socket) {
    multiSetHandler(socket);
  } else if (qt == QueryType::CANCEL_WAIT) {
    cancelWaitHandler(socket);
  } else if (qt == QueryType::LIST_KEYS) {
    listKeysHandler(socket);
  } else {
    TORCH_CHECK(false, "Unexpected query type");
  }
@ -485,13 +482,6 @@ void TCPStoreMasterDaemon::cancelWaitHandler(int socket) {
      socket, detail::WaitResponseType::WAIT_CANCELED);
}

void TCPStoreMasterDaemon::listKeysHandler(int socket) {
  tcputil::sendValue<size_t>(socket, tcpStore_.size());
  for (const auto& kv : tcpStore_) {
    tcputil::sendString(socket, kv.first);
  }
}

bool TCPStoreMasterDaemon::checkKeys(
    const std::vector<std::string>& keys) const {
  return std::all_of(keys.begin(), keys.end(), [this](const std::string& s) {
@ -36,7 +36,6 @@ enum class QueryType : uint8_t {
  QUEUE_PUSH,
  QUEUE_POP,
  QUEUE_LEN,
  LIST_KEYS,
};

enum class CheckResponseType : uint8_t { READY, NOT_READY };
@ -683,7 +683,6 @@ class LibUVStoreDaemon : public BackgroundThread {
      const std::string& queueName,
      const c10::intrusive_ptr<UvHandle>& client);
  int64_t queueLen(const std::string& queueName);
  std::vector<std::string> listKeys();

  void registerClient(const c10::intrusive_ptr<UvHandle>& client);
  void unregisterClient(const c10::intrusive_ptr<UvHandle>& client);
@ -823,10 +822,6 @@ class UvClient : public UvTcpSocket {
        if (!parse_queue_len_command())
          return;
        break;
      case QueryType::LIST_KEYS:
        if (!parse_list_keys_command())
          return;
        break;
      default:
        C10D_DEBUG(
            "Client sent invalid command. client:{} command:{}",
@ -1169,19 +1164,6 @@ class UvClient : public UvTcpSocket {
    return true;
  }

  bool parse_list_keys_command() {
    C10D_TRACE("list_keys address:{}", this->address());

    auto keys = store->listKeys();
    StreamWriter sw(iptr());
    sw.write_value<int64_t>(static_cast<int64_t>(keys.size()));
    for (const auto& key : keys) {
      sw.write_string(key);
    }
    sw.send();
    return true;
  }

 public:
  explicit UvClient(uv_loop_t* loop, LibUVStoreDaemon* store)
      : UvTcpSocket(loop), store(store) {}
@ -1560,16 +1542,6 @@ int64_t LibUVStoreDaemon::queueLen(const std::string& key) {
  }
  return static_cast<int64_t>(it->second.size());
}

std::vector<std::string> LibUVStoreDaemon::listKeys() {
  std::vector<std::string> keys;
  keys.reserve(tcpStore_.size());
  for (const auto& kv : tcpStore_) {
    keys.push_back(kv.first);
  }
  return keys;
}

#endif

std::unique_ptr<BackgroundThread> create_libuv_tcpstore_backend(
@ -1656,12 +1656,6 @@ See queue_push for more details.

Arguments:
    key (str): The key of the queue to get the length.
)")
      .def(
          "list_keys",
          &::c10d::Store::listKeys,
          R"(
Returns a list of all keys in the store.
)")
      .def(
          "has_extended_api",
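For context, the new binding would be exercised from Python roughly as follows. This is a sketch, not part of the diff: `HashStore` and `set` are existing `torch.distributed` Store APIs, while `list_keys` is the method wired up by the hunks above.

```python
import torch.distributed as dist

store = dist.HashStore()
store.set("alpha", "1")
store.set("beta", "2")

# New in this change: enumerate every key currently held by the store.
# The TCPStore and PrefixStore implementations strip their internal key
# prefixes before returning.
print(sorted(store.list_keys()))  # expected: ['alpha', 'beta']
```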
@ -26,18 +26,15 @@ def convolution_rules(op_schema: OpSchema) -> OutputSharding:

    assert isinstance(input_spec, DTensorSpec)
    assert isinstance(weight_spec, DTensorSpec)
    # bias_spec can be None (optional parameter in aten.convolution schema)
    if bias_spec is not None:
        assert isinstance(bias_spec, DTensorSpec)
    assert isinstance(bias_spec, DTensorSpec)
    assert input_spec.tensor_meta is not None
    assert weight_spec.tensor_meta is not None
    in_shape = input_spec.tensor_meta.shape
    weight_shape = weight_spec.tensor_meta.shape
    assert isinstance(stride, list), f"stride must be list, got {type(stride)}"
    assert isinstance(padding, list), f"padding must be list, got {type(padding)}"
    assert isinstance(dilation, list), f"dilation must be list, got {type(dilation)}"
    # weight_shape might not be torch.Size in all cases (e.g., SymIntArrayRef during tracing)
    # so we don't assert its type, just use it
    assert isinstance(stride, list)
    assert isinstance(padding, list)
    assert isinstance(dilation, list)
    assert isinstance(weight_shape, torch.Size)
    out_conv_shape = [
        (d + 2 * padding[i] - dilation[i] * (weight_shape[i + 1] - 1) - 1) // stride[i]
        + 1
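The `out_conv_shape` expression above is the standard convolution output-size formula, `out = (d + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1`, applied per spatial dimension. A quick standalone check against `conv2d` (not part of the diff):

```python
import torch

d, padding, dilation, kernel, stride = 32, 1, 1, 3, 2
expected = (d + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

out = torch.nn.functional.conv2d(
    torch.randn(1, 3, d, d),
    torch.randn(8, 3, kernel, kernel),
    stride=stride,
    padding=padding,
    dilation=dilation,
)
# Both the formula and the real op give a 16x16 spatial output here.
assert out.shape[-1] == expected == 16
```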
@ -85,21 +82,14 @@ def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
    assert isinstance(grad_output_spec, DTensorSpec)
    assert isinstance(input_spec, DTensorSpec)
    assert isinstance(weight_spec, DTensorSpec)
    # bias_shape_opt can be None (optional parameter in aten.convolution_backward schema)
    if bias_shape_opt is not None:
        assert isinstance(bias_shape_opt, list)
    assert isinstance(bias_shape_opt, list)
    assert input_spec.tensor_meta is not None
    weight_tensor_meta = weight_spec.tensor_meta

    # Only create bias_tensor_meta if bias_shape_opt is not None
    if bias_shape_opt is not None:
        bias_tensor_meta = TensorMeta(
            torch.Size(bias_shape_opt),
            (1,),
            input_spec.tensor_meta.dtype,
        )
    else:
        bias_tensor_meta = None
    bias_tensor_meta = TensorMeta(
        torch.Size(bias_shape_opt),
        (1,),
        input_spec.tensor_meta.dtype,
    )

    grad_input_spec = input_spec
    grad_weight_spec = DTensorSpec.from_dim_map(
@ -108,18 +98,12 @@ def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
        [0],
        tensor_meta=weight_tensor_meta,
    )

    # Only create grad_bias_spec if we have bias_tensor_meta
    if bias_tensor_meta is not None:
        grad_bias_spec = DTensorSpec.from_dim_map(
            input_spec.mesh,
            [-1],
            [0],
            tensor_meta=bias_tensor_meta,
        )
    else:
        grad_bias_spec = None

    grad_bias_spec = DTensorSpec.from_dim_map(
        input_spec.mesh,
        [-1],
        [0],
        tensor_meta=bias_tensor_meta,
    )
    # TODO: actually the output_mask is not respected here, we should
    # set the corresponding spec to `None` if the output_mask is not `False`
    # for a certain output Tensor. This also applies to the conv handler
@ -275,16 +275,14 @@ class ShardingPropagator:
                output_tensor_meta_i = output_tensor_meta[i]
                if not isinstance(output_tensor_meta_i, TensorMeta):
                    # NOTE: aten.convolution_backward.default is an exception and it
                    # needs extra handling because any Tensor in the output tuple
                    # can be `None` depending on the output_mask parameter. This can
                    # occur during double backpropagation or when certain gradients
                    # are not needed (e.g., grad_input when input has requires_grad=False,
                    # grad_weight/grad_bias when weight/bias have requires_grad=False,
                    # or grad_bias when bias is None). We explicitly allow the
                    # corresponding TensorMeta to be `None`.
                    # needs extra handling because the first Tensor in the output
                    # tuple can be `None` if the input Tensor to convolution op has
                    # `requires_grad=False` (e.g. convolution layer is the first
                    # layer in the model). We explicitly allow its corresponding
                    # TensorMeta to be `None`.
                    if (
                        op == aten.convolution_backward.default
                        and i in (0, 1, 2)
                        and i == 0
                        and output_tensor_meta_i is None
                    ):
                        assert isinstance(output_specs, list)
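The `None` outputs referenced in the notes above come from `output_mask`: autograd skips gradients that are not required, so the corresponding entries of `aten.convolution_backward`'s output tuple are `None`. A plain-PyTorch illustration of that behavior (not part of the diff):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8, requires_grad=False)  # input grad not needed
w = torch.randn(4, 3, 3, 3, requires_grad=True)

y = F.conv2d(x, w)
y.sum().backward()

# Because the input does not require grad, autograd invokes
# aten.convolution_backward with output_mask[0] == False, so the first
# element of its output tuple (grad_input) is None; only w.grad is produced.
assert x.grad is None and w.grad is not None
```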