Compare commits


1 Commit

Author SHA1 Message Date
f203b98062 Add check for MacOS26 to use different code path in SDPA 2025-11-17 15:38:19 -08:00
67 changed files with 495 additions and 5232 deletions

View File

@@ -1,19 +0,0 @@
# Aarch64 (ARM/Graviton) Support Scripts
Scripts for building aarch64 PyTorch PIP Wheels. These scripts build the following wheels:
* torch
* torchvision
* torchaudio
* torchtext
* torchdata
## Aarch64_ci_build.sh
This script is designed to support CD operations within the PyPA manylinux aarch64 container and is meant to be executed inside that container. It prepares the container and then executes __aarch64_wheel_ci_build.py__ to build the wheels. The script assumes the PyTorch repo is located at ```/pytorch``` and puts the built wheels into ```/artifacts```.
### Usage
```DESIRED_PYTHON=<PythonVersion> aarch64_ci_build.sh```
__NOTE:__ The CI build is currently __EXPERIMENTAL__
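For illustration, the container can be prepared and invoked roughly as follows; the image name, mount paths, and the Python/CUDA values here are assumptions for the sketch, not taken from the CI configuration:
```docker run --rm -v /path/to/pytorch:/pytorch -v /path/to/artifacts:/artifacts -e DESIRED_PYTHON=3.10 -e DESIRED_CUDA=cpu quay.io/pypa/manylinux_2_28_aarch64 bash /pytorch/.ci/aarch64_linux/aarch64_ci_build.sh```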
## Build_aarch64_wheel.py
This script allows building the wheels using AWS EC2 resources. It requires the AWS CLI and Boto3, along with AWS credentials, in order to launch the EC2 instances used for the wheel builds. It can be used from a CodeBuild CD pipeline or from a local system.
### Usage
```build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.8 --branch <RCtag>```
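Beyond building, the script also provides housekeeping modes, driven by flags defined in its argument parser (shown later in this diff), for example:
```build_aarch64_wheel.py --list-instances --instance-type t4g.2xlarge```
```build_aarch64_wheel.py --terminate-instances --instance-type t4g.2xlarge```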

View File

@@ -1,53 +0,0 @@
#!/bin/bash
set -eux -o pipefail
GPU_ARCH_VERSION=${GPU_ARCH_VERSION:-}
# Set CUDA architecture lists to match x86 build_cuda.sh
if [[ "$GPU_ARCH_VERSION" == *"12.6"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0"
elif [[ "$GPU_ARCH_VERSION" == *"12.8"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
elif [[ "$GPU_ARCH_VERSION" == *"12.9"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
elif [[ "$GPU_ARCH_VERSION" == *"13.0"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;11.0;12.0+PTX"
fi
# Compress the fatbin with -compress-mode=size for CUDA 13
if [[ "$DESIRED_CUDA" == *"13"* ]]; then
export TORCH_NVCC_FLAGS="-compress-mode=size"
# Bundle ptxas into the cu13 wheel, see https://github.com/pytorch/pytorch/issues/163801
export BUILD_BUNDLE_PTXAS=1
fi
SCRIPTPATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
source $SCRIPTPATH/aarch64_ci_setup.sh
###############################################################################
# Run aarch64 builder python
###############################################################################
cd /
# Add the mounted pytorch repo as a git safe.directory, since its file
# permissions/ownership will not match the container user
git config --global --add safe.directory /pytorch
pip install -r /pytorch/requirements.txt
pip install auditwheel==6.2.0 wheel
if [ "$DESIRED_CUDA" = "cpu" ]; then
echo "BASE_CUDA_VERSION is not set. Building cpu wheel."
python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn
else
echo "BASE_CUDA_VERSION is set to: $DESIRED_CUDA"
export USE_SYSTEM_NCCL=1
# Check if we should use NVIDIA libs from PyPI (similar to x86 build_cuda.sh logic)
if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then
echo "Bundling CUDA libraries with wheel for aarch64."
else
echo "Using nvidia libs from pypi for aarch64."
echo "Updated PYTORCH_EXTRA_INSTALL_REQUIREMENTS for aarch64: $PYTORCH_EXTRA_INSTALL_REQUIREMENTS"
export USE_NVIDIA_PYPI_LIBS=1
fi
python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn --enable-cuda
fi

View File

@@ -1,21 +0,0 @@
#!/bin/bash
set -eux -o pipefail
# This script prepares the Docker container for the aarch64_wheel_ci_build.py python script
# by creating symlinks from the desired /opt/python version to /usr/local/bin/
NUMPY_VERSION=2.0.2
if [[ "$DESIRED_PYTHON" == "3.13" || "$DESIRED_PYTHON" == "3.13t" ]]; then
NUMPY_VERSION=2.1.2
fi
SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"
source $SCRIPTPATH/../manywheel/set_desired_python.sh
pip install -q numpy==${NUMPY_VERSION} pyyaml==6.0.2 scons==4.7.0 ninja==1.11.1 patchelf==0.17.2
for tool in python python3 pip pip3 ninja scons patchelf; do
ln -sf ${DESIRED_PYTHON_BIN_DIR}/${tool} /usr/local/bin;
done
python --version

View File

@@ -1,333 +0,0 @@
#!/usr/bin/env python3
# encoding: UTF-8
import os
import shutil
from subprocess import check_call, check_output
def list_dir(path: str) -> list[str]:
"""'
Helper for getting paths for Python
"""
return check_output(["ls", "-1", path]).decode().split("\n")
def replace_tag(filename) -> None:
with open(filename) as f:
lines = f.readlines()
for i, line in enumerate(lines):
if line.startswith("Tag:"):
lines[i] = line.replace("-linux_", "-manylinux_2_28_")
print(f"Updated tag from {line} to {lines[i]}")
break
with open(filename, "w") as f:
f.writelines(lines)
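# Illustrative example (tag value is hypothetical): a WHEEL metadata line such as
#   Tag: cp310-cp310-linux_aarch64
# is rewritten in place to
#   Tag: cp310-cp310-manylinux_2_28_aarch64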
def patch_library_rpath(
folder: str,
lib_name: str,
use_nvidia_pypi_libs: bool = False,
desired_cuda: str = "",
) -> None:
"""Apply patchelf to set RPATH for a library in torch/lib"""
lib_path = f"{folder}/tmp/torch/lib/{lib_name}"
if use_nvidia_pypi_libs:
# For PyPI NVIDIA libraries, construct CUDA RPATH
cuda_rpaths = [
"$ORIGIN/../../nvidia/cudnn/lib",
"$ORIGIN/../../nvidia/nvshmem/lib",
"$ORIGIN/../../nvidia/nccl/lib",
"$ORIGIN/../../nvidia/cusparselt/lib",
]
if "130" in desired_cuda:
cuda_rpaths.append("$ORIGIN/../../nvidia/cu13/lib")
else:
cuda_rpaths.extend(
[
"$ORIGIN/../../nvidia/cublas/lib",
"$ORIGIN/../../nvidia/cuda_cupti/lib",
"$ORIGIN/../../nvidia/cuda_nvrtc/lib",
"$ORIGIN/../../nvidia/cuda_runtime/lib",
"$ORIGIN/../../nvidia/cufft/lib",
"$ORIGIN/../../nvidia/curand/lib",
"$ORIGIN/../../nvidia/cusolver/lib",
"$ORIGIN/../../nvidia/cusparse/lib",
"$ORIGIN/../../nvidia/nvtx/lib",
"$ORIGIN/../../nvidia/cufile/lib",
]
)
# Add $ORIGIN for local torch libs
rpath = ":".join(cuda_rpaths) + ":$ORIGIN"
else:
# For bundled libraries, just use $ORIGIN
rpath = "$ORIGIN"
if os.path.exists(lib_path):
os.system(
f"cd {folder}/tmp/torch/lib/; "
f"patchelf --set-rpath '{rpath}' --force-rpath {lib_name}"
)
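# For the bundled-library case above this amounts to running, e.g. (library name illustrative):
#   cd <folder>/tmp/torch/lib/; patchelf --set-rpath '$ORIGIN' --force-rpath libtorch_cpu.so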
def copy_and_patch_library(
src_path: str,
folder: str,
use_nvidia_pypi_libs: bool = False,
desired_cuda: str = "",
) -> None:
"""Copy a library to torch/lib and patch its RPATH"""
if os.path.exists(src_path):
lib_name = os.path.basename(src_path)
shutil.copy2(src_path, f"{folder}/tmp/torch/lib/{lib_name}")
patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda)
def package_cuda_wheel(wheel_path, desired_cuda) -> None:
"""
Package the cuda wheel libraries
"""
folder = os.path.dirname(wheel_path)
os.mkdir(f"{folder}/tmp")
os.system(f"unzip {wheel_path} -d {folder}/tmp")
# Delete original wheel since it will be repackaged
os.system(f"rm {wheel_path}")
# Check if we should use PyPI NVIDIA libraries or bundle system libraries
use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"
if use_nvidia_pypi_libs:
print("Using nvidia libs from pypi - skipping CUDA library bundling")
# For PyPI approach, we don't bundle CUDA libraries - they come from PyPI packages
# We only need to bundle non-NVIDIA libraries
minimal_libs_to_copy = [
"/lib64/libgomp.so.1",
"/usr/lib64/libgfortran.so.5",
"/acl/build/libarm_compute.so",
"/acl/build/libarm_compute_graph.so",
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_lapack_core.so.0",
"/usr/local/lib/libnvpl_blas_core.so.0",
]
# Copy minimal libraries to unzipped_folder/torch/lib
for lib_path in minimal_libs_to_copy:
copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda)
# Patch torch libraries used for searching libraries
torch_libs_to_patch = [
"libtorch.so",
"libtorch_cpu.so",
"libtorch_cuda.so",
"libtorch_cuda_linalg.so",
"libtorch_global_deps.so",
"libtorch_python.so",
"libtorch_nvshmem.so",
"libc10.so",
"libc10_cuda.so",
"libcaffe2_nvrtc.so",
"libshm.so",
]
for lib_name in torch_libs_to_patch:
patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda)
else:
print("Bundling CUDA libraries with wheel")
# Original logic for bundling system CUDA libraries
# Common libraries for all CUDA versions
common_libs = [
# Non-NVIDIA system libraries
"/lib64/libgomp.so.1",
"/usr/lib64/libgfortran.so.5",
"/acl/build/libarm_compute.so",
"/acl/build/libarm_compute_graph.so",
# Common CUDA libraries (same for all versions)
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_lapack_core.so.0",
"/usr/local/lib/libnvpl_blas_core.so.0",
"/usr/local/cuda/extras/CUPTI/lib64/libnvperf_host.so",
"/usr/local/cuda/lib64/libcudnn.so.9",
"/usr/local/cuda/lib64/libcusparseLt.so.0",
"/usr/local/cuda/lib64/libcurand.so.10",
"/usr/local/cuda/lib64/libnccl.so.2",
"/usr/local/cuda/lib64/libnvshmem_host.so.3",
"/usr/local/cuda/lib64/libcudnn_adv.so.9",
"/usr/local/cuda/lib64/libcudnn_cnn.so.9",
"/usr/local/cuda/lib64/libcudnn_graph.so.9",
"/usr/local/cuda/lib64/libcudnn_ops.so.9",
"/usr/local/cuda/lib64/libcudnn_engines_runtime_compiled.so.9",
"/usr/local/cuda/lib64/libcudnn_engines_precompiled.so.9",
"/usr/local/cuda/lib64/libcudnn_heuristic.so.9",
"/usr/local/cuda/lib64/libcufile.so.0",
"/usr/local/cuda/lib64/libcufile_rdma.so.1",
"/usr/local/cuda/lib64/libcusparse.so.12",
]
# CUDA version-specific libraries
if "13" in desired_cuda:
minor_version = desired_cuda[-1]
version_specific_libs = [
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.13",
"/usr/local/cuda/lib64/libcublas.so.13",
"/usr/local/cuda/lib64/libcublasLt.so.13",
"/usr/local/cuda/lib64/libcudart.so.13",
"/usr/local/cuda/lib64/libcufft.so.12",
"/usr/local/cuda/lib64/libcusolver.so.12",
"/usr/local/cuda/lib64/libnvJitLink.so.13",
"/usr/local/cuda/lib64/libnvrtc.so.13",
f"/usr/local/cuda/lib64/libnvrtc-builtins.so.13.{minor_version}",
]
elif "12" in desired_cuda:
# Get the last character for libnvrtc-builtins version (e.g., "129" -> "9")
minor_version = desired_cuda[-1]
version_specific_libs = [
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12",
"/usr/local/cuda/lib64/libcublas.so.12",
"/usr/local/cuda/lib64/libcublasLt.so.12",
"/usr/local/cuda/lib64/libcudart.so.12",
"/usr/local/cuda/lib64/libcufft.so.11",
"/usr/local/cuda/lib64/libcusolver.so.11",
"/usr/local/cuda/lib64/libnvJitLink.so.12",
"/usr/local/cuda/lib64/libnvrtc.so.12",
f"/usr/local/cuda/lib64/libnvrtc-builtins.so.12.{minor_version}",
]
else:
raise ValueError(f"Unsupported CUDA version: {desired_cuda}.")
# Combine all libraries
libs_to_copy = common_libs + version_specific_libs
# Copy libraries to unzipped_folder/torch/lib
for lib_path in libs_to_copy:
copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda)
# Make sure the wheel is tagged with manylinux_2_28
for f in os.scandir(f"{folder}/tmp/"):
if f.is_dir() and f.name.endswith(".dist-info"):
replace_tag(f"{f.path}/WHEEL")
break
os.system(f"wheel pack {folder}/tmp/ -d {folder}")
os.system(f"rm -rf {folder}/tmp/")
def complete_wheel(folder: str) -> str:
"""
Complete wheel build and put in artifact location
"""
wheel_name = list_dir(f"/{folder}/dist")[0]
# Please note for CUDA builds we don't run auditwheel, since we use a custom script,
# package_cuda_wheel(), to package the CUDA dependencies into the wheel file.
# However, we need to make sure the filename reflects the correct manylinux platform.
if "pytorch" in folder and not enable_cuda:
print("Repairing Wheel with AuditWheel")
check_call(["auditwheel", "repair", f"dist/{wheel_name}"], cwd=folder)
repaired_wheel_name = list_dir(f"/{folder}/wheelhouse")[0]
print(f"Moving {repaired_wheel_name} wheel to /{folder}/dist")
os.rename(
f"/{folder}/wheelhouse/{repaired_wheel_name}",
f"/{folder}/dist/{repaired_wheel_name}",
)
else:
repaired_wheel_name = list_dir(f"/{folder}/dist")[0]
print(f"Copying {repaired_wheel_name} to artifacts")
shutil.copy2(
f"/{folder}/dist/{repaired_wheel_name}", f"/artifacts/{repaired_wheel_name}"
)
return repaired_wheel_name
def parse_arguments():
"""
Parse command-line arguments
"""
from argparse import ArgumentParser
parser = ArgumentParser("AARCH64 wheels python CD")
parser.add_argument("--debug", action="store_true")
parser.add_argument("--build-only", action="store_true")
parser.add_argument("--test-only", type=str)
parser.add_argument("--enable-mkldnn", action="store_true")
parser.add_argument("--enable-cuda", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
"""
Entry Point
"""
args = parse_arguments()
enable_mkldnn = args.enable_mkldnn
enable_cuda = args.enable_cuda
branch = check_output(
["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd="/pytorch"
).decode()
print("Building PyTorch wheel")
build_vars = ""
# MAX_JOBS=5 is not required for CPU backend (see commit 465d98b)
if enable_cuda:
build_vars += "MAX_JOBS=5 "
# Handle PyPI NVIDIA libraries vs bundled libraries
use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"
if use_nvidia_pypi_libs:
print("Configuring build for PyPI NVIDIA libraries")
# Configure for dynamic linking (matching x86 logic)
build_vars += "ATEN_STATIC_CUDA=0 USE_CUDA_STATIC_LINK=0 USE_CUPTI_SO=1 "
else:
print("Configuring build for bundled NVIDIA libraries")
# Keep existing static linking approach - already configured above
override_package_version = os.getenv("OVERRIDE_PACKAGE_VERSION")
desired_cuda = os.getenv("DESIRED_CUDA")
if override_package_version is not None:
version = override_package_version
build_vars += (
f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version} PYTORCH_BUILD_NUMBER=1 "
)
elif branch in ["nightly", "main"]:
build_date = (
check_output(["git", "log", "--pretty=format:%cs", "-1"], cwd="/pytorch")
.decode()
.replace("-", "")
)
version = (
check_output(["cat", "version.txt"], cwd="/pytorch").decode().strip()[:-2]
)
if enable_cuda:
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date}+{desired_cuda} PYTORCH_BUILD_NUMBER=1 "
else:
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1 "
elif branch.startswith(("v1.", "v2.")):
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1 "
if enable_mkldnn:
print("build pytorch with mkldnn+acl backend")
build_vars += "USE_MKLDNN=ON USE_MKLDNN_ACL=ON "
build_vars += "ACL_ROOT_DIR=/acl "
if enable_cuda:
build_vars += "BLAS=NVPL "
else:
build_vars += "BLAS=OpenBLAS OpenBLAS_HOME=/opt/OpenBLAS "
else:
print("build pytorch without mkldnn backend")
os.system(f"cd /pytorch; {build_vars} python3 -m build --wheel --no-isolation")
if enable_cuda:
print("Updating Cuda Dependency")
filename = os.listdir("/pytorch/dist/")
wheel_path = f"/pytorch/dist/{filename[0]}"
package_cuda_wheel(wheel_path, desired_cuda)
pytorch_wheel_name = complete_wheel("/pytorch/")
print(f"Build Complete. Created {pytorch_wheel_name}..")

View File

@@ -1,999 +0,0 @@
#!/usr/bin/env python3
# This script is for building AARCH64 wheels using AWS EC2 instances.
# To generate binaries for the release follow these steps:
# 1. Update the mappings for each of the domain libraries by adding a new row to its table, e.g.:
# "v1.11.0": ("0.11.0", "rc1"),
# 2. Run the script with the following arguments for each of the supported python versions and the required tag, for example:
# build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.8 --branch v1.11.0-rc3
import os
import subprocess
import sys
import time
from typing import Optional, Union
import boto3
# AMI images for us-east-1, change the following based on your ~/.aws/config
os_amis = {
"ubuntu20_04": "ami-052eac90edaa9d08f", # login_name: ubuntu
"ubuntu22_04": "ami-0c6c29c5125214c77", # login_name: ubuntu
"redhat8": "ami-0698b90665a2ddcf1", # login_name: ec2-user
}
ubuntu20_04_ami = os_amis["ubuntu20_04"]
def compute_keyfile_path(key_name: Optional[str] = None) -> tuple[str, str]:
if key_name is None:
key_name = os.getenv("AWS_KEY_NAME")
if key_name is None:
return os.getenv("SSH_KEY_PATH", ""), ""
homedir_path = os.path.expanduser("~")
default_path = os.path.join(homedir_path, ".ssh", f"{key_name}.pem")
return os.getenv("SSH_KEY_PATH", default_path), key_name
ec2 = boto3.resource("ec2")
def ec2_get_instances(filter_name, filter_value):
return ec2.instances.filter(
Filters=[{"Name": filter_name, "Values": [filter_value]}]
)
def ec2_instances_of_type(instance_type="t4g.2xlarge"):
return ec2_get_instances("instance-type", instance_type)
def ec2_instances_by_id(instance_id):
rc = list(ec2_get_instances("instance-id", instance_id))
return rc[0] if len(rc) > 0 else None
def start_instance(
key_name, ami=ubuntu20_04_ami, instance_type="t4g.2xlarge", ebs_size: int = 50
):
inst = ec2.create_instances(
ImageId=ami,
InstanceType=instance_type,
SecurityGroups=["ssh-allworld"],
KeyName=key_name,
MinCount=1,
MaxCount=1,
BlockDeviceMappings=[
{
"DeviceName": "/dev/sda1",
"Ebs": {
"DeleteOnTermination": True,
"VolumeSize": ebs_size,
"VolumeType": "standard",
},
}
],
)[0]
print(f"Create instance {inst.id}")
inst.wait_until_running()
running_inst = ec2_instances_by_id(inst.id)
print(f"Instance started at {running_inst.public_dns_name}")
return running_inst
class RemoteHost:
addr: str
keyfile_path: str
login_name: str
container_id: Optional[str] = None
ami: Optional[str] = None
def __init__(self, addr: str, keyfile_path: str, login_name: str = "ubuntu"):
self.addr = addr
self.keyfile_path = keyfile_path
self.login_name = login_name
def _gen_ssh_prefix(self) -> list[str]:
return [
"ssh",
"-o",
"StrictHostKeyChecking=no",
"-i",
self.keyfile_path,
f"{self.login_name}@{self.addr}",
"--",
]
@staticmethod
def _split_cmd(args: Union[str, list[str]]) -> list[str]:
return args.split() if isinstance(args, str) else args
def run_ssh_cmd(self, args: Union[str, list[str]]) -> None:
subprocess.check_call(self._gen_ssh_prefix() + self._split_cmd(args))
def check_ssh_output(self, args: Union[str, list[str]]) -> str:
return subprocess.check_output(
self._gen_ssh_prefix() + self._split_cmd(args)
).decode("utf-8")
def scp_upload_file(self, local_file: str, remote_file: str) -> None:
subprocess.check_call(
[
"scp",
"-i",
self.keyfile_path,
local_file,
f"{self.login_name}@{self.addr}:{remote_file}",
]
)
def scp_download_file(
self, remote_file: str, local_file: Optional[str] = None
) -> None:
if local_file is None:
local_file = "."
subprocess.check_call(
[
"scp",
"-i",
self.keyfile_path,
f"{self.login_name}@{self.addr}:{remote_file}",
local_file,
]
)
def start_docker(self, image="quay.io/pypa/manylinux2014_aarch64:latest") -> None:
self.run_ssh_cmd("sudo apt-get install -y docker.io")
self.run_ssh_cmd(f"sudo usermod -a -G docker {self.login_name}")
self.run_ssh_cmd("sudo service docker start")
self.run_ssh_cmd(f"docker pull {image}")
self.container_id = self.check_ssh_output(
f"docker run -t -d -w /root {image}"
).strip()
def using_docker(self) -> bool:
return self.container_id is not None
def run_cmd(self, args: Union[str, list[str]]) -> None:
if not self.using_docker():
return self.run_ssh_cmd(args)
assert self.container_id is not None
docker_cmd = self._gen_ssh_prefix() + [
"docker",
"exec",
"-i",
self.container_id,
"bash",
]
p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE)
p.communicate(
input=" ".join(["source .bashrc && "] + self._split_cmd(args)).encode(
"utf-8"
)
)
rc = p.wait()
if rc != 0:
raise subprocess.CalledProcessError(rc, docker_cmd)
def check_output(self, args: Union[str, list[str]]) -> str:
if not self.using_docker():
return self.check_ssh_output(args)
assert self.container_id is not None
docker_cmd = self._gen_ssh_prefix() + [
"docker",
"exec",
"-i",
self.container_id,
"bash",
]
p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
(out, err) = p.communicate(
input=" ".join(["source .bashrc && "] + self._split_cmd(args)).encode(
"utf-8"
)
)
rc = p.wait()
if rc != 0:
raise subprocess.CalledProcessError(rc, docker_cmd, output=out, stderr=err)
return out.decode("utf-8")
def upload_file(self, local_file: str, remote_file: str) -> None:
if not self.using_docker():
return self.scp_upload_file(local_file, remote_file)
tmp_file = os.path.join("/tmp", os.path.basename(local_file))
self.scp_upload_file(local_file, tmp_file)
self.run_ssh_cmd(
["docker", "cp", tmp_file, f"{self.container_id}:/root/{remote_file}"]
)
self.run_ssh_cmd(["rm", tmp_file])
def download_file(self, remote_file: str, local_file: Optional[str] = None) -> None:
if not self.using_docker():
return self.scp_download_file(remote_file, local_file)
tmp_file = os.path.join("/tmp", os.path.basename(remote_file))
self.run_ssh_cmd(
["docker", "cp", f"{self.container_id}:/root/{remote_file}", tmp_file]
)
self.scp_download_file(tmp_file, local_file)
self.run_ssh_cmd(["rm", tmp_file])
def download_wheel(
self, remote_file: str, local_file: Optional[str] = None
) -> None:
if self.using_docker() and local_file is None:
basename = os.path.basename(remote_file)
local_file = basename.replace(
"-linux_aarch64.whl", "-manylinux2014_aarch64.whl"
)
self.download_file(remote_file, local_file)
def list_dir(self, path: str) -> list[str]:
return self.check_output(["ls", "-1", path]).split("\n")
def wait_for_connection(addr, port, timeout=15, attempt_cnt=5):
import socket
for i in range(attempt_cnt):
try:
with socket.create_connection((addr, port), timeout=timeout):
return
except (ConnectionRefusedError, TimeoutError): # noqa: PERF203
if i == attempt_cnt - 1:
raise
time.sleep(timeout)
def update_apt_repo(host: RemoteHost) -> None:
time.sleep(5)
host.run_cmd("sudo systemctl stop apt-daily.service || true")
host.run_cmd("sudo systemctl stop unattended-upgrades.service || true")
host.run_cmd(
"while systemctl is-active --quiet apt-daily.service; do sleep 1; done"
)
host.run_cmd(
"while systemctl is-active --quiet unattended-upgrades.service; do sleep 1; done"
)
host.run_cmd("sudo apt-get update")
time.sleep(3)
host.run_cmd("sudo apt-get update")
def install_condaforge(
host: RemoteHost, suffix: str = "latest/download/Miniforge3-Linux-aarch64.sh"
) -> None:
print("Install conda-forge")
host.run_cmd(f"curl -OL https://github.com/conda-forge/miniforge/releases/{suffix}")
host.run_cmd(f"sh -f {os.path.basename(suffix)} -b")
host.run_cmd(f"rm -f {os.path.basename(suffix)}")
if host.using_docker():
host.run_cmd("echo 'PATH=$HOME/miniforge3/bin:$PATH'>>.bashrc")
else:
host.run_cmd(
[
"sed",
"-i",
"'/^# If not running interactively.*/i PATH=$HOME/miniforge3/bin:$PATH'",
".bashrc",
]
)
def install_condaforge_python(host: RemoteHost, python_version="3.8") -> None:
if python_version == "3.6":
# Python-3.6 EOLed and not compatible with conda-4.11
install_condaforge(
host, suffix="download/4.10.3-10/Miniforge3-4.10.3-10-Linux-aarch64.sh"
)
host.run_cmd(f"conda install -y python={python_version} numpy pyyaml")
else:
install_condaforge(
host, suffix="download/4.11.0-4/Miniforge3-4.11.0-4-Linux-aarch64.sh"
)
# Pytorch-1.10 or older are not compatible with setuptools=59.6 or newer
host.run_cmd(
f"conda install -y python={python_version} numpy pyyaml setuptools>=59.5.0"
)
def embed_libgomp(host: RemoteHost, use_conda, wheel_name) -> None:
host.run_cmd("pip3 install auditwheel")
host.run_cmd(
"conda install -y patchelf" if use_conda else "sudo apt-get install -y patchelf"
)
from tempfile import NamedTemporaryFile
with NamedTemporaryFile() as tmp:
tmp.write(embed_library_script.encode("utf-8"))
tmp.flush()
host.upload_file(tmp.name, "embed_library.py")
print("Embedding libgomp into wheel")
if host.using_docker():
host.run_cmd(f"python3 embed_library.py {wheel_name} --update-tag")
else:
host.run_cmd(f"python3 embed_library.py {wheel_name}")
def checkout_repo(
host: RemoteHost,
*,
branch: str = "main",
url: str,
git_clone_flags: str,
mapping: dict[str, tuple[str, str]],
) -> Optional[str]:
for prefix in mapping:
if not branch.startswith(prefix):
continue
tag = f"v{mapping[prefix][0]}-{mapping[prefix][1]}"
host.run_cmd(f"git clone {url} -b {tag} {git_clone_flags}")
return mapping[prefix][0]
host.run_cmd(f"git clone {url} -b {branch} {git_clone_flags}")
return None
def build_torchvision(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str,
run_smoke_tests: bool = True,
) -> str:
print("Checking out TorchVision repo")
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/vision",
git_clone_flags=git_clone_flags,
mapping={
"v1.7.1": ("0.8.2", "rc2"),
"v1.8.0": ("0.9.0", "rc3"),
"v1.8.1": ("0.9.1", "rc1"),
"v1.9.0": ("0.10.0", "rc1"),
"v1.10.0": ("0.11.1", "rc1"),
"v1.10.1": ("0.11.2", "rc1"),
"v1.10.2": ("0.11.3", "rc1"),
"v1.11.0": ("0.12.0", "rc1"),
"v1.12.0": ("0.13.0", "rc4"),
"v1.12.1": ("0.13.1", "rc6"),
"v1.13.0": ("0.14.0", "rc4"),
"v1.13.1": ("0.14.1", "rc2"),
"v2.0.0": ("0.15.1", "rc2"),
"v2.0.1": ("0.15.2", "rc2"),
},
)
print("Building TorchVision wheel")
# Please note libpng and jpeg are required to build the image.so extension
if use_conda:
host.run_cmd("conda install -y libpng jpeg")
# Remove .so files to force static linking
host.run_cmd(
"rm miniforge3/lib/libpng.so miniforge3/lib/libpng16.so miniforge3/lib/libjpeg.so"
)
# And patch setup.py to include libz dependency for libpng
host.run_cmd(
[
'sed -i -e \'s/image_link_flags\\.append("png")/image_link_flags += ["png", "z"]/\' vision/setup.py'
]
)
build_vars = ""
if branch == "nightly":
version = host.check_output(
["if [ -f vision/version.txt ]; then cat vision/version.txt; fi"]
).strip()
if len(version) == 0:
# In older revisions, version was embedded in setup.py
version = (
host.check_output(["grep", '"version = \'"', "vision/setup.py"])
.strip()
.split("'")[1][:-2]
)
build_date = (
host.check_output("cd vision && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(f"cd vision && {build_vars} python3 -m build --wheel --no-isolation")
vision_wheel_name = host.list_dir("vision/dist")[0]
embed_libgomp(host, use_conda, os.path.join("vision", "dist", vision_wheel_name))
print("Copying TorchVision wheel")
host.download_wheel(os.path.join("vision", "dist", vision_wheel_name))
if run_smoke_tests:
host.run_cmd(
f"pip3 install {os.path.join('vision', 'dist', vision_wheel_name)}"
)
host.run_cmd("python3 vision/test/smoke_test.py")
print("Delete vision checkout")
host.run_cmd("rm -rf vision")
return vision_wheel_name
def build_torchdata(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> str:
print("Checking out TorchData repo")
git_clone_flags += " --recurse-submodules"
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/data",
git_clone_flags=git_clone_flags,
mapping={
"v1.13.1": ("0.5.1", ""),
"v2.0.0": ("0.6.0", "rc5"),
"v2.0.1": ("0.6.1", "rc1"),
},
)
print("Building TorchData wheel")
build_vars = ""
if branch == "nightly":
version = host.check_output(
["if [ -f data/version.txt ]; then cat data/version.txt; fi"]
).strip()
build_date = (
host.check_output("cd data && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(f"cd data && {build_vars} python3 -m build --wheel --no-isolation")
wheel_name = host.list_dir("data/dist")[0]
embed_libgomp(host, use_conda, os.path.join("data", "dist", wheel_name))
print("Copying TorchData wheel")
host.download_wheel(os.path.join("data", "dist", wheel_name))
return wheel_name
def build_torchtext(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> str:
print("Checking out TorchText repo")
git_clone_flags += " --recurse-submodules"
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/text",
git_clone_flags=git_clone_flags,
mapping={
"v1.9.0": ("0.10.0", "rc1"),
"v1.10.0": ("0.11.0", "rc2"),
"v1.10.1": ("0.11.1", "rc1"),
"v1.10.2": ("0.11.2", "rc1"),
"v1.11.0": ("0.12.0", "rc1"),
"v1.12.0": ("0.13.0", "rc2"),
"v1.12.1": ("0.13.1", "rc5"),
"v1.13.0": ("0.14.0", "rc3"),
"v1.13.1": ("0.14.1", "rc1"),
"v2.0.0": ("0.15.1", "rc2"),
"v2.0.1": ("0.15.2", "rc2"),
},
)
print("Building TorchText wheel")
build_vars = ""
if branch == "nightly":
version = host.check_output(
["if [ -f text/version.txt ]; then cat text/version.txt; fi"]
).strip()
build_date = (
host.check_output("cd text && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(f"cd text && {build_vars} python3 -m build --wheel --no-isolation")
wheel_name = host.list_dir("text/dist")[0]
embed_libgomp(host, use_conda, os.path.join("text", "dist", wheel_name))
print("Copying TorchText wheel")
host.download_wheel(os.path.join("text", "dist", wheel_name))
return wheel_name
def build_torchaudio(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> str:
print("Checking out TorchAudio repo")
git_clone_flags += " --recurse-submodules"
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/audio",
git_clone_flags=git_clone_flags,
mapping={
"v1.9.0": ("0.9.0", "rc2"),
"v1.10.0": ("0.10.0", "rc5"),
"v1.10.1": ("0.10.1", "rc1"),
"v1.10.2": ("0.10.2", "rc1"),
"v1.11.0": ("0.11.0", "rc1"),
"v1.12.0": ("0.12.0", "rc3"),
"v1.12.1": ("0.12.1", "rc5"),
"v1.13.0": ("0.13.0", "rc4"),
"v1.13.1": ("0.13.1", "rc2"),
"v2.0.0": ("2.0.1", "rc3"),
"v2.0.1": ("2.0.2", "rc2"),
},
)
print("Building TorchAudio wheel")
build_vars = ""
if branch == "nightly":
version = (
host.check_output(["grep", '"version = \'"', "audio/setup.py"])
.strip()
.split("'")[1][:-2]
)
build_date = (
host.check_output("cd audio && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(
f"cd audio && export FFMPEG_ROOT=$(pwd)/third_party/ffmpeg && export USE_FFMPEG=1 \
&& ./packaging/ffmpeg/build.sh \
&& {build_vars} python3 -m build --wheel --no-isolation"
)
wheel_name = host.list_dir("audio/dist")[0]
embed_libgomp(host, use_conda, os.path.join("audio", "dist", wheel_name))
print("Copying TorchAudio wheel")
host.download_wheel(os.path.join("audio", "dist", wheel_name))
return wheel_name
def configure_system(
host: RemoteHost,
*,
compiler: str = "gcc-8",
use_conda: bool = True,
python_version: str = "3.8",
) -> None:
if use_conda:
install_condaforge_python(host, python_version)
print("Configuring the system")
if not host.using_docker():
update_apt_repo(host)
host.run_cmd("sudo apt-get install -y ninja-build g++ git cmake gfortran unzip")
else:
host.run_cmd("yum install -y sudo")
host.run_cmd("conda install -y ninja scons")
if not use_conda:
host.run_cmd(
"sudo apt-get install -y python3-dev python3-yaml python3-setuptools python3-wheel python3-pip"
)
host.run_cmd("pip3 install dataclasses typing-extensions")
if not use_conda:
print("Installing Cython + numpy from PyPy")
host.run_cmd("sudo pip3 install Cython")
host.run_cmd("sudo pip3 install numpy")
def build_domains(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> tuple[str, str, str, str]:
vision_wheel_name = build_torchvision(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
audio_wheel_name = build_torchaudio(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
data_wheel_name = build_torchdata(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
text_wheel_name = build_torchtext(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
return (vision_wheel_name, audio_wheel_name, data_wheel_name, text_wheel_name)
def start_build(
host: RemoteHost,
*,
branch: str = "main",
compiler: str = "gcc-8",
use_conda: bool = True,
python_version: str = "3.8",
pytorch_only: bool = False,
pytorch_build_number: Optional[str] = None,
shallow_clone: bool = True,
enable_mkldnn: bool = False,
) -> tuple[str, str, str, str, str]:
git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
if host.using_docker() and not use_conda:
print("Auto-selecting conda option for docker images")
use_conda = True
if not host.using_docker():
print("Disable mkldnn for host builds")
enable_mkldnn = False
configure_system(
host, compiler=compiler, use_conda=use_conda, python_version=python_version
)
if host.using_docker():
print("Move libgfortant.a into a standard location")
# HACK: pypa gforntran.a is compiled without PIC, which leads to the following error
# libgfortran.a(error.o)(.text._gfortrani_st_printf+0x34): unresolvable R_AARCH64_ADR_PREL_PG_HI21 relocation against symbol `__stack_chk_guard@@GLIBC_2.17' # noqa: E501, B950
# Workaround by copying gfortran library from the host
host.run_ssh_cmd("sudo apt-get install -y gfortran-8")
host.run_cmd("mkdir -p /usr/lib/gcc/aarch64-linux-gnu/8")
host.run_ssh_cmd(
[
"docker",
"cp",
"/usr/lib/gcc/aarch64-linux-gnu/8/libgfortran.a",
f"{host.container_id}:/opt/rh/devtoolset-10/root/usr/lib/gcc/aarch64-redhat-linux/10/",
]
)
print("Checking out PyTorch repo")
host.run_cmd(
f"git clone --recurse-submodules -b {branch} https://github.com/pytorch/pytorch {git_clone_flags}"
)
host.run_cmd("pytorch/.ci/docker/common/install_openblas.sh")
print("Building PyTorch wheel")
build_opts = ""
if pytorch_build_number is not None:
build_opts += f" -C--build-option=--build-number={pytorch_build_number}"
# Breakpad build fails on aarch64
build_vars = "USE_BREAKPAD=0 "
if branch == "nightly":
build_date = (
host.check_output("cd pytorch && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
version = host.check_output("cat pytorch/version.txt").strip()[:-2]
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1"
if branch.startswith(("v1.", "v2.")):
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
if enable_mkldnn:
host.run_cmd("pytorch/.ci/docker/common/install_acl.sh")
print("build pytorch with mkldnn+acl backend")
build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
build_vars += " BLAS=OpenBLAS"
build_vars += " OpenBLAS_HOME=/opt/OpenBLAS"
build_vars += " ACL_ROOT_DIR=/acl"
host.run_cmd(
f"cd $HOME/pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
)
print("Repair the wheel")
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
ld_library_path = "/acl/build:$HOME/pytorch/build/lib"
host.run_cmd(
f"export LD_LIBRARY_PATH={ld_library_path} && auditwheel repair $HOME/pytorch/dist/{pytorch_wheel_name}"
)
print("replace the original wheel with the repaired one")
pytorch_repaired_wheel_name = host.list_dir("wheelhouse")[0]
host.run_cmd(
f"cp $HOME/wheelhouse/{pytorch_repaired_wheel_name} $HOME/pytorch/dist/{pytorch_wheel_name}"
)
else:
print("build pytorch without mkldnn backend")
host.run_cmd(
f"cd pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
)
print("Deleting build folder")
host.run_cmd("cd pytorch && rm -rf build")
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
embed_libgomp(host, use_conda, os.path.join("pytorch", "dist", pytorch_wheel_name))
print("Copying the wheel")
host.download_wheel(os.path.join("pytorch", "dist", pytorch_wheel_name))
print("Installing PyTorch wheel")
host.run_cmd(f"pip3 install pytorch/dist/{pytorch_wheel_name}")
if pytorch_only:
return (pytorch_wheel_name, None, None, None, None)
domain_wheels = build_domains(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
return (pytorch_wheel_name, *domain_wheels)
embed_library_script = """
#!/usr/bin/env python3
from auditwheel.patcher import Patchelf
from auditwheel.wheeltools import InWheelCtx
from auditwheel.elfutils import elf_file_filter
from auditwheel.repair import copylib
from auditwheel.lddtree import lddtree
from subprocess import check_call
import os
import shutil
import sys
from tempfile import TemporaryDirectory
def replace_tag(filename):
with open(filename, 'r') as f:
lines = f.read().split("\\n")
for i,line in enumerate(lines):
if not line.startswith("Tag: "):
continue
lines[i] = line.replace("-linux_", "-manylinux2014_")
print(f'Updated tag from {line} to {lines[i]}')
with open(filename, 'w') as f:
f.write("\\n".join(lines))
class AlignedPatchelf(Patchelf):
def set_soname(self, file_name: str, new_soname: str) -> None:
check_call(['patchelf', '--page-size', '65536', '--set-soname', new_soname, file_name])
def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
check_call(['patchelf', '--page-size', '65536', '--replace-needed', soname, new_soname, file_name])
def embed_library(whl_path, lib_soname, update_tag=False):
patcher = AlignedPatchelf()
out_dir = TemporaryDirectory()
whl_name = os.path.basename(whl_path)
tmp_whl_name = os.path.join(out_dir.name, whl_name)
with InWheelCtx(whl_path) as ctx:
torchlib_path = os.path.join(ctx._tmpdir.name, 'torch', 'lib')
ctx.out_wheel=tmp_whl_name
new_lib_path, new_lib_soname = None, None
for filename, elf in elf_file_filter(ctx.iter_files()):
if not filename.startswith('torch/lib'):
continue
libtree = lddtree(filename)
if lib_soname not in libtree['needed']:
continue
lib_path = libtree['libs'][lib_soname]['path']
if lib_path is None:
print(f"Can't embed {lib_soname} as it could not be found")
break
if lib_path.startswith(torchlib_path):
continue
if new_lib_path is None:
new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
patcher.replace_needed(filename, lib_soname, new_lib_soname)
print(f'Replacing {lib_soname} with {new_lib_soname} for {filename}')
if update_tag:
# Add manylinux2014 tag
for filename in ctx.iter_files():
if os.path.basename(filename) != 'WHEEL':
continue
replace_tag(filename)
shutil.move(tmp_whl_name, whl_path)
if __name__ == '__main__':
embed_library(sys.argv[1], 'libgomp.so.1', len(sys.argv) > 2 and sys.argv[2] == '--update-tag')
"""
def run_tests(host: RemoteHost, whl: str, branch="main") -> None:
print("Configuring the system")
update_apt_repo(host)
host.run_cmd("sudo apt-get install -y python3-pip git")
host.run_cmd("sudo pip3 install Cython")
host.run_cmd("sudo pip3 install numpy")
host.upload_file(whl, ".")
host.run_cmd(f"sudo pip3 install {whl}")
host.run_cmd("python3 -c 'import torch;print(torch.rand((3,3))'")
host.run_cmd(f"git clone -b {branch} https://github.com/pytorch/pytorch")
host.run_cmd("cd pytorch/test; python3 test_torch.py -v")
def get_instance_name(instance) -> Optional[str]:
if instance.tags is None:
return None
for tag in instance.tags:
if tag["Key"] == "Name":
return tag["Value"]
return None
def list_instances(instance_type: str) -> None:
print(f"All instances of type {instance_type}")
for instance in ec2_instances_of_type(instance_type):
ifaces = instance.network_interfaces
az = ifaces[0].subnet.availability_zone if len(ifaces) > 0 else None
print(
f"{instance.id} {get_instance_name(instance)} {instance.public_dns_name} {instance.state['Name']} {az}"
)
def terminate_instances(instance_type: str) -> None:
print(f"Terminating all instances of type {instance_type}")
instances = list(ec2_instances_of_type(instance_type))
for instance in instances:
print(f"Terminating {instance.id}")
instance.terminate()
print("Waiting for termination to complete")
for instance in instances:
instance.wait_until_terminated()
def parse_arguments():
from argparse import ArgumentParser
parser = ArgumentParser("Build and test AARCH64 wheels using EC2")
parser.add_argument("--key-name", type=str)
parser.add_argument("--debug", action="store_true")
parser.add_argument("--build-only", action="store_true")
parser.add_argument("--test-only", type=str)
group = parser.add_mutually_exclusive_group()
group.add_argument("--os", type=str, choices=list(os_amis.keys()))
group.add_argument("--ami", type=str)
parser.add_argument(
"--python-version",
type=str,
choices=[f"3.{d}" for d in range(6, 12)],
default=None,
)
parser.add_argument("--alloc-instance", action="store_true")
parser.add_argument("--list-instances", action="store_true")
parser.add_argument("--pytorch-only", action="store_true")
parser.add_argument("--keep-running", action="store_true")
parser.add_argument("--terminate-instances", action="store_true")
parser.add_argument("--instance-type", type=str, default="t4g.2xlarge")
parser.add_argument("--ebs-size", type=int, default=50)
parser.add_argument("--branch", type=str, default="main")
parser.add_argument("--use-docker", action="store_true")
parser.add_argument(
"--compiler",
type=str,
choices=["gcc-7", "gcc-8", "gcc-9", "clang"],
default="gcc-8",
)
parser.add_argument("--use-torch-from-pypi", action="store_true")
parser.add_argument("--pytorch-build-number", type=str, default=None)
parser.add_argument("--disable-mkldnn", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
args = parse_arguments()
ami = (
args.ami
if args.ami is not None
else os_amis[args.os]
if args.os is not None
else ubuntu20_04_ami
)
keyfile_path, key_name = compute_keyfile_path(args.key_name)
if args.list_instances:
list_instances(args.instance_type)
sys.exit(0)
if args.terminate_instances:
terminate_instances(args.instance_type)
sys.exit(0)
if len(key_name) == 0:
raise RuntimeError("""
Cannot start build without key_name, please specify
--key-name argument or AWS_KEY_NAME environment variable.""")
if len(keyfile_path) == 0 or not os.path.exists(keyfile_path):
raise RuntimeError(f"""
Cannot find keyfile with name: [{key_name}] in path: [{keyfile_path}], please
check `~/.ssh/` folder or manually set SSH_KEY_PATH environment variable.""")
# Starting the instance
inst = start_instance(
key_name, ami=ami, instance_type=args.instance_type, ebs_size=args.ebs_size
)
instance_name = f"{args.key_name}-{args.os}"
if args.python_version is not None:
instance_name += f"-py{args.python_version}"
inst.create_tags(
DryRun=False,
Tags=[
{
"Key": "Name",
"Value": instance_name,
}
],
)
addr = inst.public_dns_name
wait_for_connection(addr, 22)
host = RemoteHost(addr, keyfile_path)
host.ami = ami
if args.use_docker:
update_apt_repo(host)
host.start_docker()
if args.test_only:
run_tests(host, args.test_only)
sys.exit(0)
if args.alloc_instance:
if args.python_version is None:
sys.exit(0)
install_condaforge_python(host, args.python_version)
sys.exit(0)
python_version = args.python_version if args.python_version is not None else "3.10"
if args.use_torch_from_pypi:
configure_system(host, compiler=args.compiler, python_version=python_version)
print("Installing PyTorch wheel")
host.run_cmd("pip3 install torch")
build_domains(
host, branch=args.branch, git_clone_flags=" --depth 1 --shallow-submodules"
)
else:
start_build(
host,
branch=args.branch,
compiler=args.compiler,
python_version=python_version,
pytorch_only=args.pytorch_only,
pytorch_build_number=args.pytorch_build_number,
enable_mkldnn=not args.disable_mkldnn,
)
if not args.keep_running:
print(f"Waiting for instance {inst.id} to terminate")
inst.terminate()
inst.wait_until_terminated()

View File

@@ -1,87 +0,0 @@
#!/usr/bin/env python3
import os
import shutil
import sys
from subprocess import check_call
from tempfile import TemporaryDirectory
from auditwheel.elfutils import elf_file_filter
from auditwheel.lddtree import lddtree
from auditwheel.patcher import Patchelf
from auditwheel.repair import copylib
from auditwheel.wheeltools import InWheelCtx
def replace_tag(filename):
with open(filename) as f:
lines = f.read().split("\n")
for i, line in enumerate(lines):
if not line.startswith("Tag: "):
continue
lines[i] = line.replace("-linux_", "-manylinux2014_")
print(f"Updated tag from {line} to {lines[i]}")
with open(filename, "w") as f:
f.write("\n".join(lines))
class AlignedPatchelf(Patchelf):
def set_soname(self, file_name: str, new_soname: str) -> None:
check_call(
["patchelf", "--page-size", "65536", "--set-soname", new_soname, file_name]
)
def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
check_call(
[
"patchelf",
"--page-size",
"65536",
"--replace-needed",
soname,
new_soname,
file_name,
]
)
def embed_library(whl_path, lib_soname, update_tag=False):
patcher = AlignedPatchelf()
out_dir = TemporaryDirectory()
whl_name = os.path.basename(whl_path)
tmp_whl_name = os.path.join(out_dir.name, whl_name)
with InWheelCtx(whl_path) as ctx:
torchlib_path = os.path.join(ctx._tmpdir.name, "torch", "lib")
ctx.out_wheel = tmp_whl_name
new_lib_path, new_lib_soname = None, None
for filename, _ in elf_file_filter(ctx.iter_files()):
if not filename.startswith("torch/lib"):
continue
libtree = lddtree(filename)
if lib_soname not in libtree["needed"]:
continue
lib_path = libtree["libs"][lib_soname]["path"]
if lib_path is None:
print(f"Can't embed {lib_soname} as it could not be found")
break
if lib_path.startswith(torchlib_path):
continue
if new_lib_path is None:
new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
patcher.replace_needed(filename, lib_soname, new_lib_soname)
print(f"Replacing {lib_soname} with {new_lib_soname} for {filename}")
if update_tag:
# Add manylinux2014 tag
for filename in ctx.iter_files():
if os.path.basename(filename) != "WHEEL":
continue
replace_tag(filename)
shutil.move(tmp_whl_name, whl_path)
if __name__ == "__main__":
embed_library(
sys.argv[1], "libgomp.so.1", len(sys.argv) > 2 and sys.argv[2] == "--update-tag"
)
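# Usage (mirrors how build_aarch64_wheel.py's embed_libgomp() drives the embedded copy of this script):
#   python3 embed_library.py <wheel_name> --update-tag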

View File

@@ -4,14 +4,17 @@ set -ex
SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
# Source the common build script for architecture-specific configurations (MKLDNN, ACL, etc.)
source "${SCRIPTPATH}/../pytorch/build.sh" || true
case "${GPU_ARCH_TYPE:-BLANK}" in
cuda)
cuda | cuda-aarch64)
bash "${SCRIPTPATH}/build_cuda.sh"
;;
rocm)
bash "${SCRIPTPATH}/build_rocm.sh"
;;
cpu | cpu-cxx11-abi | cpu-s390x)
cpu | cpu-cxx11-abi | cpu-aarch64 | cpu-s390x)
bash "${SCRIPTPATH}/build_cpu.sh"
;;
xpu)

View File

@@ -18,12 +18,31 @@ retry () {
$* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*)
}
# Detect architecture first
ARCH=$(uname -m)
echo "Detected architecture: $ARCH"
PLATFORM=""
# TODO move this into the Docker images
OS_NAME=$(awk -F= '/^NAME/{print $2}' /etc/os-release)
if [[ "$OS_NAME" == *"AlmaLinux"* ]]; then
retry yum install -q -y zip openssl
PLATFORM="manylinux_2_28_x86_64"
# Set platform based on architecture
case $ARCH in
x86_64)
PLATFORM="manylinux_2_28_x86_64"
;;
aarch64)
PLATFORM="manylinux_2_28_aarch64"
;;
s390x)
PLATFORM="manylinux_2_28_s390x"
;;
*)
echo "Unsupported architecture: $ARCH"
exit 1
;;
esac
elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then
retry dnf install -q -y zip openssl
elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
@@ -38,6 +57,8 @@ else
exit 1
fi
echo "Platform set to: $PLATFORM"
# We use the package name to test the package by passing this to 'pip install'
# This is the env variable that setup.py uses to name the package. Note that
# pip 'normalizes' the name first by changing all - to _
@@ -299,8 +320,8 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
# ROCm workaround for roctracer dlopens
if [[ "$DESIRED_CUDA" == *"rocm"* ]]; then
patchedpath=$(fname_without_so_number $destpath)
# Keep the so number for XPU dependencies and libgomp.so.1 to avoid twice load
elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" ]]; then
# Keep the so number for XPU dependencies, libgomp.so.1, ACL libraries, and NVPL libraries to avoid twice load
elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" || "$filename" == libarm_compute* || "$filename" == libnvpl* || "$filename" == "libgfortran.so.5" ]]; then
patchedpath=$destpath
else
patchedpath=$(fname_with_sha256 $destpath)
@@ -346,9 +367,22 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
done
# Create the manylinux_2_28 tag; this needs to happen before regenerating the RECORD file
if [[ $PLATFORM == "manylinux_2_28_x86_64" && $GPU_ARCH_TYPE != "cpu-s390x" && $GPU_ARCH_TYPE != "xpu" ]]; then
# Support all architectures (x86_64, aarch64, s390x)
if [[ "$IS_MANYLINUX2_28" == "1" && $GPU_ARCH_TYPE != "xpu" ]]; then
wheel_file=$(echo $(basename $pkg) | sed -e 's/-cp.*$/.dist-info\/WHEEL/g')
sed -i -e s#linux_x86_64#"${PLATFORM}"# $wheel_file;
echo "Updating wheel tag for $ARCH architecture"
# Replace linux_* with manylinux_2_28_* based on architecture
case $ARCH in
x86_64)
sed -i -e 's#linux_x86_64#manylinux_2_28_x86_64#g' $wheel_file
;;
aarch64)
sed -i -e 's#linux_aarch64#manylinux_2_28_aarch64#g' $wheel_file
;;
s390x)
sed -i -e 's#linux_s390x#manylinux_2_28_s390x#g' $wheel_file
;;
esac
fi
# regenerate the RECORD file with new hashes

View File

@@ -15,6 +15,10 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
EXTRA_CAFFE2_CMAKE_FLAGS=()
fi
# Detect architecture
ARCH=$(uname -m)
echo "Building CPU wheel for architecture: $ARCH"
WHEELHOUSE_DIR="wheelhousecpu"
LIBTORCH_HOUSE_DIR="libtorch_housecpu"
if [[ -z "$PYTORCH_FINAL_PACKAGE_DIR" ]]; then
@@ -34,8 +38,10 @@ elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then
elif [[ "$OS_NAME" == *"AlmaLinux"* ]]; then
LIBGOMP_PATH="/usr/lib64/libgomp.so.1"
elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
if [[ "$(uname -m)" == "s390x" ]]; then
if [[ "$ARCH" == "s390x" ]]; then
LIBGOMP_PATH="/usr/lib/s390x-linux-gnu/libgomp.so.1"
elif [[ "$ARCH" == "aarch64" ]]; then
LIBGOMP_PATH="/usr/lib/aarch64-linux-gnu/libgomp.so.1"
else
LIBGOMP_PATH="/usr/lib/x86_64-linux-gnu/libgomp.so.1"
fi
@@ -49,6 +55,34 @@ DEPS_SONAME=(
"libgomp.so.1"
)
# Add ARM-specific library dependencies for CPU builds
if [[ "$ARCH" == "aarch64" ]]; then
echo "Adding ARM-specific CPU library dependencies"
# ARM Compute Library (if available)
if [[ -d "/acl/build" ]]; then
echo "Adding ARM Compute Library for CPU"
DEPS_LIST+=(
"/acl/build/libarm_compute.so"
"/acl/build/libarm_compute_graph.so"
)
DEPS_SONAME+=(
"libarm_compute.so"
"libarm_compute_graph.so"
)
fi
# ARM system libraries
DEPS_LIST+=(
"/usr/lib64/libgfortran.so.5"
"/opt/OpenBLAS/lib/libopenblas.so.0"
)
DEPS_SONAME+=(
"libgfortran.so.5"
"libopenblas.so.0"
)
fi
rm -rf /usr/local/cuda*
SOURCE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )"

View File

@@ -29,6 +29,10 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
EXTRA_CAFFE2_CMAKE_FLAGS=()
fi
# Detect architecture
ARCH=$(uname -m)
echo "Building for architecture: $ARCH"
# Determine CUDA version and architectures to build for
#
# NOTE: We should first check `DESIRED_CUDA` when determining `CUDA_VERSION`,
@@ -53,34 +57,60 @@ fi
cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.')
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
# Function to remove architectures from a list
remove_archs() {
local result="$1"
shift
for arch in "$@"; do
result="${result//${arch};/}"
done
echo "$result"
}
# Function to filter CUDA architectures for aarch64
# aarch64 ARM GPUs only support certain compute capabilities
# Keep: 8.0 (A100), 9.0+ (Hopper, Grace Hopper, newer)
# Remove: < 8.0 (no ARM GPUs), 8.6 (x86_64 RTX 3090/A6000 only)
filter_aarch64_archs() {
local arch_list="$1"
# Explicitly remove architectures not needed on aarch64
arch_list=$(remove_archs "$arch_list" "5.0" "6.0" "7.0" "7.5" "8.6")
echo "$arch_list"
}
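# Example: filter_aarch64_archs "7.0;7.5;8.0;8.6;9.0;10.0;12.0" yields "8.0;9.0;10.0;12.0"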
# Base: Common architectures across all modern CUDA versions
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0"
case ${CUDA_VERSION} in
# Removing sm_50-sm_60 as these architectures are deprecated in CUDA 12.8/9 and will be removed in future releases.
# However, we would like to keep the sm_70 architecture, see: https://github.com/pytorch/pytorch/issues/157517
12.8)
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0"
;;
12.9)
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0+PTX"
# WAR to resolve the ld error in libtorch build with CUDA 12.9
12.6) TORCH_CUDA_ARCH_LIST="5.0;6.0;${TORCH_CUDA_ARCH_LIST}" ;; # Only 12.6 includes Legacy Maxwell/Pascal that will be removed in future releases
12.8) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0" ;; # +Hopper/Blackwell support
12.9) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0+PTX" # +Hopper/Blackwell support + PTX for forward compatibility
if [[ "$PACKAGE_TYPE" == "libtorch" ]]; then
TORCH_CUDA_ARCH_LIST="7.5;8.0;9.0;10.0;12.0+PTX"
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//7.0;/}" # Remove 7.0 to resolve the ld error
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//8.6;/}" # Remove 8.6 for libtorch
fi
;;
13.0)
TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;12.0+PTX"
;;
12.6)
TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6;9.0"
;;
*)
echo "unknown cuda version $CUDA_VERSION"
exit 1
TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;$([[ "$ARCH" == "aarch64" ]] && echo "11.0;" || echo "")12.0+PTX"
export TORCH_NVCC_FLAGS="-compress-mode=size"
export BUILD_BUNDLE_PTXAS=1
;;
*) echo "unknown cuda version $CUDA_VERSION"; exit 1 ;;
esac
# Filter for aarch64: Remove < 8.0 and 8.6
[[ "$ARCH" == "aarch64" ]] && TORCH_CUDA_ARCH_LIST=$(filter_aarch64_archs "$TORCH_CUDA_ARCH_LIST")
echo "TORCH_CUDA_ARCH_LIST set to: $TORCH_CUDA_ARCH_LIST"
export TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST}
echo "${TORCH_CUDA_ARCH_LIST}"
# Disable MAGMA for aarch64 as pre-built libraries are x86-64 only
if [[ "$ARCH" == "aarch64" ]]; then
echo "Disabling MAGMA for aarch64 architecture"
export USE_MAGMA=0
fi
# Package directories
WHEELHOUSE_DIR="wheelhouse$cuda_version_nodot"
LIBTORCH_HOUSE_DIR="libtorch_house$cuda_version_nodot"
@@ -244,6 +274,51 @@ else
exit 1
fi
# Add ARM-specific library dependencies
if [[ "$ARCH" == "aarch64" ]]; then
echo "Adding ARM-specific library dependencies"
# ARM Compute Library (if available)
if [[ -d "/acl/build" ]]; then
echo "Adding ARM Compute Library"
DEPS_LIST+=(
"/acl/build/libarm_compute.so"
"/acl/build/libarm_compute_graph.so"
)
DEPS_SONAME+=(
"libarm_compute.so"
"libarm_compute_graph.so"
)
fi
# ARM system libraries
DEPS_LIST+=(
"/lib64/libgomp.so.1"
"/usr/lib64/libgfortran.so.5"
)
DEPS_SONAME+=(
"libgomp.so.1"
"libgfortran.so.5"
)
# NVPL libraries (ARM optimized BLAS/LAPACK)
if [[ -d "/usr/local/lib" && -f "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0" ]]; then
echo "Adding NVPL libraries for ARM"
DEPS_LIST+=(
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0"
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0"
"/usr/local/lib/libnvpl_lapack_core.so.0"
"/usr/local/lib/libnvpl_blas_core.so.0"
)
DEPS_SONAME+=(
"libnvpl_lapack_lp64_gomp.so.0"
"libnvpl_blas_lp64_gomp.so.0"
"libnvpl_lapack_core.so.0"
"libnvpl_blas_core.so.0"
)
fi
fi
# run_tests.sh requires DESIRED_CUDA to know what tests to exclude
export DESIRED_CUDA="$cuda_version_nodot"
@@ -251,9 +326,11 @@ export DESIRED_CUDA="$cuda_version_nodot"
rm -rf /usr/local/cuda || true
ln -s "/usr/local/cuda-${CUDA_VERSION}" /usr/local/cuda
# Switch `/usr/local/magma` to the desired CUDA version
rm -rf /usr/local/magma || true
ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma
# Switch `/usr/local/magma` to the desired CUDA version (skip for aarch64)
if [[ "$ARCH" != "aarch64" ]]; then
rm -rf /usr/local/magma || true
ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma
fi
export CUDA_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev) # 10.0.130
export CUDA_VERSION_SHORT=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev | cut -f1,2 -d".") # 10.0

View File

@@ -21,87 +21,3 @@ if [[ "${BUILD_ENVIRONMENT}" == *rocm* ]]; then
fi
mkdir -p "$pytest_reports_dir" || true
##########################################
# copied from .ci/pytorch/common_utils.sh
##########################################
function get_pinned_commit() {
cat .github/ci_commit_pins/"${1}".txt
}
function pip_install_whl() {
# This is used to install PyTorch and other build artifacts wheel locally
# without using any network connection
# Convert the input arguments into an array
local args=("$@")
# Check if the first argument contains multiple paths separated by spaces
if [[ "${args[0]}" == *" "* ]]; then
# Split the string by spaces into an array
IFS=' ' read -r -a paths <<< "${args[0]}"
# Loop through each path and install individually
for path in "${paths[@]}"; do
echo "Installing $path"
python3 -mpip install --no-index --no-deps "$path"
done
else
# Loop through each argument and install individually
for path in "${args[@]}"; do
echo "Installing $path"
python3 -mpip install --no-index --no-deps "$path"
done
fi
}
function pip_build_and_install() {
local build_target=$1
local wheel_dir=$2
local found_whl=0
for file in "${wheel_dir}"/*.whl
do
if [[ -f "${file}" ]]; then
found_whl=1
break
fi
done
# Build the wheel if it doesn't exist
if [ "${found_whl}" == "0" ]; then
python3 -m pip wheel \
--no-build-isolation \
--no-deps \
-w "${wheel_dir}" \
"${build_target}"
fi
for file in "${wheel_dir}"/*.whl
do
pip_install_whl "${file}"
done
}
function install_torchvision() {
local orig_preload
local commit
commit=$(get_pinned_commit vision)
orig_preload=${LD_PRELOAD}
if [ -n "${LD_PRELOAD}" ]; then
# Silence dlerror to work around glibc ASAN bug, see https://sourceware.org/bugzilla/show_bug.cgi?id=27653#c9
echo 'char* dlerror(void) { return "";}'|gcc -fpic -shared -o "${HOME}/dlerror.so" -x c -
LD_PRELOAD=${orig_preload}:${HOME}/dlerror.so
fi
if [[ "${BUILD_ENVIRONMENT}" == *cuda* ]]; then
# Not sure if both are needed, but why not
export FORCE_CUDA=1
export WITH_CUDA=1
fi
pip_build_and_install "git+https://github.com/pytorch/vision.git@${commit}" dist/vision
if [ -n "${LD_PRELOAD}" ]; then
LD_PRELOAD=${orig_preload}
fi
}

View File

@ -19,7 +19,7 @@ git config --global --add safe.directory /var/lib/jenkins/workspace
if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
# TODO: This can be removed later once vision is also part of the Docker image
install_torchvision
pip install -q --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)"
# JIT C++ extensions require ninja, so put it into PATH.
export PATH="/var/lib/jenkins/.local/bin:$PATH"
# NB: ONNX test is fast (~15m) so it's ok to retry it a few more times to avoid any flaky issue, we

View File

@ -86,10 +86,20 @@ else
fi
fi
# Enable MKLDNN with ARM Compute Library for ARM builds
if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then
export USE_MKLDNN=1
# ACL is required for aarch64 builds
if [[ ! -d "/acl" ]]; then
echo "ERROR: ARM Compute Library not found at /acl"
echo "ACL is required for aarch64 builds. Check Docker image setup."
exit 1
fi
export USE_MKLDNN_ACL=1
export ACL_ROOT_DIR=/acl
echo "ARM Compute Library enabled for MKLDNN: ACL_ROOT_DIR=/acl"
fi
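As a quick sanity check (not part of the build script itself), a wheel produced with these flags can be inspected at runtime; whether ACL was actually linked in usually shows up in the build summary, so the snippet below is only a hedged suggestion:
```
import torch

# Runtime sanity check for an aarch64 wheel built with MKLDNN + ACL enabled.
print(torch.backends.mkldnn.is_available())  # True if oneDNN/MKLDNN is compiled in
print(torch.__config__.show())               # build summary; look for MKLDNN/ACL entries
```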
if [[ "$BUILD_ENVIRONMENT" == *riscv64* ]]; then

View File

@ -260,11 +260,8 @@ jobs:
"${DOCKER_IMAGE}"
)
docker exec -t -w "${PYTORCH_ROOT}" "${container_name}" bash -c "bash .circleci/scripts/binary_populate_env.sh"
if [[ ${BUILD_ENVIRONMENT} == *"aarch64"* ]]; then
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/aarch64_linux/aarch64_ci_build.sh"
else
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh"
fi
# Unified build script for all architectures (x86_64, aarch64, s390x)
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh"
- name: Chown artifacts
if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' && inputs.build_environment != 'linux-s390x-binary-manywheel' }}

View File

@ -22,6 +22,7 @@ enum class MacOSVersion : uint32_t {
MACOS_VER_15_0_PLUS,
MACOS_VER_15_1_PLUS,
MACOS_VER_15_2_PLUS,
MACOS_VER_26_0_PLUS,
};
//-----------------------------------------------------------------

View File

@ -65,6 +65,7 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
static bool _macos_15_0_plus = is_os_version_at_least(15, 0);
static bool _macos_15_1_plus = is_os_version_at_least(15, 1);
static bool _macos_15_2_plus = is_os_version_at_least(15, 2);
static bool _macos_26_0_plus = is_os_version_at_least(26, 0);
switch (version) {
case MacOSVersion::MACOS_VER_14_4_PLUS:
@ -75,6 +76,8 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
return _macos_15_1_plus;
case MacOSVersion::MACOS_VER_15_2_PLUS:
return _macos_15_2_plus;
case MacOSVersion::MACOS_VER_26_0_PLUS:
return _macos_26_0_plus;
default:
return false;
}

View File

@ -5,7 +5,6 @@
#include <ATen/native/Resize.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/xpu/Blas.h>
#include <ATen/xpu/XPUScaledBlas.h>
#include <torch/library.h>
#ifndef AT_PER_OPERATOR_HEADERS
@ -340,399 +339,4 @@ Tensor _scaled_mm_xpu(
out);
}
using acceptance_fn = std::function<bool(
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&)>;
using namespace std::placeholders;
namespace scaled_blas = at::native::onednn::scaled;
using scaled_blas::convert_int_to_enum;
using scaled_blas::ScaledGemmImplementation;
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2>
scale_kernel_dispatch = {{
{"tensorwise_tensorwise",
scaled_blas::check_tensorwise_recipe,
ScaledGemmImplementation::TENSORWISE_TENSORWISE},
{"rowwise_rowwise",
scaled_blas::check_rowwise_recipe,
ScaledGemmImplementation::ROWWISE_ROWWISE},
}};
Tensor& _scaled_tensorwise_tensorwise(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
bool use_fast_accum,
Tensor& out) {
// Restrictions:
// A, B are FP8, scales are fp32
TORCH_CHECK_VALUE(
isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()),
"mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(),
mat_b.scalar_type());
TORCH_CHECK_VALUE(
scale_a.numel() == 1 && scale_a.scalar_type() == kFloat,
"scale_a must have 1 Float element")
TORCH_CHECK_VALUE(
scale_b.numel() == 1 && scale_b.scalar_type() == kFloat,
"scale_b must have 1 Float element")
auto scaling_choice_a = ScalingType::TensorWise;
auto scaling_choice_b = ScalingType::TensorWise;
_scaled_gemm(
mat_a,
mat_b,
scale_a,
scale_b,
scaling_choice_a,
scaling_choice_b,
bias,
use_fast_accum,
out);
return out;
}
Tensor& _scaled_rowwise_rowwise(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
bool use_fast_accum,
Tensor& out) {
// Restrictions:
// A, B are FP8, scales are fp32, shape M/N for A/B
TORCH_CHECK_VALUE(
isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()),
"mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(),
mat_b.scalar_type());
TORCH_CHECK_VALUE(
scale_a.size(0) == mat_a.size(0) && scale_a.size(1) == 1,
"scale_a must have shape [",
mat_a.size(0),
", 1], got [",
scale_a.sizes(),
"]");
TORCH_CHECK_VALUE(
scale_a.numel() == mat_a.size(0) && scale_a.scalar_type() == kFloat,
"scale_a must have ",
mat_a.size(0),
" Float elements, got ",
scale_a.numel())
TORCH_CHECK_VALUE(
scale_b.numel() == mat_b.size(1) && scale_b.scalar_type() == kFloat,
"scale_b must have ",
mat_b.size(1),
" Float elements, got ",
scale_b.numel())
TORCH_CHECK_VALUE(
scale_a.stride(1) == 1,
"expected scale_a.stride(1) to be 1, but got ",
scale_a.stride(1));
TORCH_CHECK_VALUE(
scale_b.stride(1) == 1,
"expected scale_b.stride(1) to be 1, but got ",
scale_b.stride(1));
auto scaling_choice_a = ScalingType::RowWise;
auto scaling_choice_b = ScalingType::RowWise;
_scaled_gemm(
mat_a,
mat_b,
scale_a,
scale_b,
scaling_choice_a,
scaling_choice_b,
bias,
use_fast_accum,
out);
return out;
}
// V2: Computes matrix multiply + bias while applying scaling to the input and
// output matrices. Scales are only applicable when matrices are of Float8 type
// and are assumed to be equal to 1.0 by default. If the output matrix type is a
// 16- or 32-bit type, scale_result is not applied. Known limitations:
// - Only works if mat1 is row-major and mat2 is column-major
// - Only works if matrices sizes are divisible by 32
// - If 1-dimensional tensors are used then scale_a should have size =
// mat1.size(0)
// and scale_b should have size = mat2.size(1)
// Arguments:
// - `mat_a`: the first operand of the matrix multiply, can be type
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `mat_b`: the second operand of the matrix multiply, can be type
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `scale_a`: a tensor with the inverse scale of `mat1`, whose
// shape/strides/dtype depend on the scaling scheme
// - `scale_recipe_a`: An integer corresponding to an enum describing the
// scaling scheme used for `scale_a`
// - `swizzle_a`: An integer corresponding to a `SwizzleType` enum describing
// the swizzling scheme for `scale_a`.
// Not supported for XPU for now.
// - `scale_b`: a tensor with the inverse scale of `mat2`, whose
// shape/strides/dtype depend on the scaling scheme
// - `scale_recipe_b`: An integer corresponding to an enum describing the
// scaling scheme used for `scale_b`
// - `swizzle_b`: An integer corresponding to a `SwizzleType` enum describing
// the swizzling scheme for `scale_b`.
// Not supported for XPU for now.
// - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
// - `out_dtype`: the output dtype, can either be a float8 or a higher
// precision floating point type
// - `contraction_dim`: describe which dimensions are `K` in the matmul.
// Not supported for XPU. Should always be empty.
// - `use_fast_accum`: Not supported for XPU, should always be false.
// - `out`: a reference to the output tensor
Tensor& _scaled_mm_xpu_v2_out(
const Tensor& mat_a,
const Tensor& mat_b,
ArrayRef<Tensor> scale_a,
IntArrayRef scale_recipe_a,
IntArrayRef swizzle_a,
ArrayRef<Tensor> scale_b,
IntArrayRef scale_recipe_b,
IntArrayRef swizzle_b,
const std::optional<Tensor>& bias,
const std::optional<c10::ScalarType> out_dtype,
IntArrayRef contraction_dim,
bool use_fast_accum,
Tensor& out) {
TORCH_CHECK_VALUE(mat_a.dim() == 2, "mat_a must be a matrix");
TORCH_CHECK_VALUE(mat_b.dim() == 2, "mat_b must be a matrix");
// If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm
// kernels do not support this case).
if (mat_a.size(0) == 0 || mat_a.size(1) == 0 || mat_b.size(1) == 0) {
// `out` was created with `at::empty`. In the case where we are multiplying
// MxK by KxN and K is the zero dim, we need to initialize here to properly
// return a tensor of zeros.
at::native::resize_output(out, {mat_a.size(0), mat_b.size(1)});
if (mat_a.size(1) == 0) {
out.zero_();
}
return out;
}
// Note: The `contraction_dim` is not actually used for now. We will need to
// align this code once the upstream CUDA code is done. Currently, the code is
// kept here only for checking.
// Check if the input matrix sizes can be multiplied
// - if optional contraction dims are provided, use those
// -- mostly for < 1B formats (i.e. nvfp4x2) where cheap .t() is not
// available.
if (contraction_dim.size() > 0) {
TORCH_CHECK_VALUE(
contraction_dim.size() == 2,
"contraction_dim must have exactly 2 elements");
auto mat_a_dim = contraction_dim[0];
auto mat_b_dim = contraction_dim[1];
TORCH_CHECK_VALUE(
mat_a.size(mat_a_dim) == mat_b.size(mat_b_dim),
"mat_a and mat_b shapes cannot be multiplied (",
mat_a.size(0),
"x",
mat_a.size(1),
" and ",
mat_b.size(0),
"x",
mat_b.size(1),
") ",
"with contraction dims mat_a: ",
mat_a_dim,
", mat_b: ",
mat_b_dim);
} else {
TORCH_CHECK_VALUE(
mat_a.size(1) == mat_b.size(0),
"mat_a and mat_b shapes cannot be multiplied (",
mat_a.size(0),
"x",
mat_a.size(1),
" and ",
mat_b.size(0),
"x",
mat_b.size(1),
")");
}
TORCH_CHECK_VALUE(
!bias || bias->numel() == mat_b.sizes()[1],
"Bias must be size ",
mat_b.sizes()[1],
" but got ",
bias->numel());
TORCH_CHECK_VALUE(
!out_dtype || *out_dtype == out.scalar_type(),
"out_dtype must match output matrix type");
if (bias) {
TORCH_CHECK_VALUE(
bias->scalar_type() == kFloat ||
bias->scalar_type() == c10::ScalarType::BFloat16 ||
bias->scalar_type() == c10::ScalarType::Half,
"Bias must be Float32 or BFloat16 or Half, but got ",
bias->scalar_type());
}
{
auto bias_ = bias.value_or(Tensor());
// NOLINTNEXTLINE(*c-array*)
TensorArg targs[]{
{out, "out", 0},
{mat_a, "mat_a", 1},
{mat_b, "mat_b", 2},
{bias_, "bias", 3},
{scale_a[0], "scale_a", 4},
{scale_b[0], "scale_b", 5}};
checkAllSameGPU(__func__, targs);
}
// Align with CUDA's default out to be bf16
auto out_dtype_ = out_dtype.value_or(c10::ScalarType::BFloat16);
// Conversion of implicitly-defined enums to explicit
auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
// XPU does not support swizzle for now, so reject any swizzled inputs here.
TORCH_CHECK_VALUE(
swizzle_a_enum[0] == at::blas::SwizzleType::NO_SWIZZLE &&
swizzle_b_enum[0] == at::blas::SwizzleType::NO_SWIZZLE,
"XPU does not support swizzle yet.");
// At this point we can start working out what we want to be doing.
// Try to do as few steps as possible.
// NOTE: support is deliberately sparse; we explicitly enumerate all
// allowed combinations via a list of defined (name, acceptance,
// concrete_impl) tuples.
bool found_impl = false;
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
for (const auto& fn_entry : scale_kernel_dispatch) {
const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
bool ok = accept_fn(
mat_a.scalar_type(),
scale_recipe_a_enum,
scale_a,
mat_b.scalar_type(),
scale_recipe_b_enum,
scale_b);
if (ok) {
gemm_impl = scaled_gemm_impl;
found_impl = true;
break;
}
}
TORCH_CHECK_VALUE(
found_impl,
"Invalid scaling configuration.\n"
"- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
"- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (",
mat_a.size(0),
", 1) and scale_b should be (1, ",
mat_b.size(1),
"), and both should be contiguous.\n"
"Got mat_a.dtype()=",
mat_a.scalar_type(),
", scale_a[0].dtype()=",
scale_a[0].scalar_type(),
", scale_a[0].size()=",
scale_a[0].sizes(),
", scale_a[0].stride()=",
scale_a[0].strides(),
", ",
"mat_b.dtype()=",
mat_b.scalar_type(),
", scale_b[0].dtype()=",
scale_b[0].scalar_type(),
", scale_b[0].size()=",
scale_b[0].sizes(),
" and scale_b[0].stride()=",
scale_b[0].strides());
at::native::resize_output(out, {mat_a.size(0), mat_b.size(1)});
auto bias_ = bias.value_or(Tensor());
// dispatch to appropriate lower-level calls for error checking & execution
if (gemm_impl == ScaledGemmImplementation::TENSORWISE_TENSORWISE) {
return _scaled_tensorwise_tensorwise(
mat_a,
mat_b,
scale_a[0],
scale_b[0],
bias,
out_dtype_,
use_fast_accum,
out);
} else if (gemm_impl == ScaledGemmImplementation::ROWWISE_ROWWISE) {
return _scaled_rowwise_rowwise(
mat_a,
mat_b,
scale_a[0],
scale_b[0],
bias,
out_dtype_,
use_fast_accum,
out);
} else {
TORCH_CHECK_VALUE(
false, "Invalid state - found an implementation, but not really");
}
}
Tensor _scaled_mm_xpu_v2(
const Tensor& mat_a,
const Tensor& mat_b,
ArrayRef<Tensor> scale_a,
IntArrayRef scale_recipe_a,
IntArrayRef swizzle_a,
ArrayRef<Tensor> scale_b,
IntArrayRef scale_recipe_b,
IntArrayRef swizzle_b,
const std::optional<Tensor>& bias,
const std::optional<c10::ScalarType> out_dtype,
IntArrayRef contraction_dim,
bool use_fast_accum) {
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
return _scaled_mm_xpu_v2_out(
mat_a,
mat_b,
scale_a,
scale_recipe_a,
swizzle_a,
scale_b,
scale_recipe_b,
swizzle_b,
bias,
out_dtype,
contraction_dim,
use_fast_accum,
out);
}
} // namespace at::native
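To make the V2 description above concrete, here is a hedged Python-level sketch of a tensorwise-scaled FP8 matmul. It uses the long-standing torch._scaled_mm entry point rather than the v2 overload (which additionally takes per-operand scale lists and integer recipe/swizzle enums), and it assumes a device and build that support FP8 matmul:
```
import torch

# Hedged sketch: tensorwise-scaled FP8 matmul. Shapes follow the restrictions
# listed above: row-major A, column-major B, sizes divisible by 32, fp32 scales.
M, K, N = 64, 64, 64
a = torch.randn(M, K, device="xpu").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="xpu").to(torch.float8_e4m3fn).t()  # column-major K x N
scale_a = torch.tensor(1.0, device="xpu")  # one fp32 scale per tensor (TensorWise)
scale_b = torch.tensor(1.0, device="xpu")
out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([64, 64])
```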

View File

@ -69,75 +69,139 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
auto out = at::empty({batchSize, num_head, qSize, headSize}, query.options());
auto attn = at::empty({batchSize, num_head, qSize, maxSeqLength}, query.options());
auto scale_factor = sdp::calculate_scale(query, scale).expect_float();
static const bool is_macOS_26_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_26_0_PLUS);
@autoreleasepool {
auto mkey = __func__ + getTensorsStringKey({query, key, value}) + ":" + std::to_string(is_causal) + ":" +
std::to_string(attn_mask.has_value());
auto cachedGraph =
LookUpOrCreateCachedGraph<CachedGraph>(mkey, [&, q_ = query, k_ = key, v_ = value](auto mpsGraph, auto graph) {
auto qTensor = mpsGraphRankedPlaceHolder(mpsGraph, q_);
auto kTensor = mpsGraphRankedPlaceHolder(mpsGraph, k_);
auto vTensor = mpsGraphRankedPlaceHolder(mpsGraph, v_);
auto kT = [mpsGraph transposeTensor:kTensor dimension:2 withDimension:3 name:nil];
auto scaleTensor = [mpsGraph constantWithScalar:scale_factor
shape:getMPSShape({1})
dataType:MPSDataTypeFloat32];
auto maskedMM = [mpsGraph matrixMultiplicationWithPrimaryTensor:qTensor secondaryTensor:kT name:nil];
CachedGraph* cachedGraph;
if (is_macOS_26_0_or_newer) {
cachedGraph =
LookUpOrCreateCachedGraph<CachedGraph>(mkey, [&, q_ = query, k_ = key, v_ = value](auto mpsGraph, auto graph) {
auto qTensor = mpsGraphRankedPlaceHolder(mpsGraph, q_);
auto kTensor = mpsGraphRankedPlaceHolder(mpsGraph, k_);
auto vTensor = mpsGraphRankedPlaceHolder(mpsGraph, v_);
if (macOS15_0_plus && [maskedMM dataType] == MPSDataTypeFloat32) {
// Bug in macOS 15: without this trick SDPA leaks memory; adding 0.0f gets ignored (it still takes the SDPA sequence
// path, which leaks)
auto oneTensor = [mpsGraph constantWithScalar:1e-20f shape:getMPSShape({1}) dataType:MPSDataTypeFloat32];
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:oneTensor name:nil];
}
if (is_causal) {
MPSShape* maskShape = @[@(qSize), @(maxSeqLength)];
auto x = [mpsGraph coordinateAlongAxis:-1 withShape:@[@(qSize), @1] name:nil];
auto y = [mpsGraph coordinateAlongAxis:-2 withShape:@[@1, @(maxSeqLength)] name:nil];
auto isLess = [mpsGraph lessThanOrEqualToWithPrimaryTensor:x secondaryTensor:y name:nil];
auto causalMask = [mpsGraph selectWithPredicateTensor:isLess
truePredicateTensor:[mpsGraph constantWithScalar:0 dataType:qTensor.dataType]
falsePredicateTensor:[mpsGraph constantWithScalar:-INFINITY dataType:qTensor.dataType]
name:nil];
graph->maskTensor = causalMask;
} else if (attn_mask) {
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
}
// upcasting to float32 if needed to improve precision when multiplying by the scale factor
maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
// Account for the case where all values were masked, causing division by 0 in softmax (issue: #156707).
// Overwrites the expected NaNs in sm with zeros.
// auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
// auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
// auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
// auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
// auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
//
// auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
// MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
// truePredicateTensor:zeroTensor
// falsePredicateTensor:sm
// name:nil];
//
// auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
if (is_causal) {
auto causalMask = [mpsGraph constantWithScalar:1.0f
shape:getMPSShape({qSize, maxSeqLength})
dataType:MPSDataTypeBool];
causalMask = [mpsGraph bandPartWithTensor:causalMask numLower:-1 numUpper:0 name:nil];
auto minusInf = [mpsGraph constantWithScalar:-1e20 shape:maskedMM.shape dataType:maskedMM.dataType];
maskedMM = [mpsGraph selectWithPredicateTensor:causalMask
truePredicateTensor:maskedMM
falsePredicateTensor:minusInf
name:nil];
} else if (attn_mask) {
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
name:nil];
}
MPSGraphTensor* output;
if(graph->maskTensor != nil) {
output = [mpsGraph scaledDotProductAttentionWithQueryTensor:qTensor
keyTensor:kTensor
valueTensor:vTensor
maskTensor:graph->maskTensor
scale:scale_factor
name:@"MPSGraph SDPA"];
} else {
output = [mpsGraph scaledDotProductAttentionWithQueryTensor:qTensor
keyTensor:kTensor
valueTensor:vTensor
scale:scale_factor
name:@"MPSGraph SDPA"];
}
graph->qTensor = qTensor;
graph->kTensor = kTensor;
graph->vTensor = vTensor;
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
// graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
});
} else {
cachedGraph =
LookUpOrCreateCachedGraph<CachedGraph>(mkey, [&, q_ = query, k_ = key, v_ = value](auto mpsGraph, auto graph) {
auto qTensor = mpsGraphRankedPlaceHolder(mpsGraph, q_);
auto kTensor = mpsGraphRankedPlaceHolder(mpsGraph, k_);
auto vTensor = mpsGraphRankedPlaceHolder(mpsGraph, v_);
auto kT = [mpsGraph transposeTensor:kTensor dimension:2 withDimension:3 name:nil];
auto scaleTensor = [mpsGraph constantWithScalar:scale_factor
shape:getMPSShape({1})
dataType:MPSDataTypeFloat32];
// Account for the case where all values were masked, causing division by 0 in softmax (issue: #156707).
// Overwrites the expected NaNs in sm with zeros.
auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
auto maskedMM = [mpsGraph matrixMultiplicationWithPrimaryTensor:qTensor secondaryTensor:kT name:nil];
auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
truePredicateTensor:zeroTensor
falsePredicateTensor:sm
name:nil];
if (macOS15_0_plus && [maskedMM dataType] == MPSDataTypeFloat32) {
// Bug in macOS 15: without this trick SDPA leaks memory; adding 0.0f gets ignored (it still takes the SDPA sequence
// path, which leaks)
auto oneTensor = [mpsGraph constantWithScalar:1e-20f shape:getMPSShape({1}) dataType:MPSDataTypeFloat32];
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:oneTensor name:nil];
}
auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
graph->qTensor = qTensor;
graph->kTensor = kTensor;
graph->vTensor = vTensor;
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
});
// upcasting to float32 if needed to improve precision when multiplying by the scale factor
maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
if (is_causal) {
auto causalMask = [mpsGraph constantWithScalar:1.0f
shape:getMPSShape({qSize, maxSeqLength})
dataType:MPSDataTypeBool];
causalMask = [mpsGraph bandPartWithTensor:causalMask numLower:-1 numUpper:0 name:nil];
auto minusInf = [mpsGraph constantWithScalar:-1e20 shape:maskedMM.shape dataType:maskedMM.dataType];
maskedMM = [mpsGraph selectWithPredicateTensor:causalMask
truePredicateTensor:maskedMM
falsePredicateTensor:minusInf
name:nil];
} else if (attn_mask) {
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
name:nil];
}
// Account for the case where all values were masked, causing division by 0 in softmax (issue: #156707).
// Overwrites the expected NaNs in sm with zeros.
auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
truePredicateTensor:zeroTensor
falsePredicateTensor:sm
name:nil];
auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
graph->qTensor = qTensor;
graph->kTensor = kTensor;
graph->vTensor = vTensor;
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
});
}
auto qPlaceholder = Placeholder(cachedGraph->qTensor, query);
auto kPlaceholder = Placeholder(cachedGraph->kTensor, key);
auto vPlaceholder = Placeholder(cachedGraph->vTensor, value);
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, out);
auto attnPlaceholder = Placeholder(cachedGraph->attnTensor, attn);
// auto attnPlaceholder = Placeholder(cachedGraph->attnTensor, attn);
NSDictionary* feeds = nil;
if (!attn_mask) {
feeds = dictionaryFromPlaceholders(qPlaceholder, kPlaceholder, vPlaceholder);
@ -145,7 +209,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
auto mPlaceholder = Placeholder(cachedGraph->maskTensor, *attn_mask);
feeds = dictionaryFromPlaceholders(qPlaceholder, kPlaceholder, vPlaceholder, mPlaceholder);
}
NSDictionary* outs = dictionaryFromPlaceholders(outputPlaceholder, attnPlaceholder);
// NSDictionary* outs = dictionaryFromPlaceholders(outputPlaceholder, attnPlaceholder);
NSDictionary* outs = dictionaryFromPlaceholders(outputPlaceholder);
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outs);
}
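For reference, the fallback (non-macOS-26) graph built above computes plain masked attention with an explicit fix-up for fully-masked rows; a minimal PyTorch sketch of the same math, assuming 4-D q/k/v and an additive float attn_mask as in the graph:
```
import math
import torch

def sdpa_reference(q, k, v, is_causal=False, attn_mask=None, scale=None):
    # Mirrors the fallback MPSGraph path: QK^T, scaling, masking, softmax,
    # then zeroing rows where every position was masked (issue #156707).
    scale = scale if scale is not None else 1.0 / math.sqrt(q.size(-1))
    scores = (q @ k.transpose(-2, -1)).float() * scale
    if is_causal:
        L, S = scores.shape[-2], scores.shape[-1]
        causal = torch.ones(L, S, dtype=torch.bool, device=q.device).tril()
        scores = scores.masked_fill(~causal, float("-inf"))
    elif attn_mask is not None:
        scores = scores + attn_mask.float()  # additive mask, as in the graph above
    attn = torch.nan_to_num(scores.softmax(dim=-1))  # fully-masked rows give NaN -> 0
    return (attn @ v.float()).to(q.dtype)
```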

View File

@ -1,122 +0,0 @@
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/util/typeid.h>
#include <cstdint>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/BlasBackend.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/GroupedMMUtils.h>
#include <ATen/native/Resize.h>
#include <c10/util/MaybeOwned.h>
#include <ATen/ceil_div.h>
#include <ATen/xpu/XPUScaledBlas.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif
using at::blas::ScalingType;
namespace at::native::onednn::scaled {
/**
* Both inputs must be fp8,
* Each needs a single scale, {Tensorwise (float)}
*/
bool check_tensorwise_recipe(
c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
return false;
}
// 1 scale each, {Tensorwise, float}
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 ||
recipe_b.size() != 1) {
return false;
}
// Need {TensorWise, fp32} for A & B
if (recipe_a[0] != ScalingType::TensorWise)
return false;
if (scales_a[0].scalar_type() != ScalarType::Float)
return false;
if (recipe_b[0] != ScalingType::TensorWise)
return false;
if (scales_b[0].scalar_type() != ScalarType::Float)
return false;
return true;
}
/**
* Both inputs must be fp8,
* Each needs scales, {Rowwise (float)}
*/
bool check_rowwise_recipe(
c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
return false;
}
// 1 scale each, {RowWise, float}
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 ||
recipe_b.size() != 1) {
return false;
}
// Need {RowWise, fp32} for A & B
if (recipe_a[0] != ScalingType::RowWise)
return false;
if (scales_a[0].scalar_type() != ScalarType::Float)
return false;
if (recipe_b[0] != ScalingType::RowWise)
return false;
if (scales_b[0].scalar_type() != ScalarType::Float)
return false;
return true;
}
} // namespace at::native::onednn::scaled
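A small illustration of the two recipes these checks accept, using only tensor shapes (a hedged sketch, not tied to any particular device):
```
import torch

M, K, N = 32, 64, 16
# TensorWise recipe: a single fp32 scale per operand.
scale_a_tw = torch.tensor(1.0)    # numel() == 1, dtype torch.float32
scale_b_tw = torch.tensor(1.0)

# RowWise recipe: one fp32 scale per row of A and per column of B.
scale_a_rw = torch.ones(M, 1)     # shape [M, 1], contiguous
scale_b_rw = torch.ones(1, N)     # shape [1, N], contiguous
```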

View File

@ -1,95 +0,0 @@
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/util/typeid.h>
#include <cstdint>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/Resize.h>
#include <c10/util/MaybeOwned.h>
#include <ATen/BlasBackend.h>
#include <ATen/ceil_div.h>
#ifdef USE_FBGEMM_GENAI
#include <fbgemm_gpu/torch_ops.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif
using at::blas::ScalingType;
namespace at::native::onednn::scaled {
/**
* Track concrete implementations available
*/
enum class ScaledGemmImplementation {
NONE = 0,
TENSORWISE_TENSORWISE = 1,
ROWWISE_ROWWISE = 2,
};
/**
* Convert passed int (enum) from python back into a
* strictly-typed enum
*/
template <class EnumType, class ArrayType>
std::vector<EnumType> convert_int_to_enum(ArrayType& v) {
std::vector<EnumType> converted;
converted.reserve(v.size());
for (auto vi : v) {
converted.push_back(static_cast<EnumType>(vi));
}
return converted;
}
bool check_tensorwise_recipe(
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&);
bool check_rowwise_recipe(
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&);
} // namespace at::native::onednn::scaled

View File

@ -1,238 +0,0 @@
# Owner(s): ["module: complex"]
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
import torch.distributed as dist
# Support importing both as part of a package and directly as a file
try:
from .utils import (
COMPLEX_DTYPES,
Descriptor,
force_test_op_db,
get_overload_packet_from_name,
implemented_op_db,
TestCase,
Variant,
)
except ImportError:
from utils import (
COMPLEX_DTYPES,
Descriptor,
force_test_op_db,
get_overload_packet_from_name,
implemented_op_db,
TestCase,
Variant,
)
from torch._subclasses.complex_tensor._ops.common import ComplexTensorMode
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
OpDTypes,
ops,
)
from torch.testing._internal.common_utils import (
run_tests,
TestGradients,
unMarkDynamoStrictTest,
)
if TYPE_CHECKING:
from torch.testing._internal.opinfo.core import OpInfo
aten = torch.ops.aten
SKIPS = {
Descriptor(op=aten.empty_like, variant=None): "Non-deterministic output",
Descriptor(op=aten.randn_like, variant=None): "Non-deterministic output",
Descriptor(op=aten.angle, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.asinh, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.atanh, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(
op=aten.reciprocal, variant=Variant.GradCheck
): "Numerical inconsistency",
Descriptor(op=aten.rsqrt, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.select, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.asin, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.log, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.sgn, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.cumprod, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.slice, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.sqrt, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.tan, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(
op=aten.true_divide, variant=Variant.GradCheck
): "Numerical inconsistency",
Descriptor(op=aten.prod, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.div, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.expm1, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.var, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.bmm, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.diagonal, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.sinh, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.abs, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.sin, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.atan, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.acos, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.acosh, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.cos, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.cosh, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.addmm, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.pow, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.log1p, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.tanh, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.mm, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.dot, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.mul, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.exp, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.to, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(
op=aten.any, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.all, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.allclose, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.conj_physical, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten._conj_physical, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.cumprod, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.index_add, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.diagonal_scatter, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.flip, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.masked_fill, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.masked_scatter, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.rsub, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.ne, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.squeeze, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.index_select, variant=Variant.Distributed
): "Sharding propagation failed",
Descriptor(op=aten.real, variant=Variant.Distributed): "No scalar support",
Descriptor(op=aten.imag, variant=Variant.Distributed): "No scalar support",
Descriptor(op=aten.isfinite, variant=Variant.Distributed): "No scalar support",
Descriptor(op=aten.transpose, variant=Variant.Distributed): "No scalar support",
Descriptor(op=aten.view_as_real, variant=Variant.Distributed): "No scalar support",
}
EXTRA_KWARGS = {
Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Op): {
"rtol": 2e-5,
"atol": 5e-5,
},
Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Op): {
"rtol": 1e-4,
"atol": 1e-5,
},
Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Op): {
"rtol": 2e-2,
"atol": 2e-6,
},
Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Distributed): {
"rtol": 2e-5,
"atol": 5e-5,
},
Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Distributed): {
"rtol": 1e-4,
"atol": 1e-5,
},
Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Distributed): {
"rtol": 2e-2,
"atol": 2e-6,
},
Descriptor(op=aten.tan, dtype=torch.complex64, variant=Variant.Distributed): {
"rtol": 2e-6,
"atol": 1e-2,
},
}
class TestComplexTensor(TestCase):
_default_dtype_check_enabled = True
@ops(
implemented_op_db,
dtypes=OpDTypes.supported,
allowed_dtypes=list(COMPLEX_DTYPES),
)
def test_consistency(self, device, dtype, op: OpInfo):
self.check_consistency(device, dtype, op, Variant.Op)
@ops(force_test_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
def test_maybe_error(self, device, dtype, op: OpInfo):
self.check_consistency(device, dtype, op, Variant.Op)
@unMarkDynamoStrictTest
class TestComplexBwdGradients(TestGradients):
_default_dtype_check_enabled = True
@ops(
implemented_op_db,
dtypes=OpDTypes.supported_backward,
allowed_dtypes=[torch.complex128],
)
def test_fn_grad(self, device: str, dtype: torch.dtype, op: OpInfo) -> None:
test_info = Descriptor(
op=get_overload_packet_from_name(op.name),
device_type=torch.device(device).type,
dtype=dtype,
variant=Variant.GradCheck,
)
for xfail_info, reason in SKIPS.items():
if xfail_info.matches(test_info):
self.skipTest(reason)
if dtype not in op.supported_backward_dtypes(torch.device(device).type):
self.skipTest(f"Skipped! {dtype=} is not in supported backward dtypes!")
with ComplexTensorMode():
op.gradcheck_fast_mode = False
self._grad_test_helper(device, dtype, op, op.get_op())
instantiate_device_type_tests(TestComplexTensor, globals())
instantiate_device_type_tests(TestComplexBwdGradients, globals())
if dist.is_available():
from torch.testing._internal.common_distributed import MultiProcessTestCase
@unMarkDynamoStrictTest
class TestComplexDistributed(TestCase, MultiProcessTestCase):
@ops(implemented_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
def test_distributed(self, device, dtype, op: OpInfo):
self.check_consistency(device, dtype, op, Variant.Distributed)
instantiate_device_type_tests(TestComplexDistributed, globals())
if __name__ == "__main__":
run_tests()

View File

@ -1,214 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass, field, fields
from enum import auto, Enum
from typing import Any, TYPE_CHECKING
import torch
import torch.distributed as dist
from torch._subclasses.complex_tensor._ops.common import (
_as_complex_tensor,
_as_interleaved,
_get_op_name,
COMPLEX_OPS_TABLE,
COMPLEX_TO_REAL,
FORCE_TEST_LIST,
OpOverloadPacket,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
from torch.utils._pytree import tree_flatten
if TYPE_CHECKING:
from collections.abc import Callable
from torch.distributed.tensor import DTensor
from torch.testing._internal.opinfo.core import OpInfo
COMPLEX_DTYPES = set(COMPLEX_TO_REAL)
class Variant(Enum):
Op = auto()
GradCheck = auto()
Distributed = auto()
def _as_local(arg: DTensor | Any) -> torch.Tensor | Any:
if not (dist.is_available() and isinstance(arg, dist.tensor.DTensor)):
return arg
return arg.full_tensor()
def _as_complex_dtensor(arg: torch.Tensor | Any) -> torch.Tensor | Any:
if not isinstance(arg, torch.Tensor):
return arg
return dist.tensor.DTensor.from_local(_as_complex_tensor(arg))
TRANSFORM_FUNCS = {
Variant.Op: _as_complex_tensor,
Variant.Distributed: _as_complex_dtensor,
}
@dataclass(frozen=True, kw_only=True)
class Descriptor:
op: OpOverloadPacket
variant: Variant | None
device_type: str | None = field(default=None)
dtype: torch.dtype | None = field(default=None)
def matches(self, other: Descriptor) -> bool:
fields1 = fields(self)
fields2 = fields(other)
if fields1 != fields2:
return False
for f in fields1:
f1 = getattr(self, f.name)
f2 = getattr(other, f.name)
if f1 is not None and f2 is not None and f1 != f2:
return False
return True
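A brief usage note for the matcher above (a hedged illustration reusing the Descriptor and Variant definitions in this file): fields left as None act as wildcards, so a skip entry that only pins op and variant matches any device and dtype for that op.
```
# Hypothetical usage of Descriptor.matches: None fields act as wildcards.
skip = Descriptor(op=torch.ops.aten.tanh, variant=Variant.GradCheck)
test = Descriptor(
    op=torch.ops.aten.tanh,
    variant=Variant.GradCheck,
    device_type="cpu",
    dtype=torch.complex128,
)
assert skip.matches(test)  # device_type/dtype are None on `skip`, so they match
```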
class TestCase(PytorchTestCase):
def assertSameResult(
self,
expected: Callable[[], Any],
actual: Callable[[], Any],
*args,
**kwargs,
) -> None:
try:
result_e = expected()
exception_e = None
except Exception as e: # noqa: BLE001
result_e = None
exception_e = e
try:
result_a = actual()
exception_a = None
except Exception as e: # noqa: BLE001
result_a = None
exception_a = e
if (exception_e is None) != (exception_a is None):
if exception_a is not None and exception_e is None:
raise exception_a
self.assertIs(
type(exception_e),
type(exception_a),
f"\n{exception_e=}\n{exception_a=}",
)
if exception_e is None:
flattened_e, spec_e = tree_flatten(result_e)
flattened_a, spec_a = tree_flatten(result_a)
self.assertEqual(
spec_e,
spec_a,
"Both functions must return a result with the same tree structure.",
)
for value_e, value_a in zip(flattened_e, flattened_a, strict=True):
value_e = _as_interleaved(_as_local(value_e))
value_a = _as_interleaved(_as_local(value_a))
self.assertEqual(value_e, value_a, *args, **kwargs)
def check_consistency(
self, device: str, dtype, op: OpInfo, variant: Variant
) -> None:
try:
from .test_complex_tensor import EXTRA_KWARGS, SKIPS
except ImportError:
from test_complex_tensor import EXTRA_KWARGS, SKIPS
test_info = Descriptor(
op=get_overload_packet_from_name(op.name),
device_type=torch.device(device).type,
dtype=dtype,
variant=variant,
)
for xfail_info, reason in SKIPS.items():
if xfail_info.matches(test_info):
self.skipTest(reason)
kwargs = {}
for extra_info, extra_kw in EXTRA_KWARGS.items():
if extra_info.matches(test_info):
kwargs = extra_kw
break
sample_inputs = op.sample_inputs(device, dtype)
transform_fn = TRANSFORM_FUNCS[variant]
for sample_input in sample_inputs:
def expected(sample_input=sample_input):
return op(sample_input.input, *sample_input.args, **sample_input.kwargs)
subclass_sample = sample_input.transform(transform_fn)
def actual(subclass_sample=subclass_sample):
return op(
subclass_sample.input,
*subclass_sample.args,
**subclass_sample.kwargs,
)
self.assertSameResult(expected, actual, **kwargs)
aten = torch.ops.aten
complex_op_db = tuple(
filter(lambda op: any(op.supports_dtype(ct, "cpu") for ct in COMPLEX_DTYPES), op_db)
)
def get_overload_packet_from_name(name: str) -> OpOverloadPacket:
for domain_name in torch.ops:
op_namespace = getattr(torch.ops, domain_name)
op: OpOverloadPacket | None = getattr(op_namespace, name, None)
if op is not None:
return op
raise RuntimeError(f"No op with {name=} found.")
force_test_names = set(map(_get_op_name, FORCE_TEST_LIST))
implemented_op_names = (
set(map(_get_op_name, COMPLEX_OPS_TABLE.keys())) - force_test_names
)
implemented_op_db = tuple(
filter(lambda op: op.name in implemented_op_names, complex_op_db)
)
force_test_op_db = tuple(filter(lambda op: op.name in force_test_names, op_db))
tested_op_names = {op.name for op in implemented_op_db} | {
op.name for op in force_test_op_db
}
non_tested_ops = {
op for op in COMPLEX_OPS_TABLE if _get_op_name(op) not in tested_op_names
}
# TODO (hameerabbasi): There are a number of ops that don't have any associated
# OpInfos. We still need to write tests for those ops.
if len(non_tested_ops) != 0:
import textwrap
import warnings
list_missing_ops = "\n".join(sorted([str(op) for op in non_tested_ops]))
warnings.warn(
"Not all implemented ops are tested. List of ops missing tests:"
f"\n{textwrap.indent(list_missing_ops, ' ')}",
UserWarning,
stacklevel=2,
)

View File

@ -230,98 +230,6 @@ class DistConvolutionOpsTest(DTensorTestBase):
out_dt, out = self._run_single_arg_fwd(model, x, [Shard(0)])
self.assertEqual(out_dt, out)
@with_comms
def test_conv2d_no_bias_compile(self):
"""Test Conv2d with bias=False in compile mode (Issue #167091)
Regression test: Previously this would fail during torch.compile
tracing with AssertionError when bias_spec was None.
"""
device_mesh = self.build_device_mesh()
def conv_fn(x, w):
return F.conv2d(x, w, bias=None, padding=1)
compiled_fn = torch.compile(conv_fn)
# Create tensors
x = torch.randn(1, 4, 5, 5, device=self.device_type)
w = torch.randn(8, 4, 3, 3, device=self.device_type)
# Distribute tensors
x_dt = distribute_tensor(x, device_mesh, [Replicate()])
w_dt = distribute_tensor(w, device_mesh, [Replicate()])
# Test eager mode for comparison
result_eager = conv_fn(x_dt, w_dt)
# Test compiled mode - this should not crash
result_compiled = compiled_fn(x_dt, w_dt)
# Verify shape is correct (the key regression test)
self.assertEqual(result_compiled.shape, torch.Size([1, 8, 5, 5]))
# Verify numerical correctness
torch.testing.assert_close(result_compiled.to_local(), result_eager.to_local())
@with_comms
def test_conv2d_no_bias_backward(self):
"""Test Conv2d backward pass with bias=False (Issue #167091)
Regression test: Previously backward pass would fail when
grad_bias_spec was None.
"""
device_mesh = self.build_device_mesh()
# Create tensors with requires_grad
x = torch.randn(1, 4, 5, 5, device=self.device_type)
w = torch.randn(8, 4, 3, 3, device=self.device_type, requires_grad=True)
# Distribute tensors
x_dt = distribute_tensor(x, device_mesh, [Replicate()])
w_dt = torch.nn.Parameter(distribute_tensor(w, device_mesh, [Replicate()]))
# Forward pass
result = F.conv2d(x_dt, w_dt, bias=None, padding=1)
# Backward pass - this should not crash
grad_output = torch.randn_like(result)
result.backward(grad_output)
# Check weight gradient exists (the key regression test)
self.assertIsNotNone(w_dt.grad)
self.assertEqual(w_dt.grad.shape, torch.Size([8, 4, 3, 3]))
@with_comms
def test_conv2d_module_no_bias(self):
"""Test nn.Conv2d module with bias=False (Issue #167091)
Regression test: Ensures nn.Conv2d with bias=False works with DTensor.
"""
device_mesh = self.build_device_mesh()
# Create model with bias=False
model = nn.Conv2d(4, 8, kernel_size=3, padding=1, bias=False).to(
self.device_type
)
nn.init.ones_(model.weight)
# Distribute model
model_dt = distribute_module(model, device_mesh, _conv_fn)
# Create input
x = torch.randn(1, 4, 5, 5, device=self.device_type)
x_dt = distribute_tensor(x, device_mesh, [Replicate()])
# Forward pass - this should not crash
output_dt = model_dt(x_dt)
# Check outputs shape is correct
self.assertEqual(output_dt.shape, torch.Size([1, 8, 5, 5]))
# Check that model.bias is None
self.assertIsNone(model.bias)
DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
DistConvolutionOpsTest,
@ -330,10 +238,6 @@ DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
"test_conv_backward_none_grad_inp",
"test_depthwise_convolution",
"test_downsampling_convolution",
# New tests for Issue #167091 - use send/recv via tp_convolution
"test_conv2d_no_bias_compile",
"test_conv2d_no_bias_backward",
"test_conv2d_module_no_bias",
],
)

View File

@ -1061,63 +1061,6 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
correct = func(a, b, c)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_multiple_hiding_nodes_bucketing(self):
"""Test that collectives hidden by multiple compute ops can bucket together."""
# Use 0.5 compute multiplier so each collective needs 2 matmuls to be fully hidden
def estimate_with_half_compute(fx_node, override_size=None):
return estimate_aten_runtime(fx_node, compute_multiplier=0.5)
def func(a, b, *, ranks):
# Two all_gathers that will be hidden by multiple compute operations
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
# Multiple compute operations that can hide the collectives
# With 0.5 multiplier: mm1 and mm2 together hide ag1, mm2 and mm3 together hide ag2
mm1 = torch.matmul(a, a.T)
mm2 = torch.matmul(b, b.T)
mm3 = torch.matmul(a + b, (a + b).T)
return ag1.sum() + ag2.sum() + mm1.sum() + mm2.sum() + mm3.sum()
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
# Patch with custom estimation that uses 0.5 multiplier
with torch._inductor.config.patch(
{
"aten_distributed_optimizations.custom_runtime_estimation": estimate_with_half_compute
}
):
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b)
# Should have 1 bucketed all_gather (both ag1 and ag2 bucketed together)
FileCheck().check_count(
"torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True
).run(aten_graph_str)
# Verify bucketed collective is scheduled before all matmuls
FileCheck().check("functional.all_gather_into_tensor").check(
"aten.mm"
).check("aten.mm").check("aten.mm").check("wait_tensor").run(aten_graph_str)
# Verify correctness
correct = func(a, b, ranks=ranks)
self.assertTrue(same(out, correct))
def get_toy_model(device_type: str):
"""

View File

@ -49,8 +49,7 @@ def build_collective_info(graph, hiding_annotations):
"""
Build CollectiveInfo dict from manual hiding annotations.
hiding_annotations: dict mapping collective_start -> hiding_compute_node(s)
Can be a single node or a list/OrderedSet of nodes
hiding_annotations: dict mapping collective_start -> hiding_compute_node
"""
from torch._inductor.fx_passes.overlap_scheduling import CollectiveInfo
@ -66,20 +65,12 @@ def build_collective_info(graph, hiding_annotations):
# Build CollectiveInfo for each collective
for start_node, wait_node in start_to_wait.items():
hiding_annotation = hiding_annotations.get(start_node)
# Convert to OrderedSet
hiding_nodes = OrderedSet()
if hiding_annotation is not None:
if isinstance(hiding_annotation, list | OrderedSet):
hiding_nodes = OrderedSet(hiding_annotation)
else:
hiding_nodes = OrderedSet([hiding_annotation])
hiding_node = hiding_annotations.get(start_node)
# Estimate size and time
size_bytes = 16 * 4 # 4x4 tensor of floats
estimated_time_ms = 1.0 # Dummy time
exposed_time_ms = 0.0 if hiding_nodes else 1.0 # Hidden if has hiding_nodes
exposed_time_ms = 0.0 if hiding_node else 1.0 # Hidden if has hiding_node
collective_info[start_node] = CollectiveInfo(
start_node=start_node,
@ -87,7 +78,7 @@ def build_collective_info(graph, hiding_annotations):
size_bytes=size_bytes,
estimated_time_ms=estimated_time_ms,
exposed_time_ms=exposed_time_ms,
hiding_nodes=hiding_nodes,
hiding_node=hiding_node,
)
return collective_info
@ -576,97 +567,6 @@ class TestOverlapPreservingBucketing(InductorTestCase):
graph_str
)
def test_can_bucket_with_multiple_hiding_nodes(self):
"""
Test that collectives with multiple hiding nodes CAN bucket.
Graph structure:
ag1_start -> ag2_start -> mm1 -> mm2 -> mm3 -> ag1_wait -> ag2_wait
Where:
- ag1 is hidden by mm1 and mm2
- ag2 is hidden by mm2 and mm3
- Both collectives share mm2 as a hiding node
"""
def func(a, b):
group_name = "0"
group_size = 1
# Start both collectives
ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
a, group_size, group_name
)
ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
b, group_size, group_name
)
# Three compute operations that hide the collectives
mm1 = torch.mm(a, a)
mm2 = torch.mm(b, b)
mm3 = torch.mm(a + b, a + b)
# Wait for both
ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)
ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)
return ag1_out.sum() + ag2_out.sum() + mm1.sum() + mm2.sum() + mm3.sum()
# Use fake mode to trace without executing
with FakeTensorMode():
a = torch.ones(4, 4, device=self.device)
b = torch.ones(4, 4, device=self.device) * 2
# Trace with make_fx
traced = make_fx(func)(a, b)
# Find nodes using find_nodes
ag1, ag2 = traced.graph.find_nodes(
op="call_function",
target=torch.ops._c10d_functional.all_gather_into_tensor.default,
)
mm1, mm2, mm3 = traced.graph.find_nodes(
op="call_function", target=torch.ops.aten.mm.default
)
# Manually annotate hiding relationships with multiple hiding nodes
hiding_annotations = {
ag1: [mm1, mm2], # ag1 is hidden by mm1 and mm2
ag2: [mm2, mm3], # ag2 is hidden by mm2 and mm3
}
# Build collective info and ancestors
collective_info = build_collective_info(traced.graph, hiding_annotations)
node_ancestors = compute_ancestors(traced.graph)
scheduled = OrderedSet(traced.graph.nodes)
# Verify hiding_nodes are correctly set
self.assertEqual(len(collective_info[ag1].hiding_nodes), 2)
self.assertIn(mm1, collective_info[ag1].hiding_nodes)
self.assertIn(mm2, collective_info[ag1].hiding_nodes)
self.assertEqual(len(collective_info[ag2].hiding_nodes), 2)
self.assertIn(mm2, collective_info[ag2].hiding_nodes)
self.assertIn(mm3, collective_info[ag2].hiding_nodes)
# Run bucketing
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
OverlapPreservingBucketer,
)
bucketer = OverlapPreservingBucketer(
traced.graph,
collective_info,
node_ancestors,
scheduled,
)
bucketer.bucket_collectives()
FileCheck().check_count(
"all_gather_into_tensor_out", 1, exactly=False
).check_count("torch.ops.aten.mm.default", 3, exactly=True).run(
str(traced.graph)
)
if __name__ == "__main__":
run_tests()

View File

@ -253,14 +253,6 @@ class StoreTestBase:
a.set("foo", "bar")
self.assertEqual(b.get("foo"), b"bar")
def test_list_keys(self):
a = self._create_store()
a.set("foo", "bar")
a.set("baz", "qux")
keys = a.list_keys()
self.assertIn("foo", keys)
self.assertIn("baz", keys)
# This is the number of keys used in test_set_get. Adding this as a class
# property instead of hardcoding in the test since some Store
# implementations will have differing number of keys. In the base case,

View File

@ -470,7 +470,7 @@ class <lambda>(torch.nn.Module):
)
@requires_cuda
def test_stream_backward_simple(self) -> None:
def test_stream_backward(self) -> None:
def fn(x, y):
s2 = torch.Stream()
s0 = torch.Stream()
@ -524,68 +524,7 @@ class GraphModule(torch.nn.Module):
# Annotation: {'stream': 1}
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
# Annotation: {'stream': 0}
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
return (add_3, add_2)
""",
)
@requires_cuda
def test_stream_backward_sync(self) -> None:
def fn(x, y):
s2 = torch.Stream()
s0 = torch.Stream()
with s0:
y0 = 2 * x + y
with s2:
z = 2 * x + y
return y0, z
inp = (
torch.ones(2, 2, device="cuda:0", requires_grad=True) + 1,
torch.ones(2, 2, device="cuda:0", requires_grad=True),
)
expected = fn(*inp)
(
actual,
_,
fw_graphs,
bw_graphs,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
# Annotation: {'stream': 1}
mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)
# Annotation: {'stream': 0}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
return (add, add_1)
""",
)
actual[1].sum().backward()
self.assertExpectedInline(
print_graph(bw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
# Annotation: {'stream': 0}
mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)
#
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None
# Annotation: {'stream': 1}
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
# Annotation: {'stream': 0}
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
return (add_3, add_2)
""",

View File

@ -456,31 +456,6 @@ def forward(self, x):
test_inputs = make_inputs()
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
def test_dynamo_graph_capture_with_call_override(self):
class _InterestingModule(torch.nn.Module):
def __init__(self, module):
super().__init__()
self._module = module
def __call__(self, *args, **kwargs):
return self._module(*args, **kwargs)
class MyModel(torch.nn.Module):
def forward(self, x):
return x + 1
foo = _InterestingModule(MyModel())
def make_inputs():
return (torch.randn(2, 3),)
trace_inputs = make_inputs()
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
test_inputs = make_inputs()
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
self.assertEqual(len(list(gm.buffers())), len(list(foo.buffers())))
self.assertEqual(len(list(gm.parameters())), len(list(foo.parameters())))
def test_dynamo_graph_capture_custom_pytree_type(self):
import torch.utils._pytree as pytree

View File

@ -3,17 +3,12 @@ import io
from unittest.mock import patch
import torch
from torch._dynamo.utils import counters
from torch._functorch.aot_autograd import aot_export_module
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)
from torch.testing._internal.common_utils import run_tests, TestCase
@instantiate_parametrized_tests
class TestHopPrint(TestCase):
def test_base_print(self):
def f(x):
@ -23,6 +18,7 @@ class TestHopPrint(TestCase):
torch._higher_order_ops.print("moo")
return x
counters.clear()
x = torch.randn(3, 3)
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
f(x)
@ -37,6 +33,7 @@ class TestHopPrint(TestCase):
x = x * x
return x
counters.clear()
x = torch.randn(3, 3)
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
f(x)
@ -187,62 +184,6 @@ x = add_1, y = add_2); getitem = None
"""print(str format_str) -> ()""",
)
@parametrize("backend", ["eager", "aot_eager"])
def test_reorder_print_no_graph_break(self, backend):
def f(x):
x1 = x + x
torch._higher_order_ops.print("moo {x}", x=x1)
x2 = x1 * x1
torch._higher_order_ops.print("moo {x}", x=x2)
x3 = x2 + x2
return (x1, x3)
# Eager and aot_eager backend for dynamo tracing testing
x = torch.randn(3, 3)
opt_f = torch.compile(backend=backend, fullgraph=True)(f)
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
opt_out = opt_f(x)
printed_output = mock_stdout.getvalue().strip()
orig_out = f(x)
self.assertEqual(
printed_output,
f"moo {x * 2}\nmoo {x * 2 * x * 2}",
)
self.assertEqual(orig_out, opt_out)
x_new = torch.randn(2, 2)
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
opt_out = opt_f(x_new)
printed_output = mock_stdout.getvalue().strip()
self.assertEqual(
printed_output,
f"moo {x_new * 2}\nmoo {x_new * 2 * x_new * 2}",
)
@parametrize("backend", ["eager", "aot_eager"])
def test_constant_mutation(self, backend):
def f(x):
alist = [x]
alist.append(x + 1)
torch._higher_order_ops.print("moo {x}", x=alist[-1])
alist[0].sum().item() # graph break
res = alist.pop()
torch._higher_order_ops.print("moo {x}", x=alist[-1])
res.sum().item() # graph break
return res
inputs = (torch.tensor([1]),)
opt_f = torch.compile(backend=backend, fullgraph=True)(f)
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
opt_out = opt_f(*inputs)
printed_output = mock_stdout.getvalue().strip()
orig_out = f(*inputs)
self.assertEqual(printed_output, "moo tensor([2])\nmoo tensor([1])")
self.assertEqual(orig_out, opt_out)
if __name__ == "__main__":
run_tests()

View File

@ -1188,22 +1188,6 @@ class TestTiling(TestCase):
with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad():
torch.compile(f)(x)
def test_find_broadcast_var(self):
"""Test broadcast variable detection for tiling improvements."""
from torch._inductor import tiling_utils
i, j, k = sympy.symbols("i j k", integer=True)
# Test broadcast pattern detection: FloorDiv creates broadcast
result = tiling_utils.find_broadcast_var(
FloorDiv(i, 10), {i: 100, j: 50, k: 20}
)
self.assertEqual(result, i)
# Test non-broadcast: linear access pattern
result = tiling_utils.find_broadcast_var(i + j * 10, {i: 10, j: 8, k: 20})
self.assertEqual(result, None)
class TestIndexInversion(TestCase):
@classmethod

View File

@ -630,6 +630,7 @@ class TestSparse(TestSparseBase):
i[0][0] = 0
self.assertEqual(torch.empty((3, 0), dtype=dtype, device=device), self.safeToDense(x))
@expectedFailureMPS
@dtypes(torch.double, torch.cdouble)
@dtypesIfMPS(torch.float32, torch.complex64)
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error")
@ -646,8 +647,7 @@ class TestSparse(TestSparseBase):
def fn(x):
return x.to_dense(masked_grad=gradcheck.masked)
x.requires_grad_(True)
kwargs = {"eps": 1e-4} if device == "mps:0" else {}
gradcheck(fn, (x,), **kwargs)
gradcheck(fn, (x,))
i = self.index_tensor([
[0, 1, 2, 2],

View File

@ -423,156 +423,6 @@ class TestFuzzerCompileIssues(TestCase):
out_compiled.sum().backward()
print("Compile Success! ✅")
@pytest.mark.xfail(reason="Issue #167937")
def test_fuzzer_issue_167937(self):
torch.manual_seed(1251149731)
torch.set_default_device("cuda")
def fuzzed_program(
arg_0,
arg_1,
arg_2,
arg_3,
arg_4,
arg_5,
arg_6,
arg_7,
arg_8,
arg_9,
sentinel,
):
var_node_3 = arg_0 # size=(27, 28, 7), stride=(196, 7, 1), dtype=bfloat16, device=cuda
var_node_4 = (
arg_1 # size=(27, 7, 6), stride=(42, 6, 1), dtype=bfloat16, device=cuda
)
var_node_2 = torch.matmul(
var_node_3.to(torch.bfloat16), var_node_4.to(torch.bfloat16)
) # size=(27, 28, 6), stride=(168, 6, 1), dtype=bfloat16, device=cuda
var_node_6 = (
arg_2 # size=(27, 6, 9), stride=(54, 9, 1), dtype=bfloat16, device=cuda
)
var_node_7 = torch.full(
(27, 9, 1), -0.310546875, dtype=torch.bfloat16
) # size=(27, 9, 1), stride=(9, 1, 1), dtype=bfloat16, device=cuda
var_node_5 = torch.matmul(
var_node_6.to(torch.bfloat16), var_node_7.to(torch.bfloat16)
) # size=(27, 6, 1), stride=(6, 1, 1), dtype=bfloat16, device=cuda
var_node_1 = torch.matmul(
var_node_2.to(torch.bfloat16), var_node_5.to(torch.bfloat16)
) # size=(27, 28, 1), stride=(28, 1, 1), dtype=bfloat16, device=cuda
var_node_8 = arg_3 # size=(27, 28, 1), stride=(28, 1, 1), dtype=bfloat16, device=cuda
var_node_9 = torch.full(
(27, 28, 1), 0.76953125, dtype=torch.bfloat16
) # size=(27, 28, 1), stride=(28, 1, 1), dtype=bfloat16, device=cuda
var_node_12 = (
arg_4 # size=(3, 4), stride=(4, 1), dtype=bfloat16, device=cuda
)
var_node_13 = (
arg_5 # size=(4, 15), stride=(15, 1), dtype=bfloat16, device=cuda
)
var_node_11 = torch.matmul(
var_node_12.to(torch.bfloat16), var_node_13.to(torch.bfloat16)
) # size=(3, 15), stride=(15, 1), dtype=bfloat16, device=cuda
var_node_15 = (
arg_6 # size=(15, 12), stride=(12, 1), dtype=bfloat16, device=cuda
)
var_node_16 = (
arg_7 # size=(12, 1), stride=(1, 1), dtype=bfloat16, device=cuda
)
var_node_14 = torch.matmul(
var_node_15.to(torch.bfloat16), var_node_16.to(torch.bfloat16)
) # size=(15, 1), stride=(1, 1), dtype=bfloat16, device=cuda
var_node_10 = torch.matmul(
var_node_11.to(torch.bfloat16), var_node_14.to(torch.bfloat16)
) # size=(3, 1), stride=(1, 1), dtype=bfloat16, device=cuda
var_node_19 = (
arg_8 # size=(1, 8), stride=(8, 1), dtype=bfloat16, device=cuda
)
var_node_20 = (
arg_9 # size=(8, 2), stride=(2, 1), dtype=bfloat16, device=cuda
)
var_node_18 = torch.matmul(
var_node_19.to(torch.bfloat16), var_node_20.to(torch.bfloat16)
) # size=(1, 2), stride=(2, 1), dtype=bfloat16, device=cuda
var_node_21 = torch.full(
(2, 1), 0.000762939453125, dtype=torch.bfloat16
) # size=(2, 1), stride=(1, 1), dtype=bfloat16, device=cuda
var_node_17 = torch.matmul(
var_node_18.to(torch.bfloat16), var_node_21.to(torch.bfloat16)
) # size=(1, 1), stride=(1, 1), dtype=bfloat16, device=cuda
var_node_0, _ = torch.nn.functional.multi_head_attention_forward(
var_node_1.to(torch.bfloat16),
var_node_8.to(torch.bfloat16),
var_node_9.to(torch.bfloat16),
1,
1,
var_node_10.to(torch.bfloat16),
None, # in_proj_bias
None, # bias_k
None, # bias_v
False, # add_zero_attn
0.0, # dropout_p (no dropout for testing)
var_node_17.to(torch.bfloat16),
None, # out_proj_bias
training=False, # Use eval mode for deterministic behavior
need_weights=False, # Don't compute attention weights for performance
) # size=(27, 28, 1), stride=(28, 1, 1), dtype=bfloat16, device=cuda
# Ensure gradient computation by multiplying with sentinel and taking real part
result = var_node_0 * sentinel
if result.is_complex():
result = result.real
return result
try:
# Sentinel tensor to ensure gradient computation
sentinel = torch.tensor(1.0, requires_grad=True)
arg_0 = torch.as_strided(
torch.randn(5292).to(torch.bfloat16), (27, 28, 7), (196, 7, 1)
)
arg_1 = torch.as_strided(
torch.randn(1134).to(torch.bfloat16), (27, 7, 6), (42, 6, 1)
)
arg_2 = torch.as_strided(
torch.randn(1458).to(torch.bfloat16), (27, 6, 9), (54, 9, 1)
)
arg_3 = torch.as_strided(
torch.randn(756).to(torch.bfloat16), (27, 28, 1), (28, 1, 1)
)
arg_4 = torch.as_strided(torch.randn(12).to(torch.bfloat16), (3, 4), (4, 1))
arg_5 = torch.as_strided(
torch.randn(60).to(torch.bfloat16), (4, 15), (15, 1)
)
arg_6 = torch.as_strided(
torch.randn(180).to(torch.bfloat16), (15, 12), (12, 1)
)
arg_7 = torch.as_strided(
torch.randn(12).to(torch.bfloat16), (12, 1), (1, 1)
)
arg_8 = torch.as_strided(torch.randn(8).to(torch.bfloat16), (1, 8), (8, 1))
arg_9 = torch.as_strided(torch.randn(16).to(torch.bfloat16), (8, 2), (2, 1))
args = (
arg_0,
arg_1,
arg_2,
arg_3,
arg_4,
arg_5,
arg_6,
arg_7,
arg_8,
arg_9,
) + (sentinel,)
out_eager = fuzzed_program(*args)
out_eager.sum().backward()
print("Eager Success! ✅")
compiled_foo = torch.compile(fuzzed_program, fullgraph=True, dynamic=True)
out_compiled = compiled_foo(*args)
out_compiled.sum().backward()
print("Compile Success! ✅")
finally:
torch.set_default_device(None)
if __name__ == "__main__":
run_tests()

View File

@ -267,10 +267,7 @@ class DefaultFuzzTemplate(FuzzTemplate):
]
def flags_codegen(self):
return [
"torch.set_default_device('cuda')",
"torch._dynamo.config.capture_scalar_outputs = True",
]
return ["torch._dynamo.config.capture_scalar_outputs = True"]
def epilogue_codegen(self):
return []
@ -493,7 +490,6 @@ class UnbackedFuzzTemplate(FuzzTemplate):
def flags_codegen(self):
return [
"torch.set_default_device('cuda')",
"torch._dynamo.config.capture_scalar_outputs = True",
"torch._dynamo.config.capture_dynamic_output_shape_ops = True",
]

View File

@ -48,8 +48,54 @@ def persist_print(msg):
# List of regex patterns for ignore bucket
IGNORE_PATTERNS: list[re.Pattern] = [
re.compile(
r"torch\._inductor\.exc\.InductorError: AssertionError: -1"
), # https://github.com/pytorch/pytorch/issues/167937
r"Dynamo failed to run FX node with fake tensors: call_method fill_diagonal_"
), # https://github.com/pytorch/pytorch/issues/163420
re.compile(
r"TypeError: unsupported operand type\(s\) for divmod\(\): 'SymInt' and 'int'"
), # https://github.com/pytorch/pytorch/issues/163457
re.compile(
r"RuntimeError: self\.stride\(-1\) must be 1 to view ComplexDouble as"
), # https://github.com/pytorch/pytorch/issues/162561
re.compile(
r"BooleanAtom not allowed in this context"
), # https://github.com/pytorch/pytorch/issues/160726
re.compile(
r"TypeError\(\"unsupported operand type\(s\) for \*: 'SymBool' and 'FakeTensor'\"\)"
), # https://github.com/pytorch/pytorch/issues/164684
re.compile(r"KeyError: u\d+"), # https://github.com/pytorch/pytorch/issues/164685
re.compile(
r"torch\._inductor\.exc\.InductorError: CppCompileError: C\+\+ compile error"
), # https://github.com/pytorch/pytorch/issues/164686
re.compile(
r"\.item\(\) # dtype="
), # https://github.com/pytorch/pytorch/issues/164725
re.compile(
r"dimensionality of sizes \(0\) must match dimensionality of strides \(1\)"
), # https://github.com/pytorch/pytorch/issues/164814
re.compile(
r"self and mat2 must have the same dtype"
), # https://github.com/pytorch/pytorch/issues/165718
re.compile(
r"free\(\): invalid next size \(fast\)"
), # TODO: figure out why sometimes heap metadata gets corrupted on program exit (checks actually pass successfully)
re.compile(
r'assert "int" in str\(indices\.get_dtype\(\)\)'
), # https://github.com/pytorch/pytorch/issues/166042
re.compile(
r'self\.shape_env\.guard_or_defer_runtime_assert\(expr, "guard_equals"\)'
), # https://github.com/pytorch/pytorch/issues/166245
re.compile(
r"assert len\(self\.stride\) == len\(order\)"
), # https://github.com/pytorch/pytorch/issues/166270
re.compile(
r"assert len\(input_size\) == len\(new_size\)"
), # https://github.com/pytorch/pytorch/issues/166279
re.compile(
r"torch\._inductor\.exc\.InductorError: IndexError: list index out of range"
), # https://github.com/pytorch/pytorch/issues/166290
re.compile(
r"assert bool\(static_expr\)"
), # https://github.com/pytorch/pytorch/issues/166319
# Add more patterns here as needed, e.g.:
# re.compile(r"Some other error message"),
]
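
The IGNORE_PATTERNS list above is a triage bucket of known-issue signatures. A minimal sketch of how a harness might consult such a list (helper name, local pattern list, and error text are illustrative, not the fuzzer's actual API):

import re

known_issue_patterns = [
    re.compile(r"list index out of range"),
    re.compile(r"unsupported operand type\(s\) for divmod"),
]

def in_ignore_bucket(error_text: str, patterns: list[re.Pattern]) -> bool:
    # True if the captured error output matches any known-issue signature
    return any(p.search(error_text) for p in patterns)

assert in_ignore_bucket("IndexError: list index out of range", known_issue_patterns)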

View File

@ -215,7 +215,6 @@ class Store:
def queue_pop(self, key: str, block: bool = True) -> bytes: ...
def queue_push(self, key: str, value: Union[bytes, str]) -> None: ...
def queue_len(self, key: str) -> int: ...
def list_keys(self) -> list[str]: ...
class FileStore(Store):
def __init__(self, path: str, numWorkers: int = ...) -> None: ...

View File

@ -1043,11 +1043,6 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
import inspect
if isinstance(mod, torch.nn.Module):
resolved_forward = mod.forward
if hasattr(resolved_forward, "__self__"):
# pyrefly: ignore [missing-attribute]
resolved_forward = resolved_forward.__func__
# Mirrored from NNModuleVariable.call_function:
# https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/variables/nn_module.py#L1035
if (
@ -1059,12 +1054,7 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
and len(mod._backward_hooks) == 0
and len(torch.nn.modules.module._global_backward_pre_hooks) == 0
and len(torch.nn.modules.module._global_backward_hooks) == 0
and resolved_forward != torch.nn.Module.forward
):
# We cannot trace __call__ by default because it will break
# the legacy dynamo export. If we want to revisit this,
# feel free to remove this path and try unittests in
# test_strict_export_v2.py
mod = mod.forward
elif isinstance(mod, torch.fx.GraphModule):
mod = mod._call_impl

View File

@ -1528,6 +1528,37 @@ class OutputGraph(OutputGraphCommon):
from .decorators import disable
if has_user_objects():
# NB: This is where we store possible user objects before running the graph
# index_to_user_object_weakref is the function used in the graph to translate
# the dynamo-generated index into the actual object passed to the compiled function.
# We generate bytecode to store all user objects at the proper index in the below
# call.
codegen = PyCodegen(
self.root_tx, root, overridden_sources=overridden_sources
)
codegen.add_push_null(
lambda: codegen.load_import_from(
torch._dynamo.graph_bytecode_inputs.__name__,
"store_user_object_weakrefs",
)
)
tmp_vars = []
for constructor in index_to_bytecode_constructor.values():
constructor(codegen)
var_name = (
self.new_var()
) # keep alive any temp objects for the rest of the frame
codegen.store(var_name)
tmp_vars.append(var_name)
for var_name in tmp_vars:
codegen.append_output(codegen.create_load(var_name))
codegen.call_function(len(index_to_bytecode_constructor), False)
codegen.pop_top()
self.add_output_instructions(codegen.get_instructions())
# to handle random calls
if len(self.random_calls) > 0:
random_calls_instructions = []
@ -2312,33 +2343,6 @@ class OutputGraph(OutputGraphCommon):
assert self.root_tx is not None
cg = PyCodegen(self.root_tx)
if has_user_objects():
# NB: This is where we store possible user objects before running the graph
# index_to_user_object_weakref is the function used in the graph to translate
# the dynamo-generated index into the actual object passed to the compiled function.
# We generate bytecode to store all user objects at the proper index in the below
# call.
cg.add_push_null(
lambda: cg.load_import_from(
torch._dynamo.graph_bytecode_inputs.__name__,
"store_user_object_weakrefs",
)
)
tmp_vars = []
for constructor in index_to_bytecode_constructor.values():
constructor(cg)
var_name = (
self.new_var()
) # keep alive any temp objects for the rest of the frame
cg.store(var_name)
tmp_vars.append(var_name)
for var_name in tmp_vars:
cg.append_output(cg.create_load(var_name))
cg.call_function(len(index_to_bytecode_constructor), False)
cg.pop_top()
for idx, arg in enumerate(self.graphargs):
self.export_metadata.graph_input_idx_to_local_source[idx] = arg.source
@ -3007,7 +3011,7 @@ class SubgraphTracer(fx.Tracer):
self.tracked_tensor_or_symint_vt: OrderedSet[VariableTracker] = OrderedSet()
def record_tensor_or_symint_vt(self, vt: VariableTracker):
def record_tensor_or_symint_vt(self, vt):
self.tracked_tensor_or_symint_vt.add(vt)
# preserve original meta if it is available

View File

@ -903,33 +903,11 @@ def reset_graph_break_dup_checker() -> None:
graph_break_dup_warning_checker.reset()
# Matches ANSI escape sequences (CSI)
ANSI_ESCAPE_PATTERN = re.compile(
r"""
\x1B # ESC
\[ # [
[0-?]* # Parameter bytes
[ -/]* # Intermediate bytes
[@-~] # Final byte
""",
re.VERBOSE,
)
class StripAnsiFormatter(logging.Formatter):
"""Logging formatter that strips ANSI escape codes."""
def format(self, record):
msg = super().format(record)
return ANSI_ESCAPE_PATTERN.sub("", msg)
def add_file_handler() -> contextlib.ExitStack:
log_path = os.path.join(get_debug_dir(), "torchdynamo")
os.makedirs(log_path, exist_ok=True)
log_file_handler = logging.FileHandler(os.path.join(log_path, "debug.log"))
log_file_handler.setFormatter(StripAnsiFormatter("%(message)s"))
logger = logging.getLogger("torch._dynamo")
logger.addHandler(log_file_handler)
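
As a usage sketch for the formatter above: attaching it to a handler strips ANSI color codes from emitted records. The compact regex below is equivalent to the verbose CSI pattern above; logger name and message are illustrative:

import logging
import re

ANSI_ESCAPE = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")  # same CSI pattern, compact form

class StripAnsi(logging.Formatter):
    def format(self, record):
        return ANSI_ESCAPE.sub("", super().format(record))

handler = logging.StreamHandler()
handler.setFormatter(StripAnsi("%(message)s"))
log = logging.getLogger("ansi_demo")
log.addHandler(handler)
log.warning("\x1b[31mred\x1b[0m text")  # emitted as plain "red text"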

View File

@ -2673,30 +2673,6 @@ class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
)
class PrintHigherOrderVariable(TorchHigherOrderOperatorVariable):
def _call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
from .builder import wrap_fx_proxy
args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
args_proxy = [arg.as_proxy() for arg in args]
kwargs_proxy = {k: v.as_proxy() for k, v in kwargs.items()}
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=tuple(args_proxy),
kwargs=kwargs_proxy,
),
)
class ExecutorchCallDelegateHigherOrderVariable(TorchHigherOrderOperatorVariable):
def _call_function(
self,
@ -4561,7 +4537,6 @@ _hop_name_to_variable_class = {
"associative_scan": AssociativeScanHigherOrderVariable,
"scan": ScanHigherOrderVariable,
"call_torchbind": CallTorchbindHigherOrderVariable,
"print": PrintHigherOrderVariable,
"wrap_with_set_grad_enabled": WrapWithSetGradEnabledHigherOrderVariable,
"wrap_with_autocast": WrapWithAutocastHigherOrderVariable,
"dynamo_bypassing_wrapper": DynamoBypassingWrapperHigherOrderVariable,

View File

@ -33,7 +33,6 @@ from .graph_capture_wrappers import (
handle_effect_tokens_fn,
)
from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta
from .streams import assign_backward_streams
from .utils import (
call_and_expect_output_descs,
copy_fwd_metadata_to_bw_nodes,
@ -474,9 +473,6 @@ def aot_dispatch_autograd_graph(
# fw node match might be erased
copy_fwd_metadata_to_bw_nodes(fx_g)
# After copying metadata, assign streams to gradient accumulation nodes
assign_backward_streams(fx_g)
fx_g.graph.eliminate_dead_code()
if not aot_config.disable_functionalization:
# There should be *NO* mutating ops in the graph at this point.

View File

@ -1,53 +0,0 @@
from typing import Optional, TypeAlias
import torch.fx
import torch.fx.traceback
from torch._dynamo.graph_utils import _get_flat_args
Node: TypeAlias = torch.fx.Node
def is_gradient_acc(node: Node) -> bool:
return node.meta.get("is_gradient_acc", False)
def get_stream(node: Node) -> Optional[int]:
maybe_annotation = node.meta.get("custom", None)
if maybe_annotation is not None:
return node.meta["custom"].get("stream", None)
else:
return None
def set_stream(node: Node, ind: int) -> None:
if "custom" in node.meta:
node.meta["custom"].update({"stream": ind})
else:
node.meta["custom"] = {"stream": ind}
def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
"""Assigns backward streams to gradient accumulation nodes"""
# NB: iterate in reverse order to more closely match eager
# the user node stream will be populated first
for node in reversed(list(gm.graph.nodes)):
if is_gradient_acc(node):
# Accumulation stream selection. Follow the rules from top to bottom to determine the accumulation stream:
# 1. Match first stream assignment of the first user with a stream
# 2. Match first stream assignment encountered in the args from left to right
# This differs from eager in some cases:
# Specifically, the eager code uses the autograd node to determine the stream;
# crucially, this does not necessarily correspond to the FX graph node. For example,
# in the backward for an add node with a constant we will pass through, and during backward tracing,
# no op will be added to the FX graph, so our stream assignment will differ in this case.
gradients = _get_flat_args(node, {})
users = list(node.users.keys())
# All gradients will be on same device, they will be coerced if they were not with a .to() node
for neighbor in users + gradients:
ind = get_stream(neighbor)
if ind is not None:
set_stream(node, ind)
break
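
get_stream/set_stream above read and write a {'stream': index} entry under node.meta["custom"]. A toy illustration of that metadata layout on a symbolically traced graph (function and stream index are illustrative):

import torch
import torch.fx as fx

def f(x):
    return x * 2 + 1

gm = fx.symbolic_trace(f)
for node in gm.graph.nodes:
    if node.op == "call_function":
        # same layout that set_stream writes and get_stream reads
        node.meta.setdefault("custom", {})["stream"] = 0

streams = [n.meta.get("custom", {}).get("stream") for n in gm.graph.nodes]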

View File

@ -78,13 +78,11 @@ class TuningProcess:
def workloop():
while True:
job, extra_env = TuningProcess.recv(read_pipe)
job = TuningProcess.recv(read_pipe)
if job is None:
# None is a sentinel for the child to shut down
break
try:
if extra_env:
os.environ.update(extra_env)
result = job()
except Exception as e:
result = e
@ -97,10 +95,8 @@ class TuningProcess:
pass
@staticmethod
def send(
obj: Any, write_pipe: IO[bytes], extra_env: dict[str, str] | None = None
) -> None:
pickle.dump((obj, extra_env), write_pipe)
def send(obj: Any, write_pipe: IO[bytes]) -> None:
pickle.dump(obj, write_pipe)
write_pipe.flush()
@staticmethod
@ -162,13 +158,13 @@ class TuningProcess:
"""
return self.running and self.process.poll() is None
def put(self, req: Any, extra_env: dict[str, str] | None = None) -> None:
def put(self, req: Any) -> None:
"""
Push a work item to the child process.
"""
if not self.alive():
self.start()
TuningProcess.send(req, self.write_pipe, extra_env=extra_env)
TuningProcess.send(req, self.write_pipe)
def get(self, timeout: float = 120.0) -> Any:
"""
@ -178,7 +174,7 @@ class TuningProcess:
try:
if not self.selector.select(timeout):
raise TimeoutError(f"Timeout in autotune subprocess {self.process.pid}")
result, _ = TuningProcess.recv(self.read_pipe)
result = TuningProcess.recv(self.read_pipe)
except TimeoutError:
self.kill()
raise
@ -309,10 +305,8 @@ class TuningProcessPool:
"""
assert choice.bmreq is not None
env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}
process = self.process_queue.get()
process.put(choice.bmreq.benchmark, extra_env=extra_env)
process.put(choice.bmreq.benchmark)
try:
return process.get(
config.max_autotune_subproc_result_timeout_seconds,
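
A minimal single-process sketch of the (obj, extra_env) framing that send/recv above pass over the pipe; the job string, env var, and path are illustrative, not Inductor's real payload:

import os
import pickle

r_fd, w_fd = os.pipe()
read_pipe = os.fdopen(r_fd, "rb")
write_pipe = os.fdopen(w_fd, "wb")

# sender side: frame the work item together with the env overrides
pickle.dump(("benchmark-job", {"TRITON_CACHE_DIR": "/tmp/triton"}), write_pipe)
write_pipe.flush()

# receiver side: unpack the same tuple
obj, extra_env = pickle.load(read_pipe)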

View File

@ -2819,8 +2819,6 @@ class SIMDScheduling(BaseScheduling):
bad_size_additional_tiling_penalty = 1.025
good_size_tiling_penalty = 1.005
total_uncoalesced = sum(coalesce_analysis.uncoalesced_addrs.values())
def score_mod(t):
score_factor = 1.0
for tile_size in t[0].tiling.values():
@ -2829,19 +2827,12 @@ class SIMDScheduling(BaseScheduling):
else:
score_factor = score_factor / good_size_tiling_penalty
# Add uncoalesced memory score to prevent small coalesced benefits
# from dominating large amounts of uncoalesced memory
uncoalesced_penalty = total_uncoalesced * 0.05
return -(t[0].score + uncoalesced_penalty) * score_factor
return -t[0].score * score_factor
# apply penalty for longer tilings that don't increase score much
for cand, tiling_score in sorted(tilings, key=score_mod):
if (
cls.tiling_is_compatible(
node_schedule, pointwise_numel, reduction_numel, cand.tiling
)
or cand.tiling == default_tiling
if cls.tiling_is_compatible(
node_schedule, pointwise_numel, reduction_numel, cand.tiling
):
# we always include default reduction numel == 1, don't include
tiling_len = len(cand.tiling) - (1 if reduction_numel == 1 else 0)

View File

@ -176,7 +176,6 @@ class OverlapPreservingBucketer:
head = None
prev_event = None
position = 0
hiding_nodes = OrderedSet()
for node in self.scheduled:
node_type = None
@ -184,12 +183,11 @@ class OverlapPreservingBucketer:
# Determine if this node is relevant for this PG
if node in self.collective_info and get_group_name(node) == pg:
node_type = "starts"
hiding_nodes |= self.collective_info[node].hiding_nodes
elif is_wait_tensor(node):
wait_input = node.args[0]
if isinstance(wait_input, fx.Node) and get_group_name(wait_input) == pg:
node_type = "waits"
elif is_compute_node(node) or node in hiding_nodes:
elif is_compute_node(node):
node_type = "compute"
if node_type is None:
@ -207,6 +205,7 @@ class OverlapPreservingBucketer:
prev_event = event
position += 1
return head
def _populate_node_to_event(self, pg: str) -> None:
@ -223,12 +222,10 @@ class OverlapPreservingBucketer:
Add hiding interval constraints: start -> compute -> wait.
"""
for start, info in self.collective_info.items():
if info.is_exposed:
continue
for hn in info.hiding_nodes:
if info.hiding_node and not info.is_exposed:
# Enforce: start -> compute -> wait
self.aug_graph.add_extra_dep(n=hn, dep=start)
self.aug_graph.add_extra_dep(n=info.wait_node, dep=hn)
self.aug_graph.add_extra_dep(n=info.hiding_node, dep=start)
self.aug_graph.add_extra_dep(n=info.wait_node, dep=info.hiding_node)
def bucket_collectives(self) -> None:
"""Main entry point for bucketing collectives."""
@ -361,13 +358,13 @@ class OverlapPreservingBucketer:
def _get_intervals(
self, event: PGEvent
) -> tuple[Optional[tuple[int, int]], list[tuple[int, int]]]:
"""Get (execution_interval, hiding_intervals) for a collective event.
) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]:
"""Get (execution_interval, hiding_interval) for a collective event.
Returns:
(execution_interval, hiding_intervals) where:
(execution_interval, hiding_interval) where:
- execution_interval is (start_pos, wait_pos) or None
- hiding_intervals is a list of (start_pos, compute_pos) tuples, one for each hiding node
- hiding_interval is (start_pos, compute_pos) or None if no hiding node
Works for both start and wait events by looking up the collective info.
"""
@ -378,13 +375,13 @@ class OverlapPreservingBucketer:
elif event.is_wait:
wait_input = event.node.args[0]
if not isinstance(wait_input, fx.Node):
return None, []
return None, None
coll = wait_input
else:
return None, []
return None, None
if coll not in self.collective_info:
return None, []
return None, None
info = self.collective_info[coll]
start_event = self.node_to_event[coll]
@ -392,17 +389,14 @@ class OverlapPreservingBucketer:
execution_interval = (start_event.position, wait_event.position)
hiding_intervals = []
if info.hiding_nodes:
for hiding_node in info.hiding_nodes:
hiding_intervals.append(
(
start_event.position,
self.node_to_event[hiding_node].position,
)
)
hiding_interval = None
if info.hiding_node:
hiding_interval = (
start_event.position,
self.node_to_event[info.hiding_node].position,
)
return execution_interval, hiding_intervals
return execution_interval, hiding_interval
def _preserves_hiding_intervals(
self,
@ -430,9 +424,9 @@ class OverlapPreservingBucketer:
# Collect hiding compute positions for the bucket
bucket_hiding_compute_positions = []
for coll in all_bucketed_colls:
for coll_hiding_node in self.collective_info[coll].hiding_nodes:
if hiding_node := self.collective_info[coll].hiding_node:
bucket_hiding_compute_positions.append(
self.node_to_event[coll_hiding_node].position
self.node_to_event[hiding_node].position
)
# Get new positions
@ -484,10 +478,11 @@ class OverlapPreservingBucketer:
curr_event.node not in all_bucketed_colls
and curr_event.node not in all_bucketed_waits
):
exec_interval, hiding_interval_list = self._get_intervals(curr_event)
exec_interval, hiding_interval = self._get_intervals(curr_event)
if exec_interval:
execution_intervals.append(exec_interval)
hiding_intervals.extend(hiding_interval_list)
if hiding_interval:
hiding_intervals.append(hiding_interval)
curr_event = curr_event.next
curr_event = new_wait_event.prev
@ -496,10 +491,11 @@ class OverlapPreservingBucketer:
curr_event.node not in all_bucketed_colls
and curr_event.node not in all_bucketed_waits
):
exec_interval, hiding_interval_list = self._get_intervals(curr_event)
exec_interval, hiding_interval = self._get_intervals(curr_event)
if exec_interval:
execution_intervals.append(exec_interval)
hiding_intervals.extend(hiding_interval_list)
if hiding_interval:
hiding_intervals.append(hiding_interval)
curr_event = curr_event.prev
# Check: no hiding interval should be enclosed by any execution interval
@ -663,12 +659,12 @@ class OverlapPreservingBucketer:
return True
# Check if existing hiding node conflicts with candidate wait
for old_hiding_node in self.collective_info[coll].hiding_nodes:
if self._ancestor_dep(old_hiding_node, candidate_wait):
if hiding_node := self.collective_info[coll].hiding_node:
if self._ancestor_dep(hiding_node, candidate_wait):
return True
# Check if candidate hiding node conflicts with existing wait
for new_hiding_node in candidate_info.hiding_nodes:
if new_hiding_node := candidate_info.hiding_node:
if self._ancestor_dep(new_hiding_node, coll_wait):
return True
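
Per the comment above, the bucketing check rejects configurations in which a hiding interval ends up enclosed by an execution interval. A minimal sketch of that interval predicate over event positions (helper name and positions are illustrative):

def encloses(outer: tuple[int, int], inner: tuple[int, int]) -> bool:
    # strict containment of one (start_pos, end_pos) interval in another
    return outer[0] < inner[0] and inner[1] < outer[1]

execution_interval = (2, 9)  # a collective starts at position 2, its wait is at 9
hiding_interval = (4, 6)     # another collective's start and the compute hiding it
assert encloses(execution_interval, hiding_interval)  # fails the preservation check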

View File

@ -5,7 +5,7 @@ import logging
import sys
from collections import Counter, defaultdict
from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Any, Literal
import torch
@ -190,7 +190,7 @@ class CollectiveInfo:
size_bytes: int
estimated_time_ms: float
exposed_time_ms: float # How much of this collective is still exposed
hiding_nodes: OrderedSet[fx.Node] = field(default_factory=OrderedSet)
hiding_node: fx.Node | None = None # Node that hides this collective
@property
def is_exposed(self) -> bool:
@ -533,8 +533,6 @@ class OverlapScheduler:
self._handle_collective_start(node)
elif is_wait_tensor(node):
self._handle_wait(node)
elif node.op == "placeholder":
self._schedule(node)
else:
self._handle_other(node)
@ -563,13 +561,11 @@ class OverlapScheduler:
additional_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
for start_node, info in self.collective_info.items():
if info.is_exposed:
continue
for hn in info.hiding_nodes:
if info.hiding_node and not info.is_exposed:
# Compute depends on collective start (compute must wait for collective to start)
additional_deps[hn].add(start_node)
additional_deps[info.hiding_node].add(start_node)
# Wait depends on compute (wait must wait for compute to finish)
additional_deps[info.wait_node].add(hn)
additional_deps[info.wait_node].add(info.hiding_node)
# Apply effect tokens to preserve these dependencies
if additional_deps:
@ -710,8 +706,9 @@ class OverlapScheduler:
overlap_amount = min(info.exposed_time_ms, available_compute)
info.exposed_time_ms -= overlap_amount
available_compute -= overlap_amount
info.hiding_nodes.add(node)
if available_compute == 0:
if info.exposed_time_ms == 0:
info.hiding_node = node
elif available_compute == 0:
break
# Then, look for unscheduled collectives we can overlap
@ -804,7 +801,8 @@ class OverlapScheduler:
# after scheduling, which will account for latency reduction of bucketing
overlap_amount = min(available_compute_time, info.exposed_time_ms)
info.exposed_time_ms -= overlap_amount
info.hiding_nodes.add(compute_node)
if info.exposed_time_ms == 0:
info.hiding_node = compute_node
available_compute_time -= overlap_amount
def _find_schedulable_path(
@ -831,7 +829,7 @@ class OverlapScheduler:
# it's fine to schedule it
if is_wait_tensor(node):
info = self.collective_info[self.wait_to_start[node]]
if info.hiding_nodes and curr_compute_node not in info.hiding_nodes:
if info.hiding_node and info.hiding_node != curr_compute_node:
continue
elif node not in self.potentially_hidden_waits:
continue
@ -867,7 +865,7 @@ class OverlapScheduler:
) -> bool:
assert is_wait_tensor(wait_node)
info = self.collective_info[self.wait_to_start[wait_node]]
return not info.is_exposed and compute_node not in info.hiding_nodes
return not info.is_exposed and info.hiding_node != compute_node
def _schedule_path_to_collective(
self, path: OrderedSet[fx.Node], curr_compute_node: fx.Node
@ -886,7 +884,7 @@ class OverlapScheduler:
continue
info = self.collective_info[self.wait_to_start[node]]
assert curr_compute_node not in info.hiding_nodes
assert info.hiding_node != curr_compute_node
self._handle_wait(node)
continue
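
The exposed-time accounting above shaves a collective's exposed time by whatever compute time is available and records the compute node as a hiding node. A small worked example with assumed timings:

exposed_time_ms = 3.0            # collective latency not yet hidden
available_compute_ms = 2.0       # runtime of the compute node being scheduled
overlap = min(exposed_time_ms, available_compute_ms)  # 2.0 ms can be hidden
exposed_time_ms -= overlap       # 1.0 ms still exposed; the collective is not fully hidden
available_compute_ms -= overlap  # 0.0 ms left, so overlapping against this node stops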

View File

@ -145,41 +145,6 @@ def solve_for_tiling(expr: sympy.Expr) -> Optional[sympy.Expr]:
return None
def find_broadcast_var(
index: sympy.Expr, var_ranges: dict[sympy.Expr, int]
) -> Optional[sympy.Expr]:
"""
Try to find the variable that this index is broadcast over.
A broadcast pattern is one where consecutive values of a variable
access the same memory location (e.g., x // 10).
"""
# Approximate analysis by evaluating at 1 and 0
variables: dict[sympy.Symbol, int] = {}
for v in index.free_symbols:
if v in var_ranges:
variables[v] = 0
else:
variables[v] = get_hint(v)
zero_index = sympy_subs(index, variables)
for v in var_ranges.keys():
if v not in index.free_symbols:
continue
variables[v] = 1
try:
new_val = sympy_subs(index, variables)
except ZeroDivisionError:
loop_tiling_log.info("zero division error %s %s", index, variables)
continue
# Broadcast means the value doesn't change when the variable increments
if new_val == zero_index:
return v
variables[v] = 0
return None
def find_coalesced_var(
index: sympy.Expr, var_ranges: dict[sympy.Expr, int]
) -> Optional[sympy.Expr]:
@ -603,12 +568,11 @@ def extract_normalized_read_writes(
return fused_out
def get_score(
addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int], buf_names: OrderedSet[str]
) -> int:
def get_score(addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int]) -> int:
"""
Score addr according to its approximate size.
Score addr according to its approximate size
"""
# TODO - deduplicate with candidate_tilings
var_sizes = []
for v in addr.free_symbols:
@ -623,15 +587,6 @@ def get_score(
)
def try_get_buf_size(buf_name: str) -> Optional[int]:
buf = V.graph.try_get_buffer(buf_name)
if not buf:
return None
return V.graph.sizevars.atomically_apply_size_hint(
sympy_product(buf.get_size()), fallback=config.unbacked_symint_fallback
)
def get_hint(v: Union[sympy.Expr, int]) -> int:
if isinstance(v, int):
return v
@ -657,8 +612,6 @@ class CoalesceVarAnalysis:
# TODO: separate into dataclass that holds mem, dtype, is_write
coalesced_by_var: dict[sympy.Expr, int]
uncoalesced_addrs: dict[sympy.Expr, int]
norm_read_writes: FusedNormalizedReadsWrites
suggested_split: Optional[VarTiling] = None
@ -704,40 +657,28 @@ def analyze_memory_coalescing(
if indirect_expr:
continue
size = get_score(memory_expr, var_ranges, buf_names)
size = get_score(memory_expr, var_ranges)
if size == 0:
continue
maybe_coalesced_var = find_coalesced_var(memory_expr, var_ranges)
# while broadcasting vars are not technically coalesced,
# accesses at least stay in cache, so they provide most of the benefit.
# treat the same for now.
if maybe_coalesced_var is None:
maybe_coalesced_var = find_broadcast_var(memory_expr, var_ranges)
total_score = 0
byte_multipler = 0
for buf_name in buf_names:
if (buf := V.graph.try_get_buffer(buf_name)) and (
buf_size := try_get_buf_size(buf_name)
):
# constrain by buf size since we'll read at most that many elements
# score could be more through either masking or by broadcasting (e.g. x // 16)
total_score += min(buf_size, size) * buf.dtype.itemsize
if buf := V.graph.try_get_buffer(buf_name):
byte_multipler += buf.dtype.itemsize
# coalesced writes more important
total_score *= 1 if is_read else 2
byte_multipler *= 1 if is_read else 2
if maybe_coalesced_var:
coalesced_by_var[maybe_coalesced_var] += total_score
coalesced_by_var[maybe_coalesced_var] += size * byte_multipler
else:
uncoalesced_addrs[memory_expr] += total_score
uncoalesced_addrs[memory_expr] += size * byte_multipler
if not uncoalesced_addrs:
return CoalesceVarAnalysis(
coalesced_by_var=coalesced_by_var,
uncoalesced_addrs=uncoalesced_addrs,
norm_read_writes=norm_read_writes,
coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes
)
# map from var -> tiling -> total_score
@ -781,9 +722,7 @@ def analyze_memory_coalescing(
if len(tiling_scores) == 0:
return CoalesceVarAnalysis(
coalesced_by_var=coalesced_by_var,
uncoalesced_addrs=uncoalesced_addrs,
norm_read_writes=norm_read_writes,
coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes
)
best_tiling: Optional[tuple[sympy.Expr, int]] = None
@ -797,9 +736,7 @@ def analyze_memory_coalescing(
if best_tiling is None:
return CoalesceVarAnalysis(
coalesced_by_var=coalesced_by_var,
uncoalesced_addrs=uncoalesced_addrs,
norm_read_writes=norm_read_writes,
coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes
)
# TODO - for strictly pointwise fusions,
@ -808,7 +745,6 @@ def analyze_memory_coalescing(
# TODO - could also prefer index var splits to reduction, better tested
return CoalesceVarAnalysis(
coalesced_by_var=coalesced_by_var,
uncoalesced_addrs=uncoalesced_addrs,
norm_read_writes=norm_read_writes,
suggested_split=VarTiling(best_tiling[0], best_tiling[1], best_tiling_score),
)
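
find_broadcast_var's docstring above defines a broadcast index as one whose value does not change when the variable increments. The same property can be checked with plain sympy; sympy.floor stands in for the internal FloorDiv helper used in the removed code:

import sympy

i = sympy.Symbol("i", integer=True)
index = sympy.floor(i / 10)   # analogous to FloorDiv(i, 10)
at_zero = index.subs(i, 0)    # 0
at_one = index.subs(i, 1)     # still 0
assert at_zero == at_one      # consecutive values hit the same location, so i is a broadcast var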

View File

@ -1,9 +0,0 @@
from ._core import ComplexTensor
from ._ops import ComplexTensorMode, is_complex_tensor
__all__ = ["ComplexTensor", "ComplexTensorMode", "is_complex_tensor"]
ComplexTensor.__module__ = __name__
ComplexTensorMode.__module__ = __name__
is_complex_tensor.__module__ = __name__

View File

@ -1,151 +0,0 @@
from __future__ import annotations
from typing import Any, TYPE_CHECKING
from typing_extensions import Self
import torch
from torch import Tensor
from torch.autograd import Function
if TYPE_CHECKING:
from torch._ops import OpOverload
from torch._prims_common import DeviceLikeType
from torch.autograd.function import FunctionCtx
class ComplexTensor(Tensor):
"""A class that decomposes all ops on complex Tensors into their real and imaginary parts."""
_re: Tensor
_im: Tensor
def __new__(cls, real: Tensor, imag: Tensor) -> Self:
"""Initialize a ComplexTensor from its real and imaginary parts."""
from ._ops.common import REAL_TO_COMPLEX
shape = real.shape
device = real.device
# TODO (hameerabbasi): `torch.compile` sometimes fails here without making these
# contiguous. Why?
real = real.contiguous()
imag = imag.contiguous()
# TODO (hameerabbasi):
# What should we do with dtype?
# We could convert to the complex type (float32 -> complex64), but we
# can't use that model for say `bfloat16` which does not have a
# corresponding complex dtype.
# If we want to support this complex rep using any float type (see
# https://github.com/pytorch/pytorch/issues/95100)
# We either need to:
# 1) add the complex types for say `complexbf32`, knowing they can't really be used anywhere
# else.
# 2) We use the real float dtype here, and it is up to the user to know
# that dtype=float<size> here really means complex<2xSize> with dtype
# matching that of re/im parts alone
# I'm going with 1 for now, so that I can make gradcheck and some complex
# ops work properly, but might want to discuss this in the RFP.
dtype = REAL_TO_COMPLEX.get(real.dtype)
if dtype is None:
raise TypeError(
"Unsupported dtype for constituent tensors. Supported dtypes are: "
f"{set(REAL_TO_COMPLEX.keys())!r}."
)
storage_offset = real.storage_offset()
strides = real.stride()
layout = real.layout
pin_memory = real.is_pinned()
assert shape == imag.shape, f"Expected imag shape {shape}, got {imag.shape}"
assert device == imag.device, (
f"Expected imag device {device}, got {imag.device}"
)
assert real.dtype == imag.dtype, (
f"Expected imag dtype {real.dtype}, got {imag.dtype}"
)
assert pin_memory == imag.is_pinned(), (
f"Expected imag pinning {pin_memory}, got {imag.is_pinned()}"
)
res = Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
shape,
device=device,
dtype=dtype,
storage_offset=storage_offset,
strides=strides,
pin_memory=pin_memory,
layout=layout,
requires_grad=False,
)
res._re = real.clone().detach()
res._im = imag.clone().detach()
return res
@property
def re(self) -> Tensor:
return self._re
@property
def im(self) -> Tensor:
return self._im
@classmethod
def __torch_dispatch__(
cls,
func: OpOverload,
types: tuple[type, ...],
args: tuple = (),
kwargs: dict | None = None,
):
from ._ops.common import lookup_complex
kwargs = {} if kwargs is None else kwargs
impl = lookup_complex(func, *args, **kwargs)
if impl is None:
return NotImplemented
return impl(*args, **kwargs)
@staticmethod
def from_interleaved(t: Tensor) -> ComplexTensor:
t_real = torch.real(t)
t_imag = torch.imag(t) if t.dtype.is_complex else torch.zeros_like(t_real)
return Complex.apply(t_real, t_imag)
def as_interleaved(self) -> Tensor:
return torch.complex(self.real, self.imag)
@staticmethod
def __tensor_unflatten__(
inner_tensors: dict[str, Tensor],
meta: Any,
outer_size: tuple[int, ...],
outer_stride: tuple[int, ...],
) -> ComplexTensor:
assert meta is None
re, im = inner_tensors["re"], inner_tensors["im"]
return ComplexTensor(re, im)
def __tensor_flatten__(self) -> tuple[list[str], Any]:
return ["re", "im"], None
def __repr__(self, *, tensor_contents=None) -> str:
return f"ComplexTensor(real={self.re!r}, imag={self.im!r})"
def is_pinned(self, device: DeviceLikeType | None = None) -> bool:
return self.re.is_pinned(device)
class Complex(Function):
@staticmethod
def forward(ctx: FunctionCtx, real: Tensor, imag: Tensor) -> ComplexTensor: # type: ignore[bad-override]
return ComplexTensor(real, imag)
@staticmethod
def backward(ctx: FunctionCtx, grad_output: ComplexTensor) -> tuple[Tensor, Tensor]: # type: ignore[bad-override]
return grad_output.real, grad_output.imag
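
A hedged usage sketch for the subclass above, assuming this (now-removed) prototype package and its op registrations are importable; the values are illustrative:

import torch

# assumes ComplexTensor from the prototype package above has been imported
z = ComplexTensor(torch.tensor([1.0, 0.0]), torch.tensor([0.0, 1.0]))  # represents 1 and i
print(z.re, z.im)  # the real and imaginary constituents stored on the subclass
print(z.dtype)     # torch.complex64, mapped from the float32 parts via REAL_TO_COMPLEX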

View File

@ -1,5 +0,0 @@
from . import aten, prims
from .common import ComplexTensorMode, is_complex_tensor
__all__ = ["ComplexTensorMode", "is_complex_tensor", "aten", "prims"]

View File

@ -1,921 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from .._core import ComplexTensor
from .common import (
_get_func_name,
COMPLEX_TO_REAL,
complex_to_real_dtype,
is_complex,
OpType,
promote_tensors,
register_binary_nonlinear,
register_complex,
register_error,
register_force_test,
register_simple,
split_complex_arg,
split_complex_tensor,
)
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from typing import Any
aten = torch.ops.aten
def register_binary_linear(op: OpType):
def impl_with_alpha(
lhs: ComplexTensor, rhs: ComplexTensor, *args, alpha, **kwargs
) -> ComplexTensor:
return op(lhs, aten.mul(rhs, alpha, *args, **kwargs), *args, **kwargs)
def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:
alpha = kwargs.pop("alpha", None)
if alpha is not None:
return impl_with_alpha(lhs, rhs, *args, alpha=alpha, **kwargs)
a_r, a_i = split_complex_arg(lhs)
b_r, b_i = split_complex_arg(rhs)
out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i)
u = op(a_r, b_r, *args, **kwargs)
v = op(a_i, b_i, *args, **kwargs)
return ComplexTensor(u.to(out_dt), v.to(out_dt))
return register_complex(op, impl)
@register_complex(aten.real)
def real_impl(self: ComplexTensor) -> torch.Tensor:
re, _ = split_complex_tensor(self)
return re
@register_complex(aten.imag)
def imag_impl(self: ComplexTensor) -> torch.Tensor:
_, im = split_complex_tensor(self)
return im
@register_complex(aten.is_pinned)
def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> bool:
return self.is_pinned(device)
SIMPLE_OPS_LIST = [
aten.slice,
aten.flatten,
aten.view,
aten.diagonal,
aten.expand,
aten.unsqueeze,
aten.unsqueeze_,
aten.mean,
aten.sum,
aten.clone,
aten.neg,
aten.flip,
aten.permute,
aten.repeat,
aten.index_select,
aten.split,
aten.split_with_sizes,
aten.cumsum,
aten.detach,
aten.select,
aten.squeeze,
aten.zero_,
aten.transpose,
aten.t,
aten.gather,
]
for simple_op in SIMPLE_OPS_LIST:
globals()[_get_func_name(simple_op)] = register_simple(simple_op)
# TODO (hameerabbasi): Not being tested
SIMPLE_FORCE_TESTED_OPS = [
aten.copy,
aten.col2im,
aten.alias,
aten.lift_fresh,
aten._unsafe_view,
aten.index,
aten._neg_view,
aten.avg_pool2d,
aten.avg_pool3d,
aten.avg_pool2d_backward,
aten.avg_pool3d_backward,
aten.masked_scatter_backward,
aten.select_backward,
aten.slice_backward,
aten.embedding,
]
for simple_op in SIMPLE_FORCE_TESTED_OPS:
globals()[_get_func_name(simple_op)] = register_force_test(
simple_op, register_simple(simple_op)
)
del simple_op
# some binary ops which we can stamp out
mul_impl = register_binary_nonlinear(aten.mul)
mul__impl = register_binary_nonlinear(aten.mul_)
mm_impl = register_binary_nonlinear(aten.mm)
dot_impl = register_binary_nonlinear(aten.dot)
bmm_impl = register_binary_nonlinear(aten.bmm)
# TODO (hameerabbasi): Not being tested
convolution_impl = register_force_test(
aten.convolution, register_binary_nonlinear(aten.convolution)
)
slice_scatter_impl = register_force_test(
aten.slice_scatter, register_binary_linear(aten.slice_scatter)
)
select_scatter_impl = register_force_test(
aten.select_scatter, register_binary_linear(aten.select_scatter)
)
add_impl = register_binary_linear(aten.add)
add__impl = register_binary_linear(aten.add_)
sub_impl = register_binary_linear(aten.sub)
sub__impl = register_binary_linear(aten.sub_)
diagonal_scatter_impl = register_binary_linear(aten.diagonal_scatter)
fill__impl = register_binary_linear(aten.fill_)
@register_complex(aten.rsub)
def rsub_impl(lhs: ComplexTensor, rhs: ComplexTensor, alpha=None) -> ComplexTensor:
if alpha is None:
return torch.sub(rhs, lhs) # type: ignore[bad-return]
return torch.sub(rhs, lhs, alpha=alpha) # type: ignore[bad-return]
@register_complex(aten.div)
@register_complex(aten.true_divide)
def div_impl(lhs: ComplexTensor, rhs: ComplexTensor, *, rounding_mode=None):
if rounding_mode is not None:
raise NotImplementedError(
"`rounding_mode` other than `None` not implemented for`ComplexTensor`."
)
a_r, a_i = split_complex_tensor(lhs)
if not is_complex(rhs):
return ComplexTensor(a_r / rhs, a_i / rhs)
b_r, b_i = split_complex_arg(rhs)
out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i)
num_r = a_r * b_r + a_i * b_i
num_i = a_i * b_r - a_r * b_i
den = b_r * b_r + b_i * b_i
return ComplexTensor(
(num_r / den).to(out_dt),
(num_i / den).to(out_dt),
)
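
For reference, div_impl above is the standard real/imaginary expansion of complex division: with a = a_r + i a_i and b = b_r + i b_i,

\frac{a}{b} = \frac{a_r b_r + a_i b_i}{b_r^2 + b_i^2} + i \, \frac{a_i b_r - a_r b_i}{b_r^2 + b_i^2}

which is exactly the num_r, num_i, and den computed above.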
@register_complex(aten.reciprocal)
def reciprocal_impl(self: ComplexTensor):
self_r, self_i = split_complex_tensor(self)
out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
den = self_r * self_r + self_i * self_i
return ComplexTensor(
aten.div(self_r, den).to(out_dt),
aten.div(-self_i, den).to(out_dt),
)
# reductions
@register_complex(aten.prod)
def prod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
out_dt, (self,) = promote_tensors(self)
dtype = kwargs.pop("dtype", out_dt)
kwargs["dtype"] = complex_to_real_dtype(self.dtype)
prod_r = torch.prod(torch.abs(self), *args, **kwargs)
sum_phi = torch.sum(torch.angle(self), *args, **kwargs)
u = prod_r * torch.cos(sum_phi)
v = prod_r * torch.sin(sum_phi)
return ComplexTensor(u, v).to(dtype) # type: ignore[bad-return]
@register_complex(aten.pow)
def pow_impl(self: ComplexTensor, exponent: ComplexTensor) -> ComplexTensor:
out_dt, (self, exponent) = promote_tensors(self, exponent)
return torch.exp(exponent * torch.log(self)).to(out_dt) # type: ignore[bad-return]
@register_complex(aten.cumprod)
def cumprod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor:
dtype = kwargs.pop("dtype", self.dtype)
kwargs["dtype"] = complex_to_real_dtype(dtype)
prod_r = torch.cumprod(torch.abs(self), *args, **kwargs)
sum_phi = torch.cumsum(torch.angle(self), *args, **kwargs)
u = prod_r * torch.cos(sum_phi)
v = prod_r * torch.sin(sum_phi)
return ComplexTensor(u, v)
# unary funcs,
# most of these are simple or require some kind of identity
@register_complex(aten.abs)
def abs_impl(self: ComplexTensor) -> torch.Tensor:
x, y = split_complex_tensor(self)
out_dt, (x, y) = promote_tensors(x, y)
result = torch.hypot(x, y)
return result.to(out_dt)
@register_complex(aten.angle)
def angle_impl(self: ComplexTensor) -> torch.Tensor:
x, y = split_complex_tensor(self)
return torch.atan2(y, x)
@register_complex(aten.acos)
def acos_impl(self: ComplexTensor) -> ComplexTensor:
_, y = split_complex_tensor(self)
acosh_z = torch.acosh(self)
assert isinstance(acosh_z, ComplexTensor)
acosh_z_re, acosh_z_im = split_complex_tensor(acosh_z)
sign_im = 2 * torch.signbit(y) - 1
return ComplexTensor(torch.abs(acosh_z_im), sign_im * torch.abs(acosh_z_re))
@register_complex(aten.asin)
def asin_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
asinh_iz = torch.asinh(ComplexTensor(-y, x))
assert isinstance(asinh_iz, ComplexTensor)
asinh_iz_re, asinh_iz_im = split_complex_tensor(asinh_iz)
return ComplexTensor(asinh_iz_im, -asinh_iz_re)
@register_complex(aten.atan)
def atan_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
tanh_iz = torch.atanh(ComplexTensor(-y, x))
assert isinstance(tanh_iz, ComplexTensor)
tanh_iz_re, tanh_iz_im = split_complex_tensor(tanh_iz)
return ComplexTensor(tanh_iz_im, -tanh_iz_re)
@register_complex(aten.asinh)
def asinh_impl(self: ComplexTensor) -> ComplexTensor:
out_dt, (self,) = promote_tensors(self)
return torch.log(self + torch.sqrt(self * self + 1)).to(out_dt) # type: ignore[bad-return]
@register_complex(aten.acosh)
def acosh_impl(self: ComplexTensor) -> ComplexTensor:
out_dt, (self,) = promote_tensors(self)
return torch.log(self + torch.sqrt(self * self - 1)).to(out_dt) # type: ignore[bad-return]
@register_complex(aten.atanh)
def atanh_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
out_dt, (x, y) = promote_tensors(x, y)
ret = 0.5 * (
torch.log(ComplexTensor(1 + x, y)) - torch.log(ComplexTensor(1 - x, -y))
)
assert isinstance(ret, ComplexTensor)
ret_re, ret_im = split_complex_tensor(ret)
return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt))
@register_complex(aten.cos)
def cos_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
return torch.cosh(ComplexTensor(-y, x)) # type: ignore[bad-return]
@register_complex(aten.cosh)
def cosh_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
out_dt, (x, y) = promote_tensors(x, y)
u = torch.cosh(x) * torch.cos(y)
v = torch.sinh(x) * torch.sin(y)
return ComplexTensor(u.to(out_dt), v.to(out_dt))
@register_complex(aten.sin)
def sin_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
sinh_iz = torch.sinh(ComplexTensor(-y, x))
assert isinstance(sinh_iz, ComplexTensor)
sinh_iz_re, sinh_iz_im = split_complex_tensor(sinh_iz)
return ComplexTensor(sinh_iz_im, -sinh_iz_re)
@register_complex(aten.sinh)
def sinh_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
out_dt, (x, y) = promote_tensors(x, y)
u = torch.sinh(x) * torch.cos(y)
v = torch.cosh(x) * torch.sin(y)
return ComplexTensor(u.to(out_dt), v.to(out_dt))
@register_complex(aten.tan)
def tan_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
tanh_iz = torch.tanh(ComplexTensor(-y, x))
assert isinstance(tanh_iz, ComplexTensor)
tanh_iz_re, tanh_iz_im = split_complex_tensor(tanh_iz)
return ComplexTensor(tanh_iz_im, -tanh_iz_re)
@register_complex(aten.tanh)
def tanh_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
out_dt, (x, y) = promote_tensors(x, y)
_2x = 2 * x
_2y = 2 * y
_d = torch.cosh(_2x) + torch.cos(_2y)
_2xsh = torch.sinh(_2x)
out_re = _2xsh / _d
out_im = torch.sin(_2y) / _d
return ComplexTensor(out_re.to(out_dt), out_im.to(out_dt))
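
For reference, tanh_impl above implements the standard identity

\tanh(x + iy) = \frac{\sinh 2x + i \, \sin 2y}{\cosh 2x + \cos 2y}

with _d playing the role of the shared denominator.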
@register_complex(aten.exp)
def exp_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
out_dt, (x, y) = promote_tensors(x, y)
ex = torch.exp(x)
u = ex * torch.cos(y)
v = ex * torch.sin(y)
return ComplexTensor(u.to(out_dt), v.to(out_dt))
@register_complex(aten.expm1)
def expm1_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
out_dt, (x, y) = promote_tensors(x, y)
# TODO (hameerabbasi): The two lines below may have numerical issues
ex = torch.exp(x)
u = ex * torch.cos(y) - 1
v = ex * torch.sin(y)
return ComplexTensor(u.to(out_dt), v.to(out_dt))
@register_complex(aten.log)
def log_impl(self: ComplexTensor) -> ComplexTensor:
out_dt, (self,) = promote_tensors(self)
re = torch.log(torch.abs(self))
im = torch.angle(self)
return ComplexTensor(re, im).to(out_dt) # type: ignore[bad-return]
@register_complex(aten.log1p)
def log1p_impl(self: ComplexTensor) -> ComplexTensor:
x, y = split_complex_tensor(self)
# TODO (hameerabbasi): The line below may have numerical issues
return torch.log(ComplexTensor(x + 1, y)) # type: ignore[bad-return]
@register_complex(aten.any)
def any_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
x, y = split_complex_tensor(self)
return torch.any(x, *args, **kwargs) | torch.any(y, *args, **kwargs)
@register_complex(aten.all)
def all_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
x, y = split_complex_tensor(self)
# an element of a complex tensor is truthy iff its real or imaginary part is nonzero
return torch.all((x != 0) | (y != 0), *args, **kwargs)
@register_complex(aten.eq)
def eq_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor:
a_r, a_i = split_complex_arg(self)
b_r, b_i = split_complex_arg(rhs)
return torch.eq(a_r, b_r, *args, **kwargs) & torch.eq(a_i, b_i, *args, **kwargs)
@register_complex(aten.ne)
def ne_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor:
a_r, a_i = split_complex_tensor(self)
b_r, b_i = split_complex_arg(rhs)
return torch.ne(a_r, b_r, *args, **kwargs) | torch.ne(a_i, b_i, *args, **kwargs)
@register_complex(aten.isnan)
def isnan_impl(self: ComplexTensor) -> torch.Tensor:
re, im = split_complex_tensor(self)
return torch.isnan(re) | torch.isnan(im)
@register_complex(aten.isinf)
def isinf_impl(self: ComplexTensor) -> torch.Tensor:
re, im = split_complex_tensor(self)
return torch.isinf(re) | torch.isinf(im)
@register_complex(aten.isfinite)
def isfinite_impl(self: ComplexTensor) -> torch.Tensor:
re, im = split_complex_tensor(self)
return torch.isfinite(re) & torch.isfinite(im)
@register_complex(aten.isclose)
def isclose_impl(
self: ComplexTensor,
rhs: ComplexTensor,
rtol=1e-5,
atol=1e-8,
equal_nan: bool = False,
) -> torch.Tensor:
abs_diff = torch.abs(self - rhs)
abs_other = torch.abs(rhs)
basic_condition = abs_diff <= (rtol * abs_other + atol)
# This is the nontrivial part
if equal_nan:
a_r, a_i = split_complex_tensor(self)
b_r, b_i = split_complex_arg(rhs)
a_r_nan = torch.isnan(a_r)
b_r_nan = torch.isnan(b_r)
a_i_nan = torch.isnan(a_i)
b_i_nan = torch.isnan(b_i)
a_nan = a_r_nan | a_i_nan
# This logical expression makes sure that the isnan of both the real and imaginary parts
# matches (so 1 + nan*i doesn't equal nan + 1*i)
equal_nan_condition = ((a_r_nan == b_r_nan) & (a_i_nan == b_i_nan)) & a_nan
return basic_condition | equal_nan_condition
return basic_condition
ERROR_OPS_LIST = [
aten.lt,
aten.le,
aten.gt,
aten.ge,
aten.amin,
aten.amax,
aten.clamp,
aten.ceil,
aten.floor,
aten.minimum,
aten.maximum,
aten.trunc,
aten.sign,
aten.argmax,
aten.argmin,
aten.sort,
aten.topk,
aten.round,
aten.fmod,
]
ERROR_TYPES = {
aten.minimum: RuntimeError,
aten.maximum: RuntimeError,
aten.argmax: RuntimeError,
aten.argmin: RuntimeError,
aten.sort: RuntimeError,
aten.topk: RuntimeError,
}
for err_op in ERROR_OPS_LIST:
globals()[_get_func_name(err_op)] = register_error(
err_op, ERROR_TYPES.get(err_op, NotImplementedError)
)
del err_op
@register_complex(aten.masked_scatter)
def masked_scatter_impl(
self: ComplexTensor, mask: torch.Tensor, source: ComplexTensor
) -> ComplexTensor:
self_r, self_i = split_complex_tensor(self)
source_r, source_i = split_complex_arg(source)
ret_r = torch.masked_scatter(self_r, mask, source_r)
ret_i = torch.masked_scatter(self_i, mask, source_i)
return ComplexTensor(ret_r, ret_i)
@register_complex(aten.where)
def where_impl(mask: torch.Tensor, x: ComplexTensor, y: ComplexTensor) -> ComplexTensor:
x_r, x_i = split_complex_arg(x)
y_r, y_i = split_complex_arg(y)
ret_r = torch.where(mask, x_r, y_r)
ret_i = torch.where(mask, x_i, y_i)
return ComplexTensor(ret_r, ret_i)
@register_complex(aten.full_like)
def full_like_impl(
input: ComplexTensor,
fill_value: complex,
*args,
dtype: torch.dtype | None = None,
**kwargs,
) -> torch.Tensor | ComplexTensor:
# Note: Cannot be merged with the cases below due to the `fill_value` argument
input_r, input_i = split_complex_tensor(input)
if dtype is not None and dtype not in COMPLEX_TO_REAL:
return torch.full_like(input_r, fill_value, *args, dtype=dtype, **kwargs)
if dtype is not None:
kwargs["dtype"] = COMPLEX_TO_REAL[dtype]
fv_r, fv_i = split_complex_arg(fill_value)
ret_r = torch.full_like(input_r, fv_r, *args, **kwargs)
ret_i = torch.full_like(input_i, fv_i, *args, **kwargs)
return ComplexTensor(ret_r, ret_i)
def register_like(op: OpType) -> Callable[..., torch.Tensor | ComplexTensor]:
def impl(
self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs
) -> torch.Tensor | ComplexTensor:
self_re, self_im = split_complex_tensor(self)
if dtype is not None and dtype not in COMPLEX_TO_REAL:
return op(self_re, *args, dtype=dtype, **kwargs)
if dtype is not None:
kwargs["dtype"] = COMPLEX_TO_REAL[dtype]
ret_re = op(self_re, *args, **kwargs)
ret_im = op(self_im, *args, **kwargs)
return ComplexTensor(ret_re, ret_im)
func_name = _get_func_name(op)
impl.__name__ = func_name
impl.__qualname__ = func_name
return register_complex(op, impl)
LIKE_OPS_LIST = [
aten.empty_like,
aten.zeros_like,
aten.randn_like,
aten.new_zeros,
]
for like_op in LIKE_OPS_LIST:
globals()[_get_func_name(like_op)] = register_like(like_op)
del like_op
@register_complex(aten.cat)
def cat_impl(tensors: Sequence[ComplexTensor], dim: int = 0) -> ComplexTensor:
tensors_r = []
tensors_i = []
for t in tensors:
t_r, t_i = split_complex_arg(t)
tensors_r.append(t_r)
tensors_i.append(t_i)
ret_r = torch.cat(tensors_r, dim=dim)
ret_i = torch.cat(tensors_i, dim=dim)
return ComplexTensor(ret_r, ret_i)
@register_complex(aten.sgn)
def sgn_impl(self: ComplexTensor) -> ComplexTensor:
self_r, self_i = split_complex_tensor(self)
out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
abs_self = torch.abs(ComplexTensor(self_r, self_i))
mask = (self_r != 0) | (self_i != 0)
masked_sgn = ComplexTensor(
(self_r / abs_self).to(out_dt), (self_i / abs_self).to(out_dt)
)
return torch.where(mask, masked_sgn, 0) # type: ignore[bad-return]
@register_complex(aten.sqrt)
def sqrt_impl(self: ComplexTensor) -> ComplexTensor:
self_r, self_i = split_complex_tensor(self)
out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
self = ComplexTensor(self_r, self_i)
self_abs_sqrt = torch.sqrt(torch.abs(self))
self_half_angle = 0.5 * torch.angle(self)
ret_r = self_abs_sqrt * torch.cos(self_half_angle)
ret_i = self_abs_sqrt * torch.sin(self_half_angle)
return ComplexTensor(ret_r.to(out_dt), ret_i.to(out_dt))
@register_complex(aten.rsqrt)
def rsqrt_impl(self: ComplexTensor) -> ComplexTensor:
self_r, self_i = split_complex_tensor(self)
out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
self = ComplexTensor(self_r, self_i)
self_abs_rsqrt = torch.rsqrt(torch.abs(self))
self_neg_half_angle = -0.5 * torch.angle(self)
ret_r = self_abs_rsqrt * torch.cos(self_neg_half_angle)
ret_i = self_abs_rsqrt * torch.sin(self_neg_half_angle)
return ComplexTensor(ret_r.to(out_dt), ret_i.to(out_dt))
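# Quick sanity check of the polar-form identities used by sqrt/rsqrt above
# (a sketch; the helper name is hypothetical): for z = -4, |z| = 4 and
# angle(z) = pi, so sqrt(z) = 2i and rsqrt(z) = -0.5i up to rounding.
def _sqrt_polar_example() -> tuple[ComplexTensor, ComplexTensor]:
    z = ComplexTensor(torch.tensor([-4.0]), torch.tensor([0.0]))
    return torch.sqrt(z), torch.rsqrt(z)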
@register_complex(aten.addmm)
def addmm_impl(
input: ComplexTensor,
mat1: ComplexTensor,
mat2: ComplexTensor,
out_dtype: torch.dtype | None = None,
beta: complex = 1,
alpha: complex = 1,
) -> ComplexTensor:
ret = beta * input + alpha * torch.mm(mat1, mat2)
assert isinstance(ret, ComplexTensor)
ret_r, ret_i = split_complex_tensor(ret)
if out_dtype is not None:
out_dtype = COMPLEX_TO_REAL[out_dtype]
ret_r, ret_i = ret_r.to(out_dtype), ret_i.to(out_dtype)
return ComplexTensor(ret_r, ret_i)
def elemwise_nonzero(self: ComplexTensor) -> torch.Tensor:
re, im = split_complex_tensor(self)
return (re != 0) | (im != 0)
def register_nonzero_impl(op: OpType):
def nonzero_impl(
self: ComplexTensor, other: ComplexTensor, *args, **kwargs
) -> torch.Tensor:
return op(elemwise_nonzero(self), elemwise_nonzero(other), *args, **kwargs)
func_name = _get_func_name(op)
nonzero_impl.__name__ = func_name
nonzero_impl.__qualname__ = func_name
return register_complex(op, nonzero_impl)
logical_and_impl = register_nonzero_impl(aten.logical_and)
logical_or_impl = register_nonzero_impl(aten.logical_or)
logical_xor_impl = register_nonzero_impl(aten.logical_xor)
@register_complex(aten.logical_not)
def logical_not_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
return torch.logical_not(elemwise_nonzero(self), *args, **kwargs)
@register_complex(aten.view_as_real)
def view_as_real_impl(self: ComplexTensor) -> torch.Tensor:
re, im = split_complex_tensor(self)
return torch.stack([re, im], dim=-1)
@register_complex(aten.linalg_vector_norm)
def linalg_vector_norm_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
return torch.linalg.vector_norm(torch.abs(self), *args, **kwargs)
@register_force_test(aten.copy_)
def copy__impl(self: ComplexTensor, src, *args, **kwargs):
self_re, self_im = split_complex_tensor(self)
src_re, src_im = split_complex_arg(src)
ret_re = self_re.copy_(src_re, *args, **kwargs)
ret_im = self_im.copy_(src_im, *args, **kwargs)
return ComplexTensor(ret_re, ret_im)
@register_complex(aten._local_scalar_dense)
def _local_scalar_dense_impl(self: ComplexTensor, *args, **kwargs) -> complex:
x, y = split_complex_tensor(self)
u = aten._local_scalar_dense(x, *args, **kwargs)
v = aten._local_scalar_dense(y, *args, **kwargs)
return complex(u, v)
@register_complex(aten.allclose)
def allclose_impl(
input: torch.Tensor,
other: torch.Tensor,
rtol: float = 1e-05,
atol: float = 1e-08,
equal_nan: bool = False,
) -> bool:
return torch.all(
torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)
).item() # type: ignore[bad-return]
@register_complex(aten.stack)
def stack_impl(self: list[ComplexTensor], *args, **kwargs) -> ComplexTensor:
re_im_tuples = [split_complex_arg(self_i) for self_i in self]
u = torch.stack([c[0] for c in re_im_tuples], *args, **kwargs)
v = torch.stack([c[1] for c in re_im_tuples], *args, **kwargs)
return ComplexTensor(u, v)
# TODO (hameerabbasi): Not being tested
@register_complex(aten._conj_physical)
@register_complex(aten.conj_physical)
def conj_physical_impl(self: ComplexTensor) -> ComplexTensor:
re, im = split_complex_tensor(self)
return ComplexTensor(re, -im)
# TODO (hameerabbasi): Not being tested
@register_complex(aten._conj)
def _conj_impl(self: ComplexTensor) -> ComplexTensor:
re, im = split_complex_tensor(self)
return ComplexTensor(re, torch._neg_view(im))
@register_complex(aten.index_add)
def index_add_impl(
self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs
) -> ComplexTensor:
alpha = kwargs.pop("alpha", None)
if alpha is not None:
source = source * alpha
self_re, self_im = split_complex_arg(self)
source_re, source_im = split_complex_arg(source)
ret_re = self_re.index_add(dim, index, source_re)
ret_im = self_im.index_add(dim, index, source_im)
return ComplexTensor(ret_re, ret_im)
# TODO (hameerabbasi): Not being tested
@register_complex(aten.index_add_)
def index_add__impl(
self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs
) -> ComplexTensor:
alpha = kwargs.pop("alpha", None)
if alpha is not None:
source = source * alpha
self_re, self_im = split_complex_arg(self)
source_re, source_im = split_complex_arg(source)
ret_re = self_re.index_add_(dim, index, source_re)
ret_im = self_im.index_add_(dim, index, source_im)
return ComplexTensor(ret_re, ret_im)
@register_complex(aten.masked_fill)
def masked_fill_impl(
self: ComplexTensor, mask: torch.Tensor, value: complex
) -> ComplexTensor:
self_re, self_im = split_complex_arg(self)
value_re, value_im = split_complex_arg(value)
ret_re = self_re.masked_fill(mask, value_re)
ret_im = self_im.masked_fill(mask, value_im)
return ComplexTensor(ret_re, ret_im)
# TODO (hameerabbasi): Not being tested
@register_complex(aten.masked_fill_)
def masked_fill__impl(
self: ComplexTensor, mask: torch.Tensor, value: complex
) -> ComplexTensor:
self_re, self_im = split_complex_arg(self)
value_re, value_im = split_complex_arg(value)
ret_re = self_re.masked_fill_(mask, value_re)
ret_im = self_im.masked_fill_(mask, value_im)
return ComplexTensor(ret_re, ret_im)
@register_complex(aten.constant_pad_nd)
def constant_pad_nd_impl(
self: ComplexTensor, pad, value: complex | None = None
) -> ComplexTensor:
self_re, self_im = split_complex_tensor(self)
if value is None:
ret_re = aten.constant_pad_nd(self_re, pad)
ret_im = aten.constant_pad_nd(self_im, pad)
else:
value_re, value_im = split_complex_arg(value)
ret_re = aten.constant_pad_nd(self_re, pad, value_re)
ret_im = aten.constant_pad_nd(self_im, pad, value_im)
return ComplexTensor(ret_re, ret_im)
@register_complex(aten.var)
def var_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
self_re, self_im = split_complex_tensor(self)
return torch.var(self_re, *args, **kwargs) + torch.var(self_im, *args, **kwargs)
@register_complex(aten.scatter_add)
def scatter_add_impl(
self: ComplexTensor, dim, index, src: ComplexTensor
) -> ComplexTensor:
self_re, self_im = split_complex_arg(self)
src_re, src_im = split_complex_arg(src)
ret_re = torch.scatter_add(self_re, dim, index, src_re)
ret_im = torch.scatter_add(self_im, dim, index, src_im)
return ComplexTensor(ret_re, ret_im)
@register_complex(aten.scatter_add_)
def scatter_add__impl(
self: ComplexTensor, dim, index, src: ComplexTensor
) -> ComplexTensor:
self_re, self_im = split_complex_arg(self)
src_re, src_im = split_complex_arg(src)
out_re = self_re.scatter_add_(dim, index, src_re)
out_im = self_im.scatter_add_(dim, index, src_im)
return ComplexTensor(out_re, out_im)
@register_complex(aten.index_put_)
def index_put__impl(
self: ComplexTensor,
indices: tuple[torch.Tensor, ...],
values: ComplexTensor,
accumulate: bool = False,
) -> ComplexTensor:
self_re, self_im = split_complex_arg(self)
values_re, values_im = split_complex_arg(values)
out_re = self_re.index_put_(indices, values_re, accumulate=accumulate)
out_im = self_im.index_put_(indices, values_im, accumulate=accumulate)
return ComplexTensor(out_re, out_im)
@register_complex(aten.tanh_backward)
def tanh_backward(out_grad: torch.Tensor, y: torch.Tensor):
return out_grad * (1.0 - y * y).conj_physical()
@register_complex(aten.diagonal_backward)
def diagonal_backward(
grad_output: torch.Tensor, input_sizes: list[int], offset: int, dim1: int, dim2: int
):
grad_input = grad_output.new_zeros(input_sizes)
return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2)
def _dt_to_real(dt: torch.dtype | Any) -> torch.dtype | Any:
if not isinstance(dt, torch.dtype):
return dt
return COMPLEX_TO_REAL[dt]
def register_to_impl(op: OpType):
"""Register an op similar to `aten.to`, but may have different signatures."""
def impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor | ComplexTensor:
x, y = split_complex_tensor(self)
try:
args = tuple(_dt_to_real(a) for a in args)
kwargs = {k: _dt_to_real(v) for k, v in kwargs.items()}
except KeyError:
return op(x, *args, **kwargs)
return ComplexTensor(op(x, *args, **kwargs), op(y, *args, **kwargs))
func_name = _get_func_name(op)
impl.__name__ = func_name
impl.__qualname__ = func_name
return register_complex(op, impl)
to_impl = register_to_impl(aten.to)
_to_copy_impl = register_to_impl(aten._to_copy)

View File

@ -1,317 +0,0 @@
from collections.abc import Callable
from typing import Any, overload, TypeAlias
from typing_extensions import TypeIs
import torch
from torch import Tensor
from torch._decomp import get_decompositions
from torch._ops import OpOverload, OpOverloadPacket
from torch._refs import is_complex as _is_complex
from torch.types import Number
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
from .._core import ComplexTensor
OpType: TypeAlias = OpOverloadPacket | OpOverload
TableType: TypeAlias = dict[OpType, Callable]
# Mapping from ops to implementations
COMPLEX_OPS_TABLE: TableType = {}
COMPLEX_TO_REAL = {
torch.complex128: torch.float64,
torch.complex64: torch.float32,
torch.complex32: torch.float16,
}
REAL_TO_COMPLEX = {v: k for k, v in COMPLEX_TO_REAL.items()}
# Used to promote dtypes in `promote_real_cpu_tensors`
PROMOTE_TYPES = {
torch.float16: torch.float32,
torch.bfloat16: torch.float32,
torch.complex32: torch.complex64,
}
def is_complex_tensor(obj: Any, /) -> TypeIs[ComplexTensor]:
r"""Returns True if the input is a ComplexTensor, else False
Args:
obj: any input
Examples:
>>> # xdoctest: +SKIP
>>> from torch.complex import ComplexTensor
>>> data = torch.zeros((3, 2), dtype=torch.complex64)
>>> ct = ComplexTensor.from_interleaved(data)
>>> is_complex_tensor(ct)
True
"""
return isinstance(obj, ComplexTensor)
@overload
def promote_tensors(
*tensors: ComplexTensor,
) -> tuple[torch.dtype, tuple[ComplexTensor, ...]]: ...
@overload
def promote_tensors(
*tensors: Tensor,
) -> tuple[torch.dtype, tuple[Tensor, ...]]: ...
def promote_tensors(
*tensors: Tensor | ComplexTensor,
) -> tuple[torch.dtype, tuple[Tensor | ComplexTensor, ...]]:
"""
Promotes all tensors to a common dtype.
Additionally promotes CPU tensors to at least `float32`.
"""
tensor = next(t for t in tensors if isinstance(t, Tensor))
out_dt = tensor.dtype
for t in tensors:
if isinstance(t, Tensor):
out_dt = torch.promote_types(out_dt, t.dtype)
prom_dt = PROMOTE_TYPES.get(out_dt, out_dt)
return out_dt, tuple(
t.to(prom_dt) if isinstance(t, Tensor) else torch.asarray(t, dtype=prom_dt)
for t in tensors
)
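# Example of the promotion behaviour (a sketch; the helper name is hypothetical):
# two float16 inputs are computed in float32, while the original dtype is
# returned so callers can cast results back.
def _promotion_example() -> torch.dtype:
    a = torch.ones(2, dtype=torch.float16)
    b = torch.ones(2, dtype=torch.float16)
    out_dt, (a32, b32) = promote_tensors(a, b)
    assert a32.dtype == b32.dtype == torch.float32
    return out_dt  # torch.float16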
def register_complex(
op: OpType,
func_impl: Callable | None = None,
):
"""Decorator to register an implementation for some ops in some dispatch tables"""
def inner(func):
if COMPLEX_OPS_TABLE.get(op, func) is not func:
raise RuntimeError(f"Attempted to register multiple functions for {op}")
COMPLEX_OPS_TABLE[op] = func
return func
if func_impl is None:
return inner
return inner(func_impl)
FORCE_TEST_LIST: list[OpType] = []
def register_force_test(op: OpType, *args, **kwargs):
"""Will attempt to test these ops even if they err on "normal" inputs"""
FORCE_TEST_LIST.append(op)
return register_complex(op, *args, **kwargs)
DECOMPOSITIONS = get_decompositions(list(torch.ops.aten)) # type: ignore[no-matching-overload]
def lookup_complex(func: OpOverload, *args, **kwargs) -> Callable | None:
"""
Lookup an impl from the table.
Try the particular overload first, then the overload packet.
If nothing is found, try the decompositions with both.
"""
return COMPLEX_OPS_TABLE.get(
func,
COMPLEX_OPS_TABLE.get(
func.overloadpacket,
DECOMPOSITIONS.get(func, DECOMPOSITIONS.get(func.overloadpacket)),
),
)
def is_complex(x: Any, /) -> bool:
"""Utility to detect if a given object is (known) to be complex."""
return (isinstance(x, Tensor) and _is_complex(x)) or isinstance(x, complex)
@overload
def split_complex_arg(
arg: Tensor | ComplexTensor,
) -> tuple[Tensor, Tensor]: ...
@overload
def split_complex_arg(
arg: complex | Number,
) -> tuple[Number, Number]: ...
def split_complex_arg(
arg: Tensor | ComplexTensor | complex | Number,
) -> tuple[Tensor, Tensor] | tuple[Number, Number]:
"""
Split a complex argument into a real/imaginary component.
If real, use zero for the imaginary part.
"""
if isinstance(arg, ComplexTensor):
return split_complex_tensor(arg)
if isinstance(arg, Tensor):
if is_complex(arg):
return arg.real, arg.imag
return arg, torch.zeros_like(arg)
# TODO (hameerabbasi): Should there be a `torch.SymComplex`?
if isinstance(arg, complex):
return arg.real, arg.imag
if isinstance(arg, float | torch.SymFloat):
return arg, 0.0
if isinstance(arg, int | torch.SymInt):
return arg, 0
if isinstance(arg, bool | torch.SymBool):
return arg, False
raise TypeError(f"Expected tensor or number got, {type(arg)}")
def split_complex_tensor(complex_tensor: ComplexTensor) -> tuple[Tensor, Tensor]:
"""Split a ComplexTensor into its real and imaginary parts."""
return complex_tensor.re, complex_tensor.im
def complex_to_real_dtype(dtype: torch.dtype) -> torch.dtype:
"""Convert a complex dtype to the dtype of its real part. Return other dtypes as-is."""
return COMPLEX_TO_REAL.get(dtype, dtype)
def _get_op_name(op: OpType) -> str:
"""Get the op name from the op."""
if isinstance(op, OpOverload):
op = op.overloadpacket
return str(op).split(".", 1)[1]
def _get_func_name(op: OpType) -> str:
"""Get the name of the implementation function from the op."""
return f"{_get_op_name(op)}_impl"
def register_error(op: OpType, exc_type: type[Exception] = NotImplementedError):
msg = f"`aten.{_get_op_name(op)}` not implemented for `{ComplexTensor.__name__}`."
def ordered_impl(*args, **kwargs):
raise exc_type(msg)
func_name = _get_func_name(op)
ordered_impl.__name__ = func_name
ordered_impl.__qualname__ = func_name
return register_force_test(op, ordered_impl)
def register_binary_nonlinear(op: OpType) -> Callable:
"""Register a "multiplication-style" op, e.g. aten.mul, aten.mm, ..."""
def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor:
a_r, a_i = split_complex_arg(lhs)
b_r, b_i = split_complex_arg(rhs)
out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i)
real = op(a_r, b_r, *args, **kwargs) - op(a_i, b_i, *args, **kwargs)
imag = op(a_r, b_i, *args, **kwargs) + op(a_i, b_r, *args, **kwargs)
return ComplexTensor(real.to(out_dt), imag.to(out_dt))
func_name = _get_func_name(op)
impl.__name__ = func_name
impl.__qualname__ = func_name
return register_complex(op, impl)
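# The helper above encodes (a + b*i) * (c + d*i) = (a*c - b*d) + (a*d + b*c)*i,
# with `op` standing in for the product so matmul-style ops are covered too.
# Hypothetical usage of the kind made elsewhere in this package (shown as a
# comment to avoid double-registering ops that already have implementations):
#     mul_impl = register_binary_nonlinear(torch.ops.aten.mul)
#     mm_impl = register_binary_nonlinear(torch.ops.aten.mm)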
def register_simple(op: OpType):
"""Register an op which can be applied independently to the real and complex parts to get the result."""
def impl(
self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs
) -> ComplexTensor:
x, y = split_complex_tensor(self)
if dtype is not None and dtype not in COMPLEX_TO_REAL:
raise RuntimeError(
"Non-complex `dtype` specified, please write custom impl."
)
if dtype in COMPLEX_TO_REAL:
assert dtype is not None
kwargs["dtype"] = COMPLEX_TO_REAL[dtype]
u = op(x, *args, **kwargs)
v = op(y, *args, **kwargs)
u_flat, u_spec = tree_flatten(u)
v_flat, v_spec = tree_flatten(v)
assert u_spec == v_spec
out_flat = [
ComplexTensor(ui, vi) for ui, vi in zip(u_flat, v_flat, strict=False)
]
return tree_unflatten(out_flat, u_spec)
func_name = _get_func_name(op)
impl.__name__ = func_name
impl.__qualname__ = func_name
return register_complex(op, impl)
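# Ops that act independently on the real and imaginary parts (negation, cloning,
# reshapes/views, slicing, ...) can reuse this helper. Hypothetical usage, shown
# as a comment for the same double-registration reason noted above:
#     neg_impl = register_simple(torch.ops.aten.neg)
#     clone_impl = register_simple(torch.ops.aten.clone)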
def _as_complex_tensor(arg: Tensor | Any) -> Tensor | ComplexTensor | Any:
"""Convert a Tensor with complex dtypes to a ComplexTensor. Pass along other args as-is."""
if (
not isinstance(arg, ComplexTensor)
and isinstance(arg, Tensor)
and arg.dtype in COMPLEX_TO_REAL
):
return ComplexTensor.from_interleaved(arg)
return arg
def _as_interleaved(arg: ComplexTensor | Any) -> Tensor | Any:
"""Convert a ComplexTensor to a Tensor with a complex dtype. Pass other arguments as-is."""
if isinstance(arg, ComplexTensor):
return arg.as_interleaved()
return arg
class ComplexTensorMode(TorchDispatchMode):
"""A TorchDispatchMode that replaces any Tensor with a complex dtype with a ComplexTensor for the computation."""
_compile: bool
def __init__(self, _dispatch_key=None, *, _compile: bool = False):
"""Initialize a ComplexTensorMode.
Args:
_dispatch_key: passed on to TorchDispatchMode
_compile: Compile the op before the computation
"""
super().__init__(_dispatch_key)
self._compile = _compile
def __torch_dispatch__(
self,
func: OpOverload,
types: tuple[type],
args: tuple = (),
kwargs: dict[str, Any] | None = None,
):
if kwargs is None:
kwargs = {}
# TODO (hameerabbasi): Test perf with `_compile` set to `True`
if self._compile:
func = torch.compile(func) # type: ignore[bad-assignment]
args = tree_map(_as_complex_tensor, args)
kwargs = tree_map(_as_complex_tensor, kwargs)
return tree_map(_as_interleaved, func(*args, **kwargs))
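# Sketch of intended usage (assumes the elementwise registrations elsewhere in
# this package are imported; the helper name is hypothetical): inside the mode,
# complex-dtype Tensors are swapped for ComplexTensor during dispatch and the
# results are converted back to interleaved complex Tensors on the way out.
def _mode_usage_example() -> torch.Tensor:
    with ComplexTensorMode():
        z = torch.ones(3, dtype=torch.complex64)
        return z * z + 1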

View File

@ -1,34 +0,0 @@
import torch
from .._core import ComplexTensor
from .common import (
complex_to_real_dtype,
register_complex,
register_force_test,
split_complex_tensor,
)
prims = torch.ops.prims
aten = torch.ops.aten
# TODO (hameerabbasi): Not being tested
@register_force_test(prims.convert_element_type)
def convert_element_type_impl(x: ComplexTensor, dtype: torch.dtype) -> ComplexTensor:
dtype = complex_to_real_dtype(dtype)
u, v = split_complex_tensor(x)
u_out = prims.convert_element_type(u, dtype)
v_out = prims.convert_element_type(v, dtype)
return ComplexTensor(u_out, v_out)
@register_complex(prims.conj_physical)
def conj_physical_impl(self: ComplexTensor) -> ComplexTensor:
return aten._conj_physical(self)
@register_complex(prims.conj)
def conj_impl(self: ComplexTensor) -> ComplexTensor:
return aten._conj(self)

View File

@ -492,17 +492,4 @@ void FileStore::wait(
}
}
std::vector<std::string> FileStore::listKeys() {
std::unique_lock<std::mutex> l(activeFileOpLock_);
File file(path_, O_RDONLY, timeout_);
auto lock = file.lockShared();
pos_ = refresh(file, pos_, cache_, deletePrefix_);
std::vector<std::string> keys;
keys.reserve(cache_.size());
for (const auto& kv : cache_) {
keys.push_back(kv.first.substr(regularPrefix_.size()));
}
return keys;
}
} // namespace c10d

View File

@ -45,8 +45,6 @@ class TORCH_API FileStore : public Store {
return path_;
}
std::vector<std::string> listKeys() override;
protected:
int64_t addHelper(const std::string& key, int64_t i);

View File

@ -217,14 +217,4 @@ int64_t HashStore::queueLen(const std::string& key) {
return static_cast<int64_t>(it->second.size());
}
std::vector<std::string> HashStore::listKeys() {
std::unique_lock<std::mutex> lock(m_);
std::vector<std::string> keys;
keys.reserve(map_.size());
for (const auto& kv : map_) {
keys.push_back(kv.first);
}
return keys;
}
} // namespace c10d

View File

@ -59,8 +59,6 @@ class TORCH_API HashStore : public Store {
int64_t queueLen(const std::string& key) override;
std::vector<std::string> listKeys() override;
protected:
bool checkLocked(
const std::unique_lock<std::mutex>& lock,

View File

@ -146,18 +146,4 @@ c10::intrusive_ptr<Store> PrefixStore::getUnderlyingNonPrefixStore() {
return store;
}
std::vector<std::string> PrefixStore::listKeys() {
auto keys = store_->listKeys();
std::vector<std::string> filteredKeys;
filteredKeys.reserve(keys.size());
for (auto& key : keys) {
if (key.find(prefix_) == 0) {
key = key.substr(prefix_.size() + 1);
filteredKeys.push_back(std::move(key));
}
}
return filteredKeys;
}
} // namespace c10d

View File

@ -64,8 +64,6 @@ class TORCH_API PrefixStore : public Store {
// Recursively fetch the underlying store beneath the layers of PrefixStore wrapping.
c10::intrusive_ptr<Store> getUnderlyingNonPrefixStore();
std::vector<std::string> listKeys() override;
protected:
std::string prefix_;
c10::intrusive_ptr<Store> store_;

View File

@ -114,11 +114,6 @@ class TORCH_API Store : public torch::CustomClassHolder {
C10_THROW_ERROR(NotImplementedError, "queue support is not implemented.");
}
virtual std::vector<std::string> listKeys() {
C10_THROW_ERROR(
NotImplementedError, "listKeys support is not implemented.");
}
protected:
std::chrono::milliseconds timeout_;
};

View File

@ -723,30 +723,6 @@ int64_t TCPStore::queueLen(const std::string& key) {
return client_->receiveValue<int64_t>();
}
std::vector<std::string> TCPStore::listKeys() {
STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__list);
const std::lock_guard<std::mutex> lock(activeOpLock_);
detail::SendBuffer buffer(*client_, detail::QueryType::LIST_KEYS);
buffer.flush();
auto numKeys = client_->receiveValue<int64_t>();
std::vector<std::string> keys;
keys.reserve(numKeys);
for (auto i = 0; i < numKeys; ++i) {
auto bits = client_->receiveBits();
std::string str(bits.begin(), bits.end());
if (str.find(keyPrefix_) == 0) {
str = str.substr(keyPrefix_.size());
} else {
continue;
}
keys.emplace_back(str);
}
return keys;
}
bool TCPStore::hasExtendedApi() const {
return true;
}

View File

@ -121,8 +121,6 @@ class TORCH_API TCPStore : public Store {
int64_t queueLen(const std::string& key) override;
std::vector<std::string> listKeys() override;
// Waits for all workers to join.
void waitForWorkers();

View File

@ -78,7 +78,6 @@ class TCPStoreMasterDaemon : public BackgroundThread {
void multiGetHandler(int socket);
void multiSetHandler(int socket);
void cancelWaitHandler(int socket);
void listKeysHandler(int socket);
void addMiscellaneousSocket(int socket);
void removeMiscellaneousSocket(int socket);
bool isMiscellaneousSocket(int socket);
@ -296,8 +295,6 @@ void TCPStoreMasterDaemon::query(int socket) {
multiSetHandler(socket);
} else if (qt == QueryType::CANCEL_WAIT) {
cancelWaitHandler(socket);
} else if (qt == QueryType::LIST_KEYS) {
listKeysHandler(socket);
} else {
TORCH_CHECK(false, "Unexpected query type");
}
@ -485,13 +482,6 @@ void TCPStoreMasterDaemon::cancelWaitHandler(int socket) {
socket, detail::WaitResponseType::WAIT_CANCELED);
}
void TCPStoreMasterDaemon::listKeysHandler(int socket) {
tcputil::sendValue<size_t>(socket, tcpStore_.size());
for (const auto& kv : tcpStore_) {
tcputil::sendString(socket, kv.first);
}
}
bool TCPStoreMasterDaemon::checkKeys(
const std::vector<std::string>& keys) const {
return std::all_of(keys.begin(), keys.end(), [this](const std::string& s) {

View File

@ -36,7 +36,6 @@ enum class QueryType : uint8_t {
QUEUE_PUSH,
QUEUE_POP,
QUEUE_LEN,
LIST_KEYS,
};
enum class CheckResponseType : uint8_t { READY, NOT_READY };

View File

@ -683,7 +683,6 @@ class LibUVStoreDaemon : public BackgroundThread {
const std::string& queueName,
const c10::intrusive_ptr<UvHandle>& client);
int64_t queueLen(const std::string& queueName);
std::vector<std::string> listKeys();
void registerClient(const c10::intrusive_ptr<UvHandle>& client);
void unregisterClient(const c10::intrusive_ptr<UvHandle>& client);
@ -823,10 +822,6 @@ class UvClient : public UvTcpSocket {
if (!parse_queue_len_command())
return;
break;
case QueryType::LIST_KEYS:
if (!parse_list_keys_command())
return;
break;
default:
C10D_DEBUG(
"Client sent invalid command. client:{} command:{}",
@ -1169,19 +1164,6 @@ class UvClient : public UvTcpSocket {
return true;
}
bool parse_list_keys_command() {
C10D_TRACE("list_keys address:{}", this->address());
auto keys = store->listKeys();
StreamWriter sw(iptr());
sw.write_value<int64_t>(static_cast<int64_t>(keys.size()));
for (const auto& key : keys) {
sw.write_string(key);
}
sw.send();
return true;
}
public:
explicit UvClient(uv_loop_t* loop, LibUVStoreDaemon* store)
: UvTcpSocket(loop), store(store) {}
@ -1560,16 +1542,6 @@ int64_t LibUVStoreDaemon::queueLen(const std::string& key) {
}
return static_cast<int64_t>(it->second.size());
}
std::vector<std::string> LibUVStoreDaemon::listKeys() {
std::vector<std::string> keys;
keys.reserve(tcpStore_.size());
for (const auto& kv : tcpStore_) {
keys.push_back(kv.first);
}
return keys;
}
#endif
std::unique_ptr<BackgroundThread> create_libuv_tcpstore_backend(

View File

@ -1656,12 +1656,6 @@ See queue_push for more details.
Arguments:
key (str): The key of the queue to get the length.
)")
.def(
"list_keys",
&::c10d::Store::listKeys,
R"(
Returns a list of all keys in the store.
)")
.def(
"has_extended_api",

View File

@ -26,18 +26,15 @@ def convolution_rules(op_schema: OpSchema) -> OutputSharding:
assert isinstance(input_spec, DTensorSpec)
assert isinstance(weight_spec, DTensorSpec)
# bias_spec can be None (optional parameter in aten.convolution schema)
if bias_spec is not None:
assert isinstance(bias_spec, DTensorSpec)
assert isinstance(bias_spec, DTensorSpec)
assert input_spec.tensor_meta is not None
assert weight_spec.tensor_meta is not None
in_shape = input_spec.tensor_meta.shape
weight_shape = weight_spec.tensor_meta.shape
assert isinstance(stride, list), f"stride must be list, got {type(stride)}"
assert isinstance(padding, list), f"padding must be list, got {type(padding)}"
assert isinstance(dilation, list), f"dilation must be list, got {type(dilation)}"
# weight_shape might not be torch.Size in all cases (e.g., SymIntArrayRef during tracing)
# so we don't assert its type, just use it
assert isinstance(stride, list)
assert isinstance(padding, list)
assert isinstance(dilation, list)
assert isinstance(weight_shape, torch.Size)
out_conv_shape = [
(d + 2 * padding[i] - dilation[i] * (weight_shape[i + 1] - 1) - 1) // stride[i]
+ 1
@ -85,21 +82,14 @@ def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
assert isinstance(grad_output_spec, DTensorSpec)
assert isinstance(input_spec, DTensorSpec)
assert isinstance(weight_spec, DTensorSpec)
# bias_shape_opt can be None (optional parameter in aten.convolution_backward schema)
if bias_shape_opt is not None:
assert isinstance(bias_shape_opt, list)
assert isinstance(bias_shape_opt, list)
assert input_spec.tensor_meta is not None
weight_tensor_meta = weight_spec.tensor_meta
# Only create bias_tensor_meta if bias_shape_opt is not None
if bias_shape_opt is not None:
bias_tensor_meta = TensorMeta(
torch.Size(bias_shape_opt),
(1,),
input_spec.tensor_meta.dtype,
)
else:
bias_tensor_meta = None
bias_tensor_meta = TensorMeta(
torch.Size(bias_shape_opt),
(1,),
input_spec.tensor_meta.dtype,
)
grad_input_spec = input_spec
grad_weight_spec = DTensorSpec.from_dim_map(
@ -108,18 +98,12 @@ def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
[0],
tensor_meta=weight_tensor_meta,
)
# Only create grad_bias_spec if we have bias_tensor_meta
if bias_tensor_meta is not None:
grad_bias_spec = DTensorSpec.from_dim_map(
input_spec.mesh,
[-1],
[0],
tensor_meta=bias_tensor_meta,
)
else:
grad_bias_spec = None
grad_bias_spec = DTensorSpec.from_dim_map(
input_spec.mesh,
[-1],
[0],
tensor_meta=bias_tensor_meta,
)
# TODO: actually the output_mask is not respected here, we should
# set the corresponding spec to `None` if the output_mask is not `False`
# for a certain output Tensor. This also applies to the conv handler

View File

@ -275,16 +275,14 @@ class ShardingPropagator:
output_tensor_meta_i = output_tensor_meta[i]
if not isinstance(output_tensor_meta_i, TensorMeta):
# NOTE: aten.convolution_backward.default is an exception and it
# needs extra handling because any Tensor in the output tuple
# can be `None` depending on the output_mask parameter. This can
# occur during double backpropagation or when certain gradients
# are not needed (e.g., grad_input when input has requires_grad=False,
# grad_weight/grad_bias when weight/bias have requires_grad=False,
# or grad_bias when bias is None). We explicitly allow the
# corresponding TensorMeta to be `None`.
# needs extra handling because the first Tensor in the output
# tuple can be `None` if the input Tensor to convolution op has
# `requires_grad=False` (e.g. convolution layer is the first
# layer in the model). We explicitly allow its corresponding
# TensorMeta to be `None`.
if (
op == aten.convolution_backward.default
and i in (0, 1, 2)
and i == 0
and output_tensor_meta_i is None
):
assert isinstance(output_specs, list)