Compare commits

..

3 Commits

Author SHA1 Message Date
ff57ecf4c8 nit 2025-11-17 22:47:11 -08:00
0193040a8e nit 2025-11-17 22:45:54 -08:00
7a899e06cb [DTensor] Fix deadlock after fast cache clear 2025-11-17 17:38:06 -08:00
315 changed files with 4891 additions and 8448 deletions

View File

@ -0,0 +1,19 @@
# Aarch64 (ARM/Graviton) Support Scripts
Scripts for building aarch64 PyTorch pip wheels. These scripts build the following wheels:
* torch
* torchvision
* torchaudio
* torchtext
* torchdata
## Aarch64_ci_build.sh
This script is designed to support CD operations within the PyPI manylinux aarch64 container and is meant to be executed inside that container. It prepares the container and then executes __aarch64_wheel_ci_build.py__ to build the wheels. The script assumes the PyTorch repo is located at ```/pytorch``` and will put the wheels into ```/artifacts```.
### Usage
```DESIRED_PYTHON=<PythonVersion> aarch64_ci_build.sh```
__NOTE:__ CI build is currently __EXPERIMENTAL__
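For local experimentation, an invocation along the following lines can be used. The container image name and mount paths below are illustrative assumptions and are not taken from the CI configuration:
```
docker run --rm -v "$(pwd)":/pytorch -v "$(pwd)/artifacts":/artifacts \
    -e DESIRED_PYTHON=3.12 -e DESIRED_CUDA=cpu \
    quay.io/pypa/manylinux_2_28_aarch64 \
    /pytorch/.ci/aarch64_linux/aarch64_ci_build.sh
```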
## Build_aarch64_wheel.py
This script builds the wheels on AWS EC2 instances. It requires the AWS CLI and Boto3, along with AWS credentials that can provision EC2 instances for the wheel builds. It can be used from a CodeBuild CD pipeline or from a local system.
### Usage
```build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.8 --branch <RCtag>```
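A minimal local setup sketch; the package names are the standard PyPI ones, and the credential workflow shown is only one option:
```
pip install boto3 awscli
aws configure   # supply credentials that are allowed to launch and terminate EC2 instances
```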

View File

@ -0,0 +1,53 @@
#!/bin/bash
set -eux -o pipefail
GPU_ARCH_VERSION=${GPU_ARCH_VERSION:-}
# Set CUDA architecture lists to match x86 build_cuda.sh
if [[ "$GPU_ARCH_VERSION" == *"12.6"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0"
elif [[ "$GPU_ARCH_VERSION" == *"12.8"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
elif [[ "$GPU_ARCH_VERSION" == *"12.9"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
elif [[ "$GPU_ARCH_VERSION" == *"13.0"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;11.0;12.0+PTX"
fi
# Compress the fatbin with -compress-mode=size for CUDA 13
if [[ "$DESIRED_CUDA" == *"13"* ]]; then
export TORCH_NVCC_FLAGS="-compress-mode=size"
# Bundle ptxas into the cu13 wheel, see https://github.com/pytorch/pytorch/issues/163801
export BUILD_BUNDLE_PTXAS=1
fi
SCRIPTPATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
source $SCRIPTPATH/aarch64_ci_setup.sh
###############################################################################
# Run aarch64 builder python
###############################################################################
cd /
# adding safe directory for git as the permissions will be
# on the mounted pytorch repo
git config --global --add safe.directory /pytorch
pip install -r /pytorch/requirements.txt
pip install auditwheel==6.2.0 wheel
if [ "$DESIRED_CUDA" = "cpu" ]; then
echo "BASE_CUDA_VERSION is not set. Building cpu wheel."
python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn
else
echo "BASE_CUDA_VERSION is set to: $DESIRED_CUDA"
export USE_SYSTEM_NCCL=1
# Check if we should use NVIDIA libs from PyPI (similar to x86 build_cuda.sh logic)
if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then
echo "Bundling CUDA libraries with wheel for aarch64."
else
echo "Using nvidia libs from pypi for aarch64."
echo "Updated PYTORCH_EXTRA_INSTALL_REQUIREMENTS for aarch64: $PYTORCH_EXTRA_INSTALL_REQUIREMENTS"
export USE_NVIDIA_PYPI_LIBS=1
fi
python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn --enable-cuda
fi
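A sketch of how the PyPI-vs-bundled toggle above is exercised; the DESIRED_CUDA/GPU_ARCH_VERSION spellings and the requirements string are assumptions, not values taken from the CI:
```
# Illustrative values; the CI provides these via its own environment
export DESIRED_PYTHON=3.12 DESIRED_CUDA=cu130 GPU_ARCH_VERSION=13.0
# Bundled CUDA libraries (default path): keep the variable set but empty
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS=""
bash aarch64_ci_build.sh
# CUDA libraries from PyPI: any non-empty value makes the script export USE_NVIDIA_PYPI_LIBS=1
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="<nvidia pip requirements>"
bash aarch64_ci_build.sh
```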

View File

@ -0,0 +1,21 @@
#!/bin/bash
set -eux -o pipefail
# This script prepares the Docker container for the aarch64_wheel_ci_build.py python script
# by creating symlinks from the desired /opt/python version to /usr/local/bin/
NUMPY_VERSION=2.0.2
if [[ "$DESIRED_PYTHON" == "3.13" || "$DESIRED_PYTHON" == "3.13t" ]]; then
NUMPY_VERSION=2.1.2
fi
SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"
source $SCRIPTPATH/../manywheel/set_desired_python.sh
pip install -q numpy==${NUMPY_VERSION} pyyaml==6.0.2 scons==4.7.0 ninja==1.11.1 patchelf==0.17.2
for tool in python python3 pip pip3 ninja scons patchelf; do
ln -sf ${DESIRED_PYTHON_BIN_DIR}/${tool} /usr/local/bin;
done
python --version

View File

@ -0,0 +1,333 @@
#!/usr/bin/env python3
# encoding: UTF-8
import os
import shutil
from subprocess import check_call, check_output
def list_dir(path: str) -> list[str]:
"""'
Helper for getting paths for Python
"""
return check_output(["ls", "-1", path]).decode().split("\n")
def replace_tag(filename) -> None:
with open(filename) as f:
lines = f.readlines()
for i, line in enumerate(lines):
if line.startswith("Tag:"):
lines[i] = line.replace("-linux_", "-manylinux_2_28_")
print(f"Updated tag from {line} to {lines[i]}")
break
with open(filename, "w") as f:
f.writelines(lines)
def patch_library_rpath(
folder: str,
lib_name: str,
use_nvidia_pypi_libs: bool = False,
desired_cuda: str = "",
) -> None:
"""Apply patchelf to set RPATH for a library in torch/lib"""
lib_path = f"{folder}/tmp/torch/lib/{lib_name}"
if use_nvidia_pypi_libs:
# For PyPI NVIDIA libraries, construct CUDA RPATH
cuda_rpaths = [
"$ORIGIN/../../nvidia/cudnn/lib",
"$ORIGIN/../../nvidia/nvshmem/lib",
"$ORIGIN/../../nvidia/nccl/lib",
"$ORIGIN/../../nvidia/cusparselt/lib",
]
if "130" in desired_cuda:
cuda_rpaths.append("$ORIGIN/../../nvidia/cu13/lib")
else:
cuda_rpaths.extend(
[
"$ORIGIN/../../nvidia/cublas/lib",
"$ORIGIN/../../nvidia/cuda_cupti/lib",
"$ORIGIN/../../nvidia/cuda_nvrtc/lib",
"$ORIGIN/../../nvidia/cuda_runtime/lib",
"$ORIGIN/../../nvidia/cufft/lib",
"$ORIGIN/../../nvidia/curand/lib",
"$ORIGIN/../../nvidia/cusolver/lib",
"$ORIGIN/../../nvidia/cusparse/lib",
"$ORIGIN/../../nvidia/nvtx/lib",
"$ORIGIN/../../nvidia/cufile/lib",
]
)
# Add $ORIGIN for local torch libs
rpath = ":".join(cuda_rpaths) + ":$ORIGIN"
else:
# For bundled libraries, just use $ORIGIN
rpath = "$ORIGIN"
if os.path.exists(lib_path):
os.system(
f"cd {folder}/tmp/torch/lib/; "
f"patchelf --set-rpath '{rpath}' --force-rpath {lib_name}"
)
def copy_and_patch_library(
src_path: str,
folder: str,
use_nvidia_pypi_libs: bool = False,
desired_cuda: str = "",
) -> None:
"""Copy a library to torch/lib and patch its RPATH"""
if os.path.exists(src_path):
lib_name = os.path.basename(src_path)
shutil.copy2(src_path, f"{folder}/tmp/torch/lib/{lib_name}")
patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda)
def package_cuda_wheel(wheel_path, desired_cuda) -> None:
"""
Package the cuda wheel libraries
"""
folder = os.path.dirname(wheel_path)
os.mkdir(f"{folder}/tmp")
os.system(f"unzip {wheel_path} -d {folder}/tmp")
# Delete original wheel since it will be repackaged
os.system(f"rm {wheel_path}")
# Check if we should use PyPI NVIDIA libraries or bundle system libraries
use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"
if use_nvidia_pypi_libs:
print("Using nvidia libs from pypi - skipping CUDA library bundling")
# For PyPI approach, we don't bundle CUDA libraries - they come from PyPI packages
# We only need to bundle non-NVIDIA libraries
minimal_libs_to_copy = [
"/lib64/libgomp.so.1",
"/usr/lib64/libgfortran.so.5",
"/acl/build/libarm_compute.so",
"/acl/build/libarm_compute_graph.so",
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_lapack_core.so.0",
"/usr/local/lib/libnvpl_blas_core.so.0",
]
# Copy minimal libraries to unzipped_folder/torch/lib
for lib_path in minimal_libs_to_copy:
copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda)
# Patch torch libraries used for searching libraries
torch_libs_to_patch = [
"libtorch.so",
"libtorch_cpu.so",
"libtorch_cuda.so",
"libtorch_cuda_linalg.so",
"libtorch_global_deps.so",
"libtorch_python.so",
"libtorch_nvshmem.so",
"libc10.so",
"libc10_cuda.so",
"libcaffe2_nvrtc.so",
"libshm.so",
]
for lib_name in torch_libs_to_patch:
patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda)
else:
print("Bundling CUDA libraries with wheel")
# Original logic for bundling system CUDA libraries
# Common libraries for all CUDA versions
common_libs = [
# Non-NVIDIA system libraries
"/lib64/libgomp.so.1",
"/usr/lib64/libgfortran.so.5",
"/acl/build/libarm_compute.so",
"/acl/build/libarm_compute_graph.so",
# Common CUDA libraries (same for all versions)
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_lapack_core.so.0",
"/usr/local/lib/libnvpl_blas_core.so.0",
"/usr/local/cuda/extras/CUPTI/lib64/libnvperf_host.so",
"/usr/local/cuda/lib64/libcudnn.so.9",
"/usr/local/cuda/lib64/libcusparseLt.so.0",
"/usr/local/cuda/lib64/libcurand.so.10",
"/usr/local/cuda/lib64/libnccl.so.2",
"/usr/local/cuda/lib64/libnvshmem_host.so.3",
"/usr/local/cuda/lib64/libcudnn_adv.so.9",
"/usr/local/cuda/lib64/libcudnn_cnn.so.9",
"/usr/local/cuda/lib64/libcudnn_graph.so.9",
"/usr/local/cuda/lib64/libcudnn_ops.so.9",
"/usr/local/cuda/lib64/libcudnn_engines_runtime_compiled.so.9",
"/usr/local/cuda/lib64/libcudnn_engines_precompiled.so.9",
"/usr/local/cuda/lib64/libcudnn_heuristic.so.9",
"/usr/local/cuda/lib64/libcufile.so.0",
"/usr/local/cuda/lib64/libcufile_rdma.so.1",
"/usr/local/cuda/lib64/libcusparse.so.12",
]
# CUDA version-specific libraries
if "13" in desired_cuda:
minor_version = desired_cuda[-1]
version_specific_libs = [
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.13",
"/usr/local/cuda/lib64/libcublas.so.13",
"/usr/local/cuda/lib64/libcublasLt.so.13",
"/usr/local/cuda/lib64/libcudart.so.13",
"/usr/local/cuda/lib64/libcufft.so.12",
"/usr/local/cuda/lib64/libcusolver.so.12",
"/usr/local/cuda/lib64/libnvJitLink.so.13",
"/usr/local/cuda/lib64/libnvrtc.so.13",
f"/usr/local/cuda/lib64/libnvrtc-builtins.so.13.{minor_version}",
]
elif "12" in desired_cuda:
# Get the last character for libnvrtc-builtins version (e.g., "129" -> "9")
minor_version = desired_cuda[-1]
version_specific_libs = [
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12",
"/usr/local/cuda/lib64/libcublas.so.12",
"/usr/local/cuda/lib64/libcublasLt.so.12",
"/usr/local/cuda/lib64/libcudart.so.12",
"/usr/local/cuda/lib64/libcufft.so.11",
"/usr/local/cuda/lib64/libcusolver.so.11",
"/usr/local/cuda/lib64/libnvJitLink.so.12",
"/usr/local/cuda/lib64/libnvrtc.so.12",
f"/usr/local/cuda/lib64/libnvrtc-builtins.so.12.{minor_version}",
]
else:
raise ValueError(f"Unsupported CUDA version: {desired_cuda}.")
# Combine all libraries
libs_to_copy = common_libs + version_specific_libs
# Copy libraries to unzipped_folder/torch/lib
for lib_path in libs_to_copy:
copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda)
# Make sure the wheel is tagged with manylinux_2_28
for f in os.scandir(f"{folder}/tmp/"):
if f.is_dir() and f.name.endswith(".dist-info"):
replace_tag(f"{f.path}/WHEEL")
break
os.system(f"wheel pack {folder}/tmp/ -d {folder}")
os.system(f"rm -rf {folder}/tmp/")
def complete_wheel(folder: str) -> str:
"""
Complete wheel build and put in artifact location
"""
wheel_name = list_dir(f"/{folder}/dist")[0]
# Please note for CUDA we don't run auditwheel since we use a custom script,
# package_cuda_wheel(), to package the CUDA dependencies into the wheel file.
# However we need to make sure the filename reflects the correct manylinux platform.
if "pytorch" in folder and not enable_cuda:
print("Repairing Wheel with AuditWheel")
check_call(["auditwheel", "repair", f"dist/{wheel_name}"], cwd=folder)
repaired_wheel_name = list_dir(f"/{folder}/wheelhouse")[0]
print(f"Moving {repaired_wheel_name} wheel to /{folder}/dist")
os.rename(
f"/{folder}/wheelhouse/{repaired_wheel_name}",
f"/{folder}/dist/{repaired_wheel_name}",
)
else:
repaired_wheel_name = list_dir(f"/{folder}/dist")[0]
print(f"Copying {repaired_wheel_name} to artifacts")
shutil.copy2(
f"/{folder}/dist/{repaired_wheel_name}", f"/artifacts/{repaired_wheel_name}"
)
return repaired_wheel_name
def parse_arguments():
"""
Parse command-line arguments
"""
from argparse import ArgumentParser
parser = ArgumentParser("AARCH64 wheels python CD")
parser.add_argument("--debug", action="store_true")
parser.add_argument("--build-only", action="store_true")
parser.add_argument("--test-only", type=str)
parser.add_argument("--enable-mkldnn", action="store_true")
parser.add_argument("--enable-cuda", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
"""
Entry Point
"""
args = parse_arguments()
enable_mkldnn = args.enable_mkldnn
enable_cuda = args.enable_cuda
branch = check_output(
["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd="/pytorch"
).decode()
print("Building PyTorch wheel")
build_vars = ""
# MAX_JOBS=5 is not required for CPU backend (see commit 465d98b)
if enable_cuda:
build_vars += "MAX_JOBS=5 "
# Handle PyPI NVIDIA libraries vs bundled libraries
use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"
if use_nvidia_pypi_libs:
print("Configuring build for PyPI NVIDIA libraries")
# Configure for dynamic linking (matching x86 logic)
build_vars += "ATEN_STATIC_CUDA=0 USE_CUDA_STATIC_LINK=0 USE_CUPTI_SO=1 "
else:
print("Configuring build for bundled NVIDIA libraries")
# Keep existing static linking approach - already configured above
override_package_version = os.getenv("OVERRIDE_PACKAGE_VERSION")
desired_cuda = os.getenv("DESIRED_CUDA")
if override_package_version is not None:
version = override_package_version
build_vars += (
f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version} PYTORCH_BUILD_NUMBER=1 "
)
elif branch in ["nightly", "main"]:
build_date = (
check_output(["git", "log", "--pretty=format:%cs", "-1"], cwd="/pytorch")
.decode()
.replace("-", "")
)
version = (
check_output(["cat", "version.txt"], cwd="/pytorch").decode().strip()[:-2]
)
if enable_cuda:
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date}+{desired_cuda} PYTORCH_BUILD_NUMBER=1 "
else:
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1 "
elif branch.startswith(("v1.", "v2.")):
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1 "
if enable_mkldnn:
print("build pytorch with mkldnn+acl backend")
build_vars += "USE_MKLDNN=ON USE_MKLDNN_ACL=ON "
build_vars += "ACL_ROOT_DIR=/acl "
if enable_cuda:
build_vars += "BLAS=NVPL "
else:
build_vars += "BLAS=OpenBLAS OpenBLAS_HOME=/opt/OpenBLAS "
else:
print("build pytorch without mkldnn backend")
os.system(f"cd /pytorch; {build_vars} python3 -m build --wheel --no-isolation")
if enable_cuda:
print("Updating Cuda Dependency")
filename = os.listdir("/pytorch/dist/")
wheel_path = f"/pytorch/dist/{filename[0]}"
package_cuda_wheel(wheel_path, desired_cuda)
pytorch_wheel_name = complete_wheel("/pytorch/")
print(f"Build Complete. Created {pytorch_wheel_name}..")

View File

@ -0,0 +1,999 @@
#!/usr/bin/env python3
# This script is for building AARCH64 wheels using AWS EC2 instances.
# To generate binaries for the release, follow these steps:
# 1. Update the mappings for each of the domain libraries by adding a new row to the corresponding table, like this:
# "v1.11.0": ("0.11.0", "rc1"),
# 2. Run the script with the following arguments for each of the supported Python versions and the required tag, for example:
# build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.8 --branch v1.11.0-rc3
import os
import subprocess
import sys
import time
from typing import Optional, Union
import boto3
# AMI images for us-east-1, change the following based on your ~/.aws/config
os_amis = {
"ubuntu20_04": "ami-052eac90edaa9d08f", # login_name: ubuntu
"ubuntu22_04": "ami-0c6c29c5125214c77", # login_name: ubuntu
"redhat8": "ami-0698b90665a2ddcf1", # login_name: ec2-user
}
ubuntu20_04_ami = os_amis["ubuntu20_04"]
def compute_keyfile_path(key_name: Optional[str] = None) -> tuple[str, str]:
if key_name is None:
key_name = os.getenv("AWS_KEY_NAME")
if key_name is None:
return os.getenv("SSH_KEY_PATH", ""), ""
homedir_path = os.path.expanduser("~")
default_path = os.path.join(homedir_path, ".ssh", f"{key_name}.pem")
return os.getenv("SSH_KEY_PATH", default_path), key_name
ec2 = boto3.resource("ec2")
def ec2_get_instances(filter_name, filter_value):
return ec2.instances.filter(
Filters=[{"Name": filter_name, "Values": [filter_value]}]
)
def ec2_instances_of_type(instance_type="t4g.2xlarge"):
return ec2_get_instances("instance-type", instance_type)
def ec2_instances_by_id(instance_id):
rc = list(ec2_get_instances("instance-id", instance_id))
return rc[0] if len(rc) > 0 else None
def start_instance(
key_name, ami=ubuntu20_04_ami, instance_type="t4g.2xlarge", ebs_size: int = 50
):
inst = ec2.create_instances(
ImageId=ami,
InstanceType=instance_type,
SecurityGroups=["ssh-allworld"],
KeyName=key_name,
MinCount=1,
MaxCount=1,
BlockDeviceMappings=[
{
"DeviceName": "/dev/sda1",
"Ebs": {
"DeleteOnTermination": True,
"VolumeSize": ebs_size,
"VolumeType": "standard",
},
}
],
)[0]
print(f"Create instance {inst.id}")
inst.wait_until_running()
running_inst = ec2_instances_by_id(inst.id)
print(f"Instance started at {running_inst.public_dns_name}")
return running_inst
class RemoteHost:
addr: str
keyfile_path: str
login_name: str
container_id: Optional[str] = None
ami: Optional[str] = None
def __init__(self, addr: str, keyfile_path: str, login_name: str = "ubuntu"):
self.addr = addr
self.keyfile_path = keyfile_path
self.login_name = login_name
def _gen_ssh_prefix(self) -> list[str]:
return [
"ssh",
"-o",
"StrictHostKeyChecking=no",
"-i",
self.keyfile_path,
f"{self.login_name}@{self.addr}",
"--",
]
@staticmethod
def _split_cmd(args: Union[str, list[str]]) -> list[str]:
return args.split() if isinstance(args, str) else args
def run_ssh_cmd(self, args: Union[str, list[str]]) -> None:
subprocess.check_call(self._gen_ssh_prefix() + self._split_cmd(args))
def check_ssh_output(self, args: Union[str, list[str]]) -> str:
return subprocess.check_output(
self._gen_ssh_prefix() + self._split_cmd(args)
).decode("utf-8")
def scp_upload_file(self, local_file: str, remote_file: str) -> None:
subprocess.check_call(
[
"scp",
"-i",
self.keyfile_path,
local_file,
f"{self.login_name}@{self.addr}:{remote_file}",
]
)
def scp_download_file(
self, remote_file: str, local_file: Optional[str] = None
) -> None:
if local_file is None:
local_file = "."
subprocess.check_call(
[
"scp",
"-i",
self.keyfile_path,
f"{self.login_name}@{self.addr}:{remote_file}",
local_file,
]
)
def start_docker(self, image="quay.io/pypa/manylinux2014_aarch64:latest") -> None:
self.run_ssh_cmd("sudo apt-get install -y docker.io")
self.run_ssh_cmd(f"sudo usermod -a -G docker {self.login_name}")
self.run_ssh_cmd("sudo service docker start")
self.run_ssh_cmd(f"docker pull {image}")
self.container_id = self.check_ssh_output(
f"docker run -t -d -w /root {image}"
).strip()
def using_docker(self) -> bool:
return self.container_id is not None
def run_cmd(self, args: Union[str, list[str]]) -> None:
if not self.using_docker():
return self.run_ssh_cmd(args)
assert self.container_id is not None
docker_cmd = self._gen_ssh_prefix() + [
"docker",
"exec",
"-i",
self.container_id,
"bash",
]
p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE)
p.communicate(
input=" ".join(["source .bashrc && "] + self._split_cmd(args)).encode(
"utf-8"
)
)
rc = p.wait()
if rc != 0:
raise subprocess.CalledProcessError(rc, docker_cmd)
def check_output(self, args: Union[str, list[str]]) -> str:
if not self.using_docker():
return self.check_ssh_output(args)
assert self.container_id is not None
docker_cmd = self._gen_ssh_prefix() + [
"docker",
"exec",
"-i",
self.container_id,
"bash",
]
p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
(out, err) = p.communicate(
input=" ".join(["source .bashrc && "] + self._split_cmd(args)).encode(
"utf-8"
)
)
rc = p.wait()
if rc != 0:
raise subprocess.CalledProcessError(rc, docker_cmd, output=out, stderr=err)
return out.decode("utf-8")
def upload_file(self, local_file: str, remote_file: str) -> None:
if not self.using_docker():
return self.scp_upload_file(local_file, remote_file)
tmp_file = os.path.join("/tmp", os.path.basename(local_file))
self.scp_upload_file(local_file, tmp_file)
self.run_ssh_cmd(
["docker", "cp", tmp_file, f"{self.container_id}:/root/{remote_file}"]
)
self.run_ssh_cmd(["rm", tmp_file])
def download_file(self, remote_file: str, local_file: Optional[str] = None) -> None:
if not self.using_docker():
return self.scp_download_file(remote_file, local_file)
tmp_file = os.path.join("/tmp", os.path.basename(remote_file))
self.run_ssh_cmd(
["docker", "cp", f"{self.container_id}:/root/{remote_file}", tmp_file]
)
self.scp_download_file(tmp_file, local_file)
self.run_ssh_cmd(["rm", tmp_file])
def download_wheel(
self, remote_file: str, local_file: Optional[str] = None
) -> None:
if self.using_docker() and local_file is None:
basename = os.path.basename(remote_file)
local_file = basename.replace(
"-linux_aarch64.whl", "-manylinux2014_aarch64.whl"
)
self.download_file(remote_file, local_file)
def list_dir(self, path: str) -> list[str]:
return self.check_output(["ls", "-1", path]).split("\n")
def wait_for_connection(addr, port, timeout=15, attempt_cnt=5):
import socket
for i in range(attempt_cnt):
try:
with socket.create_connection((addr, port), timeout=timeout):
return
except (ConnectionRefusedError, TimeoutError): # noqa: PERF203
if i == attempt_cnt - 1:
raise
time.sleep(timeout)
def update_apt_repo(host: RemoteHost) -> None:
time.sleep(5)
host.run_cmd("sudo systemctl stop apt-daily.service || true")
host.run_cmd("sudo systemctl stop unattended-upgrades.service || true")
host.run_cmd(
"while systemctl is-active --quiet apt-daily.service; do sleep 1; done"
)
host.run_cmd(
"while systemctl is-active --quiet unattended-upgrades.service; do sleep 1; done"
)
host.run_cmd("sudo apt-get update")
time.sleep(3)
host.run_cmd("sudo apt-get update")
def install_condaforge(
host: RemoteHost, suffix: str = "latest/download/Miniforge3-Linux-aarch64.sh"
) -> None:
print("Install conda-forge")
host.run_cmd(f"curl -OL https://github.com/conda-forge/miniforge/releases/{suffix}")
host.run_cmd(f"sh -f {os.path.basename(suffix)} -b")
host.run_cmd(f"rm -f {os.path.basename(suffix)}")
if host.using_docker():
host.run_cmd("echo 'PATH=$HOME/miniforge3/bin:$PATH'>>.bashrc")
else:
host.run_cmd(
[
"sed",
"-i",
"'/^# If not running interactively.*/i PATH=$HOME/miniforge3/bin:$PATH'",
".bashrc",
]
)
def install_condaforge_python(host: RemoteHost, python_version="3.8") -> None:
if python_version == "3.6":
# Python-3.6 EOLed and not compatible with conda-4.11
install_condaforge(
host, suffix="download/4.10.3-10/Miniforge3-4.10.3-10-Linux-aarch64.sh"
)
host.run_cmd(f"conda install -y python={python_version} numpy pyyaml")
else:
install_condaforge(
host, suffix="download/4.11.0-4/Miniforge3-4.11.0-4-Linux-aarch64.sh"
)
# Pytorch-1.10 or older are not compatible with setuptools=59.6 or newer
host.run_cmd(
f"conda install -y python={python_version} numpy pyyaml setuptools>=59.5.0"
)
def embed_libgomp(host: RemoteHost, use_conda, wheel_name) -> None:
host.run_cmd("pip3 install auditwheel")
host.run_cmd(
"conda install -y patchelf" if use_conda else "sudo apt-get install -y patchelf"
)
from tempfile import NamedTemporaryFile
with NamedTemporaryFile() as tmp:
tmp.write(embed_library_script.encode("utf-8"))
tmp.flush()
host.upload_file(tmp.name, "embed_library.py")
print("Embedding libgomp into wheel")
if host.using_docker():
host.run_cmd(f"python3 embed_library.py {wheel_name} --update-tag")
else:
host.run_cmd(f"python3 embed_library.py {wheel_name}")
def checkout_repo(
host: RemoteHost,
*,
branch: str = "main",
url: str,
git_clone_flags: str,
mapping: dict[str, tuple[str, str]],
) -> Optional[str]:
for prefix in mapping:
if not branch.startswith(prefix):
continue
tag = f"v{mapping[prefix][0]}-{mapping[prefix][1]}"
host.run_cmd(f"git clone {url} -b {tag} {git_clone_flags}")
return mapping[prefix][0]
host.run_cmd(f"git clone {url} -b {branch} {git_clone_flags}")
return None
def build_torchvision(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str,
run_smoke_tests: bool = True,
) -> str:
print("Checking out TorchVision repo")
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/vision",
git_clone_flags=git_clone_flags,
mapping={
"v1.7.1": ("0.8.2", "rc2"),
"v1.8.0": ("0.9.0", "rc3"),
"v1.8.1": ("0.9.1", "rc1"),
"v1.9.0": ("0.10.0", "rc1"),
"v1.10.0": ("0.11.1", "rc1"),
"v1.10.1": ("0.11.2", "rc1"),
"v1.10.2": ("0.11.3", "rc1"),
"v1.11.0": ("0.12.0", "rc1"),
"v1.12.0": ("0.13.0", "rc4"),
"v1.12.1": ("0.13.1", "rc6"),
"v1.13.0": ("0.14.0", "rc4"),
"v1.13.1": ("0.14.1", "rc2"),
"v2.0.0": ("0.15.1", "rc2"),
"v2.0.1": ("0.15.2", "rc2"),
},
)
print("Building TorchVision wheel")
# Please note libpng and jpeg are required to build the image.so extension
if use_conda:
host.run_cmd("conda install -y libpng jpeg")
# Remove .so files to force static linking
host.run_cmd(
"rm miniforge3/lib/libpng.so miniforge3/lib/libpng16.so miniforge3/lib/libjpeg.so"
)
# And patch setup.py to include libz dependency for libpng
host.run_cmd(
[
'sed -i -e \'s/image_link_flags\\.append("png")/image_link_flags += ["png", "z"]/\' vision/setup.py'
]
)
build_vars = ""
if branch == "nightly":
version = host.check_output(
["if [ -f vision/version.txt ]; then cat vision/version.txt; fi"]
).strip()
if len(version) == 0:
# In older revisions, version was embedded in setup.py
version = (
host.check_output(["grep", '"version = \'"', "vision/setup.py"])
.strip()
.split("'")[1][:-2]
)
build_date = (
host.check_output("cd vision && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(f"cd vision && {build_vars} python3 -m build --wheel --no-isolation")
vision_wheel_name = host.list_dir("vision/dist")[0]
embed_libgomp(host, use_conda, os.path.join("vision", "dist", vision_wheel_name))
print("Copying TorchVision wheel")
host.download_wheel(os.path.join("vision", "dist", vision_wheel_name))
if run_smoke_tests:
host.run_cmd(
f"pip3 install {os.path.join('vision', 'dist', vision_wheel_name)}"
)
host.run_cmd("python3 vision/test/smoke_test.py")
print("Delete vision checkout")
host.run_cmd("rm -rf vision")
return vision_wheel_name
def build_torchdata(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> str:
print("Checking out TorchData repo")
git_clone_flags += " --recurse-submodules"
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/data",
git_clone_flags=git_clone_flags,
mapping={
"v1.13.1": ("0.5.1", ""),
"v2.0.0": ("0.6.0", "rc5"),
"v2.0.1": ("0.6.1", "rc1"),
},
)
print("Building TorchData wheel")
build_vars = ""
if branch == "nightly":
version = host.check_output(
["if [ -f data/version.txt ]; then cat data/version.txt; fi"]
).strip()
build_date = (
host.check_output("cd data && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(f"cd data && {build_vars} python3 -m build --wheel --no-isolation")
wheel_name = host.list_dir("data/dist")[0]
embed_libgomp(host, use_conda, os.path.join("data", "dist", wheel_name))
print("Copying TorchData wheel")
host.download_wheel(os.path.join("data", "dist", wheel_name))
return wheel_name
def build_torchtext(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> str:
print("Checking out TorchText repo")
git_clone_flags += " --recurse-submodules"
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/text",
git_clone_flags=git_clone_flags,
mapping={
"v1.9.0": ("0.10.0", "rc1"),
"v1.10.0": ("0.11.0", "rc2"),
"v1.10.1": ("0.11.1", "rc1"),
"v1.10.2": ("0.11.2", "rc1"),
"v1.11.0": ("0.12.0", "rc1"),
"v1.12.0": ("0.13.0", "rc2"),
"v1.12.1": ("0.13.1", "rc5"),
"v1.13.0": ("0.14.0", "rc3"),
"v1.13.1": ("0.14.1", "rc1"),
"v2.0.0": ("0.15.1", "rc2"),
"v2.0.1": ("0.15.2", "rc2"),
},
)
print("Building TorchText wheel")
build_vars = ""
if branch == "nightly":
version = host.check_output(
["if [ -f text/version.txt ]; then cat text/version.txt; fi"]
).strip()
build_date = (
host.check_output("cd text && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(f"cd text && {build_vars} python3 -m build --wheel --no-isolation")
wheel_name = host.list_dir("text/dist")[0]
embed_libgomp(host, use_conda, os.path.join("text", "dist", wheel_name))
print("Copying TorchText wheel")
host.download_wheel(os.path.join("text", "dist", wheel_name))
return wheel_name
def build_torchaudio(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> str:
print("Checking out TorchAudio repo")
git_clone_flags += " --recurse-submodules"
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/audio",
git_clone_flags=git_clone_flags,
mapping={
"v1.9.0": ("0.9.0", "rc2"),
"v1.10.0": ("0.10.0", "rc5"),
"v1.10.1": ("0.10.1", "rc1"),
"v1.10.2": ("0.10.2", "rc1"),
"v1.11.0": ("0.11.0", "rc1"),
"v1.12.0": ("0.12.0", "rc3"),
"v1.12.1": ("0.12.1", "rc5"),
"v1.13.0": ("0.13.0", "rc4"),
"v1.13.1": ("0.13.1", "rc2"),
"v2.0.0": ("2.0.1", "rc3"),
"v2.0.1": ("2.0.2", "rc2"),
},
)
print("Building TorchAudio wheel")
build_vars = ""
if branch == "nightly":
version = (
host.check_output(["grep", '"version = \'"', "audio/setup.py"])
.strip()
.split("'")[1][:-2]
)
build_date = (
host.check_output("cd audio && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(
f"cd audio && export FFMPEG_ROOT=$(pwd)/third_party/ffmpeg && export USE_FFMPEG=1 \
&& ./packaging/ffmpeg/build.sh \
&& {build_vars} python3 -m build --wheel --no-isolation"
)
wheel_name = host.list_dir("audio/dist")[0]
embed_libgomp(host, use_conda, os.path.join("audio", "dist", wheel_name))
print("Copying TorchAudio wheel")
host.download_wheel(os.path.join("audio", "dist", wheel_name))
return wheel_name
def configure_system(
host: RemoteHost,
*,
compiler: str = "gcc-8",
use_conda: bool = True,
python_version: str = "3.8",
) -> None:
if use_conda:
install_condaforge_python(host, python_version)
print("Configuring the system")
if not host.using_docker():
update_apt_repo(host)
host.run_cmd("sudo apt-get install -y ninja-build g++ git cmake gfortran unzip")
else:
host.run_cmd("yum install -y sudo")
host.run_cmd("conda install -y ninja scons")
if not use_conda:
host.run_cmd(
"sudo apt-get install -y python3-dev python3-yaml python3-setuptools python3-wheel python3-pip"
)
host.run_cmd("pip3 install dataclasses typing-extensions")
if not use_conda:
print("Installing Cython + numpy from PyPy")
host.run_cmd("sudo pip3 install Cython")
host.run_cmd("sudo pip3 install numpy")
def build_domains(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> tuple[str, str, str, str]:
vision_wheel_name = build_torchvision(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
audio_wheel_name = build_torchaudio(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
data_wheel_name = build_torchdata(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
text_wheel_name = build_torchtext(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
return (vision_wheel_name, audio_wheel_name, data_wheel_name, text_wheel_name)
def start_build(
host: RemoteHost,
*,
branch: str = "main",
compiler: str = "gcc-8",
use_conda: bool = True,
python_version: str = "3.8",
pytorch_only: bool = False,
pytorch_build_number: Optional[str] = None,
shallow_clone: bool = True,
enable_mkldnn: bool = False,
) -> tuple[str, str, str, str, str]:
git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
if host.using_docker() and not use_conda:
print("Auto-selecting conda option for docker images")
use_conda = True
if not host.using_docker():
print("Disable mkldnn for host builds")
enable_mkldnn = False
configure_system(
host, compiler=compiler, use_conda=use_conda, python_version=python_version
)
if host.using_docker():
print("Move libgfortant.a into a standard location")
# HACK: pypa gforntran.a is compiled without PIC, which leads to the following error
# libgfortran.a(error.o)(.text._gfortrani_st_printf+0x34): unresolvable R_AARCH64_ADR_PREL_PG_HI21 relocation against symbol `__stack_chk_guard@@GLIBC_2.17' # noqa: E501, B950
# Workaround by copying gfortran library from the host
host.run_ssh_cmd("sudo apt-get install -y gfortran-8")
host.run_cmd("mkdir -p /usr/lib/gcc/aarch64-linux-gnu/8")
host.run_ssh_cmd(
[
"docker",
"cp",
"/usr/lib/gcc/aarch64-linux-gnu/8/libgfortran.a",
f"{host.container_id}:/opt/rh/devtoolset-10/root/usr/lib/gcc/aarch64-redhat-linux/10/",
]
)
print("Checking out PyTorch repo")
host.run_cmd(
f"git clone --recurse-submodules -b {branch} https://github.com/pytorch/pytorch {git_clone_flags}"
)
host.run_cmd("pytorch/.ci/docker/common/install_openblas.sh")
print("Building PyTorch wheel")
build_opts = ""
if pytorch_build_number is not None:
build_opts += f" -C--build-option=--build-number={pytorch_build_number}"
# Breakpad build fails on aarch64
build_vars = "USE_BREAKPAD=0 "
if branch == "nightly":
build_date = (
host.check_output("cd pytorch && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
version = host.check_output("cat pytorch/version.txt").strip()[:-2]
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1"
if branch.startswith(("v1.", "v2.")):
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
if enable_mkldnn:
host.run_cmd("pytorch/.ci/docker/common/install_acl.sh")
print("build pytorch with mkldnn+acl backend")
build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
build_vars += " BLAS=OpenBLAS"
build_vars += " OpenBLAS_HOME=/opt/OpenBLAS"
build_vars += " ACL_ROOT_DIR=/acl"
host.run_cmd(
f"cd $HOME/pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
)
print("Repair the wheel")
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
ld_library_path = "/acl/build:$HOME/pytorch/build/lib"
host.run_cmd(
f"export LD_LIBRARY_PATH={ld_library_path} && auditwheel repair $HOME/pytorch/dist/{pytorch_wheel_name}"
)
print("replace the original wheel with the repaired one")
pytorch_repaired_wheel_name = host.list_dir("wheelhouse")[0]
host.run_cmd(
f"cp $HOME/wheelhouse/{pytorch_repaired_wheel_name} $HOME/pytorch/dist/{pytorch_wheel_name}"
)
else:
print("build pytorch without mkldnn backend")
host.run_cmd(
f"cd pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
)
print("Deleting build folder")
host.run_cmd("cd pytorch && rm -rf build")
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
embed_libgomp(host, use_conda, os.path.join("pytorch", "dist", pytorch_wheel_name))
print("Copying the wheel")
host.download_wheel(os.path.join("pytorch", "dist", pytorch_wheel_name))
print("Installing PyTorch wheel")
host.run_cmd(f"pip3 install pytorch/dist/{pytorch_wheel_name}")
if pytorch_only:
return (pytorch_wheel_name, None, None, None, None)
domain_wheels = build_domains(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
return (pytorch_wheel_name, *domain_wheels)
embed_library_script = """
#!/usr/bin/env python3
from auditwheel.patcher import Patchelf
from auditwheel.wheeltools import InWheelCtx
from auditwheel.elfutils import elf_file_filter
from auditwheel.repair import copylib
from auditwheel.lddtree import lddtree
from subprocess import check_call
import os
import shutil
import sys
from tempfile import TemporaryDirectory
def replace_tag(filename):
with open(filename, 'r') as f:
lines = f.read().split("\\n")
for i,line in enumerate(lines):
if not line.startswith("Tag: "):
continue
lines[i] = line.replace("-linux_", "-manylinux2014_")
print(f'Updated tag from {line} to {lines[i]}')
with open(filename, 'w') as f:
f.write("\\n".join(lines))
class AlignedPatchelf(Patchelf):
def set_soname(self, file_name: str, new_soname: str) -> None:
check_call(['patchelf', '--page-size', '65536', '--set-soname', new_soname, file_name])
def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
check_call(['patchelf', '--page-size', '65536', '--replace-needed', soname, new_soname, file_name])
def embed_library(whl_path, lib_soname, update_tag=False):
patcher = AlignedPatchelf()
out_dir = TemporaryDirectory()
whl_name = os.path.basename(whl_path)
tmp_whl_name = os.path.join(out_dir.name, whl_name)
with InWheelCtx(whl_path) as ctx:
torchlib_path = os.path.join(ctx._tmpdir.name, 'torch', 'lib')
ctx.out_wheel=tmp_whl_name
new_lib_path, new_lib_soname = None, None
for filename, elf in elf_file_filter(ctx.iter_files()):
if not filename.startswith('torch/lib'):
continue
libtree = lddtree(filename)
if lib_soname not in libtree['needed']:
continue
lib_path = libtree['libs'][lib_soname]['path']
if lib_path is None:
print(f"Can't embed {lib_soname} as it could not be found")
break
if lib_path.startswith(torchlib_path):
continue
if new_lib_path is None:
new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
patcher.replace_needed(filename, lib_soname, new_lib_soname)
print(f'Replacing {lib_soname} with {new_lib_soname} for {filename}')
if update_tag:
# Add manylinux2014 tag
for filename in ctx.iter_files():
if os.path.basename(filename) != 'WHEEL':
continue
replace_tag(filename)
shutil.move(tmp_whl_name, whl_path)
if __name__ == '__main__':
embed_library(sys.argv[1], 'libgomp.so.1', len(sys.argv) > 2 and sys.argv[2] == '--update-tag')
"""
def run_tests(host: RemoteHost, whl: str, branch="main") -> None:
print("Configuring the system")
update_apt_repo(host)
host.run_cmd("sudo apt-get install -y python3-pip git")
host.run_cmd("sudo pip3 install Cython")
host.run_cmd("sudo pip3 install numpy")
host.upload_file(whl, ".")
host.run_cmd(f"sudo pip3 install {whl}")
host.run_cmd("python3 -c 'import torch;print(torch.rand((3,3))'")
host.run_cmd(f"git clone -b {branch} https://github.com/pytorch/pytorch")
host.run_cmd("cd pytorch/test; python3 test_torch.py -v")
def get_instance_name(instance) -> Optional[str]:
if instance.tags is None:
return None
for tag in instance.tags:
if tag["Key"] == "Name":
return tag["Value"]
return None
def list_instances(instance_type: str) -> None:
print(f"All instances of type {instance_type}")
for instance in ec2_instances_of_type(instance_type):
ifaces = instance.network_interfaces
az = ifaces[0].subnet.availability_zone if len(ifaces) > 0 else None
print(
f"{instance.id} {get_instance_name(instance)} {instance.public_dns_name} {instance.state['Name']} {az}"
)
def terminate_instances(instance_type: str) -> None:
print(f"Terminating all instances of type {instance_type}")
instances = list(ec2_instances_of_type(instance_type))
for instance in instances:
print(f"Terminating {instance.id}")
instance.terminate()
print("Waiting for termination to complete")
for instance in instances:
instance.wait_until_terminated()
def parse_arguments():
from argparse import ArgumentParser
parser = ArgumentParser("Build and test AARCH64 wheels using EC2")
parser.add_argument("--key-name", type=str)
parser.add_argument("--debug", action="store_true")
parser.add_argument("--build-only", action="store_true")
parser.add_argument("--test-only", type=str)
group = parser.add_mutually_exclusive_group()
group.add_argument("--os", type=str, choices=list(os_amis.keys()))
group.add_argument("--ami", type=str)
parser.add_argument(
"--python-version",
type=str,
choices=[f"3.{d}" for d in range(6, 12)],
default=None,
)
parser.add_argument("--alloc-instance", action="store_true")
parser.add_argument("--list-instances", action="store_true")
parser.add_argument("--pytorch-only", action="store_true")
parser.add_argument("--keep-running", action="store_true")
parser.add_argument("--terminate-instances", action="store_true")
parser.add_argument("--instance-type", type=str, default="t4g.2xlarge")
parser.add_argument("--ebs-size", type=int, default=50)
parser.add_argument("--branch", type=str, default="main")
parser.add_argument("--use-docker", action="store_true")
parser.add_argument(
"--compiler",
type=str,
choices=["gcc-7", "gcc-8", "gcc-9", "clang"],
default="gcc-8",
)
parser.add_argument("--use-torch-from-pypi", action="store_true")
parser.add_argument("--pytorch-build-number", type=str, default=None)
parser.add_argument("--disable-mkldnn", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
args = parse_arguments()
ami = (
args.ami
if args.ami is not None
else os_amis[args.os]
if args.os is not None
else ubuntu20_04_ami
)
keyfile_path, key_name = compute_keyfile_path(args.key_name)
if args.list_instances:
list_instances(args.instance_type)
sys.exit(0)
if args.terminate_instances:
terminate_instances(args.instance_type)
sys.exit(0)
if len(key_name) == 0:
raise RuntimeError("""
Cannot start build without key_name, please specify
--key-name argument or AWS_KEY_NAME environment variable.""")
if len(keyfile_path) == 0 or not os.path.exists(keyfile_path):
raise RuntimeError(f"""
Cannot find keyfile with name: [{key_name}] in path: [{keyfile_path}], please
check `~/.ssh/` folder or manually set SSH_KEY_PATH environment variable.""")
# Starting the instance
inst = start_instance(
key_name, ami=ami, instance_type=args.instance_type, ebs_size=args.ebs_size
)
instance_name = f"{args.key_name}-{args.os}"
if args.python_version is not None:
instance_name += f"-py{args.python_version}"
inst.create_tags(
DryRun=False,
Tags=[
{
"Key": "Name",
"Value": instance_name,
}
],
)
addr = inst.public_dns_name
wait_for_connection(addr, 22)
host = RemoteHost(addr, keyfile_path)
host.ami = ami
if args.use_docker:
update_apt_repo(host)
host.start_docker()
if args.test_only:
run_tests(host, args.test_only)
sys.exit(0)
if args.alloc_instance:
if args.python_version is None:
sys.exit(0)
install_condaforge_python(host, args.python_version)
sys.exit(0)
python_version = args.python_version if args.python_version is not None else "3.10"
if args.use_torch_from_pypi:
configure_system(host, compiler=args.compiler, python_version=python_version)
print("Installing PyTorch wheel")
host.run_cmd("pip3 install torch")
build_domains(
host, branch=args.branch, git_clone_flags=" --depth 1 --shallow-submodules"
)
else:
start_build(
host,
branch=args.branch,
compiler=args.compiler,
python_version=python_version,
pytorch_only=args.pytorch_only,
pytorch_build_number=args.pytorch_build_number,
enable_mkldnn=not args.disable_mkldnn,
)
if not args.keep_running:
print(f"Waiting for instance {inst.id} to terminate")
inst.terminate()
inst.wait_until_terminated()
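Besides building, the script doubles as a small EC2 instance manager; the flags below are the ones defined in parse_arguments above:
```
# List running build instances of the default type (t4g.2xlarge)
./build_aarch64_wheel.py --list-instances
# Terminate them once the wheels have been downloaded
./build_aarch64_wheel.py --terminate-instances
```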

View File

@ -0,0 +1,87 @@
#!/usr/bin/env python3
import os
import shutil
import sys
from subprocess import check_call
from tempfile import TemporaryDirectory
from auditwheel.elfutils import elf_file_filter
from auditwheel.lddtree import lddtree
from auditwheel.patcher import Patchelf
from auditwheel.repair import copylib
from auditwheel.wheeltools import InWheelCtx
def replace_tag(filename):
with open(filename) as f:
lines = f.read().split("\n")
for i, line in enumerate(lines):
if not line.startswith("Tag: "):
continue
lines[i] = line.replace("-linux_", "-manylinux2014_")
print(f"Updated tag from {line} to {lines[i]}")
with open(filename, "w") as f:
f.write("\n".join(lines))
class AlignedPatchelf(Patchelf):
def set_soname(self, file_name: str, new_soname: str) -> None:
check_call(
["patchelf", "--page-size", "65536", "--set-soname", new_soname, file_name]
)
def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
check_call(
[
"patchelf",
"--page-size",
"65536",
"--replace-needed",
soname,
new_soname,
file_name,
]
)
def embed_library(whl_path, lib_soname, update_tag=False):
patcher = AlignedPatchelf()
out_dir = TemporaryDirectory()
whl_name = os.path.basename(whl_path)
tmp_whl_name = os.path.join(out_dir.name, whl_name)
with InWheelCtx(whl_path) as ctx:
torchlib_path = os.path.join(ctx._tmpdir.name, "torch", "lib")
ctx.out_wheel = tmp_whl_name
new_lib_path, new_lib_soname = None, None
for filename, _ in elf_file_filter(ctx.iter_files()):
if not filename.startswith("torch/lib"):
continue
libtree = lddtree(filename)
if lib_soname not in libtree["needed"]:
continue
lib_path = libtree["libs"][lib_soname]["path"]
if lib_path is None:
print(f"Can't embed {lib_soname} as it could not be found")
break
if lib_path.startswith(torchlib_path):
continue
if new_lib_path is None:
new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
patcher.replace_needed(filename, lib_soname, new_lib_soname)
print(f"Replacing {lib_soname} with {new_lib_soname} for {filename}")
if update_tag:
# Add manylinux2014 tag
for filename in ctx.iter_files():
if os.path.basename(filename) != "WHEEL":
continue
replace_tag(filename)
shutil.move(tmp_whl_name, whl_path)
if __name__ == "__main__":
embed_library(
sys.argv[1], "libgomp.so.1", len(sys.argv) > 2 and sys.argv[2] == "--update-tag"
)
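Standalone usage mirrors how embed_libgomp in build_aarch64_wheel.py invokes this script; the wheel filename below is illustrative:
```
# Embed libgomp.so.1 into a wheel and retag it as manylinux2014
python3 embed_library.py torch-2.0.1-cp310-cp310-linux_aarch64.whl --update-tag
```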

View File

@ -4,17 +4,14 @@ set -ex
SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
# Source the common build script for architecture-specific configurations (MKLDNN, ACL, etc.)
source "${SCRIPTPATH}/../pytorch/build.sh" || true
case "${GPU_ARCH_TYPE:-BLANK}" in
cuda | cuda-aarch64)
cuda)
bash "${SCRIPTPATH}/build_cuda.sh"
;;
rocm)
bash "${SCRIPTPATH}/build_rocm.sh"
;;
cpu | cpu-cxx11-abi | cpu-aarch64 | cpu-s390x)
cpu | cpu-cxx11-abi | cpu-s390x)
bash "${SCRIPTPATH}/build_cpu.sh"
;;
xpu)

View File

@ -18,31 +18,12 @@ retry () {
$* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*)
}
# Detect architecture first
ARCH=$(uname -m)
echo "Detected architecture: $ARCH"
PLATFORM=""
# TODO move this into the Docker images
OS_NAME=$(awk -F= '/^NAME/{print $2}' /etc/os-release)
if [[ "$OS_NAME" == *"AlmaLinux"* ]]; then
retry yum install -q -y zip openssl
# Set platform based on architecture
case $ARCH in
x86_64)
PLATFORM="manylinux_2_28_x86_64"
;;
aarch64)
PLATFORM="manylinux_2_28_aarch64"
;;
s390x)
PLATFORM="manylinux_2_28_s390x"
;;
*)
echo "Unsupported architecture: $ARCH"
exit 1
;;
esac
PLATFORM="manylinux_2_28_x86_64"
elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then
retry dnf install -q -y zip openssl
elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
@ -57,8 +38,6 @@ else
exit 1
fi
echo "Platform set to: $PLATFORM"
# We use the package name to test the package by passing this to 'pip install'
# This is the env variable that setup.py uses to name the package. Note that
# pip 'normalizes' the name first by changing all - to _
@ -320,8 +299,8 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
# ROCm workaround for roctracer dlopens
if [[ "$DESIRED_CUDA" == *"rocm"* ]]; then
patchedpath=$(fname_without_so_number $destpath)
# Keep the so number for XPU dependencies, libgomp.so.1, ACL libraries, and NVPL libraries to avoid twice load
elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" || "$filename" == libarm_compute* || "$filename" == libnvpl* || "$filename" == "libgfortran.so.5" ]]; then
# Keep the so number for XPU dependencies and libgomp.so.1 to avoid twice load
elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" ]]; then
patchedpath=$destpath
else
patchedpath=$(fname_with_sha256 $destpath)
@ -367,22 +346,9 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
done
# create the manylinux 2_28 tag; this needs to happen before regenerating the RECORD
# Support all architectures (x86_64, aarch64, s390x)
if [[ "$IS_MANYLINUX2_28" == "1" && $GPU_ARCH_TYPE != "xpu" ]]; then
if [[ $PLATFORM == "manylinux_2_28_x86_64" && $GPU_ARCH_TYPE != "cpu-s390x" && $GPU_ARCH_TYPE != "xpu" ]]; then
wheel_file=$(echo $(basename $pkg) | sed -e 's/-cp.*$/.dist-info\/WHEEL/g')
echo "Updating wheel tag for $ARCH architecture"
# Replace linux_* with manylinux_2_28_* based on architecture
case $ARCH in
x86_64)
sed -i -e 's#linux_x86_64#manylinux_2_28_x86_64#g' $wheel_file
;;
aarch64)
sed -i -e 's#linux_aarch64#manylinux_2_28_aarch64#g' $wheel_file
;;
s390x)
sed -i -e 's#linux_s390x#manylinux_2_28_s390x#g' $wheel_file
;;
esac
sed -i -e s#linux_x86_64#"${PLATFORM}"# $wheel_file;
fi
# regenerate the RECORD file with new hashes

View File

@ -15,10 +15,6 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
EXTRA_CAFFE2_CMAKE_FLAGS=()
fi
# Detect architecture
ARCH=$(uname -m)
echo "Building CPU wheel for architecture: $ARCH"
WHEELHOUSE_DIR="wheelhousecpu"
LIBTORCH_HOUSE_DIR="libtorch_housecpu"
if [[ -z "$PYTORCH_FINAL_PACKAGE_DIR" ]]; then
@ -38,10 +34,8 @@ elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then
elif [[ "$OS_NAME" == *"AlmaLinux"* ]]; then
LIBGOMP_PATH="/usr/lib64/libgomp.so.1"
elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
if [[ "$ARCH" == "s390x" ]]; then
if [[ "$(uname -m)" == "s390x" ]]; then
LIBGOMP_PATH="/usr/lib/s390x-linux-gnu/libgomp.so.1"
elif [[ "$ARCH" == "aarch64" ]]; then
LIBGOMP_PATH="/usr/lib/aarch64-linux-gnu/libgomp.so.1"
else
LIBGOMP_PATH="/usr/lib/x86_64-linux-gnu/libgomp.so.1"
fi
@ -55,34 +49,6 @@ DEPS_SONAME=(
"libgomp.so.1"
)
# Add ARM-specific library dependencies for CPU builds
if [[ "$ARCH" == "aarch64" ]]; then
echo "Adding ARM-specific CPU library dependencies"
# ARM Compute Library (if available)
if [[ -d "/acl/build" ]]; then
echo "Adding ARM Compute Library for CPU"
DEPS_LIST+=(
"/acl/build/libarm_compute.so"
"/acl/build/libarm_compute_graph.so"
)
DEPS_SONAME+=(
"libarm_compute.so"
"libarm_compute_graph.so"
)
fi
# ARM system libraries
DEPS_LIST+=(
"/usr/lib64/libgfortran.so.5"
"/opt/OpenBLAS/lib/libopenblas.so.0"
)
DEPS_SONAME+=(
"libgfortran.so.5"
"libopenblas.so.0"
)
fi
rm -rf /usr/local/cuda*
SOURCE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )"

View File

@ -29,10 +29,6 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
EXTRA_CAFFE2_CMAKE_FLAGS=()
fi
# Detect architecture
ARCH=$(uname -m)
echo "Building for architecture: $ARCH"
# Determine CUDA version and architectures to build for
#
# NOTE: We should first check `DESIRED_CUDA` when determining `CUDA_VERSION`,
@ -57,60 +53,34 @@ fi
cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.')
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
# Function to remove architectures from a list
remove_archs() {
local result="$1"
shift
for arch in "$@"; do
result="${result//${arch};/}"
done
echo "$result"
}
# Function to filter CUDA architectures for aarch64
# aarch64 ARM GPUs only support certain compute capabilities
# Keep: 8.0 (A100), 9.0+ (Hopper, Grace Hopper, newer)
# Remove: < 8.0 (no ARM GPUs), 8.6 (x86_64 RTX 3090/A6000 only)
filter_aarch64_archs() {
local arch_list="$1"
# Explicitly remove architectures not needed on aarch64
arch_list=$(remove_archs "$arch_list" "5.0" "6.0" "7.0" "7.5" "8.6")
echo "$arch_list"
}
# Base: Common architectures across all modern CUDA versions
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0"
case ${CUDA_VERSION} in
12.6) TORCH_CUDA_ARCH_LIST="5.0;6.0;${TORCH_CUDA_ARCH_LIST}" ;; # Only 12.6 includes Legacy Maxwell/Pascal that will be removed in future releases
12.8) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0" ;; # +Hopper/Blackwell support
12.9) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0+PTX" # +Hopper/Blackwell support + PTX for forward compatibility
#removing sm_50-sm_60 as these architectures are deprecated in CUDA 12.8/9 and will be removed in future releases
#however we would like to keep sm_70 architecture see: https://github.com/pytorch/pytorch/issues/157517
12.8)
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0"
;;
12.9)
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0+PTX"
# WAR to resolve the ld error in libtorch build with CUDA 12.9
if [[ "$PACKAGE_TYPE" == "libtorch" ]]; then
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//7.0;/}" # Remove 7.0 to resolve the ld error
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//8.6;/}" # Remove 8.6 for libtorch
TORCH_CUDA_ARCH_LIST="7.5;8.0;9.0;10.0;12.0+PTX"
fi
;;
13.0)
TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;$([[ "$ARCH" == "aarch64" ]] && echo "11.0;" || echo "")12.0+PTX"
export TORCH_NVCC_FLAGS="-compress-mode=size"
export BUILD_BUNDLE_PTXAS=1
TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;12.0+PTX"
;;
12.6)
TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6;9.0"
;;
*)
echo "unknown cuda version $CUDA_VERSION"
exit 1
;;
*) echo "unknown cuda version $CUDA_VERSION"; exit 1 ;;
esac
# Filter for aarch64: Remove < 8.0 and 8.6
[[ "$ARCH" == "aarch64" ]] && TORCH_CUDA_ARCH_LIST=$(filter_aarch64_archs "$TORCH_CUDA_ARCH_LIST")
echo "TORCH_CUDA_ARCH_LIST set to: $TORCH_CUDA_ARCH_LIST"
export TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST}
echo "${TORCH_CUDA_ARCH_LIST}"
# Disable MAGMA for aarch64 as pre-built libraries are x86-64 only
if [[ "$ARCH" == "aarch64" ]]; then
echo "Disabling MAGMA for aarch64 architecture"
export USE_MAGMA=0
fi
# Package directories
WHEELHOUSE_DIR="wheelhouse$cuda_version_nodot"
LIBTORCH_HOUSE_DIR="libtorch_house$cuda_version_nodot"
@ -274,51 +244,6 @@ else
exit 1
fi
# Add ARM-specific library dependencies
if [[ "$ARCH" == "aarch64" ]]; then
echo "Adding ARM-specific library dependencies"
# ARM Compute Library (if available)
if [[ -d "/acl/build" ]]; then
echo "Adding ARM Compute Library"
DEPS_LIST+=(
"/acl/build/libarm_compute.so"
"/acl/build/libarm_compute_graph.so"
)
DEPS_SONAME+=(
"libarm_compute.so"
"libarm_compute_graph.so"
)
fi
# ARM system libraries
DEPS_LIST+=(
"/lib64/libgomp.so.1"
"/usr/lib64/libgfortran.so.5"
)
DEPS_SONAME+=(
"libgomp.so.1"
"libgfortran.so.5"
)
# NVPL libraries (ARM optimized BLAS/LAPACK)
if [[ -d "/usr/local/lib" && -f "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0" ]]; then
echo "Adding NVPL libraries for ARM"
DEPS_LIST+=(
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0"
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0"
"/usr/local/lib/libnvpl_lapack_core.so.0"
"/usr/local/lib/libnvpl_blas_core.so.0"
)
DEPS_SONAME+=(
"libnvpl_lapack_lp64_gomp.so.0"
"libnvpl_blas_lp64_gomp.so.0"
"libnvpl_lapack_core.so.0"
"libnvpl_blas_core.so.0"
)
fi
fi
# run_tests.sh requires DESIRED_CUDA to know what tests to exclude
export DESIRED_CUDA="$cuda_version_nodot"
@ -326,11 +251,9 @@ export DESIRED_CUDA="$cuda_version_nodot"
rm -rf /usr/local/cuda || true
ln -s "/usr/local/cuda-${CUDA_VERSION}" /usr/local/cuda
# Switch `/usr/local/magma` to the desired CUDA version (skip for aarch64)
if [[ "$ARCH" != "aarch64" ]]; then
rm -rf /usr/local/magma || true
ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma
fi
# Switch `/usr/local/magma` to the desired CUDA version
rm -rf /usr/local/magma || true
ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma
export CUDA_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev) # 10.0.130
export CUDA_VERSION_SHORT=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev | cut -f1,2 -d".") # 10.0

View File

@ -86,20 +86,10 @@ else
fi
fi
# Enable MKLDNN with ARM Compute Library for ARM builds
if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then
export USE_MKLDNN=1
# ACL is required for aarch64 builds
if [[ ! -d "/acl" ]]; then
echo "ERROR: ARM Compute Library not found at /acl"
echo "ACL is required for aarch64 builds. Check Docker image setup."
exit 1
fi
export USE_MKLDNN_ACL=1
export ACL_ROOT_DIR=/acl
echo "ARM Compute Library enabled for MKLDNN: ACL_ROOT_DIR=/acl"
fi
if [[ "$BUILD_ENVIRONMENT" == *riscv64* ]]; then

View File

@ -100,6 +100,337 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
)
def _compile_and_extract_symbols(
cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None
) -> list[str]:
"""
Helper to compile a C++ file and extract all symbols.
Args:
cpp_content: C++ source code to compile
compile_flags: Compilation flags
exclude_list: List of symbol names to exclude. Defaults to ["main"].
Returns:
List of all symbols found in the object file (excluding those in exclude_list).
"""
import subprocess
import tempfile
if exclude_list is None:
exclude_list = ["main"]
with tempfile.TemporaryDirectory() as tmpdir:
tmppath = Path(tmpdir)
cpp_file = tmppath / "test.cpp"
obj_file = tmppath / "test.o"
cpp_file.write_text(cpp_content)
result = subprocess.run(
compile_flags + [str(cpp_file), "-o", str(obj_file)],
capture_output=True,
text=True,
timeout=60,
)
if result.returncode != 0:
raise RuntimeError(f"Compilation failed: {result.stderr}")
symbols = get_symbols(str(obj_file))
# Return all symbol names, excluding those in the exclude list
return [name for _addr, _stype, name in symbols if name not in exclude_list]
def check_stable_only_symbols(install_root: Path) -> None:
"""
Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.
This approach tests:
1. WITHOUT macros -> many torch symbols exposed
2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
4. WITH both macros -> zero torch symbols (all hidden)
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
test_cpp_content = """
// Main torch C++ API headers
#include <torch/torch.h>
#include <torch/all.h>
// ATen tensor library
#include <ATen/ATen.h>
// Core c10 headers (commonly used)
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Optional.h>
int main() { return 0; }
"""
base_compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c", # Compile only, don't link
]
# Compile WITHOUT any macros
symbols_without = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=base_compile_flags,
)
# We expect constexpr symbols, inline functions used by other headers etc.
# to produce symbols
num_symbols_without = len(symbols_without)
print(f"Found {num_symbols_without} symbols without any macros defined")
assert num_symbols_without != 0, (
"Expected a non-zero number of symbols without any macros"
)
# Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"]
symbols_with_stable_only = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_stable_only,
)
num_symbols_with_stable_only = len(symbols_with_stable_only)
assert num_symbols_with_stable_only == 0, (
f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}"
)
# Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
compile_flags_with_target_version = base_compile_flags + [
"-DTORCH_TARGET_VERSION=1"
]
symbols_with_target_version = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_target_version,
)
num_symbols_with_target_version = len(symbols_with_target_version)
assert num_symbols_with_target_version == 0, (
f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}"
)
# Compile WITH both macros (expect 0 symbols)
compile_flags_with_both = base_compile_flags + [
"-DTORCH_STABLE_ONLY",
"-DTORCH_TARGET_VERSION=1",
]
symbols_with_both = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_both,
)
num_symbols_with_both = len(symbols_with_both)
assert num_symbols_with_both == 0, (
f"Expected no symbols with both macros, but found {num_symbols_with_both}"
)
def check_stable_api_symbols(install_root: Path) -> None:
"""
Test that stable API headers still expose symbols with TORCH_STABLE_ONLY.
The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
stable_dir = include_dir / "torch" / "csrc" / "stable"
assert stable_dir.exists(), f"Expected {stable_dir} to be present"
stable_headers = list(stable_dir.rglob("*.h"))
if not stable_headers:
raise RuntimeError("Could not find any stable headers")
includes = []
for header in stable_headers:
rel_path = header.relative_to(include_dir)
includes.append(f"#include <{rel_path.as_posix()}>")
includes_str = "\n".join(includes)
test_stable_content = f"""
{includes_str}
int main() {{ return 0; }}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_stable = _compile_and_extract_symbols(
cpp_content=test_stable_content,
compile_flags=compile_flags,
)
num_symbols_stable = len(symbols_stable)
print(f"Found {num_symbols_stable} symbols in torch/csrc/stable")
assert num_symbols_stable > 0, (
f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_stable} symbols"
)
def check_headeronly_symbols(install_root: Path) -> None:
"""
Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# Find all headers in torch/headeronly
headeronly_dir = include_dir / "torch" / "headeronly"
assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present"
headeronly_headers = list(headeronly_dir.rglob("*.h"))
if not headeronly_headers:
raise RuntimeError("Could not find any headeronly headers")
# Filter out platform-specific headers that may not compile everywhere
platform_specific_keywords = [
"cpu/vec",
]
filtered_headers = []
for header in headeronly_headers:
rel_path = header.relative_to(include_dir).as_posix()
if not any(
keyword in rel_path.lower() for keyword in platform_specific_keywords
):
filtered_headers.append(header)
includes = []
for header in filtered_headers:
rel_path = header.relative_to(include_dir)
includes.append(f"#include <{rel_path.as_posix()}>")
includes_str = "\n".join(includes)
test_headeronly_content = f"""
{includes_str}
int main() {{ return 0; }}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_headeronly = _compile_and_extract_symbols(
cpp_content=test_headeronly_content,
compile_flags=compile_flags,
)
num_symbols_headeronly = len(symbols_headeronly)
print(f"Found {num_symbols_headeronly} symbols in torch/headeronly")
assert num_symbols_headeronly > 0, (
f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_headeronly} symbols"
)
def check_aoti_shim_symbols(install_root: Path) -> None:
"""
Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# There are no constexpr symbols etc., so we need to actually use functions
# so that some symbols are found.
test_shim_content = """
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
int main() {
int32_t (*fp1)() = &aoti_torch_device_type_cpu;
int32_t (*fp2)() = &aoti_torch_dtype_float32;
(void)fp1; (void)fp2;
return 0;
}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_shim = _compile_and_extract_symbols(
cpp_content=test_shim_content,
compile_flags=compile_flags,
)
num_symbols_shim = len(symbols_shim)
assert num_symbols_shim > 0, (
f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_shim} symbols"
)
def check_stable_c_shim_symbols(install_root: Path) -> None:
"""
Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# Check if the stable C shim exists
stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
if not stable_shim.exists():
raise RuntimeError("Could not find stable c shim")
# There are no constexpr symbols etc., so we need to actually use functions
# so that some symbols are found.
test_stable_shim_content = """
#include <torch/csrc/stable/c/shim.h>
int main() {
// Reference stable C API functions to create undefined symbols
AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
(void)fp1; (void)fp2;
return 0;
}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_stable_shim = _compile_and_extract_symbols(
cpp_content=test_stable_shim_content,
compile_flags=compile_flags,
)
num_symbols_stable_shim = len(symbols_stable_shim)
assert num_symbols_stable_shim > 0, (
f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_stable_shim} symbols"
)
def check_lib_symbols_for_abi_correctness(lib: str) -> None:
print(f"lib: {lib}")
cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)
@ -129,6 +460,13 @@ def main() -> None:
check_lib_symbols_for_abi_correctness(libtorch_cpu_path)
check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path)
# Check symbols when TORCH_STABLE_ONLY is defined
check_stable_only_symbols(install_root)
check_stable_api_symbols(install_root)
check_headeronly_symbols(install_root)
check_aoti_shim_symbols(install_root)
check_stable_c_shim_symbols(install_root)
if __name__ == "__main__":
main()
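As a self-contained illustration of what these new checks measure (independent of the repo's `get_symbols` helper, which is not shown in this hunk), the sketch below compiles a small translation unit and counts the symbols `nm` reports for the resulting object file; the TORCH_STABLE_ONLY checks above apply the same idea to translation units that include the installed torch headers:

```
import subprocess
import tempfile
from pathlib import Path

CPP = """
inline int visible_helper(int x) { return x + 1; }
int use_it() { return visible_helper(41); }
int main() { return 0; }
"""

def count_symbols(extra_flags=()):
    with tempfile.TemporaryDirectory() as tmp:
        cpp, obj = Path(tmp) / "probe.cpp", Path(tmp) / "probe.o"
        cpp.write_text(CPP)
        subprocess.run(
            ["g++", "-std=c++17", "-c", *extra_flags, str(cpp), "-o", str(obj)],
            check=True,
        )
        out = subprocess.run(["nm", str(obj)], capture_output=True, text=True, check=True)
        names = [line.split()[-1] for line in out.stdout.splitlines() if line.strip()]
        return [n for n in names if n != "main"]  # "main" is excluded, as in the real check

print(count_symbols())  # e.g. ['_Z14visible_helperi', '_Z6use_itv'] with g++ on Linux
```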

View File

@ -389,13 +389,6 @@ test_lazy_tensor_meta_reference_disabled() {
export -n TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE
}
test_dynamo_core() {
time python test/run_test.py \
--include-dynamo-core-tests \
--verbose \
--upload-artifacts-while-running
assert_git_not_dirty
}
test_dynamo_wrapped_shard() {
if [[ -z "$NUM_TEST_SHARDS" ]]; then
@ -1821,8 +1814,6 @@ elif [[ "${TEST_CONFIG}" == *inductor* ]]; then
test_inductor_shard "${SHARD_NUMBER}"
elif [[ "${TEST_CONFIG}" == *einops* ]]; then
test_einops
elif [[ "${TEST_CONFIG}" == *dynamo_core* ]]; then
test_dynamo_core
elif [[ "${TEST_CONFIG}" == *dynamo_wrapped* ]]; then
install_torchvision
test_dynamo_wrapped_shard "${SHARD_NUMBER}"

View File

@ -1 +1 @@
ee1a1350eb37804b94334768f328144f058f14e9
07b6cbde121417a70e4dc871adb6d27030e0ce3f

View File

@ -1 +1 @@
2d82dc5caa336d179d9b46ac4a0fb8c43d84c5cc
acccf86477759b2d3500f1ae1be065f7b1e409ec

View File

@ -1 +1 @@
94631807d22c09723dd006f7be5beb649d5f88d0
e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a

View File

@ -7,7 +7,6 @@ ciflow_push_tags:
- ciflow/binaries
- ciflow/binaries_libtorch
- ciflow/binaries_wheel
- ciflow/dynamo
- ciflow/h100
- ciflow/h100-cutlass-backend
- ciflow/h100-distributed

View File

@ -50,7 +50,7 @@ def get_tag() -> str:
def get_base_version() -> str:
root = get_pytorch_root()
dirty_version = Path(root / "version.txt").read_text().strip()
dirty_version = open(root / "version.txt").read().strip()
# Strips trailing a0 from version.txt, not too sure why it's there in the
# first place
return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)
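The trailing `a0` mentioned in the comment is the dev-release suffix carried in `version.txt` (for example `2.10.0a0`). A minimal sketch of the stripping step, assuming `LEGACY_BASE_VERSION_SUFFIX_PATTERN` matches a trailing `a0` (the real pattern is defined elsewhere in this file and may differ):

```
import re

LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile(r"a0$")  # assumption, see note above

print(re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", "2.10.0a0"))  # -> 2.10.0
```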

View File

@ -260,8 +260,11 @@ jobs:
"${DOCKER_IMAGE}"
)
docker exec -t -w "${PYTORCH_ROOT}" "${container_name}" bash -c "bash .circleci/scripts/binary_populate_env.sh"
# Unified build script for all architectures (x86_64, aarch64, s390x)
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh"
if [[ ${BUILD_ENVIRONMENT} == *"aarch64"* ]]; then
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/aarch64_linux/aarch64_ci_build.sh"
else
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh"
fi
- name: Chown artifacts
if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' && inputs.build_environment != 'linux-s390x-binary-manywheel' }}

View File

@ -326,7 +326,7 @@ jobs:
SCCACHE_BUCKET: ${{ !contains(matrix.runner, 'b200') && 'ossci-compiler-cache-circleci-v2' || '' }}
SCCACHE_REGION: ${{ !contains(matrix.runner, 'b200') && 'us-east-1' || '' }}
SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }}
DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
DOCKER_IMAGE: ${{ inputs.docker-image }}
XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }}
XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla
PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }}

View File

@ -1,70 +0,0 @@
# Workflow: Dynamo Unit Test
# runs unit tests for dynamo.
name: dynamo-unittest
on:
push:
tags:
- ciflow/dynamo/*
workflow_call:
schedule:
- cron: 29 8 * * * # about 1:29am PDT
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
opt_out_experiments: lf
dynamo-build:
name: dynamo-build
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
strategy:
matrix:
python-version: ['3.11', '3.12']
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-py${{ matrix.python-version }}-clang12
docker-image-name: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang12
test-matrix: |
{ include: [
{ config: "dynamo_core", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
]}
secrets: inherit
dynamo-test:
name: dynamo-test
uses: ./.github/workflows/_linux-test.yml
needs: [get-label-type, dynamo-build]
strategy:
matrix:
python-version: ['3.11', '3.12']
with:
build-environment: linux-jammy-py${{ matrix.python-version }}-clang12
docker-image: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang12
test-matrix: |
{ include: [
{ config: "dynamo_core", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
]}
secrets: inherit

View File

@ -1,330 +0,0 @@
import hashlib
import subprocess
import sys
from pathlib import Path
import click
import spin
def file_digest(file, algorithm: str):
try:
return hashlib.file_digest(file, algorithm)
except AttributeError:
pass # Fallback to manual implementation below
hash = hashlib.new(algorithm)
while chunk := file.read(8192):
hash.update(chunk)
return hash
def _hash_file(file):
with open(file, "rb") as f:
hash = file_digest(f, "sha256")
return hash.hexdigest()
def _hash_files(files):
hashes = {file: _hash_file(file) for file in files}
return hashes
def _read_hashes(hash_file: Path):
if not hash_file.exists():
return {}
with hash_file.open("r") as f:
lines = f.readlines()
hashes = {}
for line in lines:
hash = line[:64]
file = line[66:].strip()
hashes[file] = hash
return hashes
def _updated_hashes(hash_file, files_to_hash):
old_hashes = _read_hashes(hash_file)
new_hashes = _hash_files(files_to_hash)
if new_hashes != old_hashes:
return new_hashes
return None
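A self-contained sketch of the cache-file round trip these helpers implement: one line per file with a 64-character sha256 digest, two spaces, then the path, which is what the `[:64]`/`[66:]` slicing in `_read_hashes` expects. Names mirror the helpers above, but the snippet stands on its own:

```
import tempfile
from pathlib import Path

def write_hashes(hash_file: Path, hashes: dict) -> None:
    with hash_file.open("w") as f:
        for file, digest in hashes.items():
            f.write(f"{digest}  {file}\n")  # 64-char digest, two spaces, file path

def read_hashes(hash_file: Path) -> dict:
    if not hash_file.exists():
        return {}
    return {line[66:]: line[:64] for line in hash_file.read_text().splitlines()}

with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp) / "demo.sha256"
    digests = {"pyproject.toml": "0" * 64, ".lintrunner.toml": "f" * 64}
    write_hashes(cache, digests)
    assert read_hashes(cache) == digests
```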
@click.command()
def regenerate_version():
"""Regenerate version.py."""
cmd = [
sys.executable,
"-m",
"tools.generate_torch_version",
"--is-debug=false",
]
spin.util.run(cmd)
TYPE_STUBS = [
(
"Pytorch type stubs",
Path(".lintbin/.pytorch-type-stubs.sha256"),
[
"aten/src/ATen/native/native_functions.yaml",
"aten/src/ATen/native/tags.yaml",
"tools/autograd/deprecated.yaml",
],
[
sys.executable,
"-m",
"tools.pyi.gen_pyi",
"--native-functions-path",
"aten/src/ATen/native/native_functions.yaml",
"--tags-path",
"aten/src/ATen/native/tags.yaml",
"--deprecated-functions-path",
"tools/autograd/deprecated.yaml",
],
),
(
"Datapipes type stubs",
None,
[],
[
sys.executable,
"torch/utils/data/datapipes/gen_pyi.py",
],
),
]
@click.command()
def regenerate_type_stubs():
"""Regenerate type stubs."""
for name, hash_file, files_to_hash, cmd in TYPE_STUBS:
if hash_file:
if hashes := _updated_hashes(hash_file, files_to_hash):
click.echo(
f"Changes detected in type stub files for {name}. Regenerating..."
)
spin.util.run(cmd)
hash_file.parent.mkdir(parents=True, exist_ok=True)
with hash_file.open("w") as f:
for file, hash in hashes.items():
f.write(f"{hash} {file}\n")
click.echo("Type stubs and hashes updated.")
else:
click.echo(f"No changes detected in type stub files for {name}.")
else:
click.echo(f"No hash file for {name}. Regenerating...")
spin.util.run(cmd)
click.echo("Type stubs regenerated.")
@click.command()
def regenerate_clangtidy_files():
"""Regenerate clang-tidy files."""
cmd = [
sys.executable,
"-m",
"tools.linter.clang_tidy.generate_build_files",
]
spin.util.run(cmd)
#: These linters are expected to need less than 3s cpu time total
VERY_FAST_LINTERS = {
"ATEN_CPU_GPU_AGNOSTIC",
"BAZEL_LINTER",
"C10_NODISCARD",
"C10_UNUSED",
"CALL_ONCE",
"CMAKE_MINIMUM_REQUIRED",
"CONTEXT_DECORATOR",
"COPYRIGHT",
"CUBINCLUDE",
"DEPLOY_DETECTION",
"ERROR_PRONE_ISINSTANCE",
"EXEC",
"HEADER_ONLY_LINTER",
"IMPORT_LINTER",
"INCLUDE",
"LINTRUNNER_VERSION",
"MERGE_CONFLICTLESS_CSV",
"META_NO_CREATE_UNBACKED",
"NEWLINE",
"NOQA",
"NO_WORKFLOWS_ON_FORK",
"ONCE_FLAG",
"PYBIND11_INCLUDE",
"PYBIND11_SPECIALIZATION",
"PYPIDEP",
"PYPROJECT",
"RAWCUDA",
"RAWCUDADEVICE",
"ROOT_LOGGING",
"TABS",
"TESTOWNERS",
"TYPEIGNORE",
"TYPENOSKIP",
"WORKFLOWSYNC",
}
#: These linters are expected to take a few seconds, but less than 10s cpu time total
FAST_LINTERS = {
"CMAKE",
"DOCSTRING_LINTER",
"GHA",
"NATIVEFUNCTIONS",
"RUFF",
"SET_LINTER",
"SHELLCHECK",
"SPACES",
}
#: These linters are expected to take more than 10s cpu time total;
#: some need more than 1 hour.
SLOW_LINTERS = {
"ACTIONLINT",
"CLANGFORMAT",
"CLANGTIDY",
"CODESPELL",
"FLAKE8",
"GB_REGISTRY",
"PYFMT",
"PYREFLY",
"TEST_DEVICE_BIAS",
"TEST_HAS_MAIN",
}
ALL_LINTERS = VERY_FAST_LINTERS | FAST_LINTERS | SLOW_LINTERS
LINTRUNNER_CACHE_INFO = (
Path(".lintbin/.lintrunner.sha256"),
[
"requirements.txt",
"pyproject.toml",
".lintrunner.toml",
],
)
LINTRUNNER_BASE_CMD = [
"uvx",
"--python",
"3.10",
"lintrunner@0.12.7",
]
@click.command()
def setup_lint():
"""Set up lintrunner with current CI version."""
cmd = LINTRUNNER_BASE_CMD + ["init"]
subprocess.run(cmd, check=True, capture_output=True, text=True)
def _check_linters():
cmd = LINTRUNNER_BASE_CMD + ["list"]
ret = spin.util.run(cmd, output=False, stderr=subprocess.PIPE)
linters = {l.strip() for l in ret.stdout.decode().strip().split("\n")[1:]}
unknown_linters = linters - ALL_LINTERS
missing_linters = ALL_LINTERS - linters
if unknown_linters:
click.secho(
f"Unknown linters found; please add them to the correct category "
f"in .spin/cmds.py: {', '.join(unknown_linters)}",
fg="yellow",
)
if missing_linters:
click.secho(
f"Missing linters found; please update the corresponding category "
f"in .spin/cmds.py: {', '.join(missing_linters)}",
fg="yellow",
)
return unknown_linters, missing_linters
@spin.util.extend_command(
setup_lint,
doc=f"""
If configuration has changed, update lintrunner.
Compares the stored old hashes of configuration files with new ones and
performs setup via setup-lint if the hashes have changed.
Hashes are stored in {LINTRUNNER_CACHE_INFO[0]}; the following files are
considered: {", ".join(LINTRUNNER_CACHE_INFO[1])}.
""",
)
@click.pass_context
def lazy_setup_lint(ctx, parent_callback, **kwargs):
if hashes := _updated_hashes(*LINTRUNNER_CACHE_INFO):
click.echo(
"Changes detected in lint configuration files. Setting up linting tools..."
)
parent_callback(**kwargs)
hash_file = LINTRUNNER_CACHE_INFO[0]
hash_file.parent.mkdir(parents=True, exist_ok=True)
with hash_file.open("w") as f:
for file, hash in hashes.items():
f.write(f"{hash} {file}\n")
click.echo("Linting tools set up and hashes updated.")
else:
click.echo("No changes detected in lint configuration files. Skipping setup.")
click.echo("Regenerating version...")
ctx.invoke(regenerate_version)
click.echo("Regenerating type stubs...")
ctx.invoke(regenerate_type_stubs)
click.echo("Done.")
_check_linters()
@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def lint(ctx, apply_patches, **kwargs):
"""Lint all files."""
ctx.invoke(lazy_setup_lint)
all_files_linters = VERY_FAST_LINTERS | FAST_LINTERS
changed_files_linters = SLOW_LINTERS
cmd = LINTRUNNER_BASE_CMD
if apply_patches:
cmd += ["--apply-patches"]
all_files_cmd = cmd + [
"--take",
",".join(all_files_linters),
"--all-files",
]
spin.util.run(all_files_cmd)
changed_files_cmd = cmd + [
"--take",
",".join(changed_files_linters),
]
spin.util.run(changed_files_cmd)
@click.command()
@click.pass_context
def fixlint(ctx, **kwargs):
"""Autofix all files."""
ctx.invoke(lint, apply_patches=True)
@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def quicklint(ctx, apply_patches, **kwargs):
"""Lint changed files."""
ctx.invoke(lazy_setup_lint)
cmd = LINTRUNNER_BASE_CMD
if apply_patches:
cmd += ["--apply-patches"]
spin.util.run(cmd)
@click.command()
@click.pass_context
def quickfix(ctx, **kwargs):
"""Autofix changed files."""
ctx.invoke(quicklint, apply_patches=True)

View File

@ -1,6 +1,5 @@
#pragma once
#include <torch/headeronly/core/TensorAccessor.h>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Deprecated.h>
@ -12,37 +11,252 @@
namespace at {
using torch::headeronly::DefaultPtrTraits;
// The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor
// is used to enable the __restrict__ keyword/modifier for the data
// passed to cuda.
template <typename T>
struct DefaultPtrTraits {
typedef T* PtrType;
};
#if defined(__CUDACC__) || defined(__HIPCC__)
using torch::headeronly::RestrictPtrTraits;
template <typename T>
struct RestrictPtrTraits {
typedef T* __restrict__ PtrType;
};
#endif
// TensorAccessorBase and TensorAccessor are used for both CPU and CUDA tensors.
// For CUDA tensors it is used in device code (only). This means that we restrict ourselves
// to functions and types available there (e.g. IntArrayRef isn't).
// The PtrTraits argument is only relevant to cuda to support `__restrict__` pointers.
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
using TensorAccessorBase = torch::headeronly::detail::TensorAccessorBase<c10::IntArrayRef, T, N, PtrTraits, index_t>;
class TensorAccessorBase {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
C10_HOST_DEVICE TensorAccessorBase(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: data_(data_), sizes_(sizes_), strides_(strides_) {}
C10_HOST IntArrayRef sizes() const {
return IntArrayRef(sizes_,N);
}
C10_HOST IntArrayRef strides() const {
return IntArrayRef(strides_,N);
}
C10_HOST_DEVICE index_t stride(index_t i) const {
return strides_[i];
}
C10_HOST_DEVICE index_t size(index_t i) const {
return sizes_[i];
}
C10_HOST_DEVICE PtrType data() {
return data_;
}
C10_HOST_DEVICE const PtrType data() const {
return data_;
}
protected:
PtrType data_;
const index_t* sizes_;
const index_t* strides_;
};
// The `TensorAccessor` is typically instantiated for CPU `Tensor`s using
// `Tensor.accessor<T, N>()`.
// For CUDA `Tensor`s, `GenericPackedTensorAccessor` is used on the host and only
// indexing on the device uses `TensorAccessor`s.
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
using TensorAccessor = torch::headeronly::detail::TensorAccessor<c10::IntArrayRef, T, N, PtrTraits, index_t>;
class TensorAccessor : public TensorAccessorBase<T,N,PtrTraits,index_t> {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
namespace detail {
C10_HOST_DEVICE TensorAccessor(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: TensorAccessorBase<T, N, PtrTraits, index_t>(data_,sizes_,strides_) {}
template <size_t N, typename index_t>
struct IndexBoundsCheck {
IndexBoundsCheck(index_t i) {
TORCH_CHECK_INDEX(
C10_HOST_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
}
C10_HOST_DEVICE const TensorAccessor<T, N-1, PtrTraits, index_t> operator[](index_t i) const {
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
}
};
template<typename T, template <typename U> class PtrTraits, typename index_t>
class TensorAccessor<T,1,PtrTraits,index_t> : public TensorAccessorBase<T,1,PtrTraits,index_t> {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
C10_HOST_DEVICE TensorAccessor(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: TensorAccessorBase<T, 1, PtrTraits, index_t>(data_,sizes_,strides_) {}
C10_HOST_DEVICE T & operator[](index_t i) {
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
return this->data_[this->strides_[0]*i];
}
C10_HOST_DEVICE const T & operator[](index_t i) const {
return this->data_[this->strides_[0]*i];
}
};
// GenericPackedTensorAccessorBase and GenericPackedTensorAccessor are used for CUDA `Tensor`s on the host.
// In contrast to `TensorAccessor`s, they copy the strides and sizes on instantiation (on the host)
// in order to transfer them to the device when calling kernels.
// On the device, indexing of multidimensional tensors yields `TensorAccessor`s.
// Use RestrictPtrTraits as PtrTraits if you want the tensor's data pointer to be marked as __restrict__.
// Instantiation from data, sizes, strides is only needed on the host and std::copy isn't available
// on the device, so those functions are host only.
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
class GenericPackedTensorAccessorBase {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
C10_HOST GenericPackedTensorAccessorBase(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: data_(data_) {
std::copy(sizes_, sizes_ + N, std::begin(this->sizes_));
std::copy(strides_, strides_ + N, std::begin(this->strides_));
}
// if index_t is not int64_t, we want to have an int64_t constructor
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
C10_HOST GenericPackedTensorAccessorBase(
PtrType data_,
const source_index_t* sizes_,
const source_index_t* strides_)
: data_(data_) {
for (const auto i : c10::irange(N)) {
this->sizes_[i] = sizes_[i];
this->strides_[i] = strides_[i];
}
}
C10_HOST_DEVICE index_t stride(index_t i) const {
return strides_[i];
}
C10_HOST_DEVICE index_t size(index_t i) const {
return sizes_[i];
}
C10_HOST_DEVICE PtrType data() {
return data_;
}
C10_HOST_DEVICE const PtrType data() const {
return data_;
}
protected:
PtrType data_;
// NOLINTNEXTLINE(*c-arrays*)
index_t sizes_[N];
// NOLINTNEXTLINE(*c-arrays*)
index_t strides_[N];
C10_HOST void bounds_check_(index_t i) const {
TORCH_CHECK_INDEX(
0 <= i && i < index_t{N},
"Index ",
i,
" is not within bounds of a tensor of dimension ",
N);
}
}
};
} // namespace detail
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
using GenericPackedTensorAccessorBase = torch::headeronly::detail::GenericPackedTensorAccessorBase<detail::IndexBoundsCheck<N, index_t>, T, N, PtrTraits, index_t>;
class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase<T,N,PtrTraits,index_t> {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
C10_HOST GenericPackedTensorAccessor(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
// if index_t is not int64_t, we want to have an int64_t constructor
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
C10_HOST GenericPackedTensorAccessor(
PtrType data_,
const source_index_t* sizes_,
const source_index_t* strides_)
: GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
C10_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
index_t* new_sizes = this->sizes_ + 1;
index_t* new_strides = this->strides_ + 1;
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
}
C10_DEVICE const TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) const {
const index_t* new_sizes = this->sizes_ + 1;
const index_t* new_strides = this->strides_ + 1;
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
}
/// Returns a PackedTensorAccessor of the same dimension after transposing the
/// two dimensions given. Does not actually move elements; transposition is
/// made by permuting the size/stride arrays. If the dimensions are not valid,
/// asserts.
C10_HOST GenericPackedTensorAccessor<T, N, PtrTraits, index_t> transpose(
index_t dim1,
index_t dim2) const {
this->bounds_check_(dim1);
this->bounds_check_(dim2);
GenericPackedTensorAccessor<T, N, PtrTraits, index_t> result(
this->data_, this->sizes_, this->strides_);
std::swap(result.strides_[dim1], result.strides_[dim2]);
std::swap(result.sizes_[dim1], result.sizes_[dim2]);
return result;
}
};
template<typename T, template <typename U> class PtrTraits, typename index_t>
class GenericPackedTensorAccessor<T,1,PtrTraits,index_t> : public GenericPackedTensorAccessorBase<T,1,PtrTraits,index_t> {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
C10_HOST GenericPackedTensorAccessor(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
// if index_t is not int64_t, we want to have an int64_t constructor
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
C10_HOST GenericPackedTensorAccessor(
PtrType data_,
const source_index_t* sizes_,
const source_index_t* strides_)
: GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
C10_DEVICE T & operator[](index_t i) {
return this->data_[this->strides_[0] * i];
}
C10_DEVICE const T& operator[](index_t i) const {
return this->data_[this->strides_[0]*i];
}
// Same as in the general N-dimensional case, but note that in the
// 1-dimensional case the returned PackedTensorAccessor will always be an
// identical copy of the original
C10_HOST GenericPackedTensorAccessor<T, 1, PtrTraits, index_t> transpose(
index_t dim1,
index_t dim2) const {
this->bounds_check_(dim1);
this->bounds_check_(dim2);
return GenericPackedTensorAccessor<T, 1, PtrTraits, index_t>(
this->data_, this->sizes_, this->strides_);
}
};
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
using GenericPackedTensorAccessor = torch::headeronly::detail::GenericPackedTensorAccessor<TensorAccessor<T, N-1, PtrTraits, index_t>, detail::IndexBoundsCheck<N, index_t>, T, N, PtrTraits, index_t>;
// Can't put this directly into the macro function args because of commas
#define AT_X GenericPackedTensorAccessor<T, N, PtrTraits, index_t>
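The `transpose` members above only permute the stored size/stride arrays rather than moving elements. The same property is visible from Python, where a transposed tensor is a view that shares storage with the original and differs only in its strides (a small illustrative check, not part of this diff):

```
import torch

t = torch.arange(6).reshape(2, 3)
tt = t.transpose(0, 1)

print(t.stride(), tt.stride())        # (3, 1) (1, 3): only the strides are permuted
print(t.data_ptr() == tt.data_ptr())  # True: no elements were moved
```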

View File

@ -245,9 +245,6 @@ class TORCH_API TensorBase {
size_t weak_use_count() const noexcept {
return impl_.weak_use_count();
}
bool is_uniquely_owned() const noexcept {
return impl_.is_uniquely_owned();
}
std::string toString() const;

View File

@ -223,62 +223,6 @@ CONVERT_FROM_BF16_TEMPLATE(double)
CONVERT_FROM_BF16_TEMPLATE(float16_t)
#endif
#ifdef __ARM_FEATURE_BF16
// clang-[17, 20] crashes when autovectorizing static cast to bf16
// Below is a workaround to have some vectorization
// Works decently well for smaller int types
template <typename from_type>
inline void convertToBf16Impl(
const from_type* __restrict src,
c10::BFloat16* __restrict dst,
uint64_t n) {
bfloat16_t* dstPtr = reinterpret_cast<bfloat16_t*>(dst);
uint64_t loopBound = n - (n % 16);
uint64_t i = 0;
for (; i < loopBound; i += 16) {
float32x4_t a, b, c, d;
a[0] = static_cast<float>(src[i]);
a[1] = static_cast<float>(src[i + 1]);
a[2] = static_cast<float>(src[i + 2]);
a[3] = static_cast<float>(src[i + 3]);
b[0] = static_cast<float>(src[i + 4]);
b[1] = static_cast<float>(src[i + 5]);
b[2] = static_cast<float>(src[i + 6]);
b[3] = static_cast<float>(src[i + 7]);
c[0] = static_cast<float>(src[i + 8]);
c[1] = static_cast<float>(src[i + 9]);
c[2] = static_cast<float>(src[i + 10]);
c[3] = static_cast<float>(src[i + 11]);
d[0] = static_cast<float>(src[i + 12]);
d[1] = static_cast<float>(src[i + 13]);
d[2] = static_cast<float>(src[i + 14]);
d[3] = static_cast<float>(src[i + 15]);
vst1q_bf16(dstPtr + i, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(a), b));
vst1q_bf16(dstPtr + i + 8, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(c), d));
}
#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
for (; i < n; i++) {
float a = static_cast<float>(src[i]);
dstPtr[i] = vcvth_bf16_f32(a);
}
}
#define CONVERT_TO_BF16_TEMPLATE(from_type) \
template <> \
inline void convert(const from_type* src, c10::BFloat16* dst, int64_t n) { \
return convertToBf16Impl<from_type>(src, dst, n); \
}
CONVERT_TO_BF16_TEMPLATE(uint8_t)
CONVERT_TO_BF16_TEMPLATE(int8_t)
CONVERT_TO_BF16_TEMPLATE(int16_t)
CONVERT_TO_BF16_TEMPLATE(int32_t)
#endif
inline void convertBoolToBfloat16Impl(
const bool* __restrict src,
c10::BFloat16* __restrict dst,

View File

@ -3,7 +3,6 @@
#include <cstdint>
#include <map>
#include <shared_mutex>
#include <cuda_runtime_api.h>
#include <cusparse.h>
@ -89,13 +88,8 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
struct WorkspaceMapWithMutex {
std::map<std::tuple<void*, void*>, at::DataPtr> map;
std::shared_mutex mutex;
};
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace();
TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
TORCH_CUDA_CPP_API size_t getCUDABlasLtWorkspaceSize();
TORCH_CUDA_CPP_API void* getCUDABlasLtWorkspace();

View File

@ -175,24 +175,17 @@ void CUDAGraph::instantiate() {
// Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
// who prefer not to report error messages through these arguments moving forward
// (they prefer return value, or errors on api calls internal to the capture)
// ROCM appears to fail with HIP error: invalid argument
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && !defined(USE_ROCM)
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, cudaGraphInstantiateFlagUseNodePriority));
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, 0));
#else
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
#endif
//Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory.
//It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch.
} else {
#if !defined(USE_ROCM)
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
graph_,
cudaGraphInstantiateFlagAutoFreeOnLaunch | cudaGraphInstantiateFlagUseNodePriority));
#else
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
graph_,
cudaGraphInstantiateFlagAutoFreeOnLaunch));
#endif
}
has_graph_exec_ = true;
}

View File

@ -99,7 +99,7 @@ void destroyCublasHandle(cublasHandle_t handle) {
// - Comments of @soumith copied from cuDNN handle pool implementation
#ifdef NO_CUDNN_DESTROY_HANDLE
#else
cublasDestroy(handle);
cublasDestroy(handle);
#endif
}
@ -107,27 +107,19 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle
} // namespace
WorkspaceMapWithMutex& cublas_handle_stream_to_workspace() {
static auto& instance = *new WorkspaceMapWithMutex;
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
return instance;
}
WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace() {
static auto& instance = *new WorkspaceMapWithMutex;
std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
return instance;
}
void clearCublasWorkspaces() {
{
auto& workspace = cublas_handle_stream_to_workspace();
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
workspace.map.clear();
}
{
auto& workspace = cublaslt_handle_stream_to_workspace();
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
workspace.map.clear();
}
cublas_handle_stream_to_workspace().clear();
cublaslt_handle_stream_to_workspace().clear();
}
size_t parseChosenWorkspaceSize() {
@ -241,38 +233,6 @@ at::DataPtr getNewCUDABlasLtWorkspace() {
return c10::cuda::CUDACachingAllocator::get()->allocate(getCUDABlasLtWorkspaceSize());
}
void setWorkspaceForHandle(cublasHandle_t handle, c10::cuda::CUDAStream stream) {
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto& workspace = cublas_handle_stream_to_workspace();
size_t workspace_size = getChosenWorkspaceSize();
// Fast path: check if workspace already exists
{
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
if (workspace_it != workspace.map.end()) {
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(
handle, workspace_it->second.get(), workspace_size));
return;
}
}
// Slow path: allocate workspace outside the lock
auto new_workspace = getNewWorkspace();
// Insert with lock (double-check in case another thread inserted while we
// were allocating)
{
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.try_emplace(key, std::move(new_workspace)).first;
TORCH_CUDABLAS_CHECK(
cublasSetWorkspace(handle, workspace_it->second.get(), workspace_size));
}
}
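The fast-path/slow-path structure above is a read-mostly, double-checked insertion: look up under a shared lock, allocate outside any lock when the entry is missing, then re-check under an exclusive lock before inserting so that a concurrent insertion wins. A minimal Python sketch of the same pattern (a plain `threading.Lock` standing in for the `std::shared_mutex`, a dict for the workspace map; purely illustrative):

```
import threading

_lock = threading.Lock()   # the C++ code uses a shared_mutex (readers/writer lock)
_workspaces = {}           # key -> workspace buffer

def expensive_allocate(key):
    return bytearray(1024)  # stand-in for getNewWorkspace()

def get_workspace(key):
    # Fast path: return an existing entry.
    with _lock:
        ws = _workspaces.get(key)
    if ws is not None:
        return ws
    # Slow path: allocate outside the lock so other threads are not blocked.
    new_ws = expensive_allocate(key)
    # Double-check under the lock: another thread may have inserted meanwhile.
    with _lock:
        return _workspaces.setdefault(key, new_ws)
```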
void* getCUDABlasLtWorkspace() {
#ifndef USE_ROCM
static bool unified = c10::utils::check_env(TORCH_CUBLASLT_UNIFIED_WORKSPACE) == true;
@ -281,10 +241,8 @@ void* getCUDABlasLtWorkspace() {
auto stream = c10::cuda::getCurrentCUDAStream();
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto& workspace = at::cuda::cublas_handle_stream_to_workspace();
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
TORCH_INTERNAL_ASSERT(workspace_it != workspace.map.end());
auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
return workspace_it->second.mutable_get();
}
#endif
@ -292,29 +250,11 @@ void* getCUDABlasLtWorkspace() {
auto stream = c10::cuda::getCurrentCUDAStream();
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto& workspace = cublaslt_handle_stream_to_workspace();
// Fast path: check if workspace already exists
{
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
if (workspace_it != workspace.map.end()) {
return workspace_it->second.mutable_get();
}
}
// Slow path: allocate workspace outside the lock
auto new_workspace = getNewCUDABlasLtWorkspace();
// Insert with lock (double-check in case another thread inserted while we
// were allocating)
{
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it =
workspace.map.try_emplace(key, std::move(new_workspace)).first;
return workspace_it->second.mutable_get();
auto workspace_it = cublaslt_handle_stream_to_workspace().find(key);
if (workspace_it == cublaslt_handle_stream_to_workspace().end()) {
workspace_it = cublaslt_handle_stream_to_workspace().insert(workspace_it, {key, getNewCUDABlasLtWorkspace()});
}
return workspace_it->second.mutable_get();
}
cublasHandle_t getCurrentCUDABlasHandle() {
@ -358,8 +298,13 @@ cublasHandle_t getCurrentCUDABlasHandle() {
// will allocate memory dynamically (even if they're cheap) outside
// PyTorch's CUDA caching allocator. It's possible that CCA used up
// all the memory and cublas's cudaMallocAsync will return OOM
setWorkspaceForHandle(handle, stream);
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = cublas_handle_stream_to_workspace().find(key);
if (workspace_it == cublas_handle_stream_to_workspace().end()) {
workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});
}
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
#if !defined(USE_ROCM)
// On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
// FP32 data type calculations based on the value of the allow_tf32 flag.

View File

@ -22,7 +22,6 @@ enum class MacOSVersion : uint32_t {
MACOS_VER_15_0_PLUS,
MACOS_VER_15_1_PLUS,
MACOS_VER_15_2_PLUS,
MACOS_VER_26_0_PLUS,
};
//-----------------------------------------------------------------

View File

@ -65,7 +65,6 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
static bool _macos_15_0_plus = is_os_version_at_least(15, 0);
static bool _macos_15_1_plus = is_os_version_at_least(15, 1);
static bool _macos_15_2_plus = is_os_version_at_least(15, 2);
static bool _macos_26_0_plus = is_os_version_at_least(26, 0);
switch (version) {
case MacOSVersion::MACOS_VER_14_4_PLUS:
@ -76,8 +75,6 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
return _macos_15_1_plus;
case MacOSVersion::MACOS_VER_15_2_PLUS:
return _macos_15_2_plus;
case MacOSVersion::MACOS_VER_26_0_PLUS:
return _macos_26_0_plus;
default:
return false;
}

View File

@ -1936,7 +1936,7 @@ static bool should_fold(const Tensor& tensor1, const Tensor& tensor2, bool has_o
// We order the tensors. t1 will be the larger tensor
// We can always transpose tensor2 as the dimensions are always >= 1 (precondition from matmul)
// and tensor1_larger iff tensor2.dim() > tensor1.dim(9
// and tensor1_larger iff tensor2.dim() > tensor1.dim()
const auto t1 = tensor1_larger ? MaybeOwned<Tensor>::borrowed(tensor1)
: MaybeOwned<Tensor>::owned(tensor2.mT());
const int64_t dim_t1 = t1->dim();
@ -1948,20 +1948,11 @@ static bool should_fold(const Tensor& tensor1, const Tensor& tensor2, bool has_o
return false;
}
// In this case we *do* incur an extra copy to avoid creating an unnecessarily large tensor in the backward
// Suppose we don't fold here. Let t1.shape = [b, m, n] t2.shape = [n, k] like in a transformer
// t2 will be expanded to a tensor of shape [b, n, k] and then we do t1.bmm(t2_expanded)
// The issue appears in the backward.
// The output gradient g of this operation would have shape [b, m, k]
// The backward wrt. t2 of bmm would be given by t1.mH @ g, which has shape [b, n, k]
// Then, the backward of expand is simply `sum(0)`. As such, we are instantiating a tensor
// of shape [b, n, k] unnecessarily, which may cause a large memory footprint, and in the
// worst case, an OOM
bool t2_requires_grad = tensor1_larger ? tensor2.requires_grad() : tensor1.requires_grad();
if (t2_requires_grad && !has_out) {
// We should be checking !at::GradMode::is_enabled(), but apparently
// this regresses performance in some cases:
// https://github.com/pytorch/pytorch/issues/118548#issuecomment-1916022394
// If we require a gradient, we should fold to minimize backward memory usage, even if this
// leads to a copy in the forward, because the folded result is needed in the backward.
// The only time we avoid this is when the output memory is strictly pre-allocated (has_out = True).
bool requires_grad = tensor1.requires_grad() || tensor2.requires_grad();
if (requires_grad && !has_out) {
return true;
}
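A small PyTorch illustration of the shape argument in the comments above, with t1 of shape [b, m, n] and t2 of shape [n, k]: without folding, t2 is broadcast to [b, n, k] for the batched matmul and its backward reduces a [b, n, k] intermediate with sum(0), while folding the batch into the rows of a single 2D matmul produces the same result and lets the [n, k] gradient be computed directly:

```
import torch

b, m, n, k = 8, 32, 64, 16
t1 = torch.randn(b, m, n)
t2 = torch.randn(n, k, requires_grad=True)

# Unfolded: the batched matmul sees t2 expanded to [b, n, k]; the backward of that
# expand materializes a [b, n, k] gradient before summing it over dim 0.
out_bmm = torch.bmm(t1, t2.expand(b, n, k))

# Folded: one 2D matmul over the flattened batch; the gradient w.r.t. t2 is [n, k] directly.
out_folded = (t1.reshape(b * m, n) @ t2).reshape(b, m, k)

print(torch.allclose(out_bmm, out_folded, atol=1e-5))  # True
```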

View File

@ -1087,8 +1087,7 @@ TORCH_IMPL_FUNC(index_copy_out)
result.copy_(self);
// See Note [Enabling Deterministic Operations]
if ((result.is_cuda() || result.is_xpu()) &&
globalContext().deterministicAlgorithms()) {
if (result.is_cuda() && globalContext().deterministicAlgorithms()) {
torch::List<std::optional<Tensor>> indices;
indices.resize(dim + 1);
indices.set(dim, index);

View File

@ -904,11 +904,19 @@ Tensor mvlgamma(const Tensor& self, int64_t p) {
return args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER);
}
// since mvlgamma_ has different signature from its
// out and functional variant, we explicitly
// define it (instead of using structured kernel).
Tensor& mvlgamma_(Tensor& self, int64_t p) {
return at::mvlgamma_out(self, self, p);
mvlgamma_check(self, p);
Tensor args = native::arange(
-p *HALF + HALF,
HALF,
HALF,
optTypeMetaToScalarType(self.options().dtype_opt()),
self.options().layout_opt(),
self.options().device_opt(),
self.options().pinned_memory_opt());
args = args.add(self.unsqueeze(-1));
const auto p2_sub_p = static_cast<double>(p * (p - 1));
return self.copy_(args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER));
}
Tensor& mvlgamma_out(const Tensor& self, int64_t p, Tensor& result) {

View File

@ -296,7 +296,7 @@ template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmAndBiasCublasLt(
// args contains result which is modified
cublasCommonArgs& args,
const std::optional<Tensor>& self,
const Tensor& self,
const Scalar& alpha,
Activation activation = Activation::None
) {
@ -304,8 +304,12 @@ bool launchGemmAndBiasCublasLt(
// or when it can be squeezed to 1D.
// self_ptr == nullptr implies ignore bias epilogue
// and use standard gemm-like API.
const auto* self_ptr = self.has_value() ? self.value().const_data_ptr<scalar_t>() : static_cast<const scalar_t*>(nullptr);
const auto* self_ptr = [&]() -> auto {
if (self.dim() == 1 || self.squeeze().dim() == 1) {
return self.const_data_ptr<scalar_t>();
}
return static_cast<const scalar_t*>(nullptr);
}();
const auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
@ -388,30 +392,35 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#ifdef USE_ROCM
// Conditioned on the device index, which is not persistent
disable_addmm_cuda_lt = disable_addmm_cuda_lt || isGloballyDisabledAddmmCudaLt(self.device());
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
#endif
// Condition on the input
disable_addmm_cuda_lt = disable_addmm_cuda_lt || !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha, activation);
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha, activation) || disable_addmm_cuda_lt;
// }
at::ScalarType scalar_type = mat1.scalar_type();
bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
#ifdef USE_ROCM
disable_addmm_cuda_lt = disable_addmm_cuda_lt || is_float_output_with_half_input;
#endif
bool use_bias_ptr_lt = (self.dim() == 1) && !disable_addmm_cuda_lt;
// for float output with half input cublasLT with bias produces wrong results
use_bias_ptr_lt &= !is_float_output_with_half_input;
// Handle result/self shapes
if (!result.is_same(self)) {
at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
// We do not copy bias only when we need the bias ptr
// We use bias ptr in the Lt path only when bias is 1D
const auto use_bias_ptr_lt = (self.dim() == 1) && !disable_addmm_cuda_lt;
const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
if (!use_bias_ptr_lt) {
// We do expand self even before
// check for beta != 0.0 to make sure that
// test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
// runs green.
return expand_size(self, result.sizes(), "addmm");
}
return c10::MaybeOwned<Tensor>::borrowed(self);
}();
// We do not copy bias only when we need the bias ptr
if (beta.toComplexDouble() != 0.0 && !use_bias_ptr_lt) {
// NOTE: self should broadcast over result
at::native::copy_(result, *expand_size(self, result.sizes(), "addmm"));
at::native::copy_(result, *self_maybe_expanded);
}
}
@ -459,7 +468,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
scalar_type,
"addmm_cuda_lt",
[&] {
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, use_bias_ptr_lt ? std::make_optional(self) : std::nullopt, alpha, activation);
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation);
}
);
#endif
@ -471,7 +480,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
scalar_type,
"addmm_cuda_lt",
[&] {
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, use_bias_ptr_lt ? std::make_optional(self) : std::nullopt, alpha, activation);
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation);
}
);
} // end is_float_output_with_half_input
@ -927,7 +936,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
return _int_mm_out_cuda(self, mat2, result);
}
static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const at::ScalarType out_dtype, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const at::ScalarType out_dtype, bool is_bmm, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
// ref ATen/native/LinearAlgebra.cpp common_checks_baddbmm_bmm
TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");
@ -951,7 +960,7 @@ static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& bat
(out_dtype == at::ScalarType::Float && (batch1.scalar_type() == at::ScalarType::Half || batch1.scalar_type() == at::ScalarType::BFloat16)),
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
if (self_baddbmm.has_value()) {
if (!is_bmm && self_baddbmm.has_value()) {
const auto& self = self_baddbmm.value();
TORCH_CHECK(self.dim() == 3, "self must be a 3D tensor");
TORCH_CHECK(self.sizes() == output_size, "self must have the same shape as the output");
@ -959,12 +968,15 @@ static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& bat
}
Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype) {
Tensor out = at::empty({batch1.size(0), batch1.size(1), batch2.size(2)}, batch1.options().dtype(out_dtype));
IntArrayRef batch1_sizes = batch1.sizes();
IntArrayRef batch2_sizes = batch2.sizes();
Tensor out = at::empty({batch1_sizes[0], batch1_sizes[1], batch2_sizes[2]}, batch1.options().dtype(out_dtype));
return _bmm_out_dtype_cuda(batch1, batch2, out_dtype, out);
}
Tensor& _bmm_out_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, Tensor &out) {
baddbmm_bmm_out_dtype_checks(batch1, batch2, 0.0, 1.0, out_dtype);
baddbmm_bmm_out_dtype_checks(batch1, batch2, 0.0, 1.0, out_dtype, true);
Scalar beta(0.0);
Scalar alpha(1.0);
{
@ -976,16 +988,14 @@ Tensor& _bmm_out_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at
}
Tensor _baddbmm_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha) {
TORCH_CHECK(self.scalar_type() == out_dtype || self.scalar_type() == batch1.dtype(),
"self dtype must match either out_dtype or batch1 dtype");
Tensor out = at::empty({batch1.size(0), batch1.size(1), batch2.size(2)}, batch1.options().dtype(out_dtype));
return _baddbmm_out_dtype_cuda(self, batch1, batch2, out_dtype, beta, alpha, out);
// We need to copy the tensor
Tensor out = self.clone().to(self.options().dtype(out_dtype));
return _baddbmm_out_dtype_cuda(out, batch1, batch2, out_dtype, beta, alpha, out);
}
Tensor& _baddbmm_out_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) {
baddbmm_bmm_out_dtype_checks(batch1, batch2, beta, alpha, out_dtype, out);
// We need to copy the tensor
out.copy_(self);
baddbmm_bmm_out_dtype_checks(batch1, batch2, beta, alpha, out_dtype, false, self);
{
NoNamesGuard guard;
baddbmm_out_cuda_impl(out, out, batch1, batch2, beta, alpha);
@ -1020,27 +1030,24 @@ Tensor& _mm_dtype_out_cuda(const Tensor& self, const Tensor& mat2, const at::Sca
}
Tensor _addmm_dtype_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha) {
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
Tensor result = at::empty({mat1.size(0), mat2.size(1)}, self.options().dtype(out_dtype));
Tensor result = at::empty(self.sizes(), self.options().dtype(out_dtype));
return _addmm_dtype_out_cuda(self, mat1, mat2, out_dtype, beta, alpha, result);
}
Tensor& _addmm_dtype_out_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) {
// repeat dimensionality checks for direct calls to `out` overload
TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "self and mat2 must have the same dtype, but got ", self.scalar_type(), " and ", mat2.scalar_type());
TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type());
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type());
TORCH_CHECK(out_dtype == mat1.scalar_type() ||
(out_dtype == at::ScalarType::Float && (mat1.scalar_type() == at::ScalarType::Half || mat1.scalar_type() == at::ScalarType::BFloat16)),
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
TORCH_CHECK(out_dtype == self.scalar_type() || self.scalar_type() == mat1.scalar_type(),
"self dtype must match either out_dtype or mat1 dtype");
TORCH_CHECK(out_dtype == self.scalar_type() ||
(out_dtype == at::ScalarType::Float && (self.scalar_type() == at::ScalarType::Half || self.scalar_type() == at::ScalarType::BFloat16)),
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
addmm_out_cuda_impl(out, self, mat1, mat2, beta, alpha);

View File

@ -75,52 +75,30 @@ static inline bool can_use_int32_nhwc(
return true;
}
static inline bool can_use_int32_nchw(
int64_t nbatch, int64_t channels,
int64_t height, int64_t width,
int64_t pooled_height, int64_t pooled_width) {
int64_t hw = height * width;
return can_use_int32_nhwc(
nbatch, channels, height, width,
pooled_height, pooled_width,
channels * hw, // in_stride_n
hw, // in_stride_c
width, // in_stride_h
1 // in_stride_w
);
}
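The can_use_int32_nchw/nhwc helpers gate a 32-bit-indexing fast path: the pooling kernels are templated on index_t, and the launcher picks int32_t only when every linear offset fits. A rough sketch of that kind of fit check (illustrative only; the real helper also inspects strides):

```cpp
#include <cstdint>
#include <limits>

// 32-bit indexing is only safe when the largest linear offset into the
// input and output stays below INT32_MAX.
static bool pooling_offsets_fit_int32(int64_t nbatch, int64_t channels,
                                      int64_t height, int64_t width,
                                      int64_t pooled_height, int64_t pooled_width) {
  constexpr int64_t int_max = std::numeric_limits<int32_t>::max();
  const int64_t in_elems  = nbatch * channels * height * width;
  const int64_t out_elems = nbatch * channels * pooled_height * pooled_width;
  return in_elems <= int_max && out_elems <= int_max;
}
```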
// kernels borrowed from Caffe
template <typename scalar_t, typename index_t>
__global__ void max_pool_forward_nchw(
const index_t nthreads,
const scalar_t* bottom_data,
const int64_t channels,
const int64_t height,
const int64_t width,
const int pooled_height,
const int pooled_width,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
scalar_t* top_data,
template <typename scalar_t>
__global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom_data,
const int64_t channels, const int64_t height,
const int64_t width, const int pooled_height, const int pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w, scalar_t* top_data,
int64_t* top_mask) {
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
index_t pw = index % pooled_width;
index_t ph = (index / pooled_width) % pooled_height;
index_t c = (index / pooled_width / pooled_height) % channels;
index_t n = index / pooled_width / pooled_height / channels;
index_t hstart = ph * stride_h - pad_h;
index_t wstart = pw * stride_w - pad_w;
index_t hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
index_t wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride_h - pad_h;
int wstart = pw * stride_w - pad_w;
int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
while(hstart < 0)
hstart += dilation_h;
while(wstart < 0)
wstart += dilation_w;
scalar_t maxval = at::numeric_limits<scalar_t>::lower_bound(); // -Infinity
index_t maxidx = hstart * width + wstart;
int maxidx = hstart * width + wstart;
const scalar_t* btm_data = bottom_data + (n * channels + c) * height * width;
for (int h = hstart; h < hend; h += dilation_h) {
for (int w = wstart; w < wend; w += dilation_w) {
@ -273,39 +251,32 @@ __global__ void max_pool_forward_nhwc(
static constexpr int BLOCK_THREADS = 256;
template <typename scalar_t, typename accscalar_t, typename index_t>
template <typename scalar_t, typename accscalar_t>
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 4)
#else
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 8)
#endif
__global__ void max_pool_backward_nchw(
const scalar_t* top_diff,
const int64_t* top_mask,
const index_t num,
const index_t channels,
const index_t height,
const index_t width,
const index_t pooled_height,
const index_t pooled_width,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
__global__ void max_pool_backward_nchw(const scalar_t* top_diff,
const int64_t* top_mask, const int num, const int64_t channels,
const int64_t height, const int64_t width, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
scalar_t* bottom_diff) {
CUDA_KERNEL_LOOP_TYPE(index, height*width, index_t) {
index_t h = index / width;
index_t w = index - h * width;
index_t phstart = p_start(h, pad_h, kernel_h, dilation_h, stride_h);
index_t phend = p_end(h, pad_h, pooled_height, stride_h);
index_t pwstart = p_start(w, pad_w, kernel_w, dilation_w, stride_w);
index_t pwend = p_end(w, pad_w, pooled_width, stride_w);
for (index_t n = blockIdx.y; n < num; n += gridDim.y) {
for (index_t c = blockIdx.z; c < channels; c += gridDim.z) {
CUDA_KERNEL_LOOP(index, height*width) {
int h = index / width;
int w = index - h * width;
int phstart = p_start(h, pad_h, kernel_h, dilation_h, stride_h);
int phend = p_end(h, pad_h, pooled_height, stride_h);
int pwstart = p_start(w, pad_w, kernel_w, dilation_w, stride_w);
int pwend = p_end(w, pad_w, pooled_width, stride_w);
for (int n = blockIdx.y; n < num; n += gridDim.y) {
for (int c = blockIdx.z; c < channels; c += gridDim.z) {
accscalar_t gradient = accscalar_t(0);
index_t offset = (n * channels + c) * pooled_height * pooled_width;
for (index_t ph = phstart; ph < phend; ++ph) {
for (index_t pw = pwstart; pw < pwend; ++pw) {
int offset = (n * channels + c) * pooled_height * pooled_width;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
if (top_mask[ph * pooled_width + pw + offset] == h * width + w) {
gradient += static_cast<accscalar_t>(top_diff[ph * pooled_width + pw + offset]);
}
@ -498,6 +469,8 @@ const Tensor& indices) {
const int64_t in_stride_h = input.stride(-2);
const int64_t in_stride_w = input.stride(-1);
const int count = safe_downcast<int, int64_t>(output.numel());
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
"max_pool2d_with_indices_out_cuda_frame",
[&] {
@ -580,42 +553,14 @@ const Tensor& indices) {
break;
}
case MemoryFormat::Contiguous: {
const int threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
BLOCK_THREADS);
const int64_t nthreads = output.numel();
bool use_int32 = can_use_int32_nchw(
nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
const int maxGridX = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
const int blocks = static_cast<int>(std::min<int64_t>(
ceil_div(nthreads, static_cast<int64_t>(threads)),
static_cast<int64_t>(maxGridX)));
auto stream = at::cuda::getCurrentCUDAStream();
if (use_int32) {
max_pool_forward_nchw<scalar_t, int32_t>
<<<blocks, threads, 0, stream>>>(
static_cast<int32_t>(nthreads),
input_data,
static_cast<int32_t>(nInputPlane),
static_cast<int32_t>(inputHeight),
static_cast<int32_t>(inputWidth),
static_cast<int32_t>(outputHeight),
static_cast<int32_t>(outputWidth),
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
output_data, indices_data);
} else {
max_pool_forward_nchw<scalar_t, int64_t>
<<<blocks, threads, 0, stream>>>(
nthreads,
input_data,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
output_data, indices_data);
}
const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
BLOCK_THREADS);
max_pool_forward_nchw<scalar_t>
<<<ceil_div(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
count, input_data,
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
output_data, indices_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
}
@ -688,6 +633,8 @@ const Tensor& gradInput) {
gradInput.zero_();
int64_t count = input.numel();
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
"max_pool2d_with_indices_out_cuda_frame",
[&] {
@ -745,45 +692,25 @@ const Tensor& gradInput) {
break;
}
case MemoryFormat::Contiguous: {
const int threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
BLOCK_THREADS);
const int imgcount = inputWidth * inputHeight;
const int maxGridX = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
const int maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const int maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
const int blocks_x = std::min(ceil_div(imgcount, threads), maxGridX);
dim3 grid(blocks_x, static_cast<unsigned>(std::min<int64_t>(nbatch, maxGridY)), static_cast<unsigned>(std::min<int64_t>(nInputPlane, maxGridZ)));
bool use_int32 = can_use_int32_nchw(
nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
auto stream = at::cuda::getCurrentCUDAStream();
if (use_int32) {
max_pool_backward_nchw<scalar_t, accscalar_t, int32_t>
<<<grid, threads, 0, stream>>>(
gradOutput_data,
indices_data,
static_cast<int32_t>(nbatch),
static_cast<int32_t>(nInputPlane),
static_cast<int32_t>(inputHeight),
static_cast<int32_t>(inputWidth),
static_cast<int32_t>(outputHeight),
static_cast<int32_t>(outputWidth),
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
gradInput_data);
} else {
max_pool_backward_nchw<scalar_t, accscalar_t, int64_t>
<<<grid, threads, 0, stream>>>(
gradOutput_data,
indices_data,
nbatch,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
gradInput_data);
}
int imgcount = inputWidth * inputHeight;
dim3 grid;
const int blocks = (imgcount + BLOCK_THREADS - 1) / BLOCK_THREADS;
grid.x = blocks;
grid.y = nbatch;
uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
if (maxGridY < grid.y) grid.y = maxGridY;
grid.z = nInputPlane;
uint64_t maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
if (maxGridZ < grid.z) grid.z = maxGridZ;
max_pool_backward_nchw<scalar_t, accscalar_t>
<<<grid, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
gradOutput_data,
indices_data,
nbatch,
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
gradInput_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
}

View File

@ -78,18 +78,9 @@ __global__ void EmbeddingBag_updateOutputKernel_max(
scalar_t weightFeatMax = 0;
int64_t bag_size_ = 0;
int64_t maxWord = -1;
// Separate validation loop reduces register pressure in the main loop below.
// No early exit (break) on invalid input as benchmarking shows it degrades performance.
bool has_invalid_index = false;
for (int64_t emb = begin; emb < end; emb++) {
index_t input_idx = input[emb];
has_invalid_index = has_invalid_index || (input_idx < 0 || input_idx >= numRows);
}
CUDA_KERNEL_ASSERT(!has_invalid_index && "Invalid input index in EmbeddingBag: index out of range [0, numRows)");
for (int64_t emb = begin; emb < end; emb++) {
bool pad = (input[emb] == padding_idx);
CUDA_KERNEL_ASSERT(input[emb] < numRows);
const int64_t weightRow = input[emb] * weight_stride0;
scalar_t weightValue = weightFeat[weightRow];
if (bag_size_ == 0 || weightValue > weightFeatMax) {
@ -138,19 +129,10 @@ __global__ void EmbeddingBag_updateOutputKernel_sum_mean(
CUDA_KERNEL_ASSERT(end >= begin);
accscalar_t weightFeatSum = 0;
int64_t bag_size_ = 0;
// Separate validation loop reduces register pressure in the main loop below.
// No early exit (break) on invalid input as benchmarking shows it degrades performance.
bool has_invalid_index = false;
for (int64_t emb = begin; emb < end; emb++) {
index_t input_idx = input[emb];
has_invalid_index = has_invalid_index || (input_idx < 0 || input_idx >= numRows);
}
CUDA_KERNEL_ASSERT(!has_invalid_index && "Invalid input index in EmbeddingBag: index out of range [0, numRows)");
for (int64_t emb = begin; emb < end; emb++) {
index_t input_idx = input[emb];
bool pad = (input_idx == padding_idx);
CUDA_KERNEL_ASSERT(0 <= input_idx && input_idx < numRows);
const int64_t weightRow = input_idx * weight_stride0;
scalar_t weightValue = weightFeat[weightRow];
weightValue = pad ? static_cast<scalar_t>(0) : weightValue;
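The "separate validation loop" comments above describe a two-pass pattern: fold validity into a flag with a branch-free OR-reduction first, assert once, then let the main loop assume in-range indices. A host-side sketch of the same structure (illustrative only, not the EmbeddingBag kernel):

```cpp
#include <cassert>
#include <cstdint>

// Pass 1 only accumulates validity (no early break, which keeps the loop
// branch-predictable); pass 2 does the real work without per-element checks.
float gather_sum_two_pass(const int64_t* indices, int64_t begin, int64_t end,
                          int64_t num_rows, const float* weight) {
  bool has_invalid_index = false;
  for (int64_t i = begin; i < end; ++i) {
    const int64_t idx = indices[i];
    has_invalid_index = has_invalid_index || (idx < 0 || idx >= num_rows);
  }
  assert(!has_invalid_index && "index out of range [0, num_rows)");
  float sum = 0.0f;
  for (int64_t i = begin; i < end; ++i) {
    sum += weight[indices[i]];
  }
  return sum;
}
```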

View File

@ -78,9 +78,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const SwizzleType swizzle_a,
const SwizzleType& swizzle_a,
const Tensor& scale_b,
const SwizzleType swizzle_b,
const SwizzleType& swizzle_b,
const std::optional<at::Tensor>& offs,
Tensor& out) {
const bool a_is_2d = mat_a.dim() == 2;

View File

@ -5,11 +5,69 @@
#include <cuda_bf16.h>
#endif
// ROCm 6.3 is planned to have these functions, but until then here they are.
#if defined(USE_ROCM)
#include <device_functions.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#define ATOMICADD unsafeAtomicAdd
__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) {
#if (defined(__gfx942__)) && \
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16)
typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2;
static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw));
union {
__hip_bfloat162_raw bf162_raw;
vec_short2 vs2;
} u{static_cast<__hip_bfloat162_raw>(value)};
u.vs2 = __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)address, u.vs2);
return static_cast<__hip_bfloat162>(u.bf162_raw);
#else
static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw));
union u_hold {
__hip_bfloat162_raw h2r;
unsigned int u32;
};
u_hold old_val, new_val;
old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
do {
new_val.h2r = __hadd2(old_val.h2r, value);
} while (!__hip_atomic_compare_exchange_strong(
(unsigned int*)address, &old_val.u32, new_val.u32,
__ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
return old_val.h2r;
#endif
}
__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) {
#if (defined(__gfx942__)) && \
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16)
// The api expects an ext_vector_type of half
typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162;
static_assert(sizeof(vec_fp162) == sizeof(__half2_raw));
union {
__half2_raw h2r;
vec_fp162 fp16;
} u {static_cast<__half2_raw>(value)};
u.fp16 = __builtin_amdgcn_flat_atomic_fadd_v2f16((vec_fp162*)address, u.fp16);
return static_cast<__half2>(u.h2r);
#else
static_assert(sizeof(__half2_raw) == sizeof(unsigned int));
union u_hold {
__half2_raw h2r;
unsigned int u32;
};
u_hold old_val, new_val;
old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
do {
new_val.h2r = __hadd2(old_val.h2r, value);
} while (!__hip_atomic_compare_exchange_strong(
(unsigned int*)address, &old_val.u32, new_val.u32,
__ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
return old_val.h2r;
#endif
}
#define ATOMICADD preview_unsafeAtomicAdd
#define NATIVE_ZERO_BF16 __float2bfloat16(0.0f)
#else
#define ATOMICADD atomicAdd
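The preview_unsafeAtomicAdd fallbacks above use the classic compare-and-swap retry loop. The same pattern, written host-side with std::atomic purely for illustration (hypothetical helper, not part of this header):

```cpp
#include <atomic>
#include <cstdint>
#include <cstring>

// Emulate an atomic float add via 32-bit CAS: read the current bits, compute
// the new value, and retry until no other writer raced in between.
float atomic_add_float_cas(std::atomic<uint32_t>& cell, float value) {
  uint32_t old_bits = cell.load(std::memory_order_relaxed);
  for (;;) {
    float old_val;
    std::memcpy(&old_val, &old_bits, sizeof(old_val));
    const float new_val = old_val + value;
    uint32_t new_bits;
    std::memcpy(&new_bits, &new_val, sizeof(new_bits));
    // On failure compare_exchange_weak reloads old_bits, so the next
    // iteration recomputes the sum against fresh contents.
    if (cell.compare_exchange_weak(old_bits, new_bits, std::memory_order_relaxed)) {
      return old_val;  // like atomicAdd, return the previous value
    }
  }
}
```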

View File

@ -2,250 +2,18 @@
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/native/cuda/ScanUtils.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/OpMathType.h>
#include <c10/util/MathConstants.h>
#include <c10/util/complex.h>
#include <cmath>
#include <limits>
// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
namespace at::native {
// custom min and max to be used in logaddexp for complex arguments
template <typename scalar_t, bool min>
__host__ __device__ c10::complex<scalar_t> _logaddexp_minmax(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
scalar_t xr = std::real(x);
scalar_t yr = std::real(y);
if (::isnan(yr) || (::isnan(std::imag(y)))) {
return y;
} else if (::isnan(xr) || (::isnan(std::imag(x)))) {
return x;
} else if (min) { // min
return (xr < yr) ? x : y;
} else { // max
return (xr >= yr) ? x : y;
}
}
template <typename scalar_t>
__host__ __device__ scalar_t _log_add_exp_helper(const scalar_t& x, const scalar_t& y) {
// Reference : https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp
// Using the original expression: `at::_isnan(y) ? y : std::min(x, y)` causes an error in ROCM
const auto isnan_x = at::_isnan(x);
const auto isnan_y = at::_isnan(y);
scalar_t min = isnan_y ? y : (isnan_x ? x : std::min(x, y));
scalar_t max = isnan_y ? y : (isnan_x ? x : std::max(x, y));
if (min != max || ::isfinite(min)) {
// nan will be propagated here
return ::log1p(std::exp(min - max)) + max;
} else {
// special case to correctly handle infinite cases
return x;
}
}
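Spelled out, the helper computes the usual overflow-safe logaddexp identity (for finite real inputs):

```latex
\log\!\left(e^{x} + e^{y}\right)
  = \max(x, y) + \log\!\left(1 + e^{\min(x, y) - \max(x, y)}\right)
  = \max(x, y) + \operatorname{log1p}\!\left(e^{-\lvert x - y\rvert}\right)
```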
template <typename scalar_t>
__host__ __device__ c10::complex<scalar_t> _fast_build_exp(const c10::complex<scalar_t>& x) {
// complex exponential function, but implemented manually to get fast compilation time
// this function only handles the case where x is finite (not inf nor nan)
const auto xreal = std::real(x);
const auto ximag = std::imag(x);
const auto exp_x_abs = std::exp(xreal);
auto exp_x_real = exp_x_abs * std::cos(ximag);
auto exp_x_imag = exp_x_abs * std::sin(ximag);
return {exp_x_real, exp_x_imag};
}
template <typename scalar_t>
__host__ __device__ c10::complex<scalar_t> _fast_build_exp_inf(const c10::complex<scalar_t>& x) {
// complex exponential function, but implemented manually to get fast compilation time
// this function only handles the case where the real part of x is infinite
const auto ximag = std::imag(x);
constexpr auto exp_x_abs = std::numeric_limits<scalar_t>::infinity();
if (!::isfinite(ximag)) { // add this to make consistent with std::exp(x+yi)
return {exp_x_abs, std::numeric_limits<scalar_t>::quiet_NaN()};
}
const auto sin = std::sin(ximag);
const auto cos = std::cos(ximag);
// special case if the angle is exactly a multiple of pi/2
auto exp_x_real = (cos == 0) ? (scalar_t)0.0 : exp_x_abs * cos;
auto exp_x_imag = (sin == 0) ? (scalar_t)0.0 : exp_x_abs * sin;
return {exp_x_real, exp_x_imag};
}
template <typename scalar_t>
__host__ __device__ c10::complex<scalar_t> _log_add_exp_helper(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
c10::complex<scalar_t> min = _logaddexp_minmax<scalar_t, /*min=*/true>(x, y);
c10::complex<scalar_t> max = _logaddexp_minmax<scalar_t, /*min=*/false>(x, y);
scalar_t min_real = std::real(min);
scalar_t max_real = std::real(max);
if (::isnan(min_real) || ::isnan(std::imag(min))) {
// handling the "infectious" NaNs
return {std::numeric_limits<scalar_t>::quiet_NaN(), std::numeric_limits<scalar_t>::quiet_NaN()};
}
else if ((!::isfinite(min_real)) && (min_real == max_real)) {
if (min_real < 0) {
// handle the -inf case: exp(value) will be around 0.0, so the angle (i.e. the
// imaginary part) cannot be determined and does not matter for the result
return min;
} else {
// handle the +inf case: we don't need log1p's extra precision for small values here,
// and this avoids producing nan when real(max) == real(min) == +inf
const auto exp_min = _fast_build_exp_inf(min);
const auto exp_max = _fast_build_exp_inf(max);
return ::log1p(exp_min + exp_max - 1); // log1p(x - 1) builds faster than log
}
} else {
const auto minmax = min - max;
c10::complex<scalar_t> exp_minmax;
if (!::isfinite(minmax.real())) {
exp_minmax = minmax.real() < 0 ? c10::complex<scalar_t>{0.0, 0.0} : _fast_build_exp_inf(minmax);
} else {
exp_minmax = _fast_build_exp(minmax);
}
return ::log1p(exp_minmax) + max;
}
}
// Complex logaddexp jiterator string
const auto logaddexp_complex_string = jiterator_stringify(
template<typename T>
std::complex<T> log1p(const std::complex<T>& z)
{
using complex_t = std::complex<T>;
T x = z.real();
T y = z.imag();
T zabs = abs(z);
T theta = atan2(y, x + T(1));
if (zabs < 0.5) {
T r = x * (T(2) + x) + y * y;
if (r == 0) { // handle underflow
return complex_t(x, theta);
}
return complex_t(T(0.5) * std::log1p(r), theta);
} else {
T z0 = std::hypot(x + 1, y);
return complex_t(log(z0), theta);
}
}
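The small-|z| branch avoids cancellation by noting that, for z = x + iy,

```latex
\lvert 1 + z\rvert^{2} = (1 + x)^{2} + y^{2} = 1 + \underbrace{x(2 + x) + y^{2}}_{r},
\qquad
\log\lvert 1 + z\rvert = \tfrac{1}{2}\,\operatorname{log1p}(r),
\qquad
\arg(1 + z) = \operatorname{atan2}(y,\, x + 1).
```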
// separated _logaddexp_minmax into 2 different functions for jiterator_string
template <typename T>
std::complex<T> logaddexp_min(const std::complex<T>& x, const std::complex<T>& y) {
T xr = x.real();
T yr = y.real();
if (isnan(yr) || isnan(y.imag())) {
return y;
} else if (isnan(xr) || isnan(x.imag())) {
return x;
} else {
return (xr < yr) ? x : y;
}
}
template <typename T>
std::complex<T> logaddexp_max(const std::complex<T>& x, const std::complex<T>& y) {
T xr = x.real();
T yr = y.real();
if (isnan(yr) || isnan(y.imag())) {
return y;
} else if (isnan(xr) || isnan(x.imag())) {
return x;
} else {
return (xr >= yr) ? x : y;
}
}
template <typename T>
std::complex<T> fast_build_exp(const std::complex<T>& x) {
const auto xreal = x.real();
const auto ximag = x.imag();
const auto exp_x_abs = exp(xreal);
auto exp_x_real = exp_x_abs * cos(ximag);
auto exp_x_imag = exp_x_abs * sin(ximag);
return std::complex<T>(exp_x_real, exp_x_imag);
}
template <typename T>
std::complex<T> fast_build_exp_inf(const std::complex<T>& x) {
using complex_t = std::complex<T>;
const auto ximag = x.imag();
const T exp_x_abs = INFINITY;
if (!isfinite(ximag)) {
return complex_t(exp_x_abs, NAN);
}
const auto sin_val = sin(ximag);
const auto cos_val = cos(ximag);
auto exp_x_real = (cos_val == T(0)) ? T(0) : exp_x_abs * cos_val;
auto exp_x_imag = (sin_val == T(0)) ? T(0) : exp_x_abs * sin_val;
return complex_t(exp_x_real, exp_x_imag);
}
template <typename complex_t>
complex_t logaddexp_complex(complex_t x, complex_t y) {
using T = typename complex_t::value_type;
complex_t min_val = logaddexp_min(x, y);
complex_t max_val = logaddexp_max(x, y);
T min_real = min_val.real();
T max_real = max_val.real();
if (isnan(min_real) || isnan(min_val.imag())) {
return complex_t(NAN, NAN);
}
else if ((!isfinite(min_real)) && (min_real == max_real)) {
if (min_real < T(0)) {
return min_val;
} else {
const auto exp_min = fast_build_exp_inf<T>(min_val);
const auto exp_max = fast_build_exp_inf<T>(max_val);
return log1p(exp_min + exp_max - complex_t(1, 0));
}
} else {
const auto minmax = min_val - max_val;
complex_t exp_minmax;
if (!isfinite(minmax.real())) {
exp_minmax = (minmax.real() < T(0)) ? complex_t(0, 0) : fast_build_exp_inf<T>(minmax);
} else {
exp_minmax = fast_build_exp<T>(minmax);
}
return log1p(exp_minmax) + max_val;
}
}
);
constexpr char logaddexp_complex_name[] = "logaddexp_complex";
void logaddexp_kernel_cuda(TensorIteratorBase& iter) {
if (at::isComplexType(iter.dtype())) {
#if AT_USE_JITERATOR()
AT_DISPATCH_COMPLEX_TYPES_AND(at::ScalarType::ComplexHalf, iter.dtype(), "logaddexp_cuda", [&]() {
jitted_gpu_kernel<
/*name=*/logaddexp_complex_name,
/*return_dtype=*/scalar_t,
/*common_dtype=*/scalar_t,
/*arity=*/2>(iter, logaddexp_complex_string);
});
#else
AT_DISPATCH_COMPLEX_TYPES_AND(at::ScalarType::ComplexHalf, iter.dtype(), "logaddexp_cuda", [&]() {
using opmath_t = at::opmath_type<scalar_t>;
gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a_, scalar_t b_) -> scalar_t {
const auto a = static_cast<opmath_t>(a_);
const auto b = static_cast<opmath_t>(b_);
return static_cast<scalar_t>(_log_add_exp_helper(a, b));
});
});
#endif
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::BFloat16, ScalarType::Half,
iter.dtype(), "logaddexp_cuda",
[&]() {
@ -261,7 +29,6 @@ void logaddexp_kernel_cuda(TensorIteratorBase& iter) {
}
});
});
}
}
void logaddexp2_kernel_cuda(TensorIteratorBase& iter) {

View File

@ -740,12 +740,7 @@ _scaled_rowwise_rowwise(
TORCH_CHECK_VALUE(scale_a.numel() == mat_a.size(0) && scale_a.scalar_type() == kFloat, "scale_a must have ", mat_a.size(0), " Float elements, got ", scale_a.numel())
TORCH_CHECK_VALUE(scale_b.numel() == mat_b.size(1) && scale_b.scalar_type() == kFloat, "scale_b must have ", mat_b.size(1), " Float elements, got ", scale_b.numel())
// if we have a scale of shape [256, 1] (say), then stride can be [1, 0] - handle this case
TORCH_CHECK_VALUE(
scale_a.stride(1) == 1 ||
scale_a.size(1) == 1,
"expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1)
);
TORCH_CHECK_VALUE(scale_a.stride(1) == 1, "expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1));
TORCH_CHECK_VALUE(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
auto scaling_choice_a = ScalingType::RowWise;
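The comment about a [256, 1] scale with stride [1, 0] refers to broadcast-style rowwise scales; a minimal ATen illustration, assuming the relaxed check that also accepts size(1) == 1:

```cpp
#include <ATen/ATen.h>

// A (256, 1) rowwise scale whose last dimension is a broadcast view:
// stride(1) is 0 rather than 1, but size(1) == 1, so it is still a valid
// one-scale-per-row layout.
at::Tensor make_broadcast_rowwise_scale() {
  auto scale = at::ones({256}, at::kFloat).as_strided({256, 1}, {1, 0});
  // scale.stride(1) == 0, scale.size(1) == 1
  return scale;
}
```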
@ -1101,19 +1096,6 @@ _scaled_mxfp8_mxfp8(
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}
void
_check_mxfp4_support() {
#ifndef USE_ROCM
auto dprops = at::cuda::getCurrentDeviceProperties();
// Only on B200 GPUs
TORCH_CHECK_NOT_IMPLEMENTED(
// B200 = 10.0, B300 = 10.3
dprops->major == 10,
"MXFP4 scaling only supported in CUDA for B200/B300"
);
#endif
}
Tensor&
_scaled_mxfp4_mxfp4(
@ -1126,7 +1108,6 @@ _scaled_mxfp4_mxfp4(
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
#else
_check_mxfp4_support();
// Restrictions:
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",

View File

@ -337,6 +337,10 @@ Tensor _convolution_out(
TORCH_CHECK(
3 == ndim || 4 == ndim || 5 == ndim,
"convolution only supports 3D, 4D, 5D tensor");
// get computation format for Conv/TransposedConv
bool is_channels_last_suggested =
use_channels_last_for_conv(input_r, weight_r);
Tensor input = input_r, weight = weight_r;
// PyTorch does not support ChannelsLast1D case,
// thus we need the transformation here
@ -344,8 +348,13 @@ Tensor _convolution_out(
input = view4d(input_r);
weight = view4d(weight_r);
}
// get computation format for Conv/TransposedConv
bool is_channels_last_suggested = use_channels_last_for_conv(input, weight);
// ensure the input/weight/bias/output are contiguous in the desired format
at::MemoryFormat mfmt = is_channels_last_suggested
? get_cl_tag_by_ndim(input.ndimension())
: at::MemoryFormat::Contiguous;
auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r;
input = input.contiguous(mfmt);
weight = weight.contiguous(mfmt);
auto k = weight.ndimension();
if (k == input.ndimension() + 1) {
@ -379,14 +388,6 @@ Tensor _convolution_out(
expand_param_if_needed(output_padding_, "output_padding", dim);
params.groups = groups_;
}
// ensure the input/weight/bias/output are contiguous in the desired format
at::MemoryFormat mfmt = is_channels_last_suggested
? get_cl_tag_by_ndim(input.ndimension())
: at::MemoryFormat::Contiguous;
auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r;
input = input.contiguous(mfmt);
weight = weight.contiguous(mfmt);
check_shape_forward(input, weight, bias, params, true);
Tensor output;
@ -513,9 +514,18 @@ Tensor convolution_overrideable(
at::borrow_from_optional_tensor(bias_r_opt);
const Tensor& bias_r = *bias_r_maybe_owned;
auto k = weight_r.ndimension();
at::MemoryFormat backend_memory_format = at::MemoryFormat::Contiguous;
if (xpu_conv_use_channels_last(input_r, weight_r)) {
backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d
: at::MemoryFormat::ChannelsLast;
}
Tensor input_c = input_r.contiguous(backend_memory_format);
Tensor weight_c = weight_r.contiguous(backend_memory_format);
return _convolution(
input_r,
weight_r,
input_c,
weight_c,
bias_r,
stride_,
padding_,

View File

@ -1,342 +0,0 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/BlasBackend.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/ceil_div.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/xpu/Blas.h>
#include <torch/library.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif
namespace at::native {
using at::blas::ScalingType;
using at::blas::SwizzleType;
namespace {
/*
* Scaling Type Determination:
* ---------------------------
* Conditions and corresponding Scaling Types:
*
* - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`:
* - Returns BlockWise (with additional size checks).
*
* - Else if scale.numel() == 1:
* - Returns TensorWise.
*
* - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) ==
* 1:
* - Returns RowWise.
*
* - Otherwise:
* - Returns Error.
*/
bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
return at::isFloat8Type(t.scalar_type()) &&
scale.scalar_type() == at::kFloat && scale.numel() == 1;
}
bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
return (
at::isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat &&
scale.dim() == 2 && scale.size(0) == t.size(0) && scale.size(1) == 1 &&
scale.is_contiguous());
}
bool is_desired_scaling(
const at::Tensor& t,
const at::Tensor& scale,
ScalingType desired_scaling) {
auto result = desired_scaling == ScalingType::TensorWise
? is_tensorwise_scaling(t, scale)
: is_rowwise_scaling(t, scale);
return result;
}
std::pair<ScalingType, ScalingType> get_joint_scaling(
std::initializer_list<std::pair<ScalingType, ScalingType>> options,
const at::Tensor& a,
const at::Tensor& b,
const at::Tensor& scale_a,
const at::Tensor& scale_b) {
for (auto [lhs, rhs] : options) {
if (is_desired_scaling(a, scale_a, lhs) &&
is_desired_scaling(b.t(), scale_b.t(), rhs)) {
return {lhs, rhs};
}
}
TORCH_CHECK(
false,
"Invalid scaling configuration.\n"
"- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
"- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (",
a.size(0),
", 1) and scale_b should be (1, ",
b.size(1),
"), and both should be contiguous.\n"
"Got a.dtype()=",
a.scalar_type(),
", scale_a.dtype()=",
scale_a.scalar_type(),
", scale_a.size()=",
scale_a.sizes(),
", scale_a.stride()=",
scale_a.strides(),
", ",
"b.dtype()=",
b.scalar_type(),
", scale_b.dtype()=",
scale_b.scalar_type(),
", scale_b.size()=",
scale_b.sizes(),
" and scale_b.stride()=",
scale_b.strides());
}
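For reference, a hedged sketch of the scale layouts this helper accepts for an (M, K) x (K, N) float8 matmul (tensor names are illustrative):

```cpp
#include <ATen/ATen.h>

void make_scales_example(int64_t M, int64_t N) {
  auto opts = at::TensorOptions().dtype(at::kFloat);
  // TensorWise: a single float scale per operand.
  at::Tensor scale_a_tw = at::ones({1}, opts);
  at::Tensor scale_b_tw = at::ones({1}, opts);
  // RowWise: one float scale per row of A and per column of B,
  // i.e. scale_a is (M, 1) and scale_b is (1, N), both contiguous.
  at::Tensor scale_a_rw = at::ones({M, 1}, opts);
  at::Tensor scale_b_rw = at::ones({1, N}, opts);
  (void)scale_a_tw; (void)scale_b_tw;
  (void)scale_a_rw; (void)scale_b_rw;
}
```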
Tensor& _scaled_gemm(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const ScalingType scaling_choice_a,
const ScalingType scaling_choice_b,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
Tensor& out,
const std::optional<Tensor>& alpha = std::nullopt) {
// TODO: scale_result and alpha are not defined or used!
std::optional<Tensor> scaled_result = std::nullopt;
at::native::onednn::scaled_matmul(
mat1,
mat2,
out,
scale_a,
scale_b,
scaling_choice_a,
scaling_choice_b,
bias,
scaled_result,
use_fast_accum);
return out;
}
} // namespace
// Computes matrix multiply + bias while applying scaling to input and output
// matrices. Scales are only applicable when matrices are of Float8 type and
// assumed to be equal to 1.0 by default. If output matrix type is 16 or 32-bit
// type, scale_result is not applied. Known limitations:
// - Only works if mat1 is row-major and mat2 is column-major
// - Only works if matrices sizes are divisible by 32
// - If 1-dimensional tensors are used then scale_a should have size =
// mat1.size(0)
// and scale_b should have size = mat2.size(1)
// Arguments:
// - `mat1`: the first operand of the matrix multiply, can be type
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `mat2`: the second operand of the matrix multiply, can be type
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
// - `out_dtype`: the output dtype, can either be a float8 or a higher
// precision floating point type
// - `scale_a`: a tensor with the inverse scale of `mat1`, whose
// shape/strides/dtype depend on the scaling scheme
// - `scale_b`: a tensor with the inverse scale of `mat2`, whose
// shape/strides/dtype depend on the scaling scheme
// - `scale_result`: a scalar tensor with the scale of the output, only
// utilized if the output is a float8 type
// - `use_fast_accum`: Not applicable for XPU. For now, it should always be
// false.
// - `out`: a reference to the output tensor
Tensor& _scaled_mm_out_xpu(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out) {
// Note: fast_accum is not supported in XPU for now.
TORCH_CHECK(!use_fast_accum, "fast_accum is not supported in XPU for now.");
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0],
"mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0],
"x",
mat1.sizes()[1],
" and ",
mat2.sizes()[0],
"x",
mat2.sizes()[1],
")");
// Check what type of scaling we are doing based on inputs. This list is
// sorted by decreasing priority.
// List of supported datatypes for XPU with oneDNN:
// https://uxlfoundation.github.io/oneDNN/dev_guide_matmul.html#data-types
auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling(
{
std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
std::make_pair(ScalingType::RowWise, ScalingType::RowWise),
},
mat1,
mat2,
scale_a,
scale_b);
TORCH_CHECK(
!scale_result ||
(scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
"scale_result must be a float scalar");
TORCH_CHECK(
!bias || bias->numel() == mat2.sizes()[1],
"Bias must be size ",
mat2.sizes()[1],
" but got ",
bias->numel());
TORCH_CHECK(
mat1.sizes()[1] % 16 == 0,
"Expected trailing dimension of mat1 to be divisible by 16 ",
"but got mat1 shape: (",
mat1.sizes()[0],
"x",
mat1.sizes()[1],
").");
TORCH_CHECK(
mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0,
"mat2 shape (",
mat2.sizes()[0],
"x",
mat2.sizes()[1],
") must be divisible by 16");
// Check types
TORCH_CHECK(
!out_dtype || *out_dtype == out.scalar_type(),
"out_dtype must match output matrix type");
TORCH_CHECK(
at::isFloat8Type(mat1.scalar_type()),
"Expected mat1 to be Float8 matrix got ",
mat1.scalar_type());
TORCH_CHECK(
at::isFloat8Type(mat2.scalar_type()),
"Expected mat2 to be Float8 matrix got ",
mat2.scalar_type());
// TODO: oneDNN currently only supports e4m3 with group scales on BMG. It does
// not support 2D scales, only 1D. More checks need to be added here.
if (bias) {
TORCH_CHECK(
bias->scalar_type() == kFloat ||
bias->scalar_type() == c10::ScalarType::BFloat16 ||
bias->scalar_type() == c10::ScalarType::Half,
"Bias must be Float32 or BFloat16 or Half, but got ",
bias->scalar_type());
}
{
auto bias_ = bias.value_or(Tensor());
auto scale_result_ = scale_result.value_or(Tensor());
// NOLINTNEXTLINE(*c-array*)
TensorArg targs[]{
{out, "out", 0},
{mat1, "mat1", 1},
{mat2, "mat2", 2},
{bias_, "bias", 3},
{scale_a, "scale_a", 4},
{scale_b, "scale_b", 5},
{scale_result_, "scale_result", 6}};
checkAllSameGPU(__func__, targs);
}
// Validation checks have passed; resize the output to the actual size
IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
// If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm
// kernels do not support this case).
if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) {
// `out` was created with `at::empty`. In the case where we are multiplying
// MxK by KxN and K is the zero dim, we need to initialize here to properly
// return a tensor of zeros.
if (mat1_sizes[1] == 0) {
out.zero_();
}
return out;
}
// TODO: Scale_result is not supported by now!!
return _scaled_gemm(
mat1,
mat2,
scale_a,
scale_b,
scaling_choice_a,
scaling_choice_b,
bias,
use_fast_accum,
out);
}
Tensor _scaled_mm_xpu(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum) {
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
return _scaled_mm_out_xpu(
mat_a,
mat_b,
scale_a,
scale_b,
bias,
scale_result,
out_dtype,
use_fast_accum,
out);
}
} // namespace at::native

View File

@ -1,4 +1,3 @@
#include <ATen/BlasBackend.h>
#include <ATen/Tensor.h>
#include <ATen/core/Tensor.h>
#include <c10/core/ScalarType.h>
@ -9,6 +8,7 @@
#include <oneapi/dnnl/dnnl.hpp>
namespace at::native::onednn {
at::Tensor broadcast_bias2D(
at::Tensor& dst,
at::Tensor& bias,
@ -328,236 +328,4 @@ void quantized_matmul(
result.copy_(dst);
}
// Describes how to configure oneDNN scales for a given role/ScalingType
struct ScaleSpec {
// specifies the way scale values will be applied to an ARG tensor.
int mask;
// specifies how scales are grouped along dimensions where
// multiple scale factors are used.
dnnl::memory::dims groups;
// specifies data type for scale factors.
dnnl::memory::data_type dtype;
// Helper to compute expected number of elements for scale tensors
// arg_type: "src" for SRC (groups pattern {1, X}),
// "wei" for WEIGHTS (groups pattern {X, 1})
int64_t expected_numel(
int64_t outer_dim,
int64_t inner_dim,
const std::string& arg_type) const {
if (groups == dnnl::memory::dims{1, 1})
return 1; // tensorwise scaling
TORCH_CHECK(
arg_type == "src" || arg_type == "wei",
"Expected arg_type to be 'src' or 'wei', but got '",
arg_type,
"'");
// For rowwise: SRC groups={1, K}, WEI groups={K, 1}
TORCH_INTERNAL_ASSERT(
(groups == dnnl::memory::dims{1, inner_dim} ||
groups == dnnl::memory::dims{inner_dim, 1}),
"The groups must be either {1, inner_dim} or {inner_dim, 1}. But got ",
groups,
".");
return outer_dim;
}
// Normalize an incoming scale tensor to contiguous storage and appropriate
// dtype/view
at::Tensor normalize(const at::Tensor& scale) const {
TORCH_INTERNAL_ASSERT(
dtype == dnnl::memory::data_type::f32,
"tensor scale currently must be f32, but got scale dtype: ",
scale.scalar_type());
return scale.to(at::kFloat).contiguous();
}
};
// This function defines how to set scales mask and groups according to:
// https://github.com/uxlfoundation/oneDNN/blob/main/tests/benchdnn/doc/knobs_attr.md#--attr-scales
// The returned value will be used in
// `set_scales(arg, mask, groups, data_type)`.
inline ScaleSpec make_scale_spec(
at::blas::ScalingType scaling_type,
int64_t M,
int64_t K,
int64_t N,
const std::string& arg_type) {
TORCH_CHECK(
arg_type == "src" || arg_type == "wei",
"Expected arg_type to be 'src' or 'wei', but got '",
arg_type,
"'");
TORCH_INTERNAL_ASSERT(
(scaling_type == at::blas::ScalingType::TensorWise ||
scaling_type == at::blas::ScalingType::RowWise),
"Currently only support scaling_type for TensorWise or RowWise");
int64_t dim = K; // Currently only K is used for grouping
bool is_src = (arg_type == "src");
if (scaling_type == at::blas::ScalingType::TensorWise) {
// Scale tensorwise. The same as `--attr-scales=common`.
// mask=0 : scale whole tensor
// groups={1, 1}: indicates that there is only one group for scaling
return {0, {1, 1}, dnnl::memory::data_type::f32};
} else {
// (scaling_type == at::blas::ScalingType::RowWise)
// Scale RowWise. The same as `--attr-scales=per_dim_01`.
// mask={(1 << 0) | (1 << 1)}: Scale on both dim0 and dim1
// SRC: groups={1, K}, WEIGHTS: groups={K, 1}
return {
(1 << 0) | (1 << 1),
is_src ? dnnl::memory::dims{1, dim} : dnnl::memory::dims{dim, 1},
dnnl::memory::data_type::f32};
}
}
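As a usage note, a sketch of what the spec above resolves to for an (M, K) x (K, N) matmul, mirroring the call in scaled_matmul below (assumes the file-local ScaleSpec and make_scale_spec defined above):

```cpp
// Rowwise scaling gives mask == (1 << 0) | (1 << 1) == 3 for both arguments,
// with src groups {1, K} and weights groups {K, 1}; tensorwise scaling
// collapses both to mask 0 and groups {1, 1}.
void scale_spec_example(int64_t M, int64_t K, int64_t N) {
  const ScaleSpec src_spec = make_scale_spec(at::blas::ScalingType::RowWise, M, K, N, "src");
  const ScaleSpec wei_spec = make_scale_spec(at::blas::ScalingType::RowWise, M, K, N, "wei");
  (void)src_spec;
  (void)wei_spec;
}
```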
sycl::event scaled_matmul(
const Tensor& mat1,
const Tensor& mat2,
Tensor& result,
const Tensor& scale_a,
const Tensor& scale_b,
at::blas::ScalingType scaling_choice_a,
at::blas::ScalingType scaling_choice_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
bool use_fast_accum) {
auto& engine = GpuEngineManager::Instance().get_engine();
auto& stream = GpuStreamManager::Instance().get_stream();
// This function proceeds in the following steps:
// 1. create memory descriptor
// 2. call write_to_dnnl_memory() to actually write memory
// 3. execute
const int64_t M = mat1.size(0);
const int64_t K = mat1.size(1);
const int64_t N = mat2.size(1);
// 1.1 Create memory descriptor
dnnl::memory::desc src_md = get_onednn_md(mat1);
dnnl::memory::desc weights_md = get_onednn_md(mat2);
dnnl::memory::desc dst_md = get_onednn_md(result);
// scale_a and scale_b have already been checked in the `is_desired_scaling()` call,
// so we can directly get their memory descriptors and set them later.
dnnl::memory::desc scale_a_md = get_onednn_md(scale_a);
dnnl::memory::desc scale_b_md = get_onednn_md(scale_b);
dnnl::memory::desc bias_md;
bool with_bias = bias.has_value();
at::Tensor possible_reshaped_bias = bias.value_or(at::Tensor());
if (with_bias) {
if (possible_reshaped_bias.dim() == 1) {
possible_reshaped_bias =
possible_reshaped_bias.reshape({1, possible_reshaped_bias.size(0)});
bias_md = get_onednn_md(possible_reshaped_bias);
} else {
bias_md = get_onednn_md(possible_reshaped_bias);
}
}
// 1.2 Create primitive descriptor and set scales mask
const ScaleSpec src_spec = make_scale_spec(scaling_choice_a, M, K, N, "src");
const ScaleSpec wei_spec = make_scale_spec(scaling_choice_b, M, K, N, "wei");
dnnl::primitive_attr op_attr = dnnl::primitive_attr();
#if ONEDNN_SUPPORT_DETERMINISTIC
if (at::globalContext().deterministicAlgorithms() ||
at::globalContext().deterministicMkldnn())
op_attr.set_deterministic(true);
#endif
std::vector<int64_t> default_groups;
op_attr.set_scales(
DNNL_ARG_SRC, src_spec.mask, src_spec.groups, src_spec.dtype);
op_attr.set_scales(
DNNL_ARG_WEIGHTS, wei_spec.mask, wei_spec.groups, wei_spec.dtype);
// scale_result tensor currently only supports scalar(TensorWise Scaling).
bool with_dst_scale = scale_result && scale_result->defined();
if (with_dst_scale) {
op_attr.set_scales(DNNL_ARG_DST, 0, {1}, dnnl::memory::data_type::f32);
}
op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
// 1.3 Create the matmul primitive descriptor
dnnl::matmul::primitive_desc matmul_pd = with_bias
? dnnl::matmul::primitive_desc(
engine, src_md, weights_md, bias_md, dst_md, op_attr)
: dnnl::matmul::primitive_desc(
engine, src_md, weights_md, dst_md, op_attr);
// 1.4 (Possible) Additional Checks
// TODO: If a memory desc does not align with the actual tensor, we might need
// to reorder the weights, similar to the CPU's reorder_if_differ_in() call
// (for example, when the weights differ from matmul_pd.weights_desc()).
// 2. Prepare memory
// Create memory
auto src_usr_m = make_onednn_memory(src_md, engine, mat1.data_ptr());
auto weights_usr_m = make_onednn_memory(weights_md, engine, mat2.data_ptr());
auto dst_usr_m = make_onednn_memory(dst_md, engine, result.data_ptr());
dnnl::memory b_usr_m;
if (with_bias) {
b_usr_m =
make_onednn_memory(bias_md, engine, possible_reshaped_bias.data_ptr());
}
// Prepare runtime scale memories (flat 1-D views) using the specs
auto make_scale_mem_from_spec = [&](const ScaleSpec& spec,
int64_t expected_numel,
const at::Tensor& scale_tensor) {
at::Tensor prepared = spec.normalize(scale_tensor);
TORCH_CHECK(
prepared.numel() == expected_numel,
"Scale buffer length mismatch. Expected ",
expected_numel,
", got ",
prepared.numel());
dnnl::memory::desc scale_md(
{prepared.numel()}, spec.dtype, dnnl::memory::format_tag::x);
return make_onednn_memory(scale_md, engine, prepared.data_ptr());
};
auto scratchpad =
make_onednn_memory(matmul_pd.scratchpad_desc(), engine, nullptr);
// 3. Setup Args for exec
std::unordered_map<int, dnnl::memory> args;
args.insert({DNNL_ARG_SRC, src_usr_m});
args.insert({DNNL_ARG_WEIGHTS, weights_usr_m});
args.insert({DNNL_ARG_DST, dst_usr_m});
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
if (with_bias) {
args.insert({DNNL_ARG_BIAS, b_usr_m});
}
// Attach runtime scales using specs
auto src_sc_mem = make_scale_mem_from_spec(
src_spec, src_spec.expected_numel(M, K, "src"), scale_a);
auto wei_sc_mem = make_scale_mem_from_spec(
wei_spec, wei_spec.expected_numel(N, K, "wei"), scale_b);
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_sc_mem});
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_sc_mem});
if (with_dst_scale) {
// Bind single f32 scalar as DST scale
at::Tensor dst_scale_f32 = scale_result->to(at::kFloat).contiguous();
dnnl::memory::desc dst_sc_md(
{1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
auto dst_sc_mem =
make_onednn_memory(dst_sc_md, engine, dst_scale_f32.data_ptr());
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_sc_mem});
}
dnnl::matmul matmul_p = dnnl::matmul(matmul_pd);
sycl::event matmul_fwd_event =
dnnl::sycl_interop::execute(matmul_p, stream, args);
return matmul_fwd_event;
}
} // namespace at::native::onednn

View File

@ -78,10 +78,6 @@ dnnl::memory::data_type get_onednn_dtype(
return dnnl::memory::data_type::f32;
case at::ScalarType::BFloat16:
return dnnl::memory::data_type::bf16;
case at::ScalarType::Float8_e4m3fn:
return dnnl::memory::data_type::f8_e4m3;
case at::ScalarType::Float8_e5m2:
return dnnl::memory::data_type::f8_e5m2;
default:
if (!allow_undef) {
TORCH_CHECK(

View File

@ -1,7 +1,6 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/BlasBackend.h>
#include <ATen/native/mkldnn/xpu/detail/Attr.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
@ -203,16 +202,4 @@ void sdpa_backward(
Tensor& grad_query,
Tensor& grad_key,
Tensor& grad_value);
sycl::event scaled_matmul(
const Tensor& mat1,
const Tensor& mat2,
Tensor& result,
const Tensor& scale_a,
const Tensor& scale_b,
at::blas::ScalingType scaling_choice_a,
at::blas::ScalingType scaling_choice_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
bool use_fast_accum);
} // namespace at::native::onednn

View File

@ -82,7 +82,6 @@ NSArray<NSNumber*>* getTensorAxes(const TensorBase& t);
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
std::string getMPSShapeString(MPSShape* shape);
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
std::string to_hex_key(float);
std::string getArrayRefString(const IntArrayRef s);
// use has_storage() on the returned tensor to determine if src actually is a view
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);

View File

@ -301,10 +301,6 @@ std::string getArrayRefString(const IntArrayRef s) {
return fmt::to_string(fmt::join(s, ","));
}
std::string to_hex_key(float f) {
return fmt::format("{:a}", f);
}
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, bool exclude_shape) {
fmt::basic_memory_buffer<char, 100> buffer;
auto buf_iterator = std::back_inserter(buffer);

View File

@ -40,7 +40,7 @@ inline c10::metal::opmath_t<T> matmul_inner(
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint k = 0; k < TILE_DIM; k++) {
sum += c10::metal::mul(A_tile[tid.y][k], B_tile[k][tid.x]);
sum += A_tile[tid.y][k] * B_tile[k][tid.x];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@ -96,9 +96,7 @@ kernel void addmm(
auto bias =
biasData[thread_id.y * strides[3].x + thread_id.x * strides[3].y];
outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] =
static_cast<T>(
c10::metal::mul(alpha_beta[0], sum) +
c10::metal::mul(alpha_beta[1], bias));
static_cast<T>(alpha_beta[0] * sum + alpha_beta[1] * bias);
}
}
@ -834,10 +832,6 @@ INSTANTIATE_MM_OPS(float);
INSTANTIATE_MM_OPS(half);
INSTANTIATE_MM_OPS(bfloat);
// Complex MM
INSTANTIATE_MM_OPS(float2);
INSTANTIATE_MM_OPS(half2);
// Integral MM
INSTANTIATE_MM_OPS(long);
INSTANTIATE_MM_OPS(int);

View File

@ -69,139 +69,75 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
auto out = at::empty({batchSize, num_head, qSize, headSize}, query.options());
auto attn = at::empty({batchSize, num_head, qSize, maxSeqLength}, query.options());
auto scale_factor = sdp::calculate_scale(query, scale).expect_float();
static const bool is_macOS_26_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_26_0_PLUS);
@autoreleasepool {
auto mkey = __func__ + getTensorsStringKey({query, key, value}) + ":" + std::to_string(is_causal) + ":" +
std::to_string(attn_mask.has_value());
auto cachedGraph =
LookUpOrCreateCachedGraph<CachedGraph>(mkey, [&, q_ = query, k_ = key, v_ = value](auto mpsGraph, auto graph) {
auto qTensor = mpsGraphRankedPlaceHolder(mpsGraph, q_);
auto kTensor = mpsGraphRankedPlaceHolder(mpsGraph, k_);
auto vTensor = mpsGraphRankedPlaceHolder(mpsGraph, v_);
auto kT = [mpsGraph transposeTensor:kTensor dimension:2 withDimension:3 name:nil];
auto scaleTensor = [mpsGraph constantWithScalar:scale_factor
shape:getMPSShape({1})
dataType:MPSDataTypeFloat32];
CachedGraph* cachedGraph;
//if(is_macOS_26_0_or_newer) {
if(true) {
cachedGraph =
LookUpOrCreateCachedGraph<CachedGraph>(mkey, [&, q_ = query, k_ = key, v_ = value](auto mpsGraph, auto graph) {
auto qTensor = mpsGraphRankedPlaceHolder(mpsGraph, q_);
auto kTensor = mpsGraphRankedPlaceHolder(mpsGraph, k_);
auto vTensor = mpsGraphRankedPlaceHolder(mpsGraph, v_);
auto maskedMM = [mpsGraph matrixMultiplicationWithPrimaryTensor:qTensor secondaryTensor:kT name:nil];
if (is_causal) {
MPSShape* maskShape = @[@(qSize), @(maxSeqLength)];
auto x = [mpsGraph coordinateAlongAxis:-1 withShape:@[@(qSize), @1] name:nil];
auto y = [mpsGraph coordinateAlongAxis:-2 withShape:@[@1, @(maxSeqLength)] name:nil];
auto isLess = [mpsGraph lessThanOrEqualToWithPrimaryTensor:x secondaryTensor:y name:nil];
auto causalMask = [mpsGraph selectWithPredicateTensor:isLess
truePredicateTensor:[mpsGraph constantWithScalar:0 dataType:qTensor.dataType]
falsePredicateTensor:[mpsGraph constantWithScalar:-INFINITY dataType:qTensor.dataType]
name:nil];
graph->maskTensor = causalMask;
} else if (attn_mask) {
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
}
if (macOS15_0_plus && [maskedMM dataType] == MPSDataTypeFloat32) {
// Bug in macOS 15: without this trick SDPA leaks memory; adding 0.0f gets ignored (it still takes the SDPA sequence
// path, which leaks)
auto oneTensor = [mpsGraph constantWithScalar:1e-20f shape:getMPSShape({1}) dataType:MPSDataTypeFloat32];
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:oneTensor name:nil];
}
// Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
// Overwrites expected NANs in sm with zeros.
// auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
// auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
// auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
// auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
// auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
//
// auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
// MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
// truePredicateTensor:zeroTensor
// falsePredicateTensor:sm
// name:nil];
//
// auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
// upcasting to float32 if needed to improve precision when multiplying by the scale factor
maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
MPSGraphTensor* output;
if(graph->maskTensor != nil) {
output = [mpsGraph scaledDotProductAttentionWithQueryTensor:qTensor
keyTensor:kTensor
valueTensor:vTensor
maskTensor:graph->maskTensor
scale:scale_factor
name:@"MPSGraph SDPA"];
} else {
output = [mpsGraph scaledDotProductAttentionWithQueryTensor:qTensor
keyTensor:kTensor
valueTensor:vTensor
scale:scale_factor
name:@"MPSGraph SDPA"];
}
graph->qTensor = qTensor;
graph->kTensor = kTensor;
graph->vTensor = vTensor;
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
// graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
});
} else {
cachedGraph =
LookUpOrCreateCachedGraph<CachedGraph>(mkey, [&, q_ = query, k_ = key, v_ = value](auto mpsGraph, auto graph) {
auto qTensor = mpsGraphRankedPlaceHolder(mpsGraph, q_);
auto kTensor = mpsGraphRankedPlaceHolder(mpsGraph, k_);
auto vTensor = mpsGraphRankedPlaceHolder(mpsGraph, v_);
auto kT = [mpsGraph transposeTensor:kTensor dimension:2 withDimension:3 name:nil];
auto scaleTensor = [mpsGraph constantWithScalar:scale_factor
shape:getMPSShape({1})
dataType:MPSDataTypeFloat32];
if (is_causal) {
auto causalMask = [mpsGraph constantWithScalar:1.0f
shape:getMPSShape({qSize, maxSeqLength})
dataType:MPSDataTypeBool];
causalMask = [mpsGraph bandPartWithTensor:causalMask numLower:-1 numUpper:0 name:nil];
auto minusInf = [mpsGraph constantWithScalar:-1e20 shape:maskedMM.shape dataType:maskedMM.dataType];
maskedMM = [mpsGraph selectWithPredicateTensor:causalMask
truePredicateTensor:maskedMM
falsePredicateTensor:minusInf
name:nil];
} else if (attn_mask) {
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
name:nil];
}
auto maskedMM = [mpsGraph matrixMultiplicationWithPrimaryTensor:qTensor secondaryTensor:kT name:nil];
// Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
// Overwrites expected NANs in sm with zeros.
auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
if (macOS15_0_plus && [maskedMM dataType] == MPSDataTypeFloat32) {
// Bug in macOS 15: without this trick SDPA leaks memory; adding 0.0f gets ignored (it still takes the SDPA sequence
// path, which leaks)
auto oneTensor = [mpsGraph constantWithScalar:1e-20f shape:getMPSShape({1}) dataType:MPSDataTypeFloat32];
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:oneTensor name:nil];
}
auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
truePredicateTensor:zeroTensor
falsePredicateTensor:sm
name:nil];
// upcasting to float32 if needed to improve precision when multiplying by the scale factor
maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
if (is_causal) {
auto causalMask = [mpsGraph constantWithScalar:1.0f
shape:getMPSShape({qSize, maxSeqLength})
dataType:MPSDataTypeBool];
causalMask = [mpsGraph bandPartWithTensor:causalMask numLower:-1 numUpper:0 name:nil];
auto minusInf = [mpsGraph constantWithScalar:-1e20 shape:maskedMM.shape dataType:maskedMM.dataType];
maskedMM = [mpsGraph selectWithPredicateTensor:causalMask
truePredicateTensor:maskedMM
falsePredicateTensor:minusInf
name:nil];
} else if (attn_mask) {
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
name:nil];
}
// Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
// Overwrites expected NANs in sm with zeros.
auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
truePredicateTensor:zeroTensor
falsePredicateTensor:sm
name:nil];
auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
graph->qTensor = qTensor;
graph->kTensor = kTensor;
graph->vTensor = vTensor;
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
});
}
auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
graph->qTensor = qTensor;
graph->kTensor = kTensor;
graph->vTensor = vTensor;
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
});
auto qPlaceholder = Placeholder(cachedGraph->qTensor, query);
auto kPlaceholder = Placeholder(cachedGraph->kTensor, key);
auto vPlaceholder = Placeholder(cachedGraph->vTensor, value);
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, out);
// auto attnPlaceholder = Placeholder(cachedGraph->attnTensor, attn);
auto attnPlaceholder = Placeholder(cachedGraph->attnTensor, attn);
NSDictionary* feeds = nil;
if (!attn_mask) {
feeds = dictionaryFromPlaceholders(qPlaceholder, kPlaceholder, vPlaceholder);
@ -209,8 +145,7 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
auto mPlaceholder = Placeholder(cachedGraph->maskTensor, *attn_mask);
feeds = dictionaryFromPlaceholders(qPlaceholder, kPlaceholder, vPlaceholder, mPlaceholder);
}
// NSDictionary* outs = dictionaryFromPlaceholders(outputPlaceholder, attnPlaceholder);
NSDictionary* outs = dictionaryFromPlaceholders(outputPlaceholder);
NSDictionary* outs = dictionaryFromPlaceholders(outputPlaceholder, attnPlaceholder);
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outs);
}
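The hunks above revolve around one numerical edge case: when every logit in an attention row is masked to -inf, softmax evaluates 0/0 and produces NaN, so the graph detects rows that are entirely -inf and overwrites the softmax result with zeros. A minimal CPU sketch of that correction as a standalone helper (not the MPSGraph calls themselves):

```cpp
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Softmax over one attention row that treats a fully masked row (all -inf)
// as "no valid keys" and returns zeros instead of NaN.
std::vector<float> masked_softmax_row(const std::vector<float>& logits) {
  const float neg_inf = -std::numeric_limits<float>::infinity();
  float max_val = neg_inf;
  for (float v : logits) {
    max_val = std::max(max_val, v);
  }
  std::vector<float> out(logits.size(), 0.0f);
  if (max_val == neg_inf) {
    return out;  // the zero_mask branch above: avoid 0/0 -> NaN
  }
  float denom = 0.0f;
  for (size_t i = 0; i < logits.size(); ++i) {
    out[i] = std::exp(logits[i] - max_val);  // numerically stable softmax
    denom += out[i];
  }
  for (float& v : out) {
    v /= denom;
  }
  return out;
}
```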

View File

@ -121,7 +121,7 @@ Tensor& do_metal_addmm(const Tensor& self,
const Scalar& alpha,
const Scalar& beta,
const Tensor& bias) {
if (beta.isFloatingPoint() && alpha.isFloatingPoint() && beta.toDouble() == 0 && alpha.toDouble() == 1) {
if (beta.toDouble() == 0 && alpha.toDouble() == 1) {
return do_metal_mm(self, other, output);
}
auto stream = getCurrentMPSStream();
@ -147,15 +147,13 @@ Tensor& do_metal_addmm(const Tensor& self,
std::array<int64_t, 2> i64;
std::array<int32_t, 2> i32;
std::array<float, 2> f32;
std::array<c10::complex<float>, 2> c64;
} alpha_beta{};
} alpha_beta;
if (output.scalar_type() == kLong) {
alpha_beta.i64 = {alpha.toLong(), beta.toLong()};
} else if (c10::isIntegralType(output.scalar_type(), true)) {
alpha_beta.i32 = {alpha.toInt(), beta.toInt()};
} else if (c10::isComplexType(output.scalar_type())) {
alpha_beta.c64 = {alpha.toComplexFloat(), beta.toComplexFloat()};
} else {
TORCH_INTERNAL_ASSERT(c10::isFloatingType(output.scalar_type()));
alpha_beta.f32 = {alpha.toFloat(), beta.toFloat()};
}
constexpr uint32_t TILE_DIM = 16; // fastest performance from tests on multiple macs
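The alpha_beta hunk above packs the two scalars into a single union whose active member is chosen from the output dtype before the bytes are passed to the Metal kernel. A rough standalone sketch of that pattern, with a hypothetical DType enum and a trivially copyable struct standing in for c10::complex<float>:

```cpp
#include <cstdint>

// Stand-ins for the real c10 types (assumptions for this sketch).
enum class DType { Long, Int, Float, ComplexFloat };
struct cfloat { float re, im; };  // trivially copyable complex pair

union AlphaBeta {
  int64_t i64[2];
  int32_t i32[2];
  float f32[2];
  cfloat c64[2];
};

// Fill the union according to the output dtype; the kernel then reinterprets
// the same 16 bytes with the matching element type.
AlphaBeta pack_alpha_beta(DType dtype, double alpha, double beta) {
  AlphaBeta ab{};
  switch (dtype) {
    case DType::Long:
      ab.i64[0] = static_cast<int64_t>(alpha);
      ab.i64[1] = static_cast<int64_t>(beta);
      break;
    case DType::Int:
      ab.i32[0] = static_cast<int32_t>(alpha);
      ab.i32[1] = static_cast<int32_t>(beta);
      break;
    case DType::ComplexFloat:
      ab.c64[0] = {static_cast<float>(alpha), 0.0f};
      ab.c64[1] = {static_cast<float>(beta), 0.0f};
      break;
    case DType::Float:
      ab.f32[0] = static_cast<float>(alpha);
      ab.f32[1] = static_cast<float>(beta);
      break;
  }
  return ab;
}
```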
@ -192,16 +190,10 @@ std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* gr
bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) {
static bool always_use_metal = c10::utils::has_env("PYTORCH_MPS_PREFER_METAL");
constexpr auto max_stride_size = 32768;
constexpr auto max_complex_inner_size = 2048;
static bool is_macos_14_4_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS);
if (always_use_metal || c10::isIntegralType(self.scalar_type(), true)) {
return true;
}
// multiplicationWithPrimaryTensor: returns incorrect results if inner size exceeds 2048
// See https://github.com/pytorch/pytorch/issues/167727#issuecomment-3529308548
if (c10::isComplexType(self.scalar_type()) && self.size(1) > max_complex_inner_size) {
return true;
}
return !is_macos_14_4_or_newer &&
(self.stride(0) > max_stride_size || self.stride(1) > max_stride_size || self.size(0) > max_stride_size ||
self.size(1) > max_stride_size || other.stride(0) > max_stride_size || other.stride(1) > max_stride_size ||

View File

@ -91,30 +91,25 @@ static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#include <ATen/native/mps/Repeat_metallib.h>
#endif
Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
TORCH_CHECK(repeat.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
template <typename index_t>
void computeRepeatIndices(const index_t* repeat_ptr,
const int64_t* cumsum_ptr,
index_t* result_ptr,
int64_t size,
int64_t result_size) {
id<MTLBuffer> repeatBuffer = reinterpret_cast<id<MTLBuffer>>(repeat_ptr);
id<MTLBuffer> cumsumBuffer = reinterpret_cast<id<MTLBuffer>>(cumsum_ptr);
id<MTLBuffer> resultBuffer = reinterpret_cast<id<MTLBuffer>>(result_ptr);
TORCH_CHECK(repeatBuffer && cumsumBuffer && resultBuffer);
std::string scalar_type;
if (repeat.scalar_type() == kInt) {
if constexpr (std::is_same_v<index_t, int32_t>) {
scalar_type = "int32_t";
} else if (repeat.scalar_type() == kLong) {
} else if constexpr (std::is_same_v<index_t, int64_t>) {
scalar_type = "int64_t";
} else {
TORCH_CHECK(false, "repeats has to be Long or Int tensor");
TORCH_CHECK(false, "repeat_interleave: unsupported indexing data type");
}
if (repeat.size(0) == 0) {
return at::empty_like(repeat, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
Tensor repeat_ = repeat.contiguous();
Tensor cumsum = repeat.cumsum(0);
int64_t total = 0;
if (output_size.has_value()) {
total = output_size.value();
} else {
total = cumsum[-1].item<int64_t>();
TORCH_CHECK((repeat >= 0).all().item<uint8_t>(), "repeats can not be negative");
}
auto result = at::empty({total}, repeat.options());
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@ -126,13 +121,20 @@ Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output
getMPSProfiler().beginProfileKernel(pipelineState, "repeat_interleave:" + scalar_type, false);
[computeEncoder setComputePipelineState:pipelineState];
mps::mtl_setArgs(computeEncoder, repeat_, cumsum, result, repeat.size(0));
mps::mtl_dispatch1DJob(computeEncoder, pipelineState, repeat.size(0));
mps::mtl_setArgs(computeEncoder, repeatBuffer, cumsumBuffer, resultBuffer, size);
mps::mtl_dispatch1DJob(computeEncoder, pipelineState, size);
getMPSProfiler().endProfileKernel(pipelineState);
}
});
return result;
}
Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
Tensor output;
AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {
output = repeat_interleave_common<index_t, computeRepeatIndices<index_t>>(repeat, output_size);
});
return output;
}
} // namespace at::native
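Either variant of this file ultimately performs the same index expansion: output position ranges defined by the cumulative sum of `repeats` are filled with the source index. A serial CPU reference of that expansion, assuming non-negative repeats (this is the semantics the kernel parallelizes, not the Metal code itself):

```cpp
#include <cstdint>
#include <vector>

// Serial reference for computeRepeatIndices: index i is written into
// repeats[i] consecutive output slots.
template <typename index_t>
std::vector<index_t> repeat_interleave_indices(const std::vector<index_t>& repeats) {
  std::vector<index_t> out;
  for (size_t i = 0; i < repeats.size(); ++i) {
    for (index_t r = 0; r < repeats[i]; ++r) {
      out.push_back(static_cast<index_t>(i));
    }
  }
  return out;  // e.g. repeats {2, 0, 3} -> {0, 0, 2, 2, 2}
}
```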

View File

@ -5,7 +5,6 @@
#include <ATen/native/Resize.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/mps/OperationUtils.h>
#include <algorithm>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -90,21 +89,13 @@ static void check_min_max_dims(const OptionalTensorRef clamp_opt, const Tensor&
auto clamp_shape = clamp_opt->sizes();
auto input_shape = input_t.sizes();
if (num_clamp_dims > num_input_dims) {
auto leading_dims = num_clamp_dims - num_input_dims;
for (int64_t i = 0; i < leading_dims; ++i) {
TORCH_CHECK(clamp_shape[i] == 1,
op_name + ": clamp tensor leading shape must be 1 to broadcast with input tensor");
}
}
TORCH_CHECK(num_clamp_dims <= num_input_dims,
op_name + ": clamp tensor number of dims must not be greater than that of input tensor")
auto clamp_idx = num_clamp_dims - 1;
auto input_idx = num_input_dims - 1;
auto common_dims = std::min(num_clamp_dims, num_input_dims);
for (int64_t i = 0; i < common_dims; ++i)
for (int i = 0; i < num_clamp_dims; i++)
// One of the indices is allowed to be 1; will be handled by broadcast
TORCH_CHECK(clamp_shape[clamp_idx - i] == input_shape[input_idx - i] || clamp_shape[clamp_idx - i] == 1 ||
input_shape[input_idx - i] == 1,
TORCH_CHECK(clamp_shape[num_clamp_dims - 1 - i] == input_shape[num_input_dims - 1 - i] ||
clamp_shape[num_clamp_dims - 1 - i] == 1 || input_shape[num_input_dims - 1 - i] == 1,
op_name + ": clamp tensor trailing shape must match input tensor")
}
}
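Both versions of this check apply the same right-aligned broadcasting rule: walking the clamp and input shapes from the last dimension, each pair of sizes must match or contain a 1 (the two versions differ only in how extra leading clamp dimensions are handled). A standalone sketch of that rule:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Right-aligned broadcast compatibility: compare trailing dimensions; a pair
// is acceptable if the sizes are equal or either of them is 1.
bool trailing_dims_broadcastable(const std::vector<int64_t>& clamp_shape,
                                 const std::vector<int64_t>& input_shape) {
  const size_t common = std::min(clamp_shape.size(), input_shape.size());
  for (size_t i = 0; i < common; ++i) {
    const int64_t c = clamp_shape[clamp_shape.size() - 1 - i];
    const int64_t x = input_shape[input_shape.size() - 1 - i];
    if (c != x && c != 1 && x != 1) {
      return false;
    }
  }
  return true;
}
```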
@ -145,6 +136,9 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
auto result_type = output_t.scalar_type();
IntArrayRef new_min_shape;
IntArrayRef new_max_shape;
auto num_min_dims = min_opt->dim();
auto num_max_dims = max_opt->dim();
auto num_input_dims = input_t.dim();
@ -152,32 +146,24 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
std::vector<int64_t> new_min_arr(num_input_dims);
std::vector<int64_t> new_max_arr(num_input_dims);
if (has_min && num_min_dims < num_input_dims) {
fill_new_shape(num_input_dims, num_min_dims, new_min_arr.data(), min_opt->sizes());
new_min_shape = IntArrayRef(new_min_arr);
}
if (has_max && num_max_dims < num_input_dims) {
fill_new_shape(num_input_dims, num_max_dims, new_max_arr.data(), max_opt->sizes());
new_max_shape = IntArrayRef(new_max_arr);
}
Tensor min_opt_tensor;
Tensor max_opt_tensor;
auto reshape_clamp_tensor = [&](const OptionalTensorRef clamp_tensor_ref,
int64_t num_clamp_dims,
std::vector<int64_t>& new_shape_storage) -> Tensor {
IntArrayRef clamp_shape = clamp_tensor_ref->sizes();
bool requires_view = false;
if (num_clamp_dims > num_input_dims) {
clamp_shape = clamp_shape.slice(num_clamp_dims - num_input_dims);
requires_view = true;
} else if (num_clamp_dims < num_input_dims) {
fill_new_shape(num_input_dims, num_clamp_dims, new_shape_storage.data(), clamp_shape);
clamp_shape = IntArrayRef(new_shape_storage);
requires_view = true;
}
return requires_view ? (*clamp_tensor_ref).view(clamp_shape) : *clamp_tensor_ref;
};
if (has_min) {
min_opt_tensor = reshape_clamp_tensor(min_opt, num_min_dims, new_min_arr);
min_opt_tensor = (num_min_dims < num_input_dims) ? (*min_opt).view(new_min_shape) : *min_opt;
}
if (has_max) {
max_opt_tensor = reshape_clamp_tensor(max_opt, num_max_dims, new_max_arr);
max_opt_tensor = (num_max_dims < num_input_dims) ? (*max_opt).view(new_max_shape) : *max_opt;
}
@autoreleasepool {
@ -258,8 +244,8 @@ static void clamp_scalar_out_mps(const Tensor& input_t,
@autoreleasepool {
// the optional min/max refs could affect how we build the cached graph
std::string key = op_name + (has_min ? ("_min:" + to_hex_key(min_scalar)) : "") +
(has_max ? ("_max:" + to_hex_key(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
std::string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") +
(has_max ? ("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
if (has_min)
newCachedGraph->minTensor = [mpsGraph constantWithScalar:min_scalar

View File

@ -4225,7 +4225,7 @@
MTIA: mm_out_mtia
MPS: mm_out_mps
XPU: mm_out_xpu
SparseCPU, SparseCUDA, SparseMPS: _sparse_mm_out
SparseCPU, SparseCUDA: _sparse_mm_out
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm_out
- func: mm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor

View File

@ -61,7 +61,6 @@ list(APPEND ATen_CUDA_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_math_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_cub_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_cublas_handle_pool_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp

View File

@ -1,77 +0,0 @@
#include <gtest/gtest.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <atomic>
#include <thread>
#include <vector>
// Test concurrent access to getCurrentCUDABlasHandle and getCUDABlasLtWorkspace
// to verify that the data race fix is working correctly
TEST(CUDABlasHandlePoolTest, ConcurrentGetAndClearWorkspaces) {
if (!at::cuda::is_available()) {
return;
}
constexpr int num_accessor_threads = 15;
constexpr int num_clear_threads = 5;
constexpr int iterations_per_thread = 50;
std::atomic<bool> stop{false};
std::atomic<int> error_count{0};
std::vector<std::thread> threads;
threads.reserve(num_accessor_threads + num_clear_threads);
// Launch accessor threads
for (int i = 0; i < num_accessor_threads; ++i) {
threads.emplace_back([&stop, &error_count]() {
try {
at::cuda::CUDAGuard device_guard(0);
while (!stop.load(std::memory_order_relaxed)) {
const auto handle = at::cuda::getCurrentCUDABlasHandle();
const auto workspace = at::cuda::getCUDABlasLtWorkspace();
if (handle == nullptr || workspace == nullptr) {
error_count++;
}
}
} catch (const std::exception& e) {
error_count++;
}
});
}
// Launch threads that clear workspaces
for (int i = 0; i < num_clear_threads; ++i) {
threads.emplace_back([&error_count]() {
try {
for (int j = 0; j < iterations_per_thread; ++j) {
at::cuda::clearCublasWorkspaces();
std::this_thread::yield();
}
} catch (const std::exception& e) {
error_count++;
}
});
}
// Let them run for a bit
std::this_thread::sleep_for(std::chrono::milliseconds(100));
stop.store(true, std::memory_order_relaxed);
for (auto& thread : threads) {
thread.join();
}
EXPECT_EQ(error_count.load(), 0);
}
int main(int argc, char* argv[]) {
::testing::InitGoogleTest(&argc, argv);
c10::cuda::CUDACachingAllocator::init(1);
return RUN_ALL_TESTS();
}

View File

@ -10,13 +10,6 @@
...
}
{
ignore_empty_generic_uninitialised_conditional_jump
Memcheck:Cond
fun:_ZN2at6detail13empty_genericEN3c108ArrayRefIlEEPNS1_9AllocatorENS1_14DispatchKeySetENS1_10ScalarTypeESt8optionalINS1_12MemoryFormatEE
...
}
{
Cond_cuda
Memcheck:Cond

View File

@ -9,61 +9,28 @@ def check_perf_csv(filename, threshold, threshold_scale):
"""
Basic performance checking.
"""
try:
df = pd.read_csv(filename)
except FileNotFoundError:
print(f"Error: File {filename} not found")
sys.exit(1)
effective_threshold = threshold * threshold_scale
print(f"Checking {filename} (speedup threshold >= {effective_threshold:.2f}x)\n")
df = pd.read_csv(filename)
failed = []
for _, row in df.iterrows():
model_name = row["name"]
speedup = float(row["speedup"])
abs_latency = float(row["abs_latency"])
compilation_latency = float(row["compilation_latency"])
compression_ratio = float(row["compression_ratio"])
eager_peak_mem = float(row["eager_peak_mem"])
dynamo_peak_mem = float(row["dynamo_peak_mem"])
speedup = row["speedup"]
if speedup < threshold * threshold_scale:
failed.append(model_name)
perf_summary = f"{model_name:34} speedup={speedup:.3f}x"
if pd.notna(abs_latency):
perf_summary += f", latency={abs_latency:.1f} ms/iter"
if pd.notna(compilation_latency):
perf_summary += f", compile={compilation_latency:.3f}s"
if pd.notna(compression_ratio):
perf_summary += f", mem_ratio={1 / compression_ratio:.2f}x"
if pd.notna(eager_peak_mem) and pd.notna(dynamo_peak_mem):
perf_summary += (
f" (eager={eager_peak_mem:.1f} GB, dynamo={dynamo_peak_mem:.1f} GB)"
)
if speedup < effective_threshold:
failed.append((model_name, speedup))
print(perf_summary)
print(f"{model_name:34} {speedup}")
if failed:
print(
textwrap.dedent(
f"""
Error {len(failed)} model(s) performance regressed
{" ".join([name for name, _ in failed])}
Error {len(failed)} models performance regressed
{" ".join(failed)}
"""
)
)
for name, sp in sorted(failed, key=lambda x: x[1]):
pct_from_target = (sp / effective_threshold - 1.0) * 100.0
print(
f" - {name}: {sp:.3f}x (< {effective_threshold:.2f}x; {pct_from_target:.1f}% from target)"
)
sys.exit(1)
else:
print(
f"\nAll {len(df)} model(s) passed threshold check (>= {effective_threshold:.2f}x)"
)
if __name__ == "__main__":
@ -77,7 +44,7 @@ if __name__ == "__main__":
"-s",
type=float,
default=1.0,
help="multiply threshold by this value to relax the check",
help="multiple threshold by this value to relax the check",
)
args = parser.parse_args()
check_perf_csv(args.file, args.threshold, args.threshold_scale)

View File

@ -2379,9 +2379,7 @@ class BenchmarkRunner:
print(
f"Load model outputs from {self.args.compare_model_outputs_with} to compare"
)
saved_result = torch.load(
self.args.compare_model_outputs_with, weights_only=False
)
saved_result = torch.load(self.args.compare_model_outputs_with)
is_bitwise_same = bitwise_same(saved_result, new_result)
if not is_bitwise_same:
print(

View File

@ -189,10 +189,6 @@ skip:
- hf_Whisper
- hf_distil_whisper
- timm_vision_transformer_large
# https://github.com/pytorch/pytorch/issues/167895
- stable_diffusion
- stable_diffusion_text_encoder
- stable_diffusion_unet
device:
cpu:

View File

@ -2,7 +2,6 @@
# These load paths point to different files in internal and OSS environment
load("@bazel_skylib//lib:paths.bzl", "paths")
load("//tools/build_defs:cell_defs.bzl", "get_fbsource_cell")
load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
@ -591,9 +590,6 @@ def pt_operator_query_codegen(
pt_allow_forced_schema_registration = True,
compatible_with = [],
apple_sdks = None):
if get_fbsource_cell() == "fbcode":
return
oplist_dir_name = name + "_pt_oplist"
# @lint-ignore BUCKLINT
@ -869,9 +865,6 @@ def define_buck_targets(
pt_xplat_cxx_library = fb_xplat_cxx_library,
c2_fbandroid_xplat_compiler_flags = [],
labels = []):
if get_fbsource_cell() == "fbcode":
return
# @lint-ignore BUCKLINT
fb_native.filegroup(
name = "metal_build_srcs",

View File

@ -44,7 +44,7 @@ struct C10_API SafePyObject {
(*other.pyinterpreter_)->incref(other.data_);
}
if (data_ != nullptr) {
(*pyinterpreter_)->decref(data_);
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
}
data_ = other.data_;
pyinterpreter_ = other.pyinterpreter_;
@ -53,7 +53,7 @@ struct C10_API SafePyObject {
~SafePyObject() {
if (data_ != nullptr) {
(*pyinterpreter_)->decref(data_);
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
}
}

View File

@ -34,6 +34,20 @@ namespace c10 {
// See [dtype Macros note] in torch/headeronly/core/ScalarType.h
// regarding macros.
template <typename T>
struct CppTypeToScalarType;
#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \
template <> \
struct CppTypeToScalarType<cpp_type> \
: std:: \
integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \
};
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
#undef SPECIALIZE_CppTypeToScalarType
#define DEFINE_CONSTANT(_, name) \
constexpr ScalarType k##name = ScalarType::name;
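The SPECIALIZE_CppTypeToScalarType block above is the usual X-macro trait pattern: one macro stamps out a template specialization per (C++ type, enum tag) pair so the mapping is available at compile time. A trimmed-down sketch with a hypothetical three-value enum, not the c10 definitions:

```cpp
#include <type_traits>

enum class ScalarTypeLite { Float, Double, Int };

template <typename T>
struct CppTypeToScalarTypeLite;  // primary template left undefined

#define SPECIALIZE_LITE(cpp_type, tag)                      \
  template <>                                               \
  struct CppTypeToScalarTypeLite<cpp_type>                  \
      : std::integral_constant<ScalarTypeLite, ScalarTypeLite::tag> {};

SPECIALIZE_LITE(float, Float)
SPECIALIZE_LITE(double, Double)
SPECIALIZE_LITE(int, Int)
#undef SPECIALIZE_LITE

// The mapping is usable in constant expressions:
static_assert(CppTypeToScalarTypeLite<float>::value == ScalarTypeLite::Float, "");
```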
@ -92,6 +106,13 @@ inline bool isComplexType(ScalarType t) {
t == ScalarType::ComplexDouble);
}
inline bool isQIntType(ScalarType t) {
// Don't forget to extend this when adding new QInt types
return t == ScalarType::QInt8 || t == ScalarType::QUInt8 ||
t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 ||
t == ScalarType::QUInt2x4;
}
inline bool isBitsType(ScalarType t) {
return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 ||
t == ScalarType::Bits4x2 || t == ScalarType::Bits8 ||

View File

@ -48,30 +48,6 @@ void warnDeprecatedDataPtr() {
TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid.");
}
void StorageImpl::incref_pyobject() const {
// Because intrusive_ptr incref uses relaxed memory order, we need to
// do an acquire fence to ensure that the kHasPyObject bit was
// observed before the load of the PyObject* below.
// NB: This is a no-op on x86/x86-64
std::atomic_thread_fence(std::memory_order_acquire);
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}
void StorageImpl::decref_pyobject() const {
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
}
bool StorageImpl::try_incref_pyobject() const {
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
if (C10_UNLIKELY(!interp)) {
return false;
}
return (*interp)->try_incref(pyobj_slot_);
}
void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
// Allowlist verification.
// Only if the devicetype is in the allowlist,

View File

@ -105,12 +105,6 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
data_ptr_.clear();
}
void incref_pyobject() const override final;
void decref_pyobject() const override final;
bool try_incref_pyobject() const override final;
size_t nbytes() const {
// OK to do this instead of maybe_as_int as nbytes is guaranteed positive
TORCH_CHECK(!size_bytes_is_heap_allocated_);
@ -376,18 +370,4 @@ C10_API c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
bool resizable,
std::optional<at::Device> device_opt);
namespace detail {
#ifndef C10_MOBILE
template <class T>
struct TargetTraits<
T,
std::enable_if_t<
std::is_base_of_v<c10::StorageImpl, std::remove_cv_t<T>>>> {
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
} // namespace c10

View File

@ -277,6 +277,7 @@ void TensorImpl::release_resources() {
if (storage_) {
storage_ = {};
}
pyobj_slot_.maybe_destroy_pyobj();
}
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
@ -988,30 +989,6 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
}
}
void TensorImpl::incref_pyobject() const {
// Because intrusive_ptr incref uses relaxed memory order, we need to
// do an acquire fence to ensure that the kHasPyObject bit was
// observed before the load of the PyObject* below.
// NB: This is a no-op on x86/x86-64
std::atomic_thread_fence(std::memory_order_acquire);
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}
void TensorImpl::decref_pyobject() const {
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
}
bool TensorImpl::try_incref_pyobject() const {
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
if (C10_UNLIKELY(!interp)) {
return false;
}
return (*interp)->try_incref(pyobj_slot_);
}
namespace impl {
namespace {

View File

@ -2178,12 +2178,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return &pyobj_slot_;
}
void incref_pyobject() const override final;
void decref_pyobject() const override final;
bool try_incref_pyobject() const override final;
private:
// See NOTE [std::optional operator usage in CUDA]
// We probably don't want to expose this publicly until
@ -3085,19 +3079,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
friend class C10_TensorImpl_Size_Check_Dummy_Class;
};
namespace detail {
#ifndef C10_MOBILE
template <class T>
struct TargetTraits<
T,
std::enable_if_t<std::is_base_of_v<c10::TensorImpl, std::remove_cv_t<T>>>> {
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
// Note [TensorImpl size constraints]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Changed the size of TensorImpl? If the size went down, good for

View File

@ -11,11 +11,8 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
void incref(PyObject* pyobj) const override {} // do nothing
void decref(PyObject* pyobj) const override {} // do nothing
bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override {
return false;
}
void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
} // do nothing
#define PANIC(m) \
TORCH_INTERNAL_ASSERT( \
@ -23,10 +20,6 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
"attempted to call " #m \
" on a Tensor with nontrivial PyObject after corresponding interpreter died")
size_t refcnt(PyObject* pyobj) const override {
PANIC(refcnt);
}
c10::intrusive_ptr<TensorImpl> detach(const TensorImpl* self) const override {
PANIC(detach);
}

View File

@ -18,9 +18,6 @@ namespace c10 {
struct IValue;
class OperatorHandle;
struct TensorImpl;
namespace impl {
struct PyObjectSlot;
} // namespace impl
} // namespace c10
namespace torch::jit {
@ -129,12 +126,9 @@ struct C10_API PyInterpreterVTable {
// Run Py_INCREF on a PyObject.
virtual void incref(PyObject* pyobj) const = 0;
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call.
virtual void decref(PyObject* pyobj) const = 0;
// Run PyUnstable_TryIncRef on a PyObject if it's not NULL.
virtual bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const = 0;
// Run Py_REFCNT on a PyObject.
virtual size_t refcnt(PyObject* pyobj) const = 0;
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
// See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;
// Perform a detach by deferring to the __torch_dispatch__ implementation of
// detach, which will also arrange for the PyObject to get copied in this

View File

@ -0,0 +1,56 @@
#include <c10/core/impl/PyObjectSlot.h>
namespace c10::impl {
PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
PyObjectSlot::~PyObjectSlot() {
maybe_destroy_pyobj();
}
void PyObjectSlot::maybe_destroy_pyobj() {
if (owns_pyobj()) {
TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
(*pyobj_interpreter_.load(std::memory_order_acquire))
->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
// NB: this destructor can only be entered when there are no
// references to this C++ object (obviously), NOR any references
// to the PyObject (if there are references to the PyObject,
// then the PyObject holds an owning reference to the tensor).
// So it is OK to clear pyobj_ here as it is impossible for it to
// be used again (modulo weak reference races)
pyobj_ = nullptr; // for safety
}
}
PyInterpreter* PyObjectSlot::pyobj_interpreter() {
return pyobj_interpreter_.load(std::memory_order_acquire);
}
PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<PyObject*>(
reinterpret_cast<uintptr_t>(pyobj_) & ~0x1ULL);
}
PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter) {
return *interpreter;
}
TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set");
}
bool PyObjectSlot::owns_pyobj() {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<uintptr_t>(pyobj_) & 1;
}
void PyObjectSlot::set_owns_pyobj(bool b) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
pyobj_ = reinterpret_cast<PyObject*>(
reinterpret_cast<uintptr_t>(_unchecked_untagged_pyobj()) | b);
}
} // namespace c10::impl

View File

@ -8,58 +8,117 @@
#include <atomic>
namespace torch::utils {
class PyObjectPreservation;
}
namespace c10::impl {
struct C10_API PyObjectSlot {
public:
PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
PyObjectSlot();
~PyObjectSlot();
void maybe_destroy_pyobj();
// Associate the TensorImpl with the specified PyObject, and, if necessary,
// also tag the interpreter.
//
// NB: This lives in a header so that we can inline away the switch on status
//
// NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
// PyObject if necessary!
void init_pyobj(PyObject* pyobj) {
pyobj_interpreter_.store(
getGlobalPyInterpreter(), std::memory_order_relaxed);
pyobj_ = pyobj;
}
// Query the PyObject interpreter. This may return null if there is no
// interpreter.
PyInterpreter* pyobj_interpreter() const {
return pyobj_interpreter_.load(std::memory_order_acquire);
// interpreter. This is racy!
PyInterpreter* pyobj_interpreter();
PyObject* _unchecked_untagged_pyobj() const;
// Test the interpreter tag. If tagged for the current interpreter, return
// a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
// returns a nullopt. If it is definitely invalid, raises an error.
//
// If `ignore_hermetic_tls` is false and this function is called from a
// hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
// nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
// context is ignored, allowing you to check the interpreter tag of a
// nonhermetic PyObject from within a hermetic context. This is necessary
// because there are some cases where the deallocator function of a
// nonhermetic PyObject is called from within a hermetic context, so it must
// be properly treated as a nonhermetic PyObject.
//
// NB: this lives in header so that we can avoid actually creating the
// std::optional
// @todo alban: I'm not too sure what's going on here, we can probably delete
// it but it's worthwhile making sure
std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const {
impl::PyInterpreter* interpreter =
pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter == nullptr) {
return std::nullopt;
}
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
return std::nullopt;
} else {
return _unchecked_untagged_pyobj();
}
}
PyInterpreter& load_pyobj_interpreter() const {
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
TORCH_INTERNAL_ASSERT(
interpreter, "cannot access PyObject for Tensor - no interpreter set");
return *interpreter;
}
PyInterpreter& load_pyobj_interpreter() const;
PyObject* load_pyobj() const {
return pyobj_.load(std::memory_order_acquire);
}
bool owns_pyobj();
void store_pyobj(PyObject* obj) {
pyobj_.store(obj, std::memory_order_release);
}
bool has_unique_reference() const {
PyObject* pyobj = load_pyobj();
return pyobj != nullptr && load_pyobj_interpreter()->refcnt(pyobj) == 1;
}
void clear() {
pyobj_.store(nullptr, std::memory_order_relaxed);
pyobj_interpreter_.store(nullptr, std::memory_order_relaxed);
}
void set_owns_pyobj(bool b);
private:
// This is now always the global interpreter if the PyObject is set.
// Maybe we can remove this field some day...
// This field contains the interpreter tag for this object. See
// Note [Python interpreter tag] for general context
//
// Note [Memory ordering on Python interpreter tag]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// What memory_order do we need when accessing this atomic? We don't
// need a single total modification order (as provided by
// memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
// transition from -1 to some positive integer and never changes afterwards.
// Because there is only one modification, it trivially already has a total
// modification order (e.g., we don't need fences or locked instructions on
// x86)
//
// In fact, one could make a reasonable argument that relaxed reads are OK,
// due to the presence of external locking (GIL) to ensure that interactions
// with other data structures are still correctly synchronized, so that
// we fall in the "Single-Location Data Structures" case as described in
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
// However, on x86, it doesn't matter if I use acquire or relaxed on the load
// as I get the same assembly in both cases. So I just use the more
// conservative acquire (which will impede compiler optimizations but I don't
// care)
std::atomic<PyInterpreter*> pyobj_interpreter_;
// The PyObject representing this Tensor or nullptr. Ownership is managed
// by intrusive_ptr. By the time the PyObjectSlot is destroyed, this
// reference is already dead.
std::atomic<PyObject*> pyobj_;
friend class torch::utils::PyObjectPreservation;
// This field contains a reference to a PyObject representing this Tensor.
// If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new
// PyObject for it and set this field. This field does not have to be
// protected by an atomic as it is only allowed to be accessed when you hold
// the GIL, or during destruction of the tensor.
//
// When a PyObject dies, you are obligated to clear this field
// (otherwise, you will try to use-after-free the pyobj); this currently
// occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp
//
// NB: Ordinarily, this should not be a strong reference, as if the
// PyObject owns the Tensor, this would create a reference cycle.
// However, sometimes this ownership flips. To track who owns
// who, this has a single pointer tag indicating whether or not the
// C++ object owns the PyObject (the common case, zero, means PyObject
// owns the C++ object); see _unchecked_untagged_pyobj for raw access
// or check_pyobj for checked access. See references to PyObject
// resurrection in torch/csrc/autograd/python_variable.cpp
PyObject* pyobj_;
};
} // namespace c10::impl
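The `pyobj_` field in the variant above carries its ownership flag in the pointer's lowest bit, which `_unchecked_untagged_pyobj()` masks off and `owns_pyobj()`/`set_owns_pyobj()` read and write. A standalone sketch of that low-bit tagging, with a plain struct standing in for PyObject (it relies on object pointers being at least 2-byte aligned):

```cpp
#include <cstdint>

struct FakePyObject { int refcnt; };  // stand-in for PyObject in this sketch

class TaggedPyObjectPtr {
 public:
  void set(FakePyObject* p, bool owns) {
    bits_ = reinterpret_cast<uintptr_t>(p) | static_cast<uintptr_t>(owns);
  }
  FakePyObject* get() const {
    // Mask off the tag bit before using the pointer.
    return reinterpret_cast<FakePyObject*>(bits_ & ~uintptr_t(1));
  }
  bool owns_pyobj() const {
    return (bits_ & uintptr_t(1)) != 0;  // ownership tag lives in bit 0
  }

 private:
  uintptr_t bits_ = 0;
};
```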

View File

@ -50,13 +50,7 @@ namespace c10 {
/// However, you should prefer to use ArrayRef when possible, because its use
/// of TORCH_CHECK will lead to better user-facing error messages.
template <typename T>
// ArrayRef cannot be derived from. Normally, we would use `final`
// specifier to force this constraint at compile time. However, Intel
// compiler does not recognize ArrayRef as a class template (which is
// required in the definition of at::TensorAccessor, for instance)
// when `final` specifier is used. So, we cannot define ArrayRef as
// final because of the Intel compiler issue.
class ArrayRef : public HeaderOnlyArrayRef<T> {
class ArrayRef final : public HeaderOnlyArrayRef<T> {
public:
/// @name Constructors, all inherited from HeaderOnlyArrayRef except for
/// SmallVector. As inherited constructors won't work with class template

View File

@ -379,11 +379,7 @@ C10_API std::string GetExceptionString(const std::exception& e);
// ----------------------------------------------------------------------------
#ifdef STRIP_ERROR_MESSAGES
#define TORCH_RETHROW(e, ...) \
do { \
(void)e; /* Suppress unused variable warning */ \
throw; \
} while (false)
#define TORCH_RETHROW(e, ...) throw
#else
#define TORCH_RETHROW(e, ...) \
do { \

View File

@ -12,10 +12,6 @@ template <typename, typename...>
class class_;
}
namespace torch::utils {
class PyObjectPreservation;
}
namespace c10 {
class intrusive_ptr_target;
namespace raw {
@ -37,8 +33,6 @@ constexpr uint64_t kImpracticallyHugeWeakReferenceCount =
constexpr uint64_t kReferenceCountOne = 1;
constexpr uint64_t kWeakReferenceCountOne = (kReferenceCountOne << 32);
constexpr uint64_t kUniqueRef = (kReferenceCountOne | kWeakReferenceCountOne);
// Indicates whether the object has a PyObject wrapper.
constexpr uint64_t kHasPyObject = (uint64_t(1) << 63);
template <class TTarget>
struct intrusive_target_default_null_type final {
@ -61,11 +55,7 @@ inline uint32_t refcount(uint64_t combined_refcount) {
}
inline uint32_t weakcount(uint64_t combined_refcount) {
return static_cast<uint32_t>((combined_refcount & ~kHasPyObject) >> 32);
}
inline bool has_pyobject(uint64_t combined_refcount) {
return (combined_refcount & kHasPyObject) != 0;
return static_cast<uint32_t>(combined_refcount >> 32);
}
// The only requirement for refcount increment is that it happens-before
@ -76,6 +66,12 @@ inline uint64_t atomic_combined_refcount_increment(
return combined_refcount.fetch_add(inc, std::memory_order_relaxed) + inc;
}
inline uint32_t atomic_refcount_increment(
std::atomic<uint64_t>& combined_refcount) {
return detail::refcount(atomic_combined_refcount_increment(
combined_refcount, kReferenceCountOne));
}
inline uint32_t atomic_weakcount_increment(
std::atomic<uint64_t>& combined_refcount) {
return detail::weakcount(atomic_combined_refcount_increment(
@ -103,11 +99,6 @@ inline uint32_t atomic_weakcount_decrement(
combined_refcount, kWeakReferenceCountOne));
}
template <class T, class = void>
struct TargetTraits {
static constexpr bool can_have_pyobject = false;
};
} // namespace detail
/**
@ -164,23 +155,6 @@ class C10_API intrusive_ptr_target {
// we can atomically operate on both at the same time for performance
// and defined behaviors.
//
// Note [PyObject preservation for Tensor and Storages]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// intrusive_ptr has special support for preserving PyObject wrappers
// for TensorImpl and StorageImpl. The most significant bit (kHasPyObject) of
// the combined_refcount_ is used to indicate whether the object has a
// PyObject wrapper.
//
// - The PyObject, if it exists, holds a strong reference to the
// intrusive_ptr_target.
//
// - When the refcount goes from 1 to 2, we incref the PyObject.
//
// - When the refcount goes from 2 to 1, we decref the PyObject.
//
// In other words, the intrusive_ptr keeps the PyObject alive as long as there
// are other C++ references to the intrusive_ptr_target.
mutable std::atomic<uint64_t> combined_refcount_;
static_assert(sizeof(std::atomic<uint64_t>) == 8);
static_assert(alignof(std::atomic<uint64_t>) == 8);
@ -198,8 +172,6 @@ class C10_API intrusive_ptr_target {
template <typename T>
friend struct ExclusivelyOwnedTensorTraits;
friend class torch::utils::PyObjectPreservation;
protected:
// protected destructor. We never want to destruct intrusive_ptr_target*
// directly.
@ -283,16 +255,6 @@ class C10_API intrusive_ptr_target {
*/
virtual void release_resources() {}
/**
* These two methods are called when the refcount transitions between one
* and two and the object has a PyObject wrapper.
*/
virtual void incref_pyobject() const {}
virtual void decref_pyobject() const {}
virtual bool try_incref_pyobject() const {
return false;
}
uint32_t refcount(std::memory_order order = std::memory_order_relaxed) const {
return detail::refcount(combined_refcount_.load(order));
}
@ -303,19 +265,6 @@ class C10_API intrusive_ptr_target {
}
};
namespace detail {
#ifndef C10_MOBILE
template <>
struct TargetTraits<c10::intrusive_ptr_target> {
// A generic intrusive_ptr<intrusive_ptr_target> may actually be a TensorImpl
// or StorageImpl, so we have to allow for PyObject support.
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
template <class TTarget, class NullType>
class weak_intrusive_ptr;
@ -365,34 +314,18 @@ class intrusive_ptr final {
void retain_() {
if (target_ != NullType::singleton()) {
uint64_t combined = detail::atomic_combined_refcount_increment(
target_->combined_refcount_, detail::kReferenceCountOne);
uint32_t new_refcount = detail::refcount(combined);
uint32_t new_refcount =
detail::atomic_refcount_increment(target_->combined_refcount_);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
new_refcount != 1,
"intrusive_ptr: Cannot increase refcount after it reached zero.");
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
// If the refcount transitioned from 1 to 2, we need to incref the
// PyObject. In other words, we need to ensure that the PyObject stays
// alive now that we have a C++ reference to this object in addition to
// the PyObject itself.
if (C10_UNLIKELY(
detail::has_pyobject(combined) &&
detail::refcount(combined) == 2)) {
target_->incref_pyobject();
}
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!detail::has_pyobject(combined),
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
}
}
}
void reset_() noexcept {
if (target_ != NullType::singleton()) {
if (is_uniquely_owned()) {
if (target_->combined_refcount_.load(std::memory_order_acquire) ==
detail::kUniqueRef) {
// Both counts are 1, so there are no weak references and
// we are releasing the last strong reference. No other
// threads can observe the effects of this target_ deletion
@ -404,10 +337,9 @@ class intrusive_ptr final {
auto combined_refcount = detail::atomic_combined_refcount_decrement(
target_->combined_refcount_, detail::kReferenceCountOne);
uint32_t new_refcount = detail::refcount(combined_refcount);
bool has_pyobject = detail::has_pyobject(combined_refcount);
if (new_refcount == 0) {
bool should_delete = detail::weakcount(combined_refcount) == 1;
if (detail::refcount(combined_refcount) == 0) {
bool should_delete =
(combined_refcount == detail::kWeakReferenceCountOne);
// See comment above about weakcount. As long as refcount>0,
// weakcount is one larger than the actual number of weak references.
// So we need to decrement it here.
@ -424,18 +356,6 @@ class intrusive_ptr final {
if (should_delete) {
delete target_;
}
} else if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
// If the refcount transitioned from 2 to 1, we need to decref the
// PyObject. In other words, we don't want to keep the PyObject alive if
// there are no C++ references to this object other than the PyObject
// itself.
if (C10_UNLIKELY(has_pyobject && new_refcount == 1)) {
target_->decref_pyobject();
}
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!has_pyobject,
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
}
}
}
@ -602,16 +522,6 @@ class intrusive_ptr final {
return use_count() == 1;
}
/**
* Stronger than unique() in that it must not have any weakrefs as well.
*/
bool is_uniquely_owned() const noexcept {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(target_ != NullType::singleton());
uint64_t combined =
target_->combined_refcount_.load(std::memory_order_acquire);
return (combined & ~detail::kHasPyObject) == detail::kUniqueRef;
}
/**
* Returns an owning (!) pointer to the underlying object and makes the
* intrusive_ptr instance invalid. That means the refcount is not decreased.
@ -1022,7 +932,6 @@ class weak_intrusive_ptr final {
if (target_ == NullType::singleton()) {
return intrusive_ptr<TTarget, NullType>();
} else {
bool increfed = false;
auto combined_refcount =
target_->combined_refcount_.load(std::memory_order_relaxed);
do {
@ -1031,31 +940,12 @@ class weak_intrusive_ptr final {
// Return nullptr.
return intrusive_ptr<TTarget, NullType>();
}
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
if (detail::has_pyobject(combined_refcount) &&
detail::refcount(combined_refcount) == 1 && !increfed) {
// Object has a python wrapper with no other C++ references.
// We need to to incref the Python object before we acquire a
// strong reference to the C++ object to avoid a situation
// where the Python object is deallocated concurrently.
if (!target_->try_incref_pyobject()) {
return intrusive_ptr<TTarget, NullType>();
}
increfed = true;
}
}
} while (!target_->combined_refcount_.compare_exchange_weak(
combined_refcount,
combined_refcount + detail::kReferenceCountOne,
std::memory_order_acquire,
std::memory_order_relaxed));
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
if (increfed && detail::refcount(combined_refcount) != 1) {
target_->decref_pyobject();
}
}
return intrusive_ptr<TTarget, NullType>(
target_, raw::DontIncreaseRefcount{});
}
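The loop above is the standard "acquire a strong reference only if one still exists" idiom: read the packed counter, give up if the strong count already reached zero, otherwise bump it with compare_exchange_weak and retry if another thread raced. A minimal sketch without the PyObject special-casing:

```cpp
#include <atomic>
#include <cstdint>

// Try to take a strong reference on a combined (strong | weak << 32) counter.
// Returns false if the strong count is already zero (object being destroyed).
inline bool try_lock_strong(std::atomic<uint64_t>& combined_refcount) {
  constexpr uint64_t kRefcountOne = 1;
  uint64_t combined = combined_refcount.load(std::memory_order_relaxed);
  do {
    if ((combined & 0xFFFFFFFFull) == 0) {
      return false;  // cannot resurrect a dead object
    }
  } while (!combined_refcount.compare_exchange_weak(
      combined,
      combined + kRefcountOne,
      std::memory_order_acquire,
      std::memory_order_relaxed));
  return true;
}
```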
@ -1170,18 +1060,7 @@ namespace intrusive_ptr {
// NullType::singleton to this function
inline void incref(intrusive_ptr_target* self) {
if (self) {
uint64_t combined = detail::atomic_combined_refcount_increment(
self->combined_refcount_, detail::kReferenceCountOne);
#ifndef C10_MOBILE
if (C10_UNLIKELY(
detail::has_pyobject(combined) &&
detail::refcount(combined) == 2)) {
self->incref_pyobject();
}
#else
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!detail::has_pyobject(combined));
#endif
detail::atomic_refcount_increment(self->combined_refcount_);
}
}
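The kHasPyObject scheme in the hunks above hinges on packing everything into one 64-bit atomic: strong count in the low 32 bits, weak count above it, and the top bit flagging a Python wrapper, so a single fetch_add reveals both the new count and whether this was the 1 -> 2 transition that requires an extra PyObject incref. A standalone sketch of that layout (not the c10 implementation):

```cpp
#include <atomic>
#include <cstdint>

// Bit layout: strong count in the low 32 bits, weak count in the next 31
// bits, kHasPyObject in the most significant bit.
constexpr uint64_t kRefcountOne = 1;
constexpr uint64_t kWeakcountOne = uint64_t(1) << 32;
constexpr uint64_t kHasPyObject = uint64_t(1) << 63;

inline uint32_t refcount(uint64_t combined) {
  return static_cast<uint32_t>(combined & 0xFFFFFFFFull);
}
inline bool has_pyobject(uint64_t combined) {
  return (combined & kHasPyObject) != 0;
}

// One atomic add both bumps the strong count and tells the caller whether
// this was the 1 -> 2 transition of a wrapped object, i.e. whether the
// PyObject now needs an incref to stay alive alongside the C++ references.
inline bool incref_and_check_pyobject(std::atomic<uint64_t>& combined_refcount) {
  const uint64_t combined =
      combined_refcount.fetch_add(kRefcountOne, std::memory_order_relaxed) +
      kRefcountOne;
  return has_pyobject(combined) && refcount(combined) == 2;
}
```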

View File

@ -15,8 +15,6 @@ using namespace c10::CachingDeviceAllocator;
// newly allocated memory with 512-byte alignment.
constexpr size_t kDeviceAlignment = 512;
class XPUAllocator;
namespace {
using stream_set = ska::flat_hash_set<xpu::XPUStream>;
@ -25,19 +23,14 @@ typedef bool (*Comparison)(const Block*, const Block*);
bool BlockComparatorSize(const Block* a, const Block* b);
bool BlockComparatorAddress(const Block* a, const Block* b);
struct PrivatePool;
struct BlockPool {
BlockPool(bool small, PrivatePool* private_pool = nullptr)
BlockPool(bool small)
: blocks(BlockComparatorSize),
unmapped(BlockComparatorAddress),
is_small(small),
owner_PrivatePool(private_pool) {}
is_small(small) {}
std::set<Block*, Comparison> blocks;
std::set<Block*, Comparison> unmapped;
const bool is_small;
PrivatePool* owner_PrivatePool;
};
struct ExpandableSegment;
@ -356,43 +349,6 @@ struct AllocParams {
StatTypes stat_types = {};
};
// Internal implementation that manages actual memory blocks.
// high level MemPool interface wraps PrivatePool via MempoolId.
struct PrivatePool {
PrivatePool(MempoolId_t id, XPUAllocator* allocator = nullptr)
: id(std::move(id)),
allocator_(allocator),
large_blocks(/*small=*/false, this),
small_blocks(/*small=*/true, this) {}
PrivatePool(const PrivatePool&) = delete;
PrivatePool(PrivatePool&&) = delete;
PrivatePool& operator=(const PrivatePool&) = delete;
PrivatePool& operator=(PrivatePool&&) = delete;
~PrivatePool() = default;
// default Mempool when no Mempool is specified
MempoolId_t id{0, 0};
// Number of live graphs using this pool
int use_count{1};
// Number of unfreed allocations made for this pool. When use_count and
// allocation_count drop to zero, we can delete this PrivatePool from
// graph_pools.
int allocation_count{0};
XPUAllocator* allocator_;
BlockPool large_blocks;
BlockPool small_blocks;
public:
XPUAllocator* allocator() {
return allocator_;
}
};
struct MempoolIdHash {
std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
}
};
} // anonymous namespace
class DeviceCachingAllocator {
@ -409,13 +365,6 @@ class DeviceCachingAllocator {
bool set_fraction = false;
std::vector<ExpandableSegment*> expandable_segments;
std::vector<c10::DeviceIndex> devices_with_peer_access; // reserved
std::vector<std::pair<MempoolId_t, std::function<bool(sycl::queue*)>>>
captures_underway;
ska::flat_hash_map<MempoolId_t, std::unique_ptr<PrivatePool>, MempoolIdHash>
graph_pools;
// Pools no longer referenced by any graph.
ska::flat_hash_map<MempoolId_t, PrivatePool*, MempoolIdHash>
graph_pools_freeable;
size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
if (!src || src->allocated || src->event_count > 0 ||
@ -514,22 +463,7 @@ class DeviceCachingAllocator {
}
}
BlockPool& get_pool(size_t size, sycl::queue* queue) {
if (C10_UNLIKELY(!captures_underway.empty())) {
for (auto& entry : captures_underway) {
// lookup for mempool id matching current capture graph
if (entry.second(queue)) {
auto it1 = graph_pools.find(entry.first);
// lookup mempool
TORCH_INTERNAL_ASSERT(it1 != graph_pools.end());
if (size <= kSmallSize) {
return it1->second->small_blocks;
} else {
return it1->second->large_blocks;
}
}
}
}
BlockPool& get_pool(size_t size) {
if (size < kSmallSize) {
return small_blocks;
} else {
@ -735,10 +669,6 @@ class DeviceCachingAllocator {
if (!ptr) {
return false;
}
if (p.pool->owner_PrivatePool) {
p.pool->owner_PrivatePool->allocation_count++;
}
p.block = new Block(device, p.queue(), size, p.pool, ptr);
for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
stats.reserved_bytes[stat_type].increase(size);
@ -747,14 +677,11 @@ class DeviceCachingAllocator {
return true;
}
void synchronize_and_free_events(PrivatePool* pool = nullptr) {
void synchronize_and_free_events() {
for (auto& xe : xpu_events) {
for (auto& e : xe.second) {
auto event = e.first;
auto* block = e.second;
if (pool && block->pool->owner_PrivatePool != pool) {
continue;
}
event.wait();
block->event_count--;
if (block->event_count == 0) {
@ -858,13 +785,6 @@ class DeviceCachingAllocator {
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
stats.reserved_bytes[stat_type].decrease(unmapped.size);
});
if (block->pool->owner_PrivatePool) {
// The Freed block belonged to a XPU graph's PrivatePool.
TORCH_INTERNAL_ASSERT(
block->pool->owner_PrivatePool->allocation_count > 0);
block->pool->owner_PrivatePool->allocation_count--;
}
}
void release_blocks(BlockPool& pool) {
@ -892,41 +812,13 @@ class DeviceCachingAllocator {
}
}
bool release_cached_blocks(MempoolId_t mempool_id) {
if (mempool_id.first == 0 && mempool_id.second == 0 &&
captures_underway.empty()) {
synchronize_and_free_events();
// See Note [Safe to Free Blocks on BlockPool]
c10::xpu::syncStreamsOnDevice(device_index);
bool release_cached_blocks() {
synchronize_and_free_events();
// See Note [Safe to Free Blocks on BlockPool]
c10::xpu::syncStreamsOnDevice(device_index);
release_blocks(large_blocks);
release_blocks(small_blocks);
}
for (auto it = graph_pools_freeable.begin();
it != graph_pools_freeable.end();) {
if (mempool_id.first != 0 || mempool_id.second != 0) {
if (it->first == mempool_id) {
// If there is an active mempool, we sync only the events
// associated with the pool
synchronize_and_free_events(it->second);
} else {
// otherwise we move on
++it;
continue;
}
}
TORCH_INTERNAL_ASSERT(it->second->use_count == 0);
release_blocks(it->second->small_blocks);
release_blocks(it->second->large_blocks);
if (it->second->allocation_count == 0) {
auto erase_count = graph_pools.erase(it->first);
TORCH_INTERNAL_ASSERT(erase_count == 1);
it = graph_pools_freeable.erase(it);
} else {
++it;
}
}
release_blocks(large_blocks);
release_blocks(small_blocks);
return true;
}
@ -1011,30 +903,6 @@ class DeviceCachingAllocator {
}
}
void create_or_incref_pool(
MempoolId_t mempool_id,
XPUAllocator* allocator = nullptr) {
auto it = graph_pools.find(mempool_id);
if (it == graph_pools.end()) {
// mempool_id does not reference an existing pool.
// Make a new pool for XPU graph capture or memory pool usage.
graph_pools.emplace(
mempool_id, std::make_unique<PrivatePool>(mempool_id, allocator));
} else {
// mempool_id references an existing pool, which the current XPU graph
// capture will share.
TORCH_INTERNAL_ASSERT(it->second->use_count > 0);
TORCH_INTERNAL_ASSERT(allocator == nullptr);
it->second->use_count++;
}
}
PrivatePool* get_private_pool(MempoolId_t mempool_id) {
auto it = graph_pools.find(mempool_id);
TORCH_INTERNAL_ASSERT(it != graph_pools.end());
return it->second.get();
}
public:
DeviceCachingAllocator(DeviceIndex device_index)
: large_blocks(/* small */ false),
@ -1043,11 +911,9 @@ class DeviceCachingAllocator {
Block* malloc(DeviceIndex device, size_t orig_size, sycl::queue& queue) {
std::scoped_lock<std::recursive_mutex> lock(mutex);
if (C10_LIKELY(captures_underway.empty())) {
process_events();
}
process_events();
size_t size = round_size(orig_size);
auto& pool = get_pool(size, &queue);
auto& pool = get_pool(size);
const size_t alloc_size = get_allocation_size(size);
AllocParams params(device, size, &queue, &pool, alloc_size);
params.stat_types = get_stat_types_for_pool(pool);
@ -1057,7 +923,7 @@ class DeviceCachingAllocator {
// Can't reuse an existing block, try to get a new one.
if (!block_found) {
block_found = alloc_block(params, false) ||
(release_cached_blocks({0, 0}) && alloc_block(params, true));
(release_cached_blocks() && alloc_block(params, true));
}
if (!block_found) {
const auto& raw_device = c10::xpu::get_raw_device(device);
@ -1150,9 +1016,9 @@ class DeviceCachingAllocator {
block->stream_uses.insert(stream);
}
void emptyCache(MempoolId_t mempool_id) {
void emptyCache() {
std::scoped_lock<std::recursive_mutex> lock(mutex);
release_cached_blocks(mempool_id);
release_cached_blocks();
}
DeviceStats getStats() {
@ -1306,9 +1172,9 @@ class XPUAllocator : public DeviceAllocator {
}
}
void emptyCache(MempoolId_t mempool_id) override {
void emptyCache(MempoolId_t mempool_id [[maybe_unused]] = {0, 0}) override {
for (auto& da : device_allocators) {
da->emptyCache(mempool_id);
da->emptyCache();
}
}
@ -1424,8 +1290,8 @@ void init(DeviceIndex device_count) {
return allocator.init(device_count);
}
void emptyCache(MempoolId_t mempool_id) {
return allocator.emptyCache(mempool_id);
void emptyCache() {
return allocator.emptyCache();
}
void resetPeakStats(DeviceIndex device) {

View File

@ -10,7 +10,7 @@ C10_XPU_API Allocator* get();
C10_XPU_API void init(DeviceIndex device_count);
C10_XPU_API void emptyCache(MempoolId_t mempool_id = {0, 0});
C10_XPU_API void emptyCache();
C10_XPU_API void resetPeakStats(DeviceIndex device);

View File

@ -734,7 +734,7 @@ void PyTorchStreamWriter::setup(const string& file_name) {
file_name,
std::ofstream::out | std::ofstream::trunc | std::ofstream::binary
);
} catch (const std::ios_base::failure&) {
} catch (const std::ios_base::failure& e) {
#ifdef _WIN32
// Windows have verbose error code, we prefer to use it than std errno.
uint32_t error_code = GetLastError();
@ -773,20 +773,8 @@ void PyTorchStreamWriter::writeRecord(
bool compress) {
AT_ASSERT(!finalized_);
AT_ASSERT(!archive_name_plus_slash_.empty());
if (files_written_.count(name) > 0) {
// Allow multiple writes for triton binaries
bool is_triton_extension =
c10::ends_with(name, ".so") ||
c10::ends_with(name, ".cubin") ||
c10::ends_with(name, ".hsaco");
if (is_triton_extension) {
LOG(WARNING) << "File '" << name << "' is being serialized multiple times";
return;
}
TORCH_INTERNAL_ASSERT(false, "Tried to serialize file twice: ", name);
}
TORCH_INTERNAL_ASSERT(
files_written_.count(name) == 0, "Tried to serialize file twice: ", name);
if (name == kSerializationIdRecordName && serialization_id_.empty()) {
// In case of copying records from another file, skip writing a different
// serialization_id than the one computed in this writer.

View File

@ -10,7 +10,7 @@ API. This API can roughly be divided into five parts:
- **TorchScript**: An interface to the TorchScript JIT compiler and interpreter.
- **C++ Extensions**: A means of extending the Python API with custom C++ and CUDA routines.
Combined, these building blocks form a research and
Combining, these building blocks form a research and
production ready C++ library for tensor computation and dynamic neural
networks with strong emphasis on GPU acceleration as well as fast CPU
performance. It is currently in use at Facebook in research and
@ -76,7 +76,7 @@ C++ Frontend
------------
The PyTorch C++ frontend provides a high level, pure C++ modeling interface for
neural networks and general ML (Machine Learning) research and production use cases,
neural network and general ML(Machine Learning) research and production use cases,
largely following the Python API in design and provided functionality. The C++
frontend includes the following:

View File

@ -1,113 +0,0 @@
# Device Management
## Background
Device management handles basic operations like querying how many devices are available and switching between them. Accelerator backends need to wrap their device runtime's APIs and expose them to PyTorch.
The OpenReg implementation ([`OpenRegFunctions.h/cpp`][OpenReg Device Management]) shows how to wrap a third-party runtime. These functions are used throughout the backend - by streams, events, generators, and Python bindings.
## Design
Accelerator vendors need to implement these core functions:
| Function Name | Description | Application Scenarios |
| ------------------------- | ---------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- |
| `device_count()` | Query the total number of available devices in the system | - Application initialization<br>- Multi-device workload distribution<br>- Validating device indices before use |
| `current_device()` | Get the currently active device for the calling thread | - Debugging and logging<br>- Determining tensor placement<br>- Guard implementations |
| `set_device()` | Change the active device for subsequent operations | - Switching context between devices<br>- Initializing specific device resources<br>- Multi-GPU training loops |
| `exchange_device()` | Atomically swap device and return the previous device | - Implementing device guards<br>- Temporarily switching device context<br>- RAII-based device management |
| `maybe_exchange_device()` | Conditionally exchange device only if the index is valid (-1 OK) | - Safe device switching with optional indices<br>- Guard implementations with nullable device values |
These functions are building blocks for more complex features like streams, events, and memory management. Make sure to validate inputs and handle errors properly.
## Implementation
This section shows how to implement device management using `set_device` as an example. The implementation requires:
1. C++ wrappers around the device runtime
2. Python bindings to expose the C++ functions
3. User-friendly Python APIs
### C++ Side
Wrap the device runtime's API and add error handling. The `SetDevice` function shows this pattern:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG SetDevice FUNCTION
:end-before: LITERALINCLUDE END: OPENREG SetDevice FUNCTION
:linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG set_device FUNCTION
:end-before: LITERALINCLUDE END: OPENREG set_device FUNCTION
:linenos:
```
### Binding
Expose the C++ functions to Python using pybind11:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: MODULE SET DEVICE HELPER
:end-before: LITERALINCLUDE END: MODULE SET DEVICE HELPER
:linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
:end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
:linenos:
:emphasize-lines: 5
```
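As a hedged illustration of this step (not the actual torch_openreg module code), the sketch below registers thin, underscore-prefixed bindings with pybind11; the module name `_sketch_C` and the local stub functions are made up for the example, while the real bindings are what the literalinclude directives above refer to.
```cpp
// Illustrative pybind11 bindings for the device-management wrappers.
// The stubs stand in for the C++ layer from the previous section.
#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace {
int g_device = 0;
int device_count() { return 2; }          // stand-in for the real wrapper
int current_device() { return g_device; }
void set_device(int device_index) { g_device = device_index; }
}  // namespace

PYBIND11_MODULE(_sketch_C, m) {
  // Keep the bindings thin; the user-facing API is added on the Python side.
  m.def("_get_device_count", &device_count);
  m.def("_get_device", &current_device);
  m.def("_set_device", &set_device, py::arg("device_index"));
}
```
Keeping the bound names private (underscore-prefixed) leaves room for the user-facing API to be shaped on the Python side, which the next subsection and the mapping table below describe.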
### Python Side
Wrap the C++ bindings with user-friendly Python functions:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py
:language: python
:start-after: LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
:end-before: LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
:linenos:
```
Here's the complete mapping from C++ to Python:
| C++ Binding Function | C++ Binding API (pybind11) | Python User API | Description |
| -------------------- | ---------------------------------------- | -------------------------------- | -------------------------------------------- |
| `_getDeviceCount` | `torch_openreg._C._get_device_count()` | `torch.openreg.device_count()` | Returns the total number of devices |
| `_getDevice` | `torch_openreg._C._get_device()` | `torch.openreg.current_device()` | Returns the current active device index |
| `_setDevice` | `torch_openreg._C._set_device(idx)` | `torch.openreg.set_device(idx)` | Sets the active device |
| `_exchangeDevice` | `torch_openreg._C._exchange_device(idx)` | N/A (internal use only) | Atomically swaps device and returns previous |
## Guard
Device guards provide automatic device switching with exception safety. They're similar to lock guards in C++ - they switch device on construction and restore it on destruction.
Implement `DeviceGuardImplInterface` to integrate with PyTorch's guard system:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h
:language: c++
:start-after: LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
:end-before: LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
:linenos:
```
**What needs to be implemented:**
1. **exchangeDevice()**: Switch to a new device and return the old one (used by guard constructors)
2. **getDevice()**: Get the current device
3. **setDevice()**: Set the active device
4. **Type checking**: Validate that device type matches the backend
This makes the guard available to PyTorch for the `PrivateUse1` device type. Users can then use standard PyTorch device guards with the custom backend.
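The snippet below is not the real `DeviceGuardImplInterface` implementation; it is a small, self-contained sketch of the RAII behavior the guard enables (switch on construction, restore on destruction, even if an exception unwinds the scope), using stand-in versions of the wrapper functions from earlier in this chapter so the example compiles on its own.
```cpp
// RAII device-guard sketch: exchange the device when the guard is created,
// restore the previous device when it goes out of scope. exchange_device()
// and set_device() are local stand-ins for the backend wrappers.
#include <cstdint>

using DeviceIndex = int8_t;

static DeviceIndex g_device = 0;  // stand-in runtime state

static DeviceIndex exchange_device(DeviceIndex device) {
  DeviceIndex previous = g_device;
  g_device = device;
  return previous;
}

static void set_device(DeviceIndex device) { g_device = device; }

class DeviceGuardSketch {
 public:
  explicit DeviceGuardSketch(DeviceIndex target)
      : previous_(exchange_device(target)) {}      // switch on construction
  ~DeviceGuardSketch() { set_device(previous_); }  // restore on destruction

  DeviceGuardSketch(const DeviceGuardSketch&) = delete;
  DeviceGuardSketch& operator=(const DeviceGuardSketch&) = delete;

 private:
  DeviceIndex previous_;
};

int main() {
  set_device(0);
  {
    DeviceGuardSketch guard(1);  // device 1 is active inside this scope
  }                              // device 0 is active again here
  return g_device == 0 ? 0 : 1;
}
```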
[OpenReg Device Management]: https://github.com/pytorch/pytorch/blob/main/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp "OpenReg Device Management"

View File

@ -42,7 +42,6 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
:glob:
:maxdepth: 1
device
hooks
autoload
operators

View File

@ -254,7 +254,7 @@ To toggle the reduced precision reduction flags in C++, one can do
.. _fp16accumulation:
Full FP16 Accumulation in FP16 GEMMs
Full FP16 Accmumulation in FP16 GEMMs
-------------------------------------
Certain GPUs have increased performance when doing _all_ FP16 GEMM accumulation

View File

@ -30,6 +30,5 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
skip_guard_on_all_nn_modules_unsafe
keep_tensor_guards_unsafe
skip_guard_on_globals_unsafe
skip_all_guards_unsafe
nested_compile_region
```

View File

@ -24,11 +24,15 @@ def gen_data(special_op_lists, analysis_name):
all_ops = get_ops_for_key(None)
composite_ops = get_ops_for_key("CompositeImplicitAutograd")
noncomposite_ops = all_ops - composite_ops
with open("../../aten/src/ATen/native/native_functions.yaml") as f:
ops = yaml.load(f.read(), Loader=yaml.CLoader)
with open("annotated_ops") as f:
annotated_ops = {a.strip(): b.strip() for a, b in csv.reader(f)}
ops = yaml.load(
open("../../aten/src/ATen/native/native_functions.yaml").read(),
Loader=yaml.CLoader,
)
annotated_ops = {
a.strip(): b.strip() for a, b in list(csv.reader(open("annotated_ops")))
}
uniq_ops = []
uniq_names = set()

View File

@ -376,19 +376,3 @@ keep-runtime-typing = true
[tool.codespell]
ignore-words = "tools/linter/dictionary.txt"
[tool.spin]
package = 'torch'
[tool.spin.commands]
"Build" = [
".spin/cmds.py:lint",
".spin/cmds.py:fixlint",
".spin/cmds.py:quicklint",
".spin/cmds.py:quickfix",
]
"Regenerate" = [
".spin/cmds.py:regenerate_version",
".spin/cmds.py:regenerate_type_stubs",
".spin/cmds.py:regenerate_clangtidy_files",
]

View File

@ -32,7 +32,7 @@ project-excludes = [
"torch/utils/tensorboard/summary.py",
# formatting issues, will turn on after adjusting where suppressions can be
# in import statements
"torch/distributed/flight_recorder/components/types.py",
"tools/flight_recorder/components/types.py",
"torch/linalg/__init__.py",
"torch/package/importer.py",
"torch/package/_package_pickler.py",

View File

@ -14,7 +14,6 @@ lintrunner ; platform_machine != "s390x" and platform_machine != "riscv64"
networkx>=2.5.1
optree>=0.13.0
psutil
spin
sympy>=1.13.3
typing-extensions>=4.13.2
wheel

View File

@ -1358,6 +1358,45 @@ class concat_license_files:
# Need to create the proper LICENSE.txt for the wheel
class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
def _wrap_headers_with_macro(self, bdist_dir: Path) -> None:
"""Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION).
Excludes:
- torch/include/torch/headeronly/*
- torch/include/torch/csrc/stable/*
- torch/include/torch/csrc/inductor/aoti_torch/c/ (only shim headers)
- torch/include/torch/csrc/inductor/aoti_torch/generated/
"""
header_extensions = (".h", ".hpp", ".cuh")
header_files = [
f for ext in header_extensions for f in bdist_dir.rglob(f"*{ext}")
]
# Paths to exclude from wrapping
exclude_dir_patterns = [
"torch/include/torch/headeronly/",
"torch/include/torch/csrc/stable/",
"torch/include/torch/csrc/inductor/aoti_torch/c/",
"torch/include/torch/csrc/inductor/aoti_torch/generated/",
]
for header_file in header_files:
rel_path = header_file.relative_to(bdist_dir).as_posix()
if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns):
report(f"Skipping header: {rel_path}")
continue
original_content = header_file.read_text(encoding="utf-8")
wrapped_content = (
"#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
f"{original_content}"
"\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
)
header_file.write_text(wrapped_content, encoding="utf-8")
report(f"Wrapped header: {rel_path}")
def run(self) -> None:
with concat_license_files(include_files=True):
super().run()
@ -1380,6 +1419,14 @@ class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
# need an __init__.py file otherwise we wouldn't have a package
(bdist_dir / "torch" / "__init__.py").touch()
# Wrap all header files with TORCH_STABLE_ONLY macro
assert self.bdist_dir is not None, "bdist_dir should be set during wheel build"
bdist_dir = Path(self.bdist_dir)
report(
"-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)"
)
self._wrap_headers_with_macro(bdist_dir)
class clean(Command):
user_options: ClassVar[list[tuple[str, str | None, str]]] = []
@ -1585,7 +1632,7 @@ def configure_extension_build() -> tuple[
if cmake_cache_vars["USE_DISTRIBUTED"]:
# Only enable fr_trace command if distributed is enabled
entry_points["console_scripts"].append(
"torchfrtrace = torch.distributed.flight_recorder.fr_trace:main",
"torchfrtrace = tools.flight_recorder.fr_trace:main",
)
return ext_modules, cmdclass, packages, entry_points, extra_install_requires

View File

@ -8,7 +8,6 @@ set(AOTI_ABI_CHECK_TEST_ROOT ${TORCH_ROOT}/test/cpp/aoti_abi_check)
# Build the cpp gtest binary containing the cpp-only tests.
set(AOTI_ABI_CHECK_TEST_SRCS
${AOTI_ABI_CHECK_TEST_ROOT}/main.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_accessor.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_cast.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_dispatch.cpp

View File

@ -1,50 +0,0 @@
#include <gtest/gtest.h>
#include <torch/headeronly/core/TensorAccessor.h>
#include <string>
TEST(TestAccessor, HeaderOnlyTensorAccessor) {
std::vector<int32_t> v = {11, 12, 13, 21, 22, 23};
std::vector<int64_t> sizes = {2, 3};
std::vector<int64_t> strides = {3, 1};
auto acc = torch::headeronly::HeaderOnlyTensorAccessor<int32_t, 2>(
v.data(), sizes.data(), strides.data());
EXPECT_EQ(acc[0][0], 11);
EXPECT_EQ(acc[0][1], 12);
EXPECT_EQ(acc[0][2], 13);
EXPECT_EQ(acc[1][0], 21);
EXPECT_EQ(acc[1][1], 22);
EXPECT_EQ(acc[1][2], 23);
}
TEST(TestAccessor, HeaderOnlyGenericPackedTensorAccessor) {
std::vector<int32_t> v = {11, 12, 13, 21, 22, 23};
std::vector<int64_t> sizes = {2, 3};
std::vector<int64_t> strides = {3, 1};
auto acc =
torch::headeronly::HeaderOnlyGenericPackedTensorAccessor<int32_t, 2>(
v.data(), sizes.data(), strides.data());
EXPECT_EQ(acc[0][0], 11);
EXPECT_EQ(acc[0][1], 12);
EXPECT_EQ(acc[0][2], 13);
EXPECT_EQ(acc[1][0], 21);
EXPECT_EQ(acc[1][1], 22);
EXPECT_EQ(acc[1][2], 23);
auto tacc = acc.transpose(0, 1);
EXPECT_EQ(tacc[0][0], 11);
EXPECT_EQ(tacc[0][1], 21);
EXPECT_EQ(tacc[1][0], 12);
EXPECT_EQ(tacc[1][1], 22);
EXPECT_EQ(tacc[2][0], 13);
EXPECT_EQ(tacc[2][1], 23);
try {
acc.transpose(0, 2);
} catch (const std::exception& e) {
EXPECT_TRUE(
std::string(e.what()).find("HeaderOnlyIndexBoundsCheck") !=
std::string::npos);
}
}

View File

@ -13,17 +13,6 @@ TEST(TestScalarType, ScalarTypeToCPPTypeT) {
#undef DEFINE_CHECK
}
TEST(TestScalarType, CppTypeToScalarType) {
using torch::headeronly::CppTypeToScalarType;
using torch::headeronly::ScalarType;
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
EXPECT_EQ(CppTypeToScalarType<TYPE>::value, ScalarType::SCALARTYPE);
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
{ \
EXPECT_EQ( \
@ -101,14 +90,3 @@ TEST(TestScalarType, toUnderlying) {
AT_FORALL_FLOAT8_TYPES(DEFINE_CHECK);
#undef DEFINE_CHECK
}
TEST(TestScalarType, isQIntType) {
using torch::headeronly::isQIntType;
using torch::headeronly::ScalarType;
#define DEFINE_CHECK(_, name) EXPECT_TRUE(isQIntType(ScalarType::name));
AT_FORALL_QINT_TYPES(DEFINE_CHECK);
#undef DEFINE_CHECK
#define DEFINE_CHECK(_, name) EXPECT_FALSE(isQIntType(ScalarType::name));
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CHECK);
#undef DEFINE_CHECK
}

View File

@ -15,7 +15,7 @@ namespace jit {
TEST(CustomOperatorTest, InferredSchema) {
torch::RegisterOperators reg(
"foo::bar", [](double a, at::Tensor b) { return a + b; });
auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar"));
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar"));
ASSERT_EQ(ops.size(), 1);
auto& op = ops.front();
@ -43,7 +43,8 @@ TEST(CustomOperatorTest, ExplicitSchema) {
"foo::bar_with_schema(float a, Tensor b) -> Tensor",
[](double a, at::Tensor b) { return a + b; });
auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema"));
auto& ops =
getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema"));
ASSERT_EQ(ops.size(), 1);
auto& op = ops.front();
@ -76,7 +77,7 @@ TEST(CustomOperatorTest, ListParameters) {
torch::List<c10::complex<double>> complexdoubles,
torch::List<at::Tensor> tensors) { return floats; });
auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists"));
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists"));
ASSERT_EQ(ops.size(), 1);
auto& op = ops.front();
@ -122,7 +123,7 @@ TEST(CustomOperatorTest, ListParameters2) {
"foo::lists2(Tensor[] tensors) -> Tensor[]",
[](torch::List<at::Tensor> tensors) { return tensors; });
auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
ASSERT_EQ(ops.size(), 1);
auto& op = ops.front();
@ -212,7 +213,7 @@ TEST(TestCustomOperator, OperatorGeneratorUndeclared) {
},
aliasAnalysisFromSchema())});
auto ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist"));
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist"));
ASSERT_EQ(ops.size(), 0);
}
@ -231,7 +232,7 @@ TEST(TestCustomOperator, OperatorGeneratorBasic) {
},
aliasAnalysisFromSchema())});
auto ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar"));
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar"));
ASSERT_EQ(ops.size(), 1);
auto& op = ops.front();

View File

@ -1,30 +0,0 @@
#include "kernel.h"
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <cuda_runtime.h>
using torch::stable::Tensor;
Tensor mv_tensor_accessor_cuda(Tensor m, Tensor v) {
STD_TORCH_CHECK(m.dim() == 2, "m must be 2D");
STD_TORCH_CHECK(v.dim() == 1, "v must be 1D");
STD_TORCH_CHECK(m.size(1) == v.size(0), "m.shape[1] == v.shape[0] must hold");
STD_TORCH_CHECK(m.scalar_type() == v.scalar_type(), "m and v must have the same dtype");
STD_TORCH_CHECK(m.device() == v.device(), "m and v must be on the same device");
Tensor res = new_empty(m, {m.size(0)});
THO_DISPATCH_V2(m.scalar_type(), "mv_tensor_accessor_cuda",
AT_WRAP(([&]() {
auto resa = Accessor_cuda<scalar_t, 1>(reinterpret_cast<scalar_t*>(res.data_ptr()), res.sizes().data(), res.strides().data());
auto ma = Accessor_cuda<scalar_t, 2>(reinterpret_cast<scalar_t*>(m.data_ptr()), m.sizes().data(), m.strides().data());
auto va = Accessor_cuda<scalar_t, 1>(reinterpret_cast<scalar_t*>(v.data_ptr()), v.sizes().data(), v.strides().data());
mv_tensor_accessor_kernel<Accessor_cuda, scalar_t><<<1, 1, 0, 0>>>(resa, ma, va);
})),
AT_FLOATING_TYPES);
return res;
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CUDA, m) {
m.impl("mv_tensor_accessor", TORCH_BOX(&mv_tensor_accessor_cuda));
}

View File

@ -1,5 +1,3 @@
#include "kernel.h"
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/device.h>
@ -310,7 +308,7 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("my_amax(Tensor a) -> Tensor");
m.def("my_amax_vec(Tensor a) -> Tensor");
m.def("my_is_cpu(Tensor t) -> bool");
m.def("test_default_constructor(bool undefined) -> bool");
m.def("test_default_constructor(bool undefined) -> bool");
}
bool test_default_constructor(bool defined) {
@ -332,47 +330,12 @@ bool test_default_constructor(bool defined) {
return out.defined();
}
uint64_t get_any_data_ptr(Tensor t, bool mutable_) {
if (mutable_) {
return reinterpret_cast<uint64_t>(t.mutable_data_ptr());
} else {
return reinterpret_cast<uint64_t>(t.const_data_ptr());
}
}
uint64_t get_template_any_data_ptr(Tensor t, c10::ScalarType dtype, bool mutable_) {
#define DEFINE_CASE(T, name) \
case torch::headeronly::ScalarType::name: { \
if (mutable_) { \
return reinterpret_cast<uint64_t>(t.mutable_data_ptr<T>()); \
} else { \
return reinterpret_cast<uint64_t>(t.const_data_ptr<T>()); \
} \
}
switch (dtype) {
// per aten/src/ATen/templates/TensorMethods.cpp:
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
DEFINE_CASE(uint16_t, UInt16)
DEFINE_CASE(uint32_t, UInt32)
DEFINE_CASE(uint64_t, UInt64)
default:
return 0;
}
#undef DEFINE_CASE
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("get_any_data_ptr(Tensor t, bool mutable_) -> int");
m.def("get_template_any_data_ptr(Tensor t, ScalarType dtype, bool mutable_) -> int");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_zero_", TORCH_BOX(&my_zero_));
m.impl("my_amax", TORCH_BOX(&my_amax));
m.impl("my_amax_vec", TORCH_BOX(&my_amax_vec));
m.impl("test_default_constructor", TORCH_BOX(&test_default_constructor));
m.impl("get_any_data_ptr", TORCH_BOX(&get_any_data_ptr));
m.impl("get_template_any_data_ptr", TORCH_BOX(&get_template_any_data_ptr));
}
std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
@ -551,32 +514,6 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("test_device_is_cpu", &boxed_test_device_is_cpu);
}
Tensor mv_tensor_accessor_cpu(Tensor m, Tensor v) {
STD_TORCH_CHECK(m.dim() == 2, "m must be 2D");
STD_TORCH_CHECK(v.dim() == 1, "v must be 1D");
STD_TORCH_CHECK(m.size(1) == v.size(0), "m.shape[1] == v.shape[0] must hold");
STD_TORCH_CHECK(m.scalar_type() == v.scalar_type(), "m and v must have the same dtype");
STD_TORCH_CHECK(m.device() == v.device(), "m and v must be on the same device");
Tensor res = new_empty(m, {m.size(0)});
THO_DISPATCH_V2(m.scalar_type(), "mv_tensor_accessor_cpu",
AT_WRAP(([&]() {
auto resa = Accessor_cpu<scalar_t, 1>(reinterpret_cast<scalar_t*>(res.data_ptr()), res.sizes().data(), res.strides().data());
auto ma = Accessor_cpu<scalar_t, 2>(reinterpret_cast<scalar_t*>(m.data_ptr()), m.sizes().data(), m.strides().data());
auto va = Accessor_cpu<scalar_t, 1>(reinterpret_cast<scalar_t*>(v.data_ptr()), v.sizes().data(), v.strides().data());
mv_tensor_accessor_kernel<Accessor_cpu, scalar_t>(resa, ma, va);
})),
AT_FLOATING_TYPES);
return res;
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("mv_tensor_accessor(Tensor m, Tensor v) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
m.impl("mv_tensor_accessor", TORCH_BOX(&mv_tensor_accessor_cpu));
}
// Test functions for torch::stable::accelerator APIs
#ifdef LAE_USE_CUDA
@ -697,38 +634,3 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("test_parallel_for", &boxed_test_parallel_for);
m.impl("test_get_num_threads", &boxed_test_get_num_threads);
}
Tensor my_empty(
torch::headeronly::HeaderOnlyArrayRef<int64_t> size,
std::optional<torch::headeronly::ScalarType> dtype,
std::optional<torch::stable::Device> device,
std::optional<bool> pin_memory) {
return empty(size, dtype, device, pin_memory);
}
Tensor my_flatten(Tensor t, int64_t start_dim, int64_t end_dim) {
return flatten(t, start_dim, end_dim);
}
Tensor my_reshape(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> shape) {
return reshape(t, shape);
}
Tensor my_view(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> size) {
return view(t, size);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def(
"my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor");
m.def("my_flatten(Tensor t, int start_dim=0, int end_dim=-1) -> Tensor");
m.def("my_reshape(Tensor t, int[] shape) -> Tensor");
m.def("my_view(Tensor t, int[] size) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_empty", TORCH_BOX(&my_empty));
m.impl("my_flatten", TORCH_BOX(&my_flatten));
m.impl("my_reshape", TORCH_BOX(&my_reshape));
m.impl("my_view", TORCH_BOX(&my_view));
}

View File

@ -1,26 +0,0 @@
#include <torch/headeronly/core/Dispatch_v2.h>
#include <torch/headeronly/core/TensorAccessor.h>
template <typename T, size_t N>
using Accessor_cpu = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;
#if defined(__CUDACC__) || defined(__HIPCC__)
#define MAYBE_GLOBAL __global__
template <typename T, size_t N>
using Accessor_cuda = torch::headeronly::HeaderOnlyGenericPackedTensorAccessor<T, N, torch::headeronly::RestrictPtrTraits>;
#else
#define MAYBE_GLOBAL
#endif
template <template <typename, size_t> class Accessor, typename scalar_t>
MAYBE_GLOBAL void mv_tensor_accessor_kernel(Accessor<scalar_t, 1> resa, Accessor<scalar_t, 2> ma, Accessor<scalar_t, 1> va) {
for (int64_t i = 0; i < resa.size(0); i++) {
scalar_t val = 0;
for (int64_t j = 0; j < ma.size(1); j++) {
val += ma[i][j] * va[j];
}
resa[i] = val;
}
}

View File

@ -227,37 +227,6 @@ def test_tensor_device(t):
return torch.ops.libtorch_agnostic.test_tensor_device.default(t)
def get_any_data_ptr(t, mutable) -> int:
"""
Return data pointer value of the tensor.
Args:
t: Input tensor
mutable: whether data pointer qualifier is mutable or const
Returns: int - pointer value
"""
return torch.ops.libtorch_agnostic.get_any_data_ptr.default(t, mutable)
def get_template_any_data_ptr(t, dtype, mutable) -> int:
"""
Return data pointer value of the tensor iff it has dtype.
Args:
t: Input tensor
dtype: Input dtype
mutable: whether data pointer qualifier is mutable or const
Returns: int - pointer value
Raises RuntimeError when t.dtype() != dtype.
"""
return torch.ops.libtorch_agnostic.get_template_any_data_ptr.default(
t, dtype, mutable
)
def my_pad(t) -> Tensor:
"""
Pads the input tensor with hardcoded padding parameters.
@ -518,72 +487,3 @@ def test_get_num_threads() -> int:
Returns: int - the number of threads for the parallel backend
"""
return torch.ops.libtorch_agnostic.test_get_num_threads.default()
def my_empty(size, dtype=None, device=None, pin_memory=None) -> Tensor:
"""
Creates an empty tensor with the specified size, dtype, device, and pin_memory.
Args:
size: list[int] - size of the tensor to create
dtype: ScalarType or None - data type of the tensor
device: Device or None - device on which to create the tensor
pin_memory: bool or None - whether to use pinned memory
Returns: Tensor - an uninitialized tensor with the specified properties
"""
return torch.ops.libtorch_agnostic.my_empty.default(size, dtype, device, pin_memory)
def my_flatten(t, start_dim=0, end_dim=-1) -> Tensor:
"""
Flattens the input tensor from start_dim to end_dim into a single dimension.
Args:
t: Tensor - tensor to flatten
start_dim: int - first dimension to flatten (default: 0)
end_dim: int - last dimension to flatten (default: -1)
Returns: Tensor - flattened tensor
"""
return torch.ops.libtorch_agnostic.my_flatten.default(t, start_dim, end_dim)
def my_reshape(t, shape) -> Tensor:
"""
Returns a tensor with the same data but different shape.
Args:
t: Tensor - tensor to reshape
shape: list[int] - new shape for the tensor
Returns: Tensor - reshaped tensor
"""
return torch.ops.libtorch_agnostic.my_reshape.default(t, shape)
def my_view(t, size) -> Tensor:
"""
Returns a new tensor with the same data as the input tensor but of a different shape.
Args:
t: Tensor - tensor to view
size: list[int] - new size for the tensor
Returns: Tensor - tensor with new view
"""
return torch.ops.libtorch_agnostic.my_view.default(t, size)
def mv_tensor_accessor(m, v) -> Tensor:
"""
Returns matrix-vector product.
Args:
m: any 2-D Tensor with shape (N, M)
v: any 1-D Tensor with shape (M,)
Returns:
a 1-D Tensor with shape (N,)
"""
return torch.ops.libtorch_agnostic.mv_tensor_accessor.default(m, v)

View File

@ -33,17 +33,16 @@ class clean(distutils.command.clean.clean):
def get_extension():
extra_compile_args = {
"cxx": ["-fdiagnostics-color=always"],
"cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
}
sources = list(CSRC_DIR.glob("**/*.cpp"))
extension = CppExtension
# allow including <cuda_runtime.h>
if torch.cuda.is_available():
extra_compile_args["cxx"].append("-DLAE_USE_CUDA")
extra_compile_args["nvcc"] = ["-O2"]
extension = CUDAExtension
sources.extend(CSRC_DIR.glob("**/*.cu"))
sources = list(CSRC_DIR.glob("**/*.cpp"))
return [
extension(

View File

@ -14,38 +14,11 @@ from torch.testing._internal.common_utils import (
install_cpp_extension,
IS_WINDOWS,
run_tests,
skipIfTorchDynamo,
TestCase,
xfailIfTorchDynamo,
)
def get_supported_dtypes():
"""Return a list of dtypes that are supported by torch stable ABI."""
return [
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.uint16,
torch.uint32,
torch.uint64,
torch.bfloat16,
torch.float16,
torch.float32,
torch.float64,
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.float8_e5m2fnuz,
torch.float8_e4m3fnuz,
torch.complex32,
torch.complex64,
torch.complex128,
torch.bool,
]
# TODO: Fix this error in Windows:
# LINK : error LNK2001: unresolved external symbol PyInit__C
if not IS_WINDOWS:
@ -301,43 +274,6 @@ if not IS_WINDOWS:
expected0 = torch.narrow(t, dim0, start0, length0)
self.assertEqual(out0, expected0)
@skipIfTorchDynamo("no data pointer defined for FakeTensor, FunctionalTensor")
def test_get_any_data_ptr(self, device):
import libtorch_agnostic
t = torch.empty(2, 5, device=device, dtype=torch.float32)
expected_p = t.data_ptr()
for mutable in [True, False]:
p = libtorch_agnostic.ops.get_any_data_ptr(t, mutable)
self.assertEqual(p, expected_p)
@skipIfTorchDynamo("no data pointer defined for FakeTensor, FunctionalTensor")
def test_get_template_any_data_ptr(self, device):
import libtorch_agnostic
supported_dtypes = get_supported_dtypes()
for dtype in supported_dtypes:
t = torch.empty(2, 5, device=device, dtype=dtype)
expected_p = t.data_ptr()
for rdtype in supported_dtypes:
if dtype == rdtype:
for mutable in [True, False]:
p = libtorch_agnostic.ops.get_template_any_data_ptr(
t, rdtype, mutable
)
self.assertEqual(p, expected_p)
else:
for mutable in [True, False]:
with self.assertRaisesRegex(
RuntimeError, "expected scalar type.* but found"
):
libtorch_agnostic.ops.get_template_any_data_ptr(
t, rdtype, mutable
)
@onlyCUDA
@deviceCountAtLeast(2)
def test_device_guard(self, device):
@ -589,113 +525,6 @@ if not IS_WINDOWS:
expected_num_threads = torch.get_num_threads()
self.assertEqual(num_threads, expected_num_threads)
def test_my_empty(self, device):
import libtorch_agnostic
deterministic = torch.are_deterministic_algorithms_enabled()
try:
# set use_deterministic_algorithms to fill uninitialized memory
torch.use_deterministic_algorithms(True)
size = [2, 3]
result = libtorch_agnostic.ops.my_empty(size, None, None, None)
expected = torch.empty(size)
self.assertEqual(result, expected, exact_device=True)
result_float = libtorch_agnostic.ops.my_empty(
size, torch.float32, None, None
)
expected_float = torch.empty(size, dtype=torch.float32)
self.assertEqual(result_float, expected_float, exact_device=True)
result_with_device = libtorch_agnostic.ops.my_empty(
size, torch.float64, device, None
)
expected_with_device = torch.empty(
size, dtype=torch.float64, device=device
)
self.assertEqual(
result_with_device, expected_with_device, exact_device=True
)
if device == "cuda":
result_pinned = libtorch_agnostic.ops.my_empty(
size, torch.float32, "cpu", True
)
expected_pinned = torch.empty(
size, dtype=torch.float32, device="cpu", pin_memory=True
)
self.assertEqual(result_pinned, expected_pinned)
self.assertTrue(result_pinned.is_pinned())
finally:
torch.use_deterministic_algorithms(deterministic)
def test_my_flatten(self, device):
import libtorch_agnostic
t = torch.randn(2, 3, 4, device=device)
result = libtorch_agnostic.ops.my_flatten(t)
expected = torch.flatten(t)
self.assertEqual(result, expected)
result_start = libtorch_agnostic.ops.my_flatten(t, 1)
expected_start = torch.flatten(t, 1)
self.assertEqual(result_start, expected_start)
result_range = libtorch_agnostic.ops.my_flatten(t, 2, -1)
expected_range = torch.flatten(t, 2, -1)
self.assertEqual(result_range, expected_range)
def test_my_reshape(self, device):
import libtorch_agnostic
t = torch.randn(2, 3, 4, device=device)
result = libtorch_agnostic.ops.my_reshape(t, [6, 4])
expected = torch.reshape(t, [6, 4])
self.assertEqual(result, expected)
result_infer = libtorch_agnostic.ops.my_reshape(t, [-1, 4])
expected_infer = torch.reshape(t, [-1, 4])
self.assertEqual(result_infer, expected_infer)
result_flat = libtorch_agnostic.ops.my_reshape(t, [-1])
expected_flat = torch.reshape(t, [-1])
self.assertEqual(result_flat, expected_flat)
def test_my_view(self, device):
import libtorch_agnostic
t = torch.randn(2, 3, 4, device=device)
result = libtorch_agnostic.ops.my_view(t, [6, 4])
expected = t.view([6, 4])
self.assertEqual(result, expected)
result_infer = libtorch_agnostic.ops.my_view(t, [-1, 4])
expected_infer = t.view([-1, 4])
self.assertEqual(result_infer, expected_infer)
result_flat = libtorch_agnostic.ops.my_view(t, [-1])
expected_flat = t.view([-1])
self.assertEqual(result_flat, expected_flat)
def test_mv_tensor_accessor(self, device):
import libtorch_agnostic
m = torch.rand(3, 5, device=device)
v = torch.rand(5, device=device)
result = libtorch_agnostic.ops.mv_tensor_accessor(m, v)
expected = torch.mv(m, v)
self.assertEqual(result, expected)
# non-contiguous inputs
m = torch.rand(3 * 2, 5 * 3, device=device)[::2, ::3]
v = torch.rand(5 * 4, device=device)[::4]
result = libtorch_agnostic.ops.mv_tensor_accessor(m, v)
expected = torch.mv(m, v)
self.assertEqual(result, expected)
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
if __name__ == "__main__":

View File

@ -4,12 +4,17 @@
#include <c10/util/Exception.h>
void orCheckFail(const char* func, const char* file, uint32_t line, const char* msg = "");
void orCheckFail(
const char* func,
const char* file,
uint32_t line,
const char* msg = "");
#define OPENREG_CHECK(EXPR, ...) \
do { \
const orError_t __err = EXPR; \
if (C10_UNLIKELY(__err != orSuccess)) { \
orCheckFail(__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
} \
#define OPENREG_CHECK(EXPR, ...) \
do { \
const orError_t __err = EXPR; \
if (__err != orSuccess) { \
orCheckFail( \
__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
} \
} while (0)

View File

@ -1,4 +1,3 @@
#include <c10/util/Exception.h>
#include <include/openreg.h>
#include "OpenRegException.h"
@ -10,22 +9,21 @@ orError_t GetDeviceCount(int* dev_count) {
return orGetDeviceCount(dev_count);
}
orError_t GetDevice(DeviceIndex* device) {
orError_t GetDevice(c10::DeviceIndex* device) {
int tmp_device = -1;
auto err = orGetDevice(&tmp_device);
*device = static_cast<DeviceIndex>(tmp_device);
*device = static_cast<c10::DeviceIndex>(tmp_device);
return err;
}
// LITERALINCLUDE START: OPENREG SetDevice FUNCTION
orError_t SetDevice(DeviceIndex device) {
orError_t SetDevice(c10::DeviceIndex device) {
int cur_device = -1;
OPENREG_CHECK(orGetDevice(&cur_device));
orGetDevice(&cur_device);
if (device == cur_device) {
return orSuccess;
}
return orSetDevice(device);
}
// LITERALINCLUDE END: OPENREG SetDevice FUNCTION
int device_count_impl() {
int count = 0;
@ -33,37 +31,34 @@ int device_count_impl() {
return count;
}
OPENREG_EXPORT DeviceIndex device_count() noexcept {
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
// initialize number of devices only once
static int count = []() {
try {
auto result = device_count_impl();
TORCH_CHECK(
result <= std::numeric_limits<DeviceIndex>::max(),
result <= std::numeric_limits<c10::DeviceIndex>::max(),
"Too many devices, DeviceIndex overflowed");
return result;
} catch (const Error& ex) {
} catch (const c10::Error& ex) {
// We don't want to fail, but still log the warning
// msg() returns the message without the stack trace
TORCH_WARN("Device initialization: ", ex.msg());
return 0;
}
}();
return static_cast<DeviceIndex>(count);
return static_cast<c10::DeviceIndex>(count);
}
OPENREG_EXPORT DeviceIndex current_device() {
DeviceIndex cur_device = -1;
OPENREG_CHECK(GetDevice(&cur_device));
OPENREG_EXPORT c10::DeviceIndex current_device() {
c10::DeviceIndex cur_device = -1;
GetDevice(&cur_device);
return cur_device;
}
// LITERALINCLUDE START: OPENREG set_device FUNCTION
OPENREG_EXPORT void set_device(DeviceIndex device) {
check_device_index(device);
OPENREG_CHECK(SetDevice(device));
OPENREG_EXPORT void set_device(c10::DeviceIndex device) {
SetDevice(device);
}
// LITERALINCLUDE END: OPENREG set_device FUNCTION
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
int current_device = -1;
@ -76,8 +71,4 @@ OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
return current_device;
}
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device) {
check_device_index(to_device);
return ExchangeDevice(to_device);
}
} // namespace c10::openreg

View File

@ -9,20 +9,10 @@
namespace c10::openreg {
OPENREG_EXPORT DeviceIndex device_count() noexcept;
OPENREG_EXPORT DeviceIndex current_device();
OPENREG_EXPORT void set_device(DeviceIndex device);
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device);
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept;
OPENREG_EXPORT c10::DeviceIndex current_device();
OPENREG_EXPORT void set_device(c10::DeviceIndex device);
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device);
static inline void check_device_index(int64_t device) {
TORCH_CHECK(device >= 0 && device < c10::openreg::device_count(),
"The device index is out of range. It must be in [0, ",
static_cast<int>(c10::openreg::device_count()),
"), but got ",
static_cast<int>(device),
".");
}
} // namespace c10::openreg

View File

@ -2,8 +2,6 @@
namespace c10::openreg {
// LITERALINCLUDE START: OPENREG GUARD REGISTRATION
C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl);
// LITERALINCLUDE END: OPENREG GUARD REGISTRATION
} // namespace c10::openreg

Some files were not shown because too many files have changed in this diff.