Compare commits

..

1 Commits

Author SHA1 Message Date
47e16bddf9 Clarify that crashes/OOB accesses and not security threats
Added note on crashes and out of bounds access in PyTorch.

Addresses https://github.com/pytorch/pytorch/issues/166881#issuecomment-3513245388
2025-11-10 22:31:52 -08:00
381 changed files with 6019 additions and 13196 deletions

View File

@ -0,0 +1,19 @@
# Aarch64 (ARM/Graviton) Support Scripts
Scripts for building aarch64 PyTorch PIP Wheels. These scripts build the following wheels:
* torch
* torchvision
* torchaudio
* torchtext
* torchdata
## Aarch64_ci_build.sh
This script is design to support CD operations within PyPi manylinux aarch64 container, and be executed in the container. It prepares the container and then executes __aarch64_wheel_ci_build.py__ to build the wheels. The script "assumes" the PyTorch repo is located at: ```/pytorch``` and will put the wheels into ```/artifacts```.
### Usage
```DESIRED_PYTHON=<PythonVersion> aarch64_ci_build.sh```
__NOTE:__ CI build is currently __EXPERMINTAL__
## Build_aarch64_wheel.py
This app allows a person to build using AWS EC3 resources and requires AWS-CLI and Boto3 with AWS credentials to support building EC2 instances for the wheel builds. Can be used in a codebuild CD or from a local system.
### Usage
```build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.8 --branch <RCtag>```

View File

@ -0,0 +1,53 @@
#!/bin/bash
set -eux -o pipefail
GPU_ARCH_VERSION=${GPU_ARCH_VERSION:-}
# Set CUDA architecture lists to match x86 build_cuda.sh
if [[ "$GPU_ARCH_VERSION" == *"12.6"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0"
elif [[ "$GPU_ARCH_VERSION" == *"12.8"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
elif [[ "$GPU_ARCH_VERSION" == *"12.9"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
elif [[ "$GPU_ARCH_VERSION" == *"13.0"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;11.0;12.0+PTX"
fi
# Compress the fatbin with -compress-mode=size for CUDA 13
if [[ "$DESIRED_CUDA" == *"13"* ]]; then
export TORCH_NVCC_FLAGS="-compress-mode=size"
# Bundle ptxas into the cu13 wheel, see https://github.com/pytorch/pytorch/issues/163801
export BUILD_BUNDLE_PTXAS=1
fi
SCRIPTPATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
source $SCRIPTPATH/aarch64_ci_setup.sh
###############################################################################
# Run aarch64 builder python
###############################################################################
cd /
# adding safe directory for git as the permissions will be
# on the mounted pytorch repo
git config --global --add safe.directory /pytorch
pip install -r /pytorch/requirements.txt
pip install auditwheel==6.2.0 wheel
if [ "$DESIRED_CUDA" = "cpu" ]; then
echo "BASE_CUDA_VERSION is not set. Building cpu wheel."
python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn
else
echo "BASE_CUDA_VERSION is set to: $DESIRED_CUDA"
export USE_SYSTEM_NCCL=1
# Check if we should use NVIDIA libs from PyPI (similar to x86 build_cuda.sh logic)
if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then
echo "Bundling CUDA libraries with wheel for aarch64."
else
echo "Using nvidia libs from pypi for aarch64."
echo "Updated PYTORCH_EXTRA_INSTALL_REQUIREMENTS for aarch64: $PYTORCH_EXTRA_INSTALL_REQUIREMENTS"
export USE_NVIDIA_PYPI_LIBS=1
fi
python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn --enable-cuda
fi

View File

@ -0,0 +1,21 @@
#!/bin/bash
set -eux -o pipefail
# This script is used to prepare the Docker container for aarch64_ci_wheel_build.py python script
# By creating symlinks from desired /opt/python to /usr/local/bin/
NUMPY_VERSION=2.0.2
if [[ "$DESIRED_PYTHON" == "3.13" || "$DESIRED_PYTHON" == "3.13t" ]]; then
NUMPY_VERSION=2.1.2
fi
SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"
source $SCRIPTPATH/../manywheel/set_desired_python.sh
pip install -q numpy==${NUMPY_VERSION} pyyaml==6.0.2 scons==4.7.0 ninja==1.11.1 patchelf==0.17.2
for tool in python python3 pip pip3 ninja scons patchelf; do
ln -sf ${DESIRED_PYTHON_BIN_DIR}/${tool} /usr/local/bin;
done
python --version

View File

@ -0,0 +1,333 @@
#!/usr/bin/env python3
# encoding: UTF-8
import os
import shutil
from subprocess import check_call, check_output
def list_dir(path: str) -> list[str]:
"""'
Helper for getting paths for Python
"""
return check_output(["ls", "-1", path]).decode().split("\n")
def replace_tag(filename) -> None:
with open(filename) as f:
lines = f.readlines()
for i, line in enumerate(lines):
if line.startswith("Tag:"):
lines[i] = line.replace("-linux_", "-manylinux_2_28_")
print(f"Updated tag from {line} to {lines[i]}")
break
with open(filename, "w") as f:
f.writelines(lines)
def patch_library_rpath(
folder: str,
lib_name: str,
use_nvidia_pypi_libs: bool = False,
desired_cuda: str = "",
) -> None:
"""Apply patchelf to set RPATH for a library in torch/lib"""
lib_path = f"{folder}/tmp/torch/lib/{lib_name}"
if use_nvidia_pypi_libs:
# For PyPI NVIDIA libraries, construct CUDA RPATH
cuda_rpaths = [
"$ORIGIN/../../nvidia/cudnn/lib",
"$ORIGIN/../../nvidia/nvshmem/lib",
"$ORIGIN/../../nvidia/nccl/lib",
"$ORIGIN/../../nvidia/cusparselt/lib",
]
if "130" in desired_cuda:
cuda_rpaths.append("$ORIGIN/../../nvidia/cu13/lib")
else:
cuda_rpaths.extend(
[
"$ORIGIN/../../nvidia/cublas/lib",
"$ORIGIN/../../nvidia/cuda_cupti/lib",
"$ORIGIN/../../nvidia/cuda_nvrtc/lib",
"$ORIGIN/../../nvidia/cuda_runtime/lib",
"$ORIGIN/../../nvidia/cufft/lib",
"$ORIGIN/../../nvidia/curand/lib",
"$ORIGIN/../../nvidia/cusolver/lib",
"$ORIGIN/../../nvidia/cusparse/lib",
"$ORIGIN/../../nvidia/nvtx/lib",
"$ORIGIN/../../nvidia/cufile/lib",
]
)
# Add $ORIGIN for local torch libs
rpath = ":".join(cuda_rpaths) + ":$ORIGIN"
else:
# For bundled libraries, just use $ORIGIN
rpath = "$ORIGIN"
if os.path.exists(lib_path):
os.system(
f"cd {folder}/tmp/torch/lib/; "
f"patchelf --set-rpath '{rpath}' --force-rpath {lib_name}"
)
def copy_and_patch_library(
src_path: str,
folder: str,
use_nvidia_pypi_libs: bool = False,
desired_cuda: str = "",
) -> None:
"""Copy a library to torch/lib and patch its RPATH"""
if os.path.exists(src_path):
lib_name = os.path.basename(src_path)
shutil.copy2(src_path, f"{folder}/tmp/torch/lib/{lib_name}")
patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda)
def package_cuda_wheel(wheel_path, desired_cuda) -> None:
"""
Package the cuda wheel libraries
"""
folder = os.path.dirname(wheel_path)
os.mkdir(f"{folder}/tmp")
os.system(f"unzip {wheel_path} -d {folder}/tmp")
# Delete original wheel since it will be repackaged
os.system(f"rm {wheel_path}")
# Check if we should use PyPI NVIDIA libraries or bundle system libraries
use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"
if use_nvidia_pypi_libs:
print("Using nvidia libs from pypi - skipping CUDA library bundling")
# For PyPI approach, we don't bundle CUDA libraries - they come from PyPI packages
# We only need to bundle non-NVIDIA libraries
minimal_libs_to_copy = [
"/lib64/libgomp.so.1",
"/usr/lib64/libgfortran.so.5",
"/acl/build/libarm_compute.so",
"/acl/build/libarm_compute_graph.so",
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_lapack_core.so.0",
"/usr/local/lib/libnvpl_blas_core.so.0",
]
# Copy minimal libraries to unzipped_folder/torch/lib
for lib_path in minimal_libs_to_copy:
copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda)
# Patch torch libraries used for searching libraries
torch_libs_to_patch = [
"libtorch.so",
"libtorch_cpu.so",
"libtorch_cuda.so",
"libtorch_cuda_linalg.so",
"libtorch_global_deps.so",
"libtorch_python.so",
"libtorch_nvshmem.so",
"libc10.so",
"libc10_cuda.so",
"libcaffe2_nvrtc.so",
"libshm.so",
]
for lib_name in torch_libs_to_patch:
patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda)
else:
print("Bundling CUDA libraries with wheel")
# Original logic for bundling system CUDA libraries
# Common libraries for all CUDA versions
common_libs = [
# Non-NVIDIA system libraries
"/lib64/libgomp.so.1",
"/usr/lib64/libgfortran.so.5",
"/acl/build/libarm_compute.so",
"/acl/build/libarm_compute_graph.so",
# Common CUDA libraries (same for all versions)
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_lapack_core.so.0",
"/usr/local/lib/libnvpl_blas_core.so.0",
"/usr/local/cuda/extras/CUPTI/lib64/libnvperf_host.so",
"/usr/local/cuda/lib64/libcudnn.so.9",
"/usr/local/cuda/lib64/libcusparseLt.so.0",
"/usr/local/cuda/lib64/libcurand.so.10",
"/usr/local/cuda/lib64/libnccl.so.2",
"/usr/local/cuda/lib64/libnvshmem_host.so.3",
"/usr/local/cuda/lib64/libcudnn_adv.so.9",
"/usr/local/cuda/lib64/libcudnn_cnn.so.9",
"/usr/local/cuda/lib64/libcudnn_graph.so.9",
"/usr/local/cuda/lib64/libcudnn_ops.so.9",
"/usr/local/cuda/lib64/libcudnn_engines_runtime_compiled.so.9",
"/usr/local/cuda/lib64/libcudnn_engines_precompiled.so.9",
"/usr/local/cuda/lib64/libcudnn_heuristic.so.9",
"/usr/local/cuda/lib64/libcufile.so.0",
"/usr/local/cuda/lib64/libcufile_rdma.so.1",
"/usr/local/cuda/lib64/libcusparse.so.12",
]
# CUDA version-specific libraries
if "13" in desired_cuda:
minor_version = desired_cuda[-1]
version_specific_libs = [
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.13",
"/usr/local/cuda/lib64/libcublas.so.13",
"/usr/local/cuda/lib64/libcublasLt.so.13",
"/usr/local/cuda/lib64/libcudart.so.13",
"/usr/local/cuda/lib64/libcufft.so.12",
"/usr/local/cuda/lib64/libcusolver.so.12",
"/usr/local/cuda/lib64/libnvJitLink.so.13",
"/usr/local/cuda/lib64/libnvrtc.so.13",
f"/usr/local/cuda/lib64/libnvrtc-builtins.so.13.{minor_version}",
]
elif "12" in desired_cuda:
# Get the last character for libnvrtc-builtins version (e.g., "129" -> "9")
minor_version = desired_cuda[-1]
version_specific_libs = [
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12",
"/usr/local/cuda/lib64/libcublas.so.12",
"/usr/local/cuda/lib64/libcublasLt.so.12",
"/usr/local/cuda/lib64/libcudart.so.12",
"/usr/local/cuda/lib64/libcufft.so.11",
"/usr/local/cuda/lib64/libcusolver.so.11",
"/usr/local/cuda/lib64/libnvJitLink.so.12",
"/usr/local/cuda/lib64/libnvrtc.so.12",
f"/usr/local/cuda/lib64/libnvrtc-builtins.so.12.{minor_version}",
]
else:
raise ValueError(f"Unsupported CUDA version: {desired_cuda}.")
# Combine all libraries
libs_to_copy = common_libs + version_specific_libs
# Copy libraries to unzipped_folder/torch/lib
for lib_path in libs_to_copy:
copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda)
# Make sure the wheel is tagged with manylinux_2_28
for f in os.scandir(f"{folder}/tmp/"):
if f.is_dir() and f.name.endswith(".dist-info"):
replace_tag(f"{f.path}/WHEEL")
break
os.system(f"wheel pack {folder}/tmp/ -d {folder}")
os.system(f"rm -rf {folder}/tmp/")
def complete_wheel(folder: str) -> str:
"""
Complete wheel build and put in artifact location
"""
wheel_name = list_dir(f"/{folder}/dist")[0]
# Please note for cuda we don't run auditwheel since we use custom script to package
# the cuda dependencies to the wheel file using update_wheel() method.
# However we need to make sure filename reflects the correct Manylinux platform.
if "pytorch" in folder and not enable_cuda:
print("Repairing Wheel with AuditWheel")
check_call(["auditwheel", "repair", f"dist/{wheel_name}"], cwd=folder)
repaired_wheel_name = list_dir(f"/{folder}/wheelhouse")[0]
print(f"Moving {repaired_wheel_name} wheel to /{folder}/dist")
os.rename(
f"/{folder}/wheelhouse/{repaired_wheel_name}",
f"/{folder}/dist/{repaired_wheel_name}",
)
else:
repaired_wheel_name = list_dir(f"/{folder}/dist")[0]
print(f"Copying {repaired_wheel_name} to artifacts")
shutil.copy2(
f"/{folder}/dist/{repaired_wheel_name}", f"/artifacts/{repaired_wheel_name}"
)
return repaired_wheel_name
def parse_arguments():
"""
Parse inline arguments
"""
from argparse import ArgumentParser
parser = ArgumentParser("AARCH64 wheels python CD")
parser.add_argument("--debug", action="store_true")
parser.add_argument("--build-only", action="store_true")
parser.add_argument("--test-only", type=str)
parser.add_argument("--enable-mkldnn", action="store_true")
parser.add_argument("--enable-cuda", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
"""
Entry Point
"""
args = parse_arguments()
enable_mkldnn = args.enable_mkldnn
enable_cuda = args.enable_cuda
branch = check_output(
["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd="/pytorch"
).decode()
print("Building PyTorch wheel")
build_vars = ""
# MAX_JOB=5 is not required for CPU backend (see commit 465d98b)
if enable_cuda:
build_vars += "MAX_JOBS=5 "
# Handle PyPI NVIDIA libraries vs bundled libraries
use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"
if use_nvidia_pypi_libs:
print("Configuring build for PyPI NVIDIA libraries")
# Configure for dynamic linking (matching x86 logic)
build_vars += "ATEN_STATIC_CUDA=0 USE_CUDA_STATIC_LINK=0 USE_CUPTI_SO=1 "
else:
print("Configuring build for bundled NVIDIA libraries")
# Keep existing static linking approach - already configured above
override_package_version = os.getenv("OVERRIDE_PACKAGE_VERSION")
desired_cuda = os.getenv("DESIRED_CUDA")
if override_package_version is not None:
version = override_package_version
build_vars += (
f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version} PYTORCH_BUILD_NUMBER=1 "
)
elif branch in ["nightly", "main"]:
build_date = (
check_output(["git", "log", "--pretty=format:%cs", "-1"], cwd="/pytorch")
.decode()
.replace("-", "")
)
version = (
check_output(["cat", "version.txt"], cwd="/pytorch").decode().strip()[:-2]
)
if enable_cuda:
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date}+{desired_cuda} PYTORCH_BUILD_NUMBER=1 "
else:
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1 "
elif branch.startswith(("v1.", "v2.")):
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1 "
if enable_mkldnn:
print("build pytorch with mkldnn+acl backend")
build_vars += "USE_MKLDNN=ON USE_MKLDNN_ACL=ON "
build_vars += "ACL_ROOT_DIR=/acl "
if enable_cuda:
build_vars += "BLAS=NVPL "
else:
build_vars += "BLAS=OpenBLAS OpenBLAS_HOME=/opt/OpenBLAS "
else:
print("build pytorch without mkldnn backend")
os.system(f"cd /pytorch; {build_vars} python3 -m build --wheel --no-isolation")
if enable_cuda:
print("Updating Cuda Dependency")
filename = os.listdir("/pytorch/dist/")
wheel_path = f"/pytorch/dist/{filename[0]}"
package_cuda_wheel(wheel_path, desired_cuda)
pytorch_wheel_name = complete_wheel("/pytorch/")
print(f"Build Complete. Created {pytorch_wheel_name}..")

View File

@ -0,0 +1,999 @@
#!/usr/bin/env python3
# This script is for building AARCH64 wheels using AWS EC2 instances.
# To generate binaries for the release follow these steps:
# 1. Update mappings for each of the Domain Libraries by adding new row to a table like this:
# "v1.11.0": ("0.11.0", "rc1"),
# 2. Run script with following arguments for each of the supported python versions and required tag, for example:
# build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.8 --branch v1.11.0-rc3
import os
import subprocess
import sys
import time
from typing import Optional, Union
import boto3
# AMI images for us-east-1, change the following based on your ~/.aws/config
os_amis = {
"ubuntu20_04": "ami-052eac90edaa9d08f", # login_name: ubuntu
"ubuntu22_04": "ami-0c6c29c5125214c77", # login_name: ubuntu
"redhat8": "ami-0698b90665a2ddcf1", # login_name: ec2-user
}
ubuntu20_04_ami = os_amis["ubuntu20_04"]
def compute_keyfile_path(key_name: Optional[str] = None) -> tuple[str, str]:
if key_name is None:
key_name = os.getenv("AWS_KEY_NAME")
if key_name is None:
return os.getenv("SSH_KEY_PATH", ""), ""
homedir_path = os.path.expanduser("~")
default_path = os.path.join(homedir_path, ".ssh", f"{key_name}.pem")
return os.getenv("SSH_KEY_PATH", default_path), key_name
ec2 = boto3.resource("ec2")
def ec2_get_instances(filter_name, filter_value):
return ec2.instances.filter(
Filters=[{"Name": filter_name, "Values": [filter_value]}]
)
def ec2_instances_of_type(instance_type="t4g.2xlarge"):
return ec2_get_instances("instance-type", instance_type)
def ec2_instances_by_id(instance_id):
rc = list(ec2_get_instances("instance-id", instance_id))
return rc[0] if len(rc) > 0 else None
def start_instance(
key_name, ami=ubuntu20_04_ami, instance_type="t4g.2xlarge", ebs_size: int = 50
):
inst = ec2.create_instances(
ImageId=ami,
InstanceType=instance_type,
SecurityGroups=["ssh-allworld"],
KeyName=key_name,
MinCount=1,
MaxCount=1,
BlockDeviceMappings=[
{
"DeviceName": "/dev/sda1",
"Ebs": {
"DeleteOnTermination": True,
"VolumeSize": ebs_size,
"VolumeType": "standard",
},
}
],
)[0]
print(f"Create instance {inst.id}")
inst.wait_until_running()
running_inst = ec2_instances_by_id(inst.id)
print(f"Instance started at {running_inst.public_dns_name}")
return running_inst
class RemoteHost:
addr: str
keyfile_path: str
login_name: str
container_id: Optional[str] = None
ami: Optional[str] = None
def __init__(self, addr: str, keyfile_path: str, login_name: str = "ubuntu"):
self.addr = addr
self.keyfile_path = keyfile_path
self.login_name = login_name
def _gen_ssh_prefix(self) -> list[str]:
return [
"ssh",
"-o",
"StrictHostKeyChecking=no",
"-i",
self.keyfile_path,
f"{self.login_name}@{self.addr}",
"--",
]
@staticmethod
def _split_cmd(args: Union[str, list[str]]) -> list[str]:
return args.split() if isinstance(args, str) else args
def run_ssh_cmd(self, args: Union[str, list[str]]) -> None:
subprocess.check_call(self._gen_ssh_prefix() + self._split_cmd(args))
def check_ssh_output(self, args: Union[str, list[str]]) -> str:
return subprocess.check_output(
self._gen_ssh_prefix() + self._split_cmd(args)
).decode("utf-8")
def scp_upload_file(self, local_file: str, remote_file: str) -> None:
subprocess.check_call(
[
"scp",
"-i",
self.keyfile_path,
local_file,
f"{self.login_name}@{self.addr}:{remote_file}",
]
)
def scp_download_file(
self, remote_file: str, local_file: Optional[str] = None
) -> None:
if local_file is None:
local_file = "."
subprocess.check_call(
[
"scp",
"-i",
self.keyfile_path,
f"{self.login_name}@{self.addr}:{remote_file}",
local_file,
]
)
def start_docker(self, image="quay.io/pypa/manylinux2014_aarch64:latest") -> None:
self.run_ssh_cmd("sudo apt-get install -y docker.io")
self.run_ssh_cmd(f"sudo usermod -a -G docker {self.login_name}")
self.run_ssh_cmd("sudo service docker start")
self.run_ssh_cmd(f"docker pull {image}")
self.container_id = self.check_ssh_output(
f"docker run -t -d -w /root {image}"
).strip()
def using_docker(self) -> bool:
return self.container_id is not None
def run_cmd(self, args: Union[str, list[str]]) -> None:
if not self.using_docker():
return self.run_ssh_cmd(args)
assert self.container_id is not None
docker_cmd = self._gen_ssh_prefix() + [
"docker",
"exec",
"-i",
self.container_id,
"bash",
]
p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE)
p.communicate(
input=" ".join(["source .bashrc && "] + self._split_cmd(args)).encode(
"utf-8"
)
)
rc = p.wait()
if rc != 0:
raise subprocess.CalledProcessError(rc, docker_cmd)
def check_output(self, args: Union[str, list[str]]) -> str:
if not self.using_docker():
return self.check_ssh_output(args)
assert self.container_id is not None
docker_cmd = self._gen_ssh_prefix() + [
"docker",
"exec",
"-i",
self.container_id,
"bash",
]
p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
(out, err) = p.communicate(
input=" ".join(["source .bashrc && "] + self._split_cmd(args)).encode(
"utf-8"
)
)
rc = p.wait()
if rc != 0:
raise subprocess.CalledProcessError(rc, docker_cmd, output=out, stderr=err)
return out.decode("utf-8")
def upload_file(self, local_file: str, remote_file: str) -> None:
if not self.using_docker():
return self.scp_upload_file(local_file, remote_file)
tmp_file = os.path.join("/tmp", os.path.basename(local_file))
self.scp_upload_file(local_file, tmp_file)
self.run_ssh_cmd(
["docker", "cp", tmp_file, f"{self.container_id}:/root/{remote_file}"]
)
self.run_ssh_cmd(["rm", tmp_file])
def download_file(self, remote_file: str, local_file: Optional[str] = None) -> None:
if not self.using_docker():
return self.scp_download_file(remote_file, local_file)
tmp_file = os.path.join("/tmp", os.path.basename(remote_file))
self.run_ssh_cmd(
["docker", "cp", f"{self.container_id}:/root/{remote_file}", tmp_file]
)
self.scp_download_file(tmp_file, local_file)
self.run_ssh_cmd(["rm", tmp_file])
def download_wheel(
self, remote_file: str, local_file: Optional[str] = None
) -> None:
if self.using_docker() and local_file is None:
basename = os.path.basename(remote_file)
local_file = basename.replace(
"-linux_aarch64.whl", "-manylinux2014_aarch64.whl"
)
self.download_file(remote_file, local_file)
def list_dir(self, path: str) -> list[str]:
return self.check_output(["ls", "-1", path]).split("\n")
def wait_for_connection(addr, port, timeout=15, attempt_cnt=5):
import socket
for i in range(attempt_cnt):
try:
with socket.create_connection((addr, port), timeout=timeout):
return
except (ConnectionRefusedError, TimeoutError): # noqa: PERF203
if i == attempt_cnt - 1:
raise
time.sleep(timeout)
def update_apt_repo(host: RemoteHost) -> None:
time.sleep(5)
host.run_cmd("sudo systemctl stop apt-daily.service || true")
host.run_cmd("sudo systemctl stop unattended-upgrades.service || true")
host.run_cmd(
"while systemctl is-active --quiet apt-daily.service; do sleep 1; done"
)
host.run_cmd(
"while systemctl is-active --quiet unattended-upgrades.service; do sleep 1; done"
)
host.run_cmd("sudo apt-get update")
time.sleep(3)
host.run_cmd("sudo apt-get update")
def install_condaforge(
host: RemoteHost, suffix: str = "latest/download/Miniforge3-Linux-aarch64.sh"
) -> None:
print("Install conda-forge")
host.run_cmd(f"curl -OL https://github.com/conda-forge/miniforge/releases/{suffix}")
host.run_cmd(f"sh -f {os.path.basename(suffix)} -b")
host.run_cmd(f"rm -f {os.path.basename(suffix)}")
if host.using_docker():
host.run_cmd("echo 'PATH=$HOME/miniforge3/bin:$PATH'>>.bashrc")
else:
host.run_cmd(
[
"sed",
"-i",
"'/^# If not running interactively.*/i PATH=$HOME/miniforge3/bin:$PATH'",
".bashrc",
]
)
def install_condaforge_python(host: RemoteHost, python_version="3.8") -> None:
if python_version == "3.6":
# Python-3.6 EOLed and not compatible with conda-4.11
install_condaforge(
host, suffix="download/4.10.3-10/Miniforge3-4.10.3-10-Linux-aarch64.sh"
)
host.run_cmd(f"conda install -y python={python_version} numpy pyyaml")
else:
install_condaforge(
host, suffix="download/4.11.0-4/Miniforge3-4.11.0-4-Linux-aarch64.sh"
)
# Pytorch-1.10 or older are not compatible with setuptools=59.6 or newer
host.run_cmd(
f"conda install -y python={python_version} numpy pyyaml setuptools>=59.5.0"
)
def embed_libgomp(host: RemoteHost, use_conda, wheel_name) -> None:
host.run_cmd("pip3 install auditwheel")
host.run_cmd(
"conda install -y patchelf" if use_conda else "sudo apt-get install -y patchelf"
)
from tempfile import NamedTemporaryFile
with NamedTemporaryFile() as tmp:
tmp.write(embed_library_script.encode("utf-8"))
tmp.flush()
host.upload_file(tmp.name, "embed_library.py")
print("Embedding libgomp into wheel")
if host.using_docker():
host.run_cmd(f"python3 embed_library.py {wheel_name} --update-tag")
else:
host.run_cmd(f"python3 embed_library.py {wheel_name}")
def checkout_repo(
host: RemoteHost,
*,
branch: str = "main",
url: str,
git_clone_flags: str,
mapping: dict[str, tuple[str, str]],
) -> Optional[str]:
for prefix in mapping:
if not branch.startswith(prefix):
continue
tag = f"v{mapping[prefix][0]}-{mapping[prefix][1]}"
host.run_cmd(f"git clone {url} -b {tag} {git_clone_flags}")
return mapping[prefix][0]
host.run_cmd(f"git clone {url} -b {branch} {git_clone_flags}")
return None
def build_torchvision(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str,
run_smoke_tests: bool = True,
) -> str:
print("Checking out TorchVision repo")
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/vision",
git_clone_flags=git_clone_flags,
mapping={
"v1.7.1": ("0.8.2", "rc2"),
"v1.8.0": ("0.9.0", "rc3"),
"v1.8.1": ("0.9.1", "rc1"),
"v1.9.0": ("0.10.0", "rc1"),
"v1.10.0": ("0.11.1", "rc1"),
"v1.10.1": ("0.11.2", "rc1"),
"v1.10.2": ("0.11.3", "rc1"),
"v1.11.0": ("0.12.0", "rc1"),
"v1.12.0": ("0.13.0", "rc4"),
"v1.12.1": ("0.13.1", "rc6"),
"v1.13.0": ("0.14.0", "rc4"),
"v1.13.1": ("0.14.1", "rc2"),
"v2.0.0": ("0.15.1", "rc2"),
"v2.0.1": ("0.15.2", "rc2"),
},
)
print("Building TorchVision wheel")
# Please note libnpg and jpeg are required to build image.so extension
if use_conda:
host.run_cmd("conda install -y libpng jpeg")
# Remove .so files to force static linking
host.run_cmd(
"rm miniforge3/lib/libpng.so miniforge3/lib/libpng16.so miniforge3/lib/libjpeg.so"
)
# And patch setup.py to include libz dependency for libpng
host.run_cmd(
[
'sed -i -e \'s/image_link_flags\\.append("png")/image_link_flags += ["png", "z"]/\' vision/setup.py'
]
)
build_vars = ""
if branch == "nightly":
version = host.check_output(
["if [ -f vision/version.txt ]; then cat vision/version.txt; fi"]
).strip()
if len(version) == 0:
# In older revisions, version was embedded in setup.py
version = (
host.check_output(["grep", '"version = \'"', "vision/setup.py"])
.strip()
.split("'")[1][:-2]
)
build_date = (
host.check_output("cd vision && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(f"cd vision && {build_vars} python3 -m build --wheel --no-isolation")
vision_wheel_name = host.list_dir("vision/dist")[0]
embed_libgomp(host, use_conda, os.path.join("vision", "dist", vision_wheel_name))
print("Copying TorchVision wheel")
host.download_wheel(os.path.join("vision", "dist", vision_wheel_name))
if run_smoke_tests:
host.run_cmd(
f"pip3 install {os.path.join('vision', 'dist', vision_wheel_name)}"
)
host.run_cmd("python3 vision/test/smoke_test.py")
print("Delete vision checkout")
host.run_cmd("rm -rf vision")
return vision_wheel_name
def build_torchdata(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> str:
print("Checking out TorchData repo")
git_clone_flags += " --recurse-submodules"
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/data",
git_clone_flags=git_clone_flags,
mapping={
"v1.13.1": ("0.5.1", ""),
"v2.0.0": ("0.6.0", "rc5"),
"v2.0.1": ("0.6.1", "rc1"),
},
)
print("Building TorchData wheel")
build_vars = ""
if branch == "nightly":
version = host.check_output(
["if [ -f data/version.txt ]; then cat data/version.txt; fi"]
).strip()
build_date = (
host.check_output("cd data && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(f"cd data && {build_vars} python3 -m build --wheel --no-isolation")
wheel_name = host.list_dir("data/dist")[0]
embed_libgomp(host, use_conda, os.path.join("data", "dist", wheel_name))
print("Copying TorchData wheel")
host.download_wheel(os.path.join("data", "dist", wheel_name))
return wheel_name
def build_torchtext(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> str:
print("Checking out TorchText repo")
git_clone_flags += " --recurse-submodules"
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/text",
git_clone_flags=git_clone_flags,
mapping={
"v1.9.0": ("0.10.0", "rc1"),
"v1.10.0": ("0.11.0", "rc2"),
"v1.10.1": ("0.11.1", "rc1"),
"v1.10.2": ("0.11.2", "rc1"),
"v1.11.0": ("0.12.0", "rc1"),
"v1.12.0": ("0.13.0", "rc2"),
"v1.12.1": ("0.13.1", "rc5"),
"v1.13.0": ("0.14.0", "rc3"),
"v1.13.1": ("0.14.1", "rc1"),
"v2.0.0": ("0.15.1", "rc2"),
"v2.0.1": ("0.15.2", "rc2"),
},
)
print("Building TorchText wheel")
build_vars = ""
if branch == "nightly":
version = host.check_output(
["if [ -f text/version.txt ]; then cat text/version.txt; fi"]
).strip()
build_date = (
host.check_output("cd text && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(f"cd text && {build_vars} python3 -m build --wheel --no-isolation")
wheel_name = host.list_dir("text/dist")[0]
embed_libgomp(host, use_conda, os.path.join("text", "dist", wheel_name))
print("Copying TorchText wheel")
host.download_wheel(os.path.join("text", "dist", wheel_name))
return wheel_name
def build_torchaudio(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> str:
print("Checking out TorchAudio repo")
git_clone_flags += " --recurse-submodules"
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/audio",
git_clone_flags=git_clone_flags,
mapping={
"v1.9.0": ("0.9.0", "rc2"),
"v1.10.0": ("0.10.0", "rc5"),
"v1.10.1": ("0.10.1", "rc1"),
"v1.10.2": ("0.10.2", "rc1"),
"v1.11.0": ("0.11.0", "rc1"),
"v1.12.0": ("0.12.0", "rc3"),
"v1.12.1": ("0.12.1", "rc5"),
"v1.13.0": ("0.13.0", "rc4"),
"v1.13.1": ("0.13.1", "rc2"),
"v2.0.0": ("2.0.1", "rc3"),
"v2.0.1": ("2.0.2", "rc2"),
},
)
print("Building TorchAudio wheel")
build_vars = ""
if branch == "nightly":
version = (
host.check_output(["grep", '"version = \'"', "audio/setup.py"])
.strip()
.split("'")[1][:-2]
)
build_date = (
host.check_output("cd audio && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(
f"cd audio && export FFMPEG_ROOT=$(pwd)/third_party/ffmpeg && export USE_FFMPEG=1 \
&& ./packaging/ffmpeg/build.sh \
&& {build_vars} python3 -m build --wheel --no-isolation"
)
wheel_name = host.list_dir("audio/dist")[0]
embed_libgomp(host, use_conda, os.path.join("audio", "dist", wheel_name))
print("Copying TorchAudio wheel")
host.download_wheel(os.path.join("audio", "dist", wheel_name))
return wheel_name
def configure_system(
host: RemoteHost,
*,
compiler: str = "gcc-8",
use_conda: bool = True,
python_version: str = "3.8",
) -> None:
if use_conda:
install_condaforge_python(host, python_version)
print("Configuring the system")
if not host.using_docker():
update_apt_repo(host)
host.run_cmd("sudo apt-get install -y ninja-build g++ git cmake gfortran unzip")
else:
host.run_cmd("yum install -y sudo")
host.run_cmd("conda install -y ninja scons")
if not use_conda:
host.run_cmd(
"sudo apt-get install -y python3-dev python3-yaml python3-setuptools python3-wheel python3-pip"
)
host.run_cmd("pip3 install dataclasses typing-extensions")
if not use_conda:
print("Installing Cython + numpy from PyPy")
host.run_cmd("sudo pip3 install Cython")
host.run_cmd("sudo pip3 install numpy")
def build_domains(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> tuple[str, str, str, str]:
vision_wheel_name = build_torchvision(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
audio_wheel_name = build_torchaudio(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
data_wheel_name = build_torchdata(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
text_wheel_name = build_torchtext(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
return (vision_wheel_name, audio_wheel_name, data_wheel_name, text_wheel_name)
def start_build(
host: RemoteHost,
*,
branch: str = "main",
compiler: str = "gcc-8",
use_conda: bool = True,
python_version: str = "3.8",
pytorch_only: bool = False,
pytorch_build_number: Optional[str] = None,
shallow_clone: bool = True,
enable_mkldnn: bool = False,
) -> tuple[str, str, str, str, str]:
git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
if host.using_docker() and not use_conda:
print("Auto-selecting conda option for docker images")
use_conda = True
if not host.using_docker():
print("Disable mkldnn for host builds")
enable_mkldnn = False
configure_system(
host, compiler=compiler, use_conda=use_conda, python_version=python_version
)
if host.using_docker():
print("Move libgfortant.a into a standard location")
# HACK: pypa gforntran.a is compiled without PIC, which leads to the following error
# libgfortran.a(error.o)(.text._gfortrani_st_printf+0x34): unresolvable R_AARCH64_ADR_PREL_PG_HI21 relocation against symbol `__stack_chk_guard@@GLIBC_2.17' # noqa: E501, B950
# Workaround by copying gfortran library from the host
host.run_ssh_cmd("sudo apt-get install -y gfortran-8")
host.run_cmd("mkdir -p /usr/lib/gcc/aarch64-linux-gnu/8")
host.run_ssh_cmd(
[
"docker",
"cp",
"/usr/lib/gcc/aarch64-linux-gnu/8/libgfortran.a",
f"{host.container_id}:/opt/rh/devtoolset-10/root/usr/lib/gcc/aarch64-redhat-linux/10/",
]
)
print("Checking out PyTorch repo")
host.run_cmd(
f"git clone --recurse-submodules -b {branch} https://github.com/pytorch/pytorch {git_clone_flags}"
)
host.run_cmd("pytorch/.ci/docker/common/install_openblas.sh")
print("Building PyTorch wheel")
build_opts = ""
if pytorch_build_number is not None:
build_opts += f" -C--build-option=--build-number={pytorch_build_number}"
# Breakpad build fails on aarch64
build_vars = "USE_BREAKPAD=0 "
if branch == "nightly":
build_date = (
host.check_output("cd pytorch && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
version = host.check_output("cat pytorch/version.txt").strip()[:-2]
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1"
if branch.startswith(("v1.", "v2.")):
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
if enable_mkldnn:
host.run_cmd("pytorch/.ci/docker/common/install_acl.sh")
print("build pytorch with mkldnn+acl backend")
build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
build_vars += " BLAS=OpenBLAS"
build_vars += " OpenBLAS_HOME=/opt/OpenBLAS"
build_vars += " ACL_ROOT_DIR=/acl"
host.run_cmd(
f"cd $HOME/pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
)
print("Repair the wheel")
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
ld_library_path = "/acl/build:$HOME/pytorch/build/lib"
host.run_cmd(
f"export LD_LIBRARY_PATH={ld_library_path} && auditwheel repair $HOME/pytorch/dist/{pytorch_wheel_name}"
)
print("replace the original wheel with the repaired one")
pytorch_repaired_wheel_name = host.list_dir("wheelhouse")[0]
host.run_cmd(
f"cp $HOME/wheelhouse/{pytorch_repaired_wheel_name} $HOME/pytorch/dist/{pytorch_wheel_name}"
)
else:
print("build pytorch without mkldnn backend")
host.run_cmd(
f"cd pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
)
print("Deleting build folder")
host.run_cmd("cd pytorch && rm -rf build")
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
embed_libgomp(host, use_conda, os.path.join("pytorch", "dist", pytorch_wheel_name))
print("Copying the wheel")
host.download_wheel(os.path.join("pytorch", "dist", pytorch_wheel_name))
print("Installing PyTorch wheel")
host.run_cmd(f"pip3 install pytorch/dist/{pytorch_wheel_name}")
if pytorch_only:
return (pytorch_wheel_name, None, None, None, None)
domain_wheels = build_domains(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
return (pytorch_wheel_name, *domain_wheels)
embed_library_script = """
#!/usr/bin/env python3
from auditwheel.patcher import Patchelf
from auditwheel.wheeltools import InWheelCtx
from auditwheel.elfutils import elf_file_filter
from auditwheel.repair import copylib
from auditwheel.lddtree import lddtree
from subprocess import check_call
import os
import shutil
import sys
from tempfile import TemporaryDirectory
def replace_tag(filename):
with open(filename, 'r') as f:
lines = f.read().split("\\n")
for i,line in enumerate(lines):
if not line.startswith("Tag: "):
continue
lines[i] = line.replace("-linux_", "-manylinux2014_")
print(f'Updated tag from {line} to {lines[i]}')
with open(filename, 'w') as f:
f.write("\\n".join(lines))
class AlignedPatchelf(Patchelf):
def set_soname(self, file_name: str, new_soname: str) -> None:
check_call(['patchelf', '--page-size', '65536', '--set-soname', new_soname, file_name])
def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
check_call(['patchelf', '--page-size', '65536', '--replace-needed', soname, new_soname, file_name])
def embed_library(whl_path, lib_soname, update_tag=False):
patcher = AlignedPatchelf()
out_dir = TemporaryDirectory()
whl_name = os.path.basename(whl_path)
tmp_whl_name = os.path.join(out_dir.name, whl_name)
with InWheelCtx(whl_path) as ctx:
torchlib_path = os.path.join(ctx._tmpdir.name, 'torch', 'lib')
ctx.out_wheel=tmp_whl_name
new_lib_path, new_lib_soname = None, None
for filename, elf in elf_file_filter(ctx.iter_files()):
if not filename.startswith('torch/lib'):
continue
libtree = lddtree(filename)
if lib_soname not in libtree['needed']:
continue
lib_path = libtree['libs'][lib_soname]['path']
if lib_path is None:
print(f"Can't embed {lib_soname} as it could not be found")
break
if lib_path.startswith(torchlib_path):
continue
if new_lib_path is None:
new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
patcher.replace_needed(filename, lib_soname, new_lib_soname)
print(f'Replacing {lib_soname} with {new_lib_soname} for {filename}')
if update_tag:
# Add manylinux2014 tag
for filename in ctx.iter_files():
if os.path.basename(filename) != 'WHEEL':
continue
replace_tag(filename)
shutil.move(tmp_whl_name, whl_path)
if __name__ == '__main__':
embed_library(sys.argv[1], 'libgomp.so.1', len(sys.argv) > 2 and sys.argv[2] == '--update-tag')
"""
def run_tests(host: RemoteHost, whl: str, branch="main") -> None:
print("Configuring the system")
update_apt_repo(host)
host.run_cmd("sudo apt-get install -y python3-pip git")
host.run_cmd("sudo pip3 install Cython")
host.run_cmd("sudo pip3 install numpy")
host.upload_file(whl, ".")
host.run_cmd(f"sudo pip3 install {whl}")
host.run_cmd("python3 -c 'import torch;print(torch.rand((3,3))'")
host.run_cmd(f"git clone -b {branch} https://github.com/pytorch/pytorch")
host.run_cmd("cd pytorch/test; python3 test_torch.py -v")
def get_instance_name(instance) -> Optional[str]:
if instance.tags is None:
return None
for tag in instance.tags:
if tag["Key"] == "Name":
return tag["Value"]
return None
def list_instances(instance_type: str) -> None:
print(f"All instances of type {instance_type}")
for instance in ec2_instances_of_type(instance_type):
ifaces = instance.network_interfaces
az = ifaces[0].subnet.availability_zone if len(ifaces) > 0 else None
print(
f"{instance.id} {get_instance_name(instance)} {instance.public_dns_name} {instance.state['Name']} {az}"
)
def terminate_instances(instance_type: str) -> None:
print(f"Terminating all instances of type {instance_type}")
instances = list(ec2_instances_of_type(instance_type))
for instance in instances:
print(f"Terminating {instance.id}")
instance.terminate()
print("Waiting for termination to complete")
for instance in instances:
instance.wait_until_terminated()
def parse_arguments():
from argparse import ArgumentParser
parser = ArgumentParser("Build and test AARCH64 wheels using EC2")
parser.add_argument("--key-name", type=str)
parser.add_argument("--debug", action="store_true")
parser.add_argument("--build-only", action="store_true")
parser.add_argument("--test-only", type=str)
group = parser.add_mutually_exclusive_group()
group.add_argument("--os", type=str, choices=list(os_amis.keys()))
group.add_argument("--ami", type=str)
parser.add_argument(
"--python-version",
type=str,
choices=[f"3.{d}" for d in range(6, 12)],
default=None,
)
parser.add_argument("--alloc-instance", action="store_true")
parser.add_argument("--list-instances", action="store_true")
parser.add_argument("--pytorch-only", action="store_true")
parser.add_argument("--keep-running", action="store_true")
parser.add_argument("--terminate-instances", action="store_true")
parser.add_argument("--instance-type", type=str, default="t4g.2xlarge")
parser.add_argument("--ebs-size", type=int, default=50)
parser.add_argument("--branch", type=str, default="main")
parser.add_argument("--use-docker", action="store_true")
parser.add_argument(
"--compiler",
type=str,
choices=["gcc-7", "gcc-8", "gcc-9", "clang"],
default="gcc-8",
)
parser.add_argument("--use-torch-from-pypi", action="store_true")
parser.add_argument("--pytorch-build-number", type=str, default=None)
parser.add_argument("--disable-mkldnn", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
args = parse_arguments()
ami = (
args.ami
if args.ami is not None
else os_amis[args.os]
if args.os is not None
else ubuntu20_04_ami
)
keyfile_path, key_name = compute_keyfile_path(args.key_name)
if args.list_instances:
list_instances(args.instance_type)
sys.exit(0)
if args.terminate_instances:
terminate_instances(args.instance_type)
sys.exit(0)
if len(key_name) == 0:
raise RuntimeError("""
Cannot start build without key_name, please specify
--key-name argument or AWS_KEY_NAME environment variable.""")
if len(keyfile_path) == 0 or not os.path.exists(keyfile_path):
raise RuntimeError(f"""
Cannot find keyfile with name: [{key_name}] in path: [{keyfile_path}], please
check `~/.ssh/` folder or manually set SSH_KEY_PATH environment variable.""")
# Starting the instance
inst = start_instance(
key_name, ami=ami, instance_type=args.instance_type, ebs_size=args.ebs_size
)
instance_name = f"{args.key_name}-{args.os}"
if args.python_version is not None:
instance_name += f"-py{args.python_version}"
inst.create_tags(
DryRun=False,
Tags=[
{
"Key": "Name",
"Value": instance_name,
}
],
)
addr = inst.public_dns_name
wait_for_connection(addr, 22)
host = RemoteHost(addr, keyfile_path)
host.ami = ami
if args.use_docker:
update_apt_repo(host)
host.start_docker()
if args.test_only:
run_tests(host, args.test_only)
sys.exit(0)
if args.alloc_instance:
if args.python_version is None:
sys.exit(0)
install_condaforge_python(host, args.python_version)
sys.exit(0)
python_version = args.python_version if args.python_version is not None else "3.10"
if args.use_torch_from_pypi:
configure_system(host, compiler=args.compiler, python_version=python_version)
print("Installing PyTorch wheel")
host.run_cmd("pip3 install torch")
build_domains(
host, branch=args.branch, git_clone_flags=" --depth 1 --shallow-submodules"
)
else:
start_build(
host,
branch=args.branch,
compiler=args.compiler,
python_version=python_version,
pytorch_only=args.pytorch_only,
pytorch_build_number=args.pytorch_build_number,
enable_mkldnn=not args.disable_mkldnn,
)
if not args.keep_running:
print(f"Waiting for instance {inst.id} to terminate")
inst.terminate()
inst.wait_until_terminated()

View File

@ -0,0 +1,87 @@
#!/usr/bin/env python3
import os
import shutil
import sys
from subprocess import check_call
from tempfile import TemporaryDirectory
from auditwheel.elfutils import elf_file_filter
from auditwheel.lddtree import lddtree
from auditwheel.patcher import Patchelf
from auditwheel.repair import copylib
from auditwheel.wheeltools import InWheelCtx
def replace_tag(filename):
with open(filename) as f:
lines = f.read().split("\\n")
for i, line in enumerate(lines):
if not line.startswith("Tag: "):
continue
lines[i] = line.replace("-linux_", "-manylinux2014_")
print(f"Updated tag from {line} to {lines[i]}")
with open(filename, "w") as f:
f.write("\\n".join(lines))
class AlignedPatchelf(Patchelf):
def set_soname(self, file_name: str, new_soname: str) -> None:
check_call(
["patchelf", "--page-size", "65536", "--set-soname", new_soname, file_name]
)
def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
check_call(
[
"patchelf",
"--page-size",
"65536",
"--replace-needed",
soname,
new_soname,
file_name,
]
)
def embed_library(whl_path, lib_soname, update_tag=False):
patcher = AlignedPatchelf()
out_dir = TemporaryDirectory()
whl_name = os.path.basename(whl_path)
tmp_whl_name = os.path.join(out_dir.name, whl_name)
with InWheelCtx(whl_path) as ctx:
torchlib_path = os.path.join(ctx._tmpdir.name, "torch", "lib")
ctx.out_wheel = tmp_whl_name
new_lib_path, new_lib_soname = None, None
for filename, _ in elf_file_filter(ctx.iter_files()):
if not filename.startswith("torch/lib"):
continue
libtree = lddtree(filename)
if lib_soname not in libtree["needed"]:
continue
lib_path = libtree["libs"][lib_soname]["path"]
if lib_path is None:
print(f"Can't embed {lib_soname} as it could not be found")
break
if lib_path.startswith(torchlib_path):
continue
if new_lib_path is None:
new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
patcher.replace_needed(filename, lib_soname, new_lib_soname)
print(f"Replacing {lib_soname} with {new_lib_soname} for {filename}")
if update_tag:
# Add manylinux2014 tag
for filename in ctx.iter_files():
if os.path.basename(filename) != "WHEEL":
continue
replace_tag(filename)
shutil.move(tmp_whl_name, whl_path)
if __name__ == "__main__":
embed_library(
sys.argv[1], "libgomp.so.1", len(sys.argv) > 2 and sys.argv[2] == "--update-tag"
)

View File

@ -30,6 +30,7 @@ into a tarball, with the following structure:
More specifically, `build_magma.sh` copies over the relevant files from the `package_files` directory depending on the ROCm version.
Outputted binaries should be in the `output` folder.
## Pushing
Packages can be uploaded to an S3 bucket using:

View File

@ -4,17 +4,14 @@ set -ex
SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
# Source the common build script for architecture-specific configurations (MKLDNN, ACL, etc.)
source "${SCRIPTPATH}/../pytorch/build.sh" || true
case "${GPU_ARCH_TYPE:-BLANK}" in
cuda | cuda-aarch64)
cuda)
bash "${SCRIPTPATH}/build_cuda.sh"
;;
rocm)
bash "${SCRIPTPATH}/build_rocm.sh"
;;
cpu | cpu-cxx11-abi | cpu-aarch64 | cpu-s390x)
cpu | cpu-cxx11-abi | cpu-s390x)
bash "${SCRIPTPATH}/build_cpu.sh"
;;
xpu)

View File

@ -18,31 +18,12 @@ retry () {
$* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*)
}
# Detect architecture first
ARCH=$(uname -m)
echo "Detected architecture: $ARCH"
PLATFORM=""
# TODO move this into the Docker images
OS_NAME=$(awk -F= '/^NAME/{print $2}' /etc/os-release)
if [[ "$OS_NAME" == *"AlmaLinux"* ]]; then
retry yum install -q -y zip openssl
# Set platform based on architecture
case $ARCH in
x86_64)
PLATFORM="manylinux_2_28_x86_64"
;;
aarch64)
PLATFORM="manylinux_2_28_aarch64"
;;
s390x)
PLATFORM="manylinux_2_28_s390x"
;;
*)
echo "Unsupported architecture: $ARCH"
exit 1
;;
esac
PLATFORM="manylinux_2_28_x86_64"
elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then
retry dnf install -q -y zip openssl
elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
@ -57,8 +38,6 @@ else
exit 1
fi
echo "Platform set to: $PLATFORM"
# We use the package name to test the package by passing this to 'pip install'
# This is the env variable that setup.py uses to name the package. Note that
# pip 'normalizes' the name first by changing all - to _
@ -320,8 +299,8 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
# ROCm workaround for roctracer dlopens
if [[ "$DESIRED_CUDA" == *"rocm"* ]]; then
patchedpath=$(fname_without_so_number $destpath)
# Keep the so number for XPU dependencies, libgomp.so.1, ACL libraries, and NVPL libraries to avoid twice load
elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" || "$filename" == libarm_compute* || "$filename" == libnvpl* || "$filename" == "libgfortran.so.5" ]]; then
# Keep the so number for XPU dependencies and libgomp.so.1 to avoid twice load
elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" ]]; then
patchedpath=$destpath
else
patchedpath=$(fname_with_sha256 $destpath)
@ -367,22 +346,9 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
done
# create Manylinux 2_28 tag this needs to happen before regenerate the RECORD
# Support all architectures (x86_64, aarch64, s390x)
if [[ "$IS_MANYLINUX2_28" == "1" && $GPU_ARCH_TYPE != "xpu" ]]; then
if [[ $PLATFORM == "manylinux_2_28_x86_64" && $GPU_ARCH_TYPE != "cpu-s390x" && $GPU_ARCH_TYPE != "xpu" ]]; then
wheel_file=$(echo $(basename $pkg) | sed -e 's/-cp.*$/.dist-info\/WHEEL/g')
echo "Updating wheel tag for $ARCH architecture"
# Replace linux_* with manylinux_2_28_* based on architecture
case $ARCH in
x86_64)
sed -i -e 's#linux_x86_64#manylinux_2_28_x86_64#g' $wheel_file
;;
aarch64)
sed -i -e 's#linux_aarch64#manylinux_2_28_aarch64#g' $wheel_file
;;
s390x)
sed -i -e 's#linux_s390x#manylinux_2_28_s390x#g' $wheel_file
;;
esac
sed -i -e s#linux_x86_64#"${PLATFORM}"# $wheel_file;
fi
# regenerate the RECORD file with new hashes

View File

@ -15,10 +15,6 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
EXTRA_CAFFE2_CMAKE_FLAGS=()
fi
# Detect architecture
ARCH=$(uname -m)
echo "Building CPU wheel for architecture: $ARCH"
WHEELHOUSE_DIR="wheelhousecpu"
LIBTORCH_HOUSE_DIR="libtorch_housecpu"
if [[ -z "$PYTORCH_FINAL_PACKAGE_DIR" ]]; then
@ -38,10 +34,8 @@ elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then
elif [[ "$OS_NAME" == *"AlmaLinux"* ]]; then
LIBGOMP_PATH="/usr/lib64/libgomp.so.1"
elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
if [[ "$ARCH" == "s390x" ]]; then
if [[ "$(uname -m)" == "s390x" ]]; then
LIBGOMP_PATH="/usr/lib/s390x-linux-gnu/libgomp.so.1"
elif [[ "$ARCH" == "aarch64" ]]; then
LIBGOMP_PATH="/usr/lib/aarch64-linux-gnu/libgomp.so.1"
else
LIBGOMP_PATH="/usr/lib/x86_64-linux-gnu/libgomp.so.1"
fi
@ -55,32 +49,6 @@ DEPS_SONAME=(
"libgomp.so.1"
)
# Add ARM-specific library dependencies for CPU builds
if [[ "$ARCH" == "aarch64" ]]; then
echo "Adding ARM-specific CPU library dependencies"
# ARM Compute Library (if available)
if [[ -d "/acl/build" ]]; then
echo "Adding ARM Compute Library for CPU"
DEPS_LIST+=(
"/acl/build/libarm_compute.so"
"/acl/build/libarm_compute_graph.so"
)
DEPS_SONAME+=(
"libarm_compute.so"
"libarm_compute_graph.so"
)
fi
# ARM system libraries
DEPS_LIST+=(
"/usr/lib64/libgfortran.so.5"
)
DEPS_SONAME+=(
"libgfortran.so.5"
)
fi
rm -rf /usr/local/cuda*
SOURCE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )"

View File

@ -29,10 +29,6 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
EXTRA_CAFFE2_CMAKE_FLAGS=()
fi
# Detect architecture
ARCH=$(uname -m)
echo "Building for architecture: $ARCH"
# Determine CUDA version and architectures to build for
#
# NOTE: We should first check `DESIRED_CUDA` when determining `CUDA_VERSION`,
@ -57,60 +53,34 @@ fi
cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.')
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
# Function to remove architectures from a list
remove_archs() {
local result="$1"
shift
for arch in "$@"; do
result="${result//${arch};/}"
done
echo "$result"
}
# Function to filter CUDA architectures for aarch64
# aarch64 ARM GPUs only support certain compute capabilities
# Keep: 8.0 (A100), 9.0+ (Hopper, Grace Hopper, newer)
# Remove: < 8.0 (no ARM GPUs), 8.6 (x86_64 RTX 3090/A6000 only)
filter_aarch64_archs() {
local arch_list="$1"
# Explicitly remove architectures not needed on aarch64
arch_list=$(remove_archs "$arch_list" "5.0" "6.0" "7.0" "7.5" "8.6")
echo "$arch_list"
}
# Base: Common architectures across all modern CUDA versions
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0"
case ${CUDA_VERSION} in
12.6) TORCH_CUDA_ARCH_LIST="5.0;6.0;${TORCH_CUDA_ARCH_LIST}" ;; # Only 12.6 includes Legacy Maxwell/Pascal that will be removed in future releases
12.8) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0" ;; # +Hopper/Blackwell support
12.9) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0+PTX" # +Hopper/Blackwell support + PTX for forward compatibility
#removing sm_50-sm_60 as these architectures are deprecated in CUDA 12.8/9 and will be removed in future releases
#however we would like to keep sm_70 architecture see: https://github.com/pytorch/pytorch/issues/157517
12.8)
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0"
;;
12.9)
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0+PTX"
# WAR to resolve the ld error in libtorch build with CUDA 12.9
if [[ "$PACKAGE_TYPE" == "libtorch" ]]; then
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//7.0;/}" # Remove 7.0 to resolve the ld error
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//8.6;/}" # Remove 8.6 for libtorch
TORCH_CUDA_ARCH_LIST="7.5;8.0;9.0;10.0;12.0+PTX"
fi
;;
13.0)
TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;$([[ "$ARCH" == "aarch64" ]] && echo "11.0;" || echo "")12.0+PTX"
export TORCH_NVCC_FLAGS="-compress-mode=size"
export BUILD_BUNDLE_PTXAS=1
TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;12.0+PTX"
;;
12.6)
TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6;9.0"
;;
*)
echo "unknown cuda version $CUDA_VERSION"
exit 1
;;
*) echo "unknown cuda version $CUDA_VERSION"; exit 1 ;;
esac
# Filter for aarch64: Remove < 8.0 and 8.6
[[ "$ARCH" == "aarch64" ]] && TORCH_CUDA_ARCH_LIST=$(filter_aarch64_archs "$TORCH_CUDA_ARCH_LIST")
echo "TORCH_CUDA_ARCH_LIST set to: $TORCH_CUDA_ARCH_LIST"
export TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST}
echo "${TORCH_CUDA_ARCH_LIST}"
# Disable MAGMA for aarch64 as pre-built libraries are x86-64 only
if [[ "$ARCH" == "aarch64" ]]; then
echo "Disabling MAGMA for aarch64 architecture"
export USE_MAGMA=0
fi
# Package directories
WHEELHOUSE_DIR="wheelhouse$cuda_version_nodot"
LIBTORCH_HOUSE_DIR="libtorch_house$cuda_version_nodot"
@ -274,51 +244,6 @@ else
exit 1
fi
# Add ARM-specific library dependencies
if [[ "$ARCH" == "aarch64" ]]; then
echo "Adding ARM-specific library dependencies"
# ARM Compute Library (if available)
if [[ -d "/acl/build" ]]; then
echo "Adding ARM Compute Library"
DEPS_LIST+=(
"/acl/build/libarm_compute.so"
"/acl/build/libarm_compute_graph.so"
)
DEPS_SONAME+=(
"libarm_compute.so"
"libarm_compute_graph.so"
)
fi
# ARM system libraries
DEPS_LIST+=(
"/lib64/libgomp.so.1"
"/usr/lib64/libgfortran.so.5"
)
DEPS_SONAME+=(
"libgomp.so.1"
"libgfortran.so.5"
)
# NVPL libraries (ARM optimized BLAS/LAPACK)
if [[ -d "/usr/local/lib" && -f "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0" ]]; then
echo "Adding NVPL libraries for ARM"
DEPS_LIST+=(
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0"
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0"
"/usr/local/lib/libnvpl_lapack_core.so.0"
"/usr/local/lib/libnvpl_blas_core.so.0"
)
DEPS_SONAME+=(
"libnvpl_lapack_lp64_gomp.so.0"
"libnvpl_blas_lp64_gomp.so.0"
"libnvpl_lapack_core.so.0"
"libnvpl_blas_core.so.0"
)
fi
fi
# run_tests.sh requires DESIRED_CUDA to know what tests to exclude
export DESIRED_CUDA="$cuda_version_nodot"
@ -326,11 +251,9 @@ export DESIRED_CUDA="$cuda_version_nodot"
rm -rf /usr/local/cuda || true
ln -s "/usr/local/cuda-${CUDA_VERSION}" /usr/local/cuda
# Switch `/usr/local/magma` to the desired CUDA version (skip for aarch64)
if [[ "$ARCH" != "aarch64" ]]; then
rm -rf /usr/local/magma || true
ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma
fi
# Switch `/usr/local/magma` to the desired CUDA version
rm -rf /usr/local/magma || true
ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma
export CUDA_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev) # 10.0.130
export CUDA_VERSION_SHORT=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev | cut -f1,2 -d".") # 10.0

View File

@ -86,20 +86,10 @@ else
fi
fi
# Enable MKLDNN with ARM Compute Library for ARM builds
if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then
export USE_MKLDNN=1
# ACL is required for aarch64 builds
if [[ ! -d "/acl" ]]; then
echo "ERROR: ARM Compute Library not found at /acl"
echo "ACL is required for aarch64 builds. Check Docker image setup."
exit 1
fi
export USE_MKLDNN_ACL=1
export ACL_ROOT_DIR=/acl
echo "ARM Compute Library enabled for MKLDNN: ACL_ROOT_DIR=/acl"
fi
if [[ "$BUILD_ENVIRONMENT" == *riscv64* ]]; then

View File

@ -96,6 +96,7 @@ function pip_build_and_install() {
python3 -m pip wheel \
--no-build-isolation \
--no-deps \
--no-use-pep517 \
-w "${wheel_dir}" \
"${build_target}"
fi
@ -307,28 +308,6 @@ function install_torchao() {
pip_build_and_install "git+https://github.com/pytorch/ao.git@${commit}" dist/ao
}
function install_flash_attn_cute() {
echo "Installing FlashAttention CuTe from GitHub..."
# Grab latest main til we have a pinned commit
local flash_attn_commit
flash_attn_commit=$(git ls-remote https://github.com/Dao-AILab/flash-attention.git HEAD | cut -f1)
# Clone the repo to a temporary directory
rm -rf flash-attention-build
git clone --depth 1 --recursive https://github.com/Dao-AILab/flash-attention.git flash-attention-build
pushd flash-attention-build
git checkout "${flash_attn_commit}"
# Install only the 'cute' sub-directory
pip_install -e flash_attn/cute/
popd
# remove the local repo
rm -rf flash-attention-build
echo "FlashAttention CuTe installation complete."
}
function print_sccache_stats() {
echo 'PyTorch Build Statistics'
sccache --show-stats

View File

@ -100,337 +100,6 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
)
def _compile_and_extract_symbols(
cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None
) -> list[str]:
"""
Helper to compile a C++ file and extract all symbols.
Args:
cpp_content: C++ source code to compile
compile_flags: Compilation flags
exclude_list: List of symbol names to exclude. Defaults to ["main"].
Returns:
List of all symbols found in the object file (excluding those in exclude_list).
"""
import subprocess
import tempfile
if exclude_list is None:
exclude_list = ["main"]
with tempfile.TemporaryDirectory() as tmpdir:
tmppath = Path(tmpdir)
cpp_file = tmppath / "test.cpp"
obj_file = tmppath / "test.o"
cpp_file.write_text(cpp_content)
result = subprocess.run(
compile_flags + [str(cpp_file), "-o", str(obj_file)],
capture_output=True,
text=True,
timeout=60,
)
if result.returncode != 0:
raise RuntimeError(f"Compilation failed: {result.stderr}")
symbols = get_symbols(str(obj_file))
# Return all symbol names, excluding those in the exclude list
return [name for _addr, _stype, name in symbols if name not in exclude_list]
def check_stable_only_symbols(install_root: Path) -> None:
"""
Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.
This approach tests:
1. WITHOUT macros -> many torch symbols exposed
2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
4. WITH both macros -> zero torch symbols (all hidden)
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
test_cpp_content = """
// Main torch C++ API headers
#include <torch/torch.h>
#include <torch/all.h>
// ATen tensor library
#include <ATen/ATen.h>
// Core c10 headers (commonly used)
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Optional.h>
int main() { return 0; }
"""
base_compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c", # Compile only, don't link
]
# Compile WITHOUT any macros
symbols_without = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=base_compile_flags,
)
# We expect constexpr symbols, inline functions used by other headers etc.
# to produce symbols
num_symbols_without = len(symbols_without)
print(f"Found {num_symbols_without} symbols without any macros defined")
assert num_symbols_without != 0, (
"Expected a non-zero number of symbols without any macros"
)
# Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"]
symbols_with_stable_only = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_stable_only,
)
num_symbols_with_stable_only = len(symbols_with_stable_only)
assert num_symbols_with_stable_only == 0, (
f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}"
)
# Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
compile_flags_with_target_version = base_compile_flags + [
"-DTORCH_TARGET_VERSION=1"
]
symbols_with_target_version = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_target_version,
)
num_symbols_with_target_version = len(symbols_with_target_version)
assert num_symbols_with_target_version == 0, (
f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}"
)
# Compile WITH both macros (expect 0 symbols)
compile_flags_with_both = base_compile_flags + [
"-DTORCH_STABLE_ONLY",
"-DTORCH_TARGET_VERSION=1",
]
symbols_with_both = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_both,
)
num_symbols_with_both = len(symbols_with_both)
assert num_symbols_with_both == 0, (
f"Expected no symbols with both macros, but found {num_symbols_with_both}"
)
def check_stable_api_symbols(install_root: Path) -> None:
"""
Test that stable API headers still expose symbols with TORCH_STABLE_ONLY.
The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
stable_dir = include_dir / "torch" / "csrc" / "stable"
assert stable_dir.exists(), f"Expected {stable_dir} to be present"
stable_headers = list(stable_dir.rglob("*.h"))
if not stable_headers:
raise RuntimeError("Could not find any stable headers")
includes = []
for header in stable_headers:
rel_path = header.relative_to(include_dir)
includes.append(f"#include <{rel_path.as_posix()}>")
includes_str = "\n".join(includes)
test_stable_content = f"""
{includes_str}
int main() {{ return 0; }}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_stable = _compile_and_extract_symbols(
cpp_content=test_stable_content,
compile_flags=compile_flags,
)
num_symbols_stable = len(symbols_stable)
print(f"Found {num_symbols_stable} symbols in torch/csrc/stable")
assert num_symbols_stable > 0, (
f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_stable} symbols"
)
def check_headeronly_symbols(install_root: Path) -> None:
"""
Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# Find all headers in torch/headeronly
headeronly_dir = include_dir / "torch" / "headeronly"
assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present"
headeronly_headers = list(headeronly_dir.rglob("*.h"))
if not headeronly_headers:
raise RuntimeError("Could not find any headeronly headers")
# Filter out platform-specific headers that may not compile everywhere
platform_specific_keywords = [
"cpu/vec",
]
filtered_headers = []
for header in headeronly_headers:
rel_path = header.relative_to(include_dir).as_posix()
if not any(
keyword in rel_path.lower() for keyword in platform_specific_keywords
):
filtered_headers.append(header)
includes = []
for header in filtered_headers:
rel_path = header.relative_to(include_dir)
includes.append(f"#include <{rel_path.as_posix()}>")
includes_str = "\n".join(includes)
test_headeronly_content = f"""
{includes_str}
int main() {{ return 0; }}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_headeronly = _compile_and_extract_symbols(
cpp_content=test_headeronly_content,
compile_flags=compile_flags,
)
num_symbols_headeronly = len(symbols_headeronly)
print(f"Found {num_symbols_headeronly} symbols in torch/headeronly")
assert num_symbols_headeronly > 0, (
f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_headeronly} symbols"
)
def check_aoti_shim_symbols(install_root: Path) -> None:
"""
Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# There are no constexpr symbols etc., so we need to actually use functions
# so that some symbols are found.
test_shim_content = """
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
int main() {
int32_t (*fp1)() = &aoti_torch_device_type_cpu;
int32_t (*fp2)() = &aoti_torch_dtype_float32;
(void)fp1; (void)fp2;
return 0;
}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_shim = _compile_and_extract_symbols(
cpp_content=test_shim_content,
compile_flags=compile_flags,
)
num_symbols_shim = len(symbols_shim)
assert num_symbols_shim > 0, (
f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_shim} symbols"
)
def check_stable_c_shim_symbols(install_root: Path) -> None:
"""
Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# Check if the stable C shim exists
stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
if not stable_shim.exists():
raise RuntimeError("Could not find stable c shim")
# There are no constexpr symbols etc., so we need to actually use functions
# so that some symbols are found.
test_stable_shim_content = """
#include <torch/csrc/stable/c/shim.h>
int main() {
// Reference stable C API functions to create undefined symbols
AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
(void)fp1; (void)fp2;
return 0;
}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_stable_shim = _compile_and_extract_symbols(
cpp_content=test_stable_shim_content,
compile_flags=compile_flags,
)
num_symbols_stable_shim = len(symbols_stable_shim)
assert num_symbols_stable_shim > 0, (
f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_stable_shim} symbols"
)
def check_lib_symbols_for_abi_correctness(lib: str) -> None:
print(f"lib: {lib}")
cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)
@ -460,13 +129,6 @@ def main() -> None:
check_lib_symbols_for_abi_correctness(libtorch_cpu_path)
check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path)
# Check symbols when TORCH_STABLE_ONLY is defined
check_stable_only_symbols(install_root)
check_stable_api_symbols(install_root)
check_headeronly_symbols(install_root)
check_aoti_shim_symbols(install_root)
check_stable_c_shim_symbols(install_root)
if __name__ == "__main__":
main()

View File

@ -353,17 +353,6 @@ def test_linalg(device="cpu") -> None:
torch.linalg.svd(A)
def test_sdpa(device="cpu", dtype=torch.float16) -> None:
"""Regression test for https://github.com/pytorch/pytorch/issues/167602
Without nvrtc_builtins on CuDNN-9.13 on CUDA-13 fails with ` No valid execution plans built.`
"""
print(f"Testing SDPA on {device} using type {dtype}")
k, q, v = torch.rand(3, 1, 16, 77, 64, dtype=dtype, device=device).unbind(0)
attn = torch.rand(1, 1, 77, 77, dtype=dtype, device=device)
rc = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn)
assert rc.isnan().any().item() is False
def smoke_test_compile(device: str = "cpu") -> None:
supported_dtypes = [torch.float16, torch.float32, torch.float64]
@ -500,12 +489,10 @@ def main() -> None:
smoke_test_conv2d()
test_linalg()
test_numpy()
test_sdpa()
if is_cuda_system:
test_linalg("cuda")
test_cuda_gds_errors_captured()
test_sdpa("cuda")
if options.package == "all":
smoke_test_modules()

View File

@ -344,18 +344,8 @@ test_python_smoke() {
}
test_python_smoke_b200() {
# Targeted smoke tests for B200 including FlashAttention CuTe coverage
install_flash_attn_cute
time python test/run_test.py \
--include \
test_matmul_cuda \
test_scaled_matmul_cuda \
inductor/test_fp8 \
nn/attention/test_fa4 \
nn/attention/test_open_registry \
inductor/test_flex_flash \
$PYTHON_TEST_EXTRA_OPTION \
--upload-artifacts-while-running
# Targeted smoke tests for B200 - staged approach to avoid too many failures
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
assert_git_not_dirty
}
@ -1680,22 +1670,6 @@ test_operator_microbenchmark() {
done
}
test_attention_microbenchmark() {
TEST_REPORTS_DIR=$(pwd)/test/test-reports
mkdir -p "$TEST_REPORTS_DIR"
TEST_DIR=$(pwd)
# Install attention-gym dependency
echo "Installing attention-gym..."
python -m pip install git+https://github.com/meta-pytorch/attention-gym.git@main
pip show triton
cd "${TEST_DIR}"/benchmarks/transformer
$TASKSET python score_mod.py --config configs/config_basic.yaml \
--output-json-for-dashboard "${TEST_REPORTS_DIR}/attention_microbenchmark.json"
}
if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
(cd test && python -c "import torch; print(torch.__config__.show())")
(cd test && python -c "import torch; print(torch.__config__.parallel_info())")
@ -1753,8 +1727,6 @@ elif [[ "${TEST_CONFIG}" == *operator_benchmark* ]]; then
fi
elif [[ "${TEST_CONFIG}" == *operator_microbenchmark* ]]; then
test_operator_microbenchmark
elif [[ "${TEST_CONFIG}" == *attention_microbenchmark* ]]; then
test_attention_microbenchmark
elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
test_inductor_distributed
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then

View File

@ -63,7 +63,7 @@ self-hosted-runner:
- linux.rocm.gpu.gfx942.1
- linux.rocm.gpu.gfx942.2
- linux.rocm.gpu.gfx942.4
- linux.rocm.gfx942.docker-cache
- rocm-docker
# Org wise AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors)
- macos-m1-stable
- macos-m1-14

View File

@ -1 +1 @@
07b6cbde121417a70e4dc871adb6d27030e0ce3f
ad5816f0eee1c873df1b7d371c69f1f811a89387

View File

@ -1 +1 @@
acccf86477759b2d3500f1ae1be065f7b1e409ec
ccb801b88af136454798b945175c4c87e636ac33

View File

@ -50,7 +50,7 @@ def get_tag() -> str:
def get_base_version() -> str:
root = get_pytorch_root()
dirty_version = Path(root / "version.txt").read_text().strip()
dirty_version = open(root / "version.txt").read().strip()
# Strips trailing a0 from version.txt, not too sure why it's there in the
# first place
return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)

View File

@ -260,8 +260,11 @@ jobs:
"${DOCKER_IMAGE}"
)
docker exec -t -w "${PYTORCH_ROOT}" "${container_name}" bash -c "bash .circleci/scripts/binary_populate_env.sh"
# Unified build script for all architectures (x86_64, aarch64, s390x)
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh"
if [[ ${BUILD_ENVIRONMENT} == *"aarch64"* ]]; then
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/aarch64_linux/aarch64_ci_build.sh"
else
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh"
fi
- name: Chown artifacts
if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' && inputs.build_environment != 'linux-s390x-binary-manywheel' }}

View File

@ -1,73 +0,0 @@
name: attention_op_microbenchmark
on:
push:
tags:
- ciflow/op-benchmark/*
workflow_dispatch:
schedule:
# Run at 06:00 UTC everyday
- cron: 0 7 * * *
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
attn-microbenchmark-build:
if: github.repository_owner == 'pytorch'
uses: ./.github/workflows/_linux-build.yml
with:
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '8.0 9.0'
test-matrix: |
{ include: [
{ config: "attention_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.a100" },
{ config: "attention_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.h100" },
]}
secrets: inherit
attn-microbenchmark-test:
name: attn-microbenchmark-test
uses: ./.github/workflows/_linux-test.yml
needs: attn-microbenchmark-build
with:
timeout-minutes: 500
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image: ${{ needs.attn-microbenchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.attn-microbenchmark-build.outputs.test-matrix }}
secrets: inherit
# B200 runner
opmicrobenchmark-build-b200:
if: github.repository_owner == 'pytorch'
name: opmicrobenchmark-build-b200
uses: ./.github/workflows/_linux-build.yml
with:
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '10.0'
test-matrix: |
{ include: [
{ config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
]}
secrets: inherit
opmicrobenchmark-test-b200:
name: opmicrobenchmark-test-b200
uses: ./.github/workflows/_linux-test.yml
needs: opmicrobenchmark-build-b200
with:
timeout-minutes: 500
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }}
test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }}
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
secrets: inherit

View File

@ -37,7 +37,6 @@ jobs:
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '10.0'

View File

@ -37,7 +37,6 @@ jobs:
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100-symm
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '10.0'

View File

@ -119,22 +119,6 @@ jobs:
with:
docker-image: ${{ steps.build-docker-image.outputs.docker-image }}
- name: Generate output
if: contains(matrix.docker-image-name, 'rocm')
id: generate_output
run: |
docker_image_name="${{ matrix.docker-image-name }}"
docker_image_tag="${{ steps.build-docker-image.outputs.docker-image }}"
echo "${docker_image_name}=${docker_image_tag}" >> docker-builds-output-${docker_image_name}.txt
- name: Upload artifacts
uses: actions/upload-artifact@v4.4.0
if: contains(matrix.docker-image-name, 'rocm')
with:
name: docker-builds-artifacts-${{ matrix.docker-image-name }}
retention-days: 14
path: ./docker-builds-output-${{ matrix.docker-image-name }}.txt
- uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
name: Push to https://ghcr.io/
id: push-to-ghcr-io

View File

@ -0,0 +1,55 @@
name: docker-cache-mi300
on:
# run every 6 hours
schedule:
- cron: 0 0,6,12,18 * * *
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
docker-cache:
if: github.repository_owner == 'pytorch'
runs-on: rocm-docker
steps:
- name: Checkout PyTorch
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
with:
no-sudo: true
- name: configure aws credentials
id: aws_creds
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Login to Amazon ECR
id: login-ecr
continue-on-error: false
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
- name: Calculate docker image
id: calculate-docker-image
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
with:
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
push: false
- name: Pull docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Tar and upload to S3 bucket
run: |
sudo docker save -o ~/docker-data/pytorch/pytorch_docker_image.tar ${{ steps.calculate-docker-image.outputs.docker-image }}
sudo rclone copy -P --s3-upload-concurrency 64 --s3-chunk-size 200M --s3-upload-cutoff 300M ~/docker-data/pytorch/pytorch_docker_image.tar oci:pytorchbucket0002/pytorch_docker_image --progress

View File

@ -1,105 +0,0 @@
name: docker-cache-rocm
on:
workflow_run:
workflows: [docker-builds]
branches: [main, release]
types:
- completed
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
actions: read
jobs:
download-docker-builds-artifacts:
if: github.repository_owner == 'pytorch'
name: download-docker-builds-artifacts
runs-on: ubuntu-latest
outputs:
pytorch-linux-jammy-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}
pytorch-linux-noble-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}
pytorch-linux-jammy-rocm-n-py3-benchmarks: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}
steps:
- name: Download artifacts
uses: actions/download-artifact@v4.1.7
with:
run-id: ${{ github.event.workflow_run.id }}
path: ./docker-builds-artifacts
merge-multiple: true
github-token: ${{ secrets.GITHUB_TOKEN }}
- name: Process artifacts
id: process-artifacts
run: |
ls -R ./docker-builds-artifacts
cat ./docker-builds-artifacts/*txt >> "${GITHUB_OUTPUT}"
cat "${GITHUB_OUTPUT}"
docker-cache:
if: github.repository_owner == 'pytorch'
needs: download-docker-builds-artifacts
strategy:
fail-fast: false
matrix:
runner: [linux.rocm.gfx942.docker-cache]
docker-image: [
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}",
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}",
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}"
]
runs-on: "${{ matrix.runner }}"
steps:
- name: debug
run: |
JSON_STRINGIFIED="${{ toJSON(needs.download-docker-builds-artifacts.outputs) }}"
echo "Outputs of download-docker-builds-artifacts job: ${JSON_STRINGIFIED}"
- name: configure aws credentials
id: aws_creds
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Login to Amazon ECR
id: login-ecr
continue-on-error: false
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
- name: Generate ghrc.io tag
id: ghcr-io-tag
run: |
ecr_image="${{ matrix.docker-image }}"
ghcr_image="ghcr.io/pytorch/ci-image:${ecr_image##*:}"
echo "ghcr_image=${ghcr_image}" >> "$GITHUB_OUTPUT"
- name: Pull docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.ghcr-io-tag.outputs.ghcr_image }}
- name: Save as tarball
run: |
docker_image_tag=${{ matrix.docker-image }}
docker_image_tag="${docker_image_tag#*:}" # Remove everything before and including first ":"
docker_image_tag="${docker_image_tag%-*}" # Remove everything after and including last "-"
ref_name=${{ github.event.workflow_run.head_branch }}
if [[ $ref_name =~ "release/" ]]; then
ref_suffix="release"
elif [[ $ref_name == "main" ]]; then
ref_suffix="main"
else
echo "Unexpected branch in ref_name: ${ref_name}" && exit 1
fi
docker tag ${{ steps.ghcr-io-tag.outputs.ghcr_image }} ${{ matrix.docker-image }}
# mv is atomic operation, so we use intermediate tar.tmp file to prevent read-write contention
docker save -o ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ${{ matrix.docker-image }}
mv ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ~/pytorch-data/docker/${docker_image_tag}_${ref_suffix}.tar

View File

@ -1,4 +1,4 @@
name: inductor-rocm-mi200
name: inductor-rocm
on:
schedule:

View File

@ -1,4 +1,4 @@
name: rocm-mi200
name: rocm
on:
push:

View File

@ -5,9 +5,7 @@
# Flow:
# 1. Builds PyTorch with CUDA 12.8+ and sm100 architecture for B200
# 2. Runs smoke tests on linux.dgx.b200 runner
# 3. Tests executed are defined in .ci/pytorch/test.sh -> test_python_smoke_b200() function
# - Includes matmul, scaled_matmul, FP8, and FlashAttention CuTe tests
# - FlashAttention CuTe DSL is installed as part of test execution
# 3. Tests executed are defined in .ci/pytorch/test.sh -> test_python_smoke() function
#
# Triggered by:
# - Pull requests modifying this workflow file
@ -54,7 +52,6 @@ jobs:
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '10.0'
@ -75,4 +72,4 @@ jobs:
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build.outputs.test-matrix }}
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
secrets: inherit
secrets: inherit

View File

@ -1,83 +0,0 @@
name: trunk-rocm-mi300
on:
push:
branches:
- main
- release/*
workflow_dispatch:
schedule:
- cron: 29 8 * * * # about 1:29am PDT
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
llm-td:
if: github.repository_owner == 'pytorch'
name: before-test
uses: ./.github/workflows/llm_td_retrieval.yml
permissions:
id-token: write
contents: read
target-determination:
name: before-test
uses: ./.github/workflows/target_determination.yml
needs: llm-td
permissions:
id-token: write
contents: read
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-rocm-py3_10-build:
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
{ config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b" },
{ config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b" },
{ config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b" },
]}
secrets: inherit
linux-jammy-rocm-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
secrets: inherit

View File

@ -5,7 +5,6 @@ on:
workflows:
- pull
- trunk
- trunk-rocm-mi300
- periodic
- periodic-rocm-mi200
- periodic-rocm-mi300

View File

@ -37,7 +37,7 @@ Copyright (c) 2024 Tri Dao.
All rights reserved.
All contributions by Arm:
Copyright (c) 2021, 2023-2025 Arm Limited and/or its affiliates
Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates
All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors

View File

@ -245,9 +245,6 @@ class TORCH_API TensorBase {
size_t weak_use_count() const noexcept {
return impl_.weak_use_count();
}
bool is_uniquely_owned() const noexcept {
return impl_.is_uniquely_owned();
}
std::string toString() const;

View File

@ -18,8 +18,6 @@
#include <unordered_set>
#include <utility>
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
namespace torch {
class TORCH_API CustomClassHolder : public c10::intrusive_ptr_target {};
namespace jit {
@ -1632,6 +1630,4 @@ struct TORCH_API WeakOrStrongTypePtr {
} // namespace c10
C10_DIAGNOSTIC_POP()
#include <ATen/core/ivalue_inl.h> // IWYU pragma: keep

View File

@ -29,8 +29,6 @@
#include <c10/util/intrusive_ptr.h>
#include <c10/util/irange.h>
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
namespace torch {
namespace jit {
struct Function;
@ -2569,5 +2567,3 @@ TypePtr IValue::type() const {
}
} // namespace c10
C10_DIAGNOSTIC_POP()

View File

@ -223,62 +223,6 @@ CONVERT_FROM_BF16_TEMPLATE(double)
CONVERT_FROM_BF16_TEMPLATE(float16_t)
#endif
#ifdef __ARM_FEATURE_BF16
// clang-[17, 20] crashes when autovectorizing static cast to bf16
// Below is a workaround to have some vectorization
// Works decently well for smaller int types
template <typename from_type>
inline void convertToBf16Impl(
const from_type* __restrict src,
c10::BFloat16* __restrict dst,
uint64_t n) {
bfloat16_t* dstPtr = reinterpret_cast<bfloat16_t*>(dst);
uint64_t loopBound = n - (n % 16);
uint64_t i = 0;
for (; i < loopBound; i += 16) {
float32x4_t a, b, c, d;
a[0] = static_cast<float>(src[i]);
a[1] = static_cast<float>(src[i + 1]);
a[2] = static_cast<float>(src[i + 2]);
a[3] = static_cast<float>(src[i + 3]);
b[0] = static_cast<float>(src[i + 4]);
b[1] = static_cast<float>(src[i + 5]);
b[2] = static_cast<float>(src[i + 6]);
b[3] = static_cast<float>(src[i + 7]);
c[0] = static_cast<float>(src[i + 8]);
c[1] = static_cast<float>(src[i + 9]);
c[2] = static_cast<float>(src[i + 10]);
c[3] = static_cast<float>(src[i + 11]);
d[0] = static_cast<float>(src[i + 12]);
d[1] = static_cast<float>(src[i + 13]);
d[2] = static_cast<float>(src[i + 14]);
d[3] = static_cast<float>(src[i + 15]);
vst1q_bf16(dstPtr + i, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(a), b));
vst1q_bf16(dstPtr + i + 8, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(c), d));
}
#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
for (; i < n; i++) {
float a = static_cast<float>(src[i]);
dstPtr[i] = vcvth_bf16_f32(a);
}
}
#define CONVERT_TO_BF16_TEMPLATE(from_type) \
template <> \
inline void convert(const from_type* src, c10::BFloat16* dst, int64_t n) { \
return convertToBf16Impl<from_type>(src, dst, n); \
}
CONVERT_TO_BF16_TEMPLATE(uint8_t)
CONVERT_TO_BF16_TEMPLATE(int8_t)
CONVERT_TO_BF16_TEMPLATE(int16_t)
CONVERT_TO_BF16_TEMPLATE(int32_t)
#endif
inline void convertBoolToBfloat16Impl(
const bool* __restrict src,
c10::BFloat16* __restrict dst,

View File

@ -11,8 +11,6 @@
#include <sleef.h>
#endif
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
// Sleef offers vectorized versions of some transcedentals
// such as sin, cos, tan etc..
// However for now opting for STL, since we are not building
@ -652,5 +650,3 @@ inline Vectorized<float> Vectorized<float>::erf() const {
} // namespace CPU_CAPABILITY
} // namespace at::vec
C10_DIAGNOSTIC_POP()

View File

@ -1,7 +1,6 @@
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraph.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/MemPool.h>
#include <ATen/Functions.h>
#include <c10/cuda/CUDAFunctions.h>
@ -14,7 +13,7 @@ static bool _cuda_graphs_debug = false;
MempoolId_t graph_pool_handle() {
// Sets just the second value, to distinguish it from MempoolId_ts created from
// cudaStreamGetCaptureInfo id_s in capture_begin.
return at::cuda::MemPool::graph_pool_handle();
return c10::cuda::MemPool::graph_pool_handle();
}
/**
@ -91,7 +90,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
} else {
// User did not ask us to share a mempool. Create graph pool handle using is_user_created=false.
// Sets just the first value, to distinguish it from MempoolId_ts created by graph_pool_handle().
mempool_id_ = at::cuda::MemPool::graph_pool_handle(false);
mempool_id_ = c10::cuda::MemPool::graph_pool_handle(false);
TORCH_INTERNAL_ASSERT(mempool_id_.first > 0);
}

View File

@ -1,69 +0,0 @@
#include <ATen/core/CachingHostAllocator.h>
#include <ATen/cuda/MemPool.h>
namespace at::cuda {
// uid_ is incremented when a user creates a MemPool,
// for example: using graph_pool_handle() or c10::cuda::MemPool().
//
// uuid_ is incremented when CUDAGraph creates a MemPool
// as a result of a user not providing a pool.
//
// MempoolId_t of {0, 0} is used to denote when no MemPool has been
// passed to a function, either by user or CUDAGraphs. For example,
// default value of MempoolId_t for capture_begin function is {0, 0}.
// That's why uid_ and uuid_ start at 1.
std::atomic<CaptureId_t> MemPool::uid_{1};
std::atomic<CaptureId_t> MemPool::uuid_{1};
MemPool::MemPool(
CUDACachingAllocator::CUDAAllocator* allocator,
bool is_user_created,
bool use_on_oom)
: allocator_(allocator), is_user_created_(is_user_created) {
if (is_user_created_) {
id_ = {0, uid_++};
} else {
id_ = {uuid_++, 0};
}
device_ = c10::cuda::current_device();
CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
if (use_on_oom) {
CUDACachingAllocator::setUseOnOOM(device_, id_);
}
}
MemPool::~MemPool() {
// TORCH_INTERNAL_ASSERT(use_count() == 1);
// We used to assert that TORCH_INTERNAL_ASSERT(use_count() == 1);
// However, this assertion is not true if a memory pool is shared
// with a cuda graph. That CUDAGraph will increase the use count
// until it is reset.
CUDACachingAllocator::releasePool(device_, id_);
c10::cuda::CUDACachingAllocator::emptyCache(id_);
}
MempoolId_t MemPool::id() {
return id_;
}
CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
return allocator_;
}
int MemPool::use_count() {
return CUDACachingAllocator::getPoolUseCount(device_, id_);
}
c10::DeviceIndex MemPool::device() {
return device_;
}
MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
if (is_user_created) {
return {0, uid_++};
}
return {uuid_++, 0};
}
} // namespace at::cuda

View File

@ -1,44 +0,0 @@
#pragma once
#include <c10/core/Allocator.h>
#include <c10/cuda/CUDACachingAllocator.h>
namespace at::cuda {
// Keep BC only
using c10::CaptureId_t;
using c10::MempoolId_t;
// MemPool represents a pool of memory in a caching allocator. Currently,
// it's just the ID of the pool object maintained in the CUDACachingAllocator.
//
// An allocator pointer can be passed to the MemPool to define how the
// allocations should be done in the pool. For example: using a different
// system allocator such as ncclMemAlloc.
struct TORCH_CUDA_CPP_API MemPool {
MemPool(
c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
bool is_user_created = true,
bool use_on_oom = false);
MemPool(const MemPool&) = delete;
MemPool(MemPool&&) = default;
MemPool& operator=(const MemPool&) = delete;
MemPool& operator=(MemPool&&) = default;
~MemPool();
MempoolId_t id();
c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator();
int use_count();
c10::DeviceIndex device();
static MempoolId_t graph_pool_handle(bool is_user_created = true);
private:
static std::atomic<CaptureId_t> uid_;
static std::atomic<CaptureId_t> uuid_;
c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator_;
bool is_user_created_;
MempoolId_t id_;
c10::DeviceIndex device_;
};
} // namespace at::cuda

View File

@ -1936,7 +1936,7 @@ static bool should_fold(const Tensor& tensor1, const Tensor& tensor2, bool has_o
// We order the tensors. t1 will be the larger tensor
// We can always transpose tensor2 as the dimensions are always >= 1 (precondition from matmul)
// and tensor1_larger iff tensor2.dim() > tensor1.dim()
// and tensor1_larger iff tensor2.dim() > tensor1.dim(9
const auto t1 = tensor1_larger ? MaybeOwned<Tensor>::borrowed(tensor1)
: MaybeOwned<Tensor>::owned(tensor2.mT());
const int64_t dim_t1 = t1->dim();
@ -1948,11 +1948,20 @@ static bool should_fold(const Tensor& tensor1, const Tensor& tensor2, bool has_o
return false;
}
// If we require a gradient, we should fold to minimize backward memory usage - even if this
// leads to a copy in forward because is needed in backward,
// only time we avoid this strict pre-allocated memory usage (has_out = True)
bool requires_grad = tensor1.requires_grad() || tensor2.requires_grad();
if (requires_grad && !has_out) {
// In this case we *do* incur in an extra copy to avoid creating an unnecessary large tensor in the backward
// Suppose we don't fold here. Let t1.shape = [b, m, n] t2.shape = [n, k] like in a transformer
// t2 will be expanded to a tensor of shape [b, n, k] and then we do t1.bmm(t2_expanded)
// The issue appears in the backward.
// The output gradient g of this operation would have shape [b, m, k]
// The backward wrt. t2 of bmm would be given by t1.mH @ g, which has shape [b, n, k]
// Then, the backward of expand is simply `sum(0)`. As such, we are instantiating a tensor
// of shape [b, n, k] unnecessarily, which may cause a large memory footprint, and in the
// worst case, an OOM
bool t2_requires_grad = tensor1_larger ? tensor2.requires_grad() : tensor1.requires_grad();
if (t2_requires_grad && !has_out) {
// We should be checking !at::GradMode::is_enabled(), but apparently
// this regresses performance in some cases:
// https://github.com/pytorch/pytorch/issues/118548#issuecomment-1916022394
return true;
}

View File

@ -142,7 +142,6 @@ Tensor _pack_padded_sequence_backward_symint(const Tensor& grad, c10::SymIntArra
std::tuple<Tensor, Tensor> _pad_packed_sequence(const Tensor& data, const Tensor& _batch_sizes, bool batch_first, const Scalar& padding_value, int64_t total_length) {
auto batch_sizes_t = _batch_sizes.contiguous();
checkLongTensor(batch_sizes_t);
TORCH_CHECK(batch_sizes_t.numel() > 0, "batch_sizes can not be empty");
int64_t * batch_sizes = batch_sizes_t.data_ptr<int64_t>();
int64_t max_batch_size = batch_sizes[0];

View File

@ -1,8 +1,6 @@
#pragma once
#include <c10/util/Exception.h>
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
namespace at::native {
// Used as an interface between the different BLAS-like libraries
@ -23,5 +21,3 @@ static inline char to_blas(TransposeType trans) {
}
} // namespace at::native
C10_DIAGNOSTIC_POP()

View File

@ -1,7 +1,6 @@
#pragma once
#include <ATen/native/CompositeRandomAccessorCommon.h>
#include <thrust/swap.h>
#include <thrust/tuple.h>
namespace at { namespace native {

View File

@ -75,52 +75,30 @@ static inline bool can_use_int32_nhwc(
return true;
}
static inline bool can_use_int32_nchw(
int64_t nbatch, int64_t channels,
int64_t height, int64_t width,
int64_t pooled_height, int64_t pooled_width) {
int64_t hw = height * width;
return can_use_int32_nhwc(
nbatch, channels, height, width,
pooled_height, pooled_width,
channels * hw, // in_stride_n
hw, // in_stride_c
width, // in_stride_h
1 // in_stride_w
);
}
// kernels borrowed from Caffe
template <typename scalar_t, typename index_t>
__global__ void max_pool_forward_nchw(
const index_t nthreads,
const scalar_t* bottom_data,
const int64_t channels,
const int64_t height,
const int64_t width,
const int pooled_height,
const int pooled_width,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
scalar_t* top_data,
template <typename scalar_t>
__global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom_data,
const int64_t channels, const int64_t height,
const int64_t width, const int pooled_height, const int pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w, scalar_t* top_data,
int64_t* top_mask) {
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
index_t pw = index % pooled_width;
index_t ph = (index / pooled_width) % pooled_height;
index_t c = (index / pooled_width / pooled_height) % channels;
index_t n = index / pooled_width / pooled_height / channels;
index_t hstart = ph * stride_h - pad_h;
index_t wstart = pw * stride_w - pad_w;
index_t hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
index_t wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride_h - pad_h;
int wstart = pw * stride_w - pad_w;
int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
while(hstart < 0)
hstart += dilation_h;
while(wstart < 0)
wstart += dilation_w;
scalar_t maxval = at::numeric_limits<scalar_t>::lower_bound(); // -Infinity
index_t maxidx = hstart * width + wstart;
int maxidx = hstart * width + wstart;
const scalar_t* btm_data = bottom_data + (n * channels + c) * height * width;
for (int h = hstart; h < hend; h += dilation_h) {
for (int w = wstart; w < wend; w += dilation_w) {
@ -273,39 +251,32 @@ __global__ void max_pool_forward_nhwc(
static constexpr int BLOCK_THREADS = 256;
template <typename scalar_t, typename accscalar_t, typename index_t>
template <typename scalar_t, typename accscalar_t>
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 4)
#else
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 8)
#endif
__global__ void max_pool_backward_nchw(
const scalar_t* top_diff,
const int64_t* top_mask,
const index_t num,
const index_t channels,
const index_t height,
const index_t width,
const index_t pooled_height,
const index_t pooled_width,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
__global__ void max_pool_backward_nchw(const scalar_t* top_diff,
const int64_t* top_mask, const int num, const int64_t channels,
const int64_t height, const int64_t width, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
scalar_t* bottom_diff) {
CUDA_KERNEL_LOOP_TYPE(index, height*width, index_t) {
index_t h = index / width;
index_t w = index - h * width;
index_t phstart = p_start(h, pad_h, kernel_h, dilation_h, stride_h);
index_t phend = p_end(h, pad_h, pooled_height, stride_h);
index_t pwstart = p_start(w, pad_w, kernel_w, dilation_w, stride_w);
index_t pwend = p_end(w, pad_w, pooled_width, stride_w);
for (index_t n = blockIdx.y; n < num; n += gridDim.y) {
for (index_t c = blockIdx.z; c < channels; c += gridDim.z) {
CUDA_KERNEL_LOOP(index, height*width) {
int h = index / width;
int w = index - h * width;
int phstart = p_start(h, pad_h, kernel_h, dilation_h, stride_h);
int phend = p_end(h, pad_h, pooled_height, stride_h);
int pwstart = p_start(w, pad_w, kernel_w, dilation_w, stride_w);
int pwend = p_end(w, pad_w, pooled_width, stride_w);
for (int n = blockIdx.y; n < num; n += gridDim.y) {
for (int c = blockIdx.z; c < channels; c+= gridDim.z) {
accscalar_t gradient = accscalar_t(0);
index_t offset = (n * channels + c) * pooled_height * pooled_width;
for (index_t ph = phstart; ph < phend; ++ph) {
for (index_t pw = pwstart; pw < pwend; ++pw) {
int offset = (n * channels + c) * pooled_height * pooled_width;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
if (top_mask[ph * pooled_width + pw + offset] == h * width + w) {
gradient += static_cast<accscalar_t>(top_diff[ph * pooled_width + pw + offset]);
}
@ -498,6 +469,8 @@ const Tensor& indices) {
const int64_t in_stride_h = input.stride(-2);
const int64_t in_stride_w = input.stride(-1);
const int count = safe_downcast<int, int64_t>(output.numel());
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
"max_pool2d_with_indices_out_cuda_frame",
[&] {
@ -580,42 +553,14 @@ const Tensor& indices) {
break;
}
case MemoryFormat::Contiguous: {
const int threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
BLOCK_THREADS);
const int64_t nthreads = output.numel();
bool use_int32 = can_use_int32_nchw(
nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
const int maxGridX = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
const int blocks = static_cast<int>(std::min<int64_t>(
ceil_div(nthreads, static_cast<int64_t>(threads)),
static_cast<int64_t>(maxGridX)));
auto stream = at::cuda::getCurrentCUDAStream();
if (use_int32) {
max_pool_forward_nchw<scalar_t, int32_t>
<<<blocks, threads, 0, stream>>>(
static_cast<int32_t>(nthreads),
input_data,
static_cast<int32_t>(nInputPlane),
static_cast<int32_t>(inputHeight),
static_cast<int32_t>(inputWidth),
static_cast<int32_t>(outputHeight),
static_cast<int32_t>(outputWidth),
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
output_data, indices_data);
} else {
max_pool_forward_nchw<scalar_t, int64_t>
<<<blocks, threads, 0, stream>>>(
nthreads,
input_data,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
output_data, indices_data);
}
const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
BLOCK_THREADS);
max_pool_forward_nchw<scalar_t>
<<<ceil_div(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
count, input_data,
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
output_data, indices_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
}
@ -688,6 +633,8 @@ const Tensor& gradInput) {
gradInput.zero_();
int64_t count = input.numel();
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
"max_pool2d_with_indices_out_cuda_frame",
[&] {
@ -745,45 +692,25 @@ const Tensor& gradInput) {
break;
}
case MemoryFormat::Contiguous: {
const int threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
BLOCK_THREADS);
const int imgcount = inputWidth * inputHeight;
const int maxGridX = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
const int maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const int maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
const int blocks_x = std::min(ceil_div(imgcount, threads), maxGridX);
dim3 grid(blocks_x, static_cast<unsigned>(std::min<int64_t>(nbatch, maxGridY)), static_cast<unsigned>(std::min<int64_t>(nInputPlane, maxGridZ)));
bool use_int32 = can_use_int32_nchw(
nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
auto stream = at::cuda::getCurrentCUDAStream();
if (use_int32) {
max_pool_backward_nchw<scalar_t, accscalar_t, int32_t>
<<<grid, threads, 0, stream>>>(
gradOutput_data,
indices_data,
static_cast<int32_t>(nbatch),
static_cast<int32_t>(nInputPlane),
static_cast<int32_t>(inputHeight),
static_cast<int32_t>(inputWidth),
static_cast<int32_t>(outputHeight),
static_cast<int32_t>(outputWidth),
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
gradInput_data);
} else {
max_pool_backward_nchw<scalar_t, accscalar_t, int64_t>
<<<grid, threads, 0, stream>>>(
gradOutput_data,
indices_data,
nbatch,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
gradInput_data);
}
int imgcount = inputWidth * inputHeight;
dim3 grid;
const int blocks = (imgcount + BLOCK_THREADS - 1) / BLOCK_THREADS;
grid.x = blocks;
grid.y = nbatch;
uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
if (maxGridY < grid.y) grid.y = maxGridY;
grid.z = nInputPlane;
uint64_t maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
if (maxGridZ < grid.z) grid.z = maxGridZ;
max_pool_backward_nchw<scalar_t, accscalar_t>
<<<grid, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
gradOutput_data,
indices_data,
nbatch,
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
gradInput_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
}

View File

@ -669,12 +669,9 @@ std::optional<c10::ScalarType> out_dtype) {
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
// On non CK system(w/ ROCm), make sure use_fast_path is false
#if defined(USE_ROCM_CK_GEMM)
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
use_fast_path = true;
}
#endif //USE_ROCM_CK_GEMM
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
@ -683,11 +680,7 @@ std::optional<c10::ScalarType> out_dtype) {
#ifndef USE_ROCM
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
#if defined(USE_ROCM_CK_GEMM)
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#else
TORCH_WARN("ROCm: Group Gemm through CK not selected.");
#endif //USE_ROCM_CK_GEMM
#endif
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);

View File

@ -267,15 +267,15 @@ void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, con
* outer dimensions, which contains several "inner rows").
* Each thread processes a single inner row at a time.
*/
template<typename scalar_t, typename index_t, class BinaryOp>
template<typename scalar_t, class BinaryOp>
__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, const scalar_t *src_,
const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size,
const scalar_t init, BinaryOp binary_op)
{
for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
const scalar_t *src = src_ + static_cast<index_t>(orow) * row_size * num_irows + irow;
scalar_t *tgt = tgt_ + (index_t) orow * row_size * num_irows + irow;
const scalar_t *src = src_ + orow * row_size * num_irows + irow;
scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow;
scalar_t acc = init;
for (uint32_t col = 0; col < row_size; ++col) {
@ -409,15 +409,10 @@ __host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,
check_fits_in_unsigned(num_irows, "num_irows");
check_fits_in_unsigned(num_orows, "num_orows");
check_fits_in_unsigned(row_size, "row_size");
if (static_cast<size_t>(num_irows) * num_orows * row_size <= UINT_MAX) {
tensor_kernel_scan_outer_dim<scalar_t, uint32_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
num_orows, num_irows, row_size, init, binary_op);
} else {
tensor_kernel_scan_outer_dim<scalar_t, size_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
num_orows, num_irows, row_size, init, binary_op);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

View File

@ -337,6 +337,10 @@ Tensor _convolution_out(
TORCH_CHECK(
3 == ndim || 4 == ndim || 5 == ndim,
"convolution only supports 3D, 4D, 5D tensor");
// get computation format for Conv/TransposedConv
bool is_channels_last_suggested =
use_channels_last_for_conv(input_r, weight_r);
Tensor input = input_r, weight = weight_r;
// PyTorch does not support ChannelsLast1D case,
// thus we need the transformation here
@ -344,8 +348,13 @@ Tensor _convolution_out(
input = view4d(input_r);
weight = view4d(weight_r);
}
// get computation format for Conv/TransposedConv
bool is_channels_last_suggested = use_channels_last_for_conv(input, weight);
// ensure the input/weight/bias/output are congituous in desired format
at::MemoryFormat mfmt = is_channels_last_suggested
? get_cl_tag_by_ndim(input.ndimension())
: at::MemoryFormat::Contiguous;
auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r;
input = input.contiguous(mfmt);
weight = weight.contiguous(mfmt);
auto k = weight.ndimension();
if (k == input.ndimension() + 1) {
@ -379,14 +388,6 @@ Tensor _convolution_out(
expand_param_if_needed(output_padding_, "output_padding", dim);
params.groups = groups_;
}
// ensure the input/weight/bias/output are congituous in desired format
at::MemoryFormat mfmt = is_channels_last_suggested
? get_cl_tag_by_ndim(input.ndimension())
: at::MemoryFormat::Contiguous;
auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r;
input = input.contiguous(mfmt);
weight = weight.contiguous(mfmt);
check_shape_forward(input, weight, bias, params, true);
Tensor output;
@ -513,9 +514,18 @@ Tensor convolution_overrideable(
at::borrow_from_optional_tensor(bias_r_opt);
const Tensor& bias_r = *bias_r_maybe_owned;
auto k = weight_r.ndimension();
at::MemoryFormat backend_memory_format = at::MemoryFormat::Contiguous;
if (xpu_conv_use_channels_last(input_r, weight_r)) {
backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d
: at::MemoryFormat::ChannelsLast;
}
Tensor input_c = input_r.contiguous(backend_memory_format);
Tensor weight_c = weight_r.contiguous(backend_memory_format);
return _convolution(
input_r,
weight_r,
input_c,
weight_c,
bias_r,
stride_,
padding_,

View File

@ -40,7 +40,7 @@ inline c10::metal::opmath_t<T> matmul_inner(
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint k = 0; k < TILE_DIM; k++) {
sum += c10::metal::mul(A_tile[tid.y][k], B_tile[k][tid.x]);
sum += A_tile[tid.y][k] * B_tile[k][tid.x];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@ -832,10 +832,6 @@ INSTANTIATE_MM_OPS(float);
INSTANTIATE_MM_OPS(half);
INSTANTIATE_MM_OPS(bfloat);
// Complex MM
INSTANTIATE_MM_OPS(float2);
INSTANTIATE_MM_OPS(half2);
// Integral MM
INSTANTIATE_MM_OPS(long);
INSTANTIATE_MM_OPS(int);

View File

@ -190,16 +190,10 @@ std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* gr
bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) {
static bool always_use_metal = c10::utils::has_env("PYTORCH_MPS_PREFER_METAL");
constexpr auto max_stride_size = 32768;
constexpr auto max_complex_inner_size = 2048;
static bool is_macos_14_4_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS);
if (always_use_metal || c10::isIntegralType(self.scalar_type(), true)) {
return true;
}
// multiplicationWithPrimaryTensor: returns incorrect results if inner size exceeds 2048
// See https://github.com/pytorch/pytorch/issues/167727#issuecomment-3529308548
if (c10::isComplexType(self.scalar_type()) && self.size(1) > max_complex_inner_size) {
return true;
}
return !is_macos_14_4_or_newer &&
(self.stride(0) > max_stride_size || self.stride(1) > max_stride_size || self.size(0) > max_stride_size ||
self.size(1) > max_stride_size || other.stride(0) > max_stride_size || other.stride(1) > max_stride_size ||

View File

@ -7518,7 +7518,7 @@
- func: _sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor
variants: method
dispatch:
SparseCPU, SparseCUDA, SparseMPS: sparse_mask_projection
SparseCPU, SparseCUDA: sparse_mask_projection
autogen: _sparse_mask_projection.out
- func: _to_cpu(Tensor[] tensors) -> Tensor[]

View File

@ -30,12 +30,10 @@
#include <thrust/binary_search.h>
#include <thrust/device_ptr.h>
#include <thrust/distance.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/system/cuda/execution_policy.h>
#include <thrust/iterator/constant_iterator.h>
#include <cuda_runtime_api.h>
#include <cusparse.h>
@ -49,7 +47,6 @@
#include <c10/macros/Macros.h>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/distance.h>
#include <thrust/for_each.h>
#include <thrust/functional.h>
#include <thrust/gather.h>

View File

@ -445,33 +445,6 @@ static SparseTensor& mul_out_dense_sparse_mps(
return out;
}
static std::tuple<Tensor, Tensor, int64_t> mps_intersect_binary_search(
const Tensor& A_keys,
const Tensor& B_keys,
int64_t lenA,
int64_t lenB,
bool boolean_flag) {
auto stream = getCurrentMPSStream();
auto outA_idx = at::empty({lenA}, A_keys.options().dtype(at::kLong));
auto outB_idx = at::empty({lenA}, A_keys.options().dtype(at::kLong));
auto counter = at::zeros({1}, A_keys.options().dtype(at::kInt));
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
static_cast<uint32_t>(lenB), boolean_flag);
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
}
});
const auto match_count = static_cast<int64_t>(counter.item<int32_t>());
return std::make_tuple(std::move(outA_idx), std::move(outB_idx), match_count);
}
SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTensor& r_) {
TORCH_CHECK(r_.is_mps(), "mul: expected 'out' to be MPS, but got ", r_.device());
@ -550,10 +523,22 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
auto A_keys = A_is_lhs ? lhs_keys : rhs_keys;
auto B_keys = A_is_lhs ? rhs_keys : lhs_keys;
auto [outA_idx, outB_idx, M_int64] = mps_intersect_binary_search(
A_keys, B_keys, lenA, lenB, A_is_lhs);
auto outA_idx = at::empty({lenA}, at::device(device).dtype(kLong));
auto outB_idx = at::empty({lenA}, at::device(device).dtype(kLong));
auto counter = at::zeros({1}, at::device(device).dtype(kInt));
const auto M = static_cast<uint32_t>(M_int64); // number of structural matches
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
static_cast<uint32_t>(lenB), A_is_lhs);
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
}
});
const uint32_t M = counter.item<int32_t>(); // number of structural matches
r_.resize_as_(lhs);
@ -777,14 +762,6 @@ SparseTensor& add_out_sparse_mps(const SparseTensor& self,
using OptTensor = std::optional<Tensor>;
static Tensor create_sparse_output_values(
const Tensor& template_values,
int64_t output_nnz,
ScalarType dtype) {
auto out_val_sizes = template_values.sizes().vec();
out_val_sizes[0] = output_nnz;
return at::zeros(out_val_sizes, template_values.options().dtype(dtype));
}
static void sparse_mask_apply_out_mps_kernel(
Tensor& result,
@ -806,9 +783,9 @@ static void sparse_mask_apply_out_mps_kernel(
auto src = src_in.coalesce();
auto mask = coalesce_mask ? mask_in.coalesce() : mask_in;
const auto src_nnz = src._nnz();
const auto mask_nnz = mask._nnz();
const auto sd = src.sparse_dim();
const int64_t src_nnz = src._nnz();
const int64_t mask_nnz = mask._nnz();
const int64_t sd = src.sparse_dim();
result.sparse_resize_(mask.sizes(), mask.sparse_dim(), mask.dense_dim());
auto commonDtype = at::result_type(src, mask);
@ -837,27 +814,53 @@ static void sparse_mask_apply_out_mps_kernel(
return;
}
auto mask_indices = mask._indices().contiguous();
auto src_values = src._values().to(commonDtype).contiguous();
auto out_values = create_sparse_output_values(src_values, mask_nnz, commonDtype);
if (src_nnz == 0) {
alias_into_sparse(result, mask_indices, out_values);
auto out_indices = mask._indices().contiguous();
auto src_values = src._values().to(commonDtype);
auto out_val_sizes = src_values.sizes().vec();
out_val_sizes[0] = mask_nnz;
auto out_values = at::zeros(out_val_sizes, src_values.options());
alias_into_sparse(result, out_indices, out_values);
result._coalesced_(mask.is_coalesced());
return;
}
auto mask_keys = flatten_indices(mask._indices().contiguous(), mask.sizes().slice(0, sd)).contiguous();
auto src_keys = flatten_indices(src._indices().contiguous(), src.sizes().slice(0, sd)).contiguous();
auto mask_indices = mask._indices().contiguous();
auto src_indices = src._indices().contiguous();
auto src_values = src._values().to(commonDtype).contiguous();
const auto A_is_src = (src_nnz <= mask_nnz);
const auto lenA = A_is_src ? src_nnz : mask_nnz;
const auto lenB = A_is_src ? mask_nnz : src_nnz;
auto mask_keys = flatten_indices(mask_indices, mask.sizes().slice(0, sd)).contiguous();
auto src_keys = flatten_indices(src_indices, src.sizes().slice(0, sd)).contiguous();
const bool A_is_src = (src_nnz <= mask_nnz);
const int64_t lenA = A_is_src ? src_nnz : mask_nnz;
const int64_t lenB = A_is_src ? mask_nnz : src_nnz;
auto A_keys = A_is_src ? src_keys : mask_keys;
auto B_keys = A_is_src ? mask_keys : src_keys;
auto [outA_idx, outB_idx, M] = mps_intersect_binary_search(
A_keys, B_keys, lenA, lenB, A_is_src);
const auto device = result.device();
auto stream = getCurrentMPSStream();
auto outA_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
auto outB_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
auto counter = at::zeros({1}, at::device(device).dtype(at::kInt));
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
static_cast<uint32_t>(lenB), A_is_src);
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
}
});
const int64_t M = static_cast<int64_t>(counter.item<int32_t>());
auto out_val_sizes = src_values.sizes().vec();
out_val_sizes[0] = mask_nnz;
auto out_values = at::zeros(out_val_sizes, src_values.options());
if (M > 0) {
auto src_match = outA_idx.narrow(0, 0, M);
@ -875,70 +878,6 @@ static void sparse_mask_apply_out_mps_kernel(
result._coalesced_(mask.is_coalesced());
}
static void sparse_mask_projection_out_mps_kernel(
Tensor& result,
const Tensor& lhs,
const Tensor& rhs,
const OptTensor& /*x_hash_opt*/,
bool accumulate_matches) {
TORCH_CHECK(lhs.is_sparse() && rhs.is_sparse(), "sparse_mask_projection: expected sparse COO");
TORCH_CHECK(lhs.is_mps() && rhs.is_mps(), "sparse_mask_projection: expected MPS tensors");
TORCH_CHECK(lhs.sparse_dim() == rhs.sparse_dim(), "sparse_dim mismatch");
auto lhs_c = lhs.coalesce();
auto rhs_c = rhs.coalesce();
const auto sd = lhs_c.sparse_dim();
const auto lhs_nnz = lhs_c._nnz();
const auto rhs_nnz = rhs_c._nnz();
auto commonDtype = at::result_type(lhs_c, rhs_c);
TORCH_CHECK(canCast(commonDtype, result.scalar_type()),
"Can't convert ", commonDtype, " to output ", result.scalar_type());
result.sparse_resize_(lhs.sizes(), lhs.sparse_dim(), lhs.dense_dim());
auto lhs_indices = lhs_c._indices().contiguous();
auto rhs_values = rhs_c._values().to(commonDtype).contiguous();
auto out_values = create_sparse_output_values(rhs_values, lhs_nnz, commonDtype);
if (lhs_nnz > 0 && rhs_nnz > 0) {
auto lhs_keys = flatten_indices(lhs_indices, lhs_c.sizes().slice(0, sd)).contiguous();
auto rhs_keys = flatten_indices(rhs_c._indices().contiguous(), rhs_c.sizes().slice(0, sd)).contiguous();
const auto A_is_lhs = (lhs_nnz <= rhs_nnz);
const auto lenA = A_is_lhs ? lhs_nnz : rhs_nnz;
const auto lenB = A_is_lhs ? rhs_nnz : lhs_nnz;
auto A_keys = A_is_lhs ? lhs_keys : rhs_keys;
auto B_keys = A_is_lhs ? rhs_keys : lhs_keys;
auto [outA_idx, outB_idx, M] = mps_intersect_binary_search(
A_keys, B_keys, lenA, lenB, A_is_lhs);
if (M > 0) {
auto idx_in_A = outA_idx.narrow(0, 0, M);
auto idx_in_B = outB_idx.narrow(0, 0, M);
auto idx_in_lhs = A_is_lhs ? idx_in_A : idx_in_B;
auto idx_in_rhs = A_is_lhs ? idx_in_B : idx_in_A;
const auto view_cols = rhs_values.numel() / std::max<int64_t>(rhs_nnz, 1);
auto rhs_rows = rhs_values.index_select(0, idx_in_rhs).contiguous();
auto rhs_rows_2d = rhs_rows.view({M, view_cols});
auto out_2d = out_values.view({lhs_nnz, view_cols});
if (accumulate_matches) {
out_2d.index_add_(0, idx_in_lhs, rhs_rows_2d);
} else {
out_2d.index_copy_(0, idx_in_lhs, rhs_rows_2d);
}
}
}
alias_into_sparse(result, lhs._indices(), out_values);
result._coalesced_(lhs.is_coalesced());
}
static void sparse_mask_intersection_out_mps_kernel(
Tensor& result,
const Tensor& lhs,
@ -1063,5 +1002,4 @@ Tensor sparse_sparse_matmul_mps(const Tensor& mat1_, const Tensor& mat2_) {
}
REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
REGISTER_MPS_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_mps_kernel);
} // namespace at::native

View File

@ -1,3 +1,191 @@
#pragma once
#include <ATen/xpu/XPUContext.h>
#include <c10/xpu/XPUEvent.h>
#include <optional>
namespace at::xpu {
/*
* XPUEvent are movable not copyable wrappers around SYCL event. XPUEvent are
* constructed lazily when first recorded. It has a device, and this device is
* acquired from the first recording stream. Later streams that record the event
* must match the same device.
*
* Currently, XPUEvent does NOT support to export an inter-process event from
* another process via inter-process communication(IPC). So it means that
* inter-process communication for event handles between different processes is
* not available. This could impact some applications that rely on cross-process
* synchronization and communication.
*/
struct TORCH_XPU_API XPUEvent {
// Constructors
XPUEvent(bool enable_timing = false) noexcept
: enable_timing_{enable_timing} {}
~XPUEvent() {
if (isCreated()) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_deletion(
at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
}
}
}
XPUEvent(const XPUEvent&) = delete;
XPUEvent& operator=(const XPUEvent&) = delete;
XPUEvent(XPUEvent&& other) = default;
XPUEvent& operator=(XPUEvent&& other) = default;
operator sycl::event&() const {
return event();
}
std::optional<at::Device> device() const {
if (isCreated()) {
return at::Device(at::kXPU, device_index_);
} else {
return std::nullopt;
}
}
inline bool isCreated() const {
return (event_.get() != nullptr);
}
DeviceIndex device_index() const {
return device_index_;
}
sycl::event& event() const {
return *event_;
}
bool query() const {
using namespace sycl::info;
if (!isCreated()) {
return true;
}
return event().get_info<event::command_execution_status>() ==
event_command_status::complete;
}
void record() {
record(getCurrentXPUStream());
}
void recordOnce(const XPUStream& stream) {
if (!isCreated()) {
record(stream);
}
}
void record(const XPUStream& stream) {
if (!isCreated()) {
device_index_ = stream.device_index();
assignEvent(stream.queue());
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_creation(
at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
}
} else {
TORCH_CHECK(
device_index_ == stream.device_index(),
"Event device ",
device_index_,
" does not match recording stream's device ",
stream.device_index(),
".");
reassignEvent(stream.queue());
}
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_record(
at::kXPU,
reinterpret_cast<uintptr_t>(event_.get()),
reinterpret_cast<uintptr_t>(&stream.queue()));
}
}
void block(const XPUStream& stream) {
if (isCreated()) {
std::vector<sycl::event> event_list{event()};
// Make this stream wait until event_ is completed.
stream.queue().ext_oneapi_submit_barrier(event_list);
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_wait(
at::kXPU,
reinterpret_cast<uintptr_t>(event_.get()),
reinterpret_cast<uintptr_t>(&stream.queue()));
}
}
}
double elapsed_time(const XPUEvent& other) const {
TORCH_CHECK(
isCreated() && other.isCreated(),
"Both events must be recorded before calculating elapsed time.");
TORCH_CHECK(
query() && other.query(),
"Both events must be completed before calculating elapsed time.");
TORCH_CHECK(
enable_timing_ && other.enable_timing_,
"Both events must be created with argument 'enable_timing=True'.");
#if SYCL_COMPILER_VERSION < 20250000
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"elapsed_time of XPUEvent requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
#endif
using namespace sycl::info::event_profiling;
// Block until both of the recorded events are completed.
uint64_t end_time_ns = other.event().get_profiling_info<command_end>();
uint64_t start_time_ns = event().get_profiling_info<command_end>();
// Return the eplased time in milliseconds.
return 1e-6 *
(static_cast<double>(end_time_ns) - static_cast<double>(start_time_ns));
}
void synchronize() const {
if (isCreated()) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_synchronization(
at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
}
event().wait_and_throw();
}
}
private:
void assignEvent(sycl::queue& queue) {
#if SYCL_COMPILER_VERSION >= 20250000
if (enable_timing_) {
event_ = std::make_unique<sycl::event>(
sycl::ext::oneapi::experimental::submit_profiling_tag(queue));
} else {
event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
}
#else
event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
#endif
}
void reassignEvent(sycl::queue& queue) {
event_.reset();
assignEvent(queue);
}
bool enable_timing_ = false;
DeviceIndex device_index_ = -1;
// Only need to track the last event, as events in an in-order queue are
// executed sequentially.
std::unique_ptr<sycl::event> event_;
};
} // namespace at::xpu

View File

@ -10,13 +10,6 @@
...
}
{
ignore_empty_generic_uninitialised_conditional_jump
Memcheck:Cond
fun:_ZN2at6detail13empty_genericEN3c108ArrayRefIlEEPNS1_9AllocatorENS1_14DispatchKeySetENS1_10ScalarTypeESt8optionalINS1_12MemoryFormatEE
...
}
{
Cond_cuda
Memcheck:Cond

View File

@ -50,7 +50,6 @@ def check_accuracy(actual_csv, expected_csv, expected_filename):
"mobilenet_v2",
"pytorch_CycleGAN_and_pix2pix",
"pytorch_stargan",
"repvgg_a2",
"resnet152",
"resnet18",
"resnet50",

View File

@ -10,7 +10,7 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
@ -66,7 +66,7 @@ visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,fail_accuracy,7
vit_base_patch14_dinov2.lvd142m,pass,7

1 name accuracy graph_breaks
10 mobilenetv2_100 pass 7
11 mobilenetv3_large_100 pass 7
12 mobilevit_s pass 6
13 nfnet_l0 pass 7
14 repvgg_a2 pass 7
15 swin_base_patch4_window7_224 pass 7
16 tf_efficientnet_b0 pass 6
66
67
68
69
70
71
72

View File

@ -50,7 +50,7 @@ nfnet_l0,pass,7
repvgg_a2,pass,7
repvgg_a2,fail_accuracy,7

1 name accuracy graph_breaks
50
51
52
53
54
55
56

View File

@ -2288,9 +2288,11 @@ class BenchmarkRunner:
)
):
is_same = False
except Exception:
except Exception as e:
# Sometimes torch.allclose may throw RuntimeError
is_same = False
exception_string = str(e)
accuracy_status = f"fail_exception: {exception_string}"
return record_status(accuracy_status, dynamo_start_stats=start_stats)
if not is_same:
accuracy_status = "eager_two_runs_differ"
@ -2407,9 +2409,11 @@ class BenchmarkRunner:
force_max_multiplier=force_max_multiplier,
):
is_same = False
except Exception:
except Exception as e:
# Sometimes torch.allclose may throw RuntimeError
is_same = False
exception_string = str(e)
accuracy_status = f"fail_exception: {exception_string}"
return record_status(accuracy_status, dynamo_start_stats=start_stats)
if not is_same:
if self.args.skip_accuracy_check:

View File

@ -2,7 +2,6 @@ import csv
import os
import re
import sys
from pathlib import Path
# This script takes the logs produced by the benchmark scripts (e.g.,
@ -16,7 +15,8 @@ from pathlib import Path
# This script is not very well written, feel free to rewrite it as necessary
assert len(sys.argv) == 2
full_log = Path(sys.argv[1]).read_text()
full_log = open(sys.argv[1]).read()
# If the log contains a gist URL, extract it so we can include it in the CSV
gist_url = ""

View File

@ -1,62 +0,0 @@
import sys
from benchmark_base import BenchmarkBase
import torch
from torch.distributed._tensor import DTensor, Replicate
from torch.testing._internal.distributed.fake_pg import FakeStore
class BenchmarkDTensorDispatch(BenchmarkBase):
def __init__(self, operator, world_size) -> None:
super().__init__(
category=f"dtensor_dispatch_{operator}",
device="cuda",
)
self.world_size = world_size
def name(self) -> str:
prefix = f"{self.category()}"
return prefix
def description(self) -> str:
return f"DTensor dispatch time for {self.category()}"
def _prepare_once(self) -> None:
self.mesh = torch.distributed.device_mesh.init_device_mesh(
"cuda", (self.world_size,), mesh_dim_names=("dp",)
)
self.a = DTensor.from_local(
torch.ones(10, 10, device=self.device()), self.mesh, [Replicate()]
)
self.b = DTensor.from_local(
torch.ones(10, 10, device=self.device()), self.mesh, [Replicate()]
)
def _prepare(self) -> None:
pass
class BenchmarkDetach(BenchmarkDTensorDispatch):
def __init__(self, world_size) -> None:
super().__init__(operator="detach", world_size=world_size)
def _work(self) -> None:
self.a.detach()
def main():
world_size = 256
fake_store = FakeStore()
torch.distributed.init_process_group(
"fake", store=fake_store, rank=0, world_size=world_size
)
result_path = sys.argv[1]
BenchmarkDetach(world_size).enable_instruction_count().collect_all().append_results(
result_path
)
torch.distributed.destroy_process_group()
if __name__ == "__main__":
main()

View File

@ -83,13 +83,10 @@ if __name__ == "__main__":
if args.outfile == "stdout":
outfile = sys.stdout
need_close = False
elif args.outfile == "stderr":
outfile = sys.stderr
need_close = False
else:
outfile = open(args.outfile, "a")
need_close = True
test_count = args.test_count
m = args.m
@ -150,5 +147,3 @@ if __name__ == "__main__":
time,
file=outfile,
)
if need_close:
outfile.close()

View File

@ -82,13 +82,10 @@ if __name__ == "__main__":
if args.outfile == "stdout":
outfile = sys.stdout
need_close = False
elif args.outfile == "stderr":
outfile = sys.stderr
need_close = False
else:
outfile = open(args.outfile, "a")
need_close = True
test_count = args.test_count
m = args.m
@ -135,5 +132,3 @@ if __name__ == "__main__":
time_csr,
file=outfile,
)
if need_close:
outfile.close()

View File

@ -179,13 +179,10 @@ if __name__ == "__main__":
if args.outfile == "stdout":
outfile = sys.stdout
need_close = False
elif args.outfile == "stderr":
outfile = sys.stderr
need_close = False
else:
outfile = open(args.outfile, "a")
need_close = True
ops = args.ops.split(",")
@ -437,5 +434,3 @@ if __name__ == "__main__":
if op not in {"bsr_scatter_mm6", "bsr_dense_mm_with_meta"}:
# Break on operations that do not consume parameters
break
if need_close:
outfile.close()

View File

@ -125,17 +125,6 @@ AttentionType = Literal[
]
DtypeString = Literal["bfloat16", "float16", "float32"]
SpeedupType = Literal["fwd", "bwd"]
# Operator Name mapping
backend_to_operator_name = {
"math": "math attention kernel",
"efficient": "efficient attention kernel",
"cudnn": "cudnn attention kernel",
"fav2": "flash attention 2 kernel",
"fav3": "flash attention 3 kernel",
"fakv": "flash attention kv cache kernel",
"og-eager": "eager attention kernel",
"flex": "flex attention kernel",
}
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
@ -1276,14 +1265,12 @@ def _output_json_for_dashboard(
model: ModelInfo
metric: MetricInfo
operator_name = backend_to_operator_name.get(backend, backend)
# Benchmark extra info
benchmark_extra_info = {
"input_config": input_config,
"device": device,
"arch": device_arch,
"operator_name": operator_name,
"operator_name": backend,
"attn_type": config.attn_type,
"shape": str(config.shape),
"max_autotune": config.max_autotune,
@ -1301,7 +1288,7 @@ def _output_json_for_dashboard(
type="attention-benchmark",
origins=["pytorch"],
extra_info={
"operator_name": operator_name,
"operator_name": backend,
"attn_type": config.attn_type,
},
),
@ -1328,7 +1315,7 @@ def _output_json_for_dashboard(
type="attention-benchmark",
origins=["pytorch"],
extra_info={
"operator_name": operator_name,
"operator_name": backend,
},
),
metric=MetricInfo(
@ -1354,7 +1341,7 @@ def _output_json_for_dashboard(
type="attention-benchmark",
origins=["pytorch"],
extra_info={
"operator_name": operator_name,
"operator_name": backend,
},
),
metric=MetricInfo(
@ -1384,7 +1371,7 @@ def _output_json_for_dashboard(
type="attention-benchmark",
origins=["pytorch"],
extra_info={
"operator_name": operator_name,
"operator_name": backend,
},
),
metric=MetricInfo(

View File

@ -19,17 +19,6 @@
namespace c10 {
using CaptureId_t = unsigned long long;
// first is set if the instance is created by CUDAGraph::capture_begin.
// second is set if the instance is created by at::cuda::graph_pool_handle.
using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>;
struct MempoolIdHash {
std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
}
};
// A DataPtr is a unique pointer (with an attached deleter and some
// context for the deleter) to some memory, which also records what
// device is for its data.

View File

@ -99,10 +99,7 @@ struct C10_API DeviceAllocator : public c10::Allocator {
// Return the free memory size and total memory size in bytes for the
// specified device.
virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "getMemoryInfo is not implemented for this allocator yet.");
}
virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) = 0;
};
// This function is used to get the DeviceAllocator for a specific device type

View File

@ -44,7 +44,7 @@ struct C10_API SafePyObject {
(*other.pyinterpreter_)->incref(other.data_);
}
if (data_ != nullptr) {
(*pyinterpreter_)->decref(data_);
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
}
data_ = other.data_;
pyinterpreter_ = other.pyinterpreter_;
@ -53,7 +53,7 @@ struct C10_API SafePyObject {
~SafePyObject() {
if (data_ != nullptr) {
(*pyinterpreter_)->decref(data_);
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
}
}

View File

@ -27,7 +27,6 @@
#include <torch/headeronly/core/ScalarType.h>
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
namespace c10 {
@ -206,12 +205,6 @@ inline bool isSignedType(ScalarType t) {
break;
// Do not add default here, but rather define behavior of every new entry
// here. `-Wswitch-enum` would raise a warning in those cases.
// TODO: get PyTorch to adopt exhaustive switches by default with a way to
// opt specific switches to being non-exhaustive.
// Exhaustive:
// `-Wswitch-enum`, `-Wswitch-default`, `-Wno-covered-switch-default`
// Non-Exhaustive:
// `-Wno-switch-enum`, `-Wswitch-default`, `-Wcovered-switch-default`
}
TORCH_CHECK(false, "Unknown ScalarType ", t);
#undef CASE_ISSIGNED

View File

@ -48,30 +48,6 @@ void warnDeprecatedDataPtr() {
TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid.");
}
void StorageImpl::incref_pyobject() const {
// Because intrusive_ptr incref uses relaxed memory order, we need to
// do an acquire fence to ensure that the kHasPyObject bit was
// observed before the load of the PyObject* below.
// NB: This is a no-op on x86/x86-64
std::atomic_thread_fence(std::memory_order_acquire);
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}
void StorageImpl::decref_pyobject() const {
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
}
bool StorageImpl::try_incref_pyobject() const {
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
if (C10_UNLIKELY(!interp)) {
return false;
}
return (*interp)->try_incref(pyobj_slot_);
}
void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
// Allowlist verification.
// Only if the devicetype is in the allowlist,

View File

@ -105,12 +105,6 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
data_ptr_.clear();
}
void incref_pyobject() const override final;
void decref_pyobject() const override final;
bool try_incref_pyobject() const override final;
size_t nbytes() const {
// OK to do this instead of maybe_as_int as nbytes is guaranteed positive
TORCH_CHECK(!size_bytes_is_heap_allocated_);
@ -376,18 +370,4 @@ C10_API c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
bool resizable,
std::optional<at::Device> device_opt);
namespace detail {
#ifndef C10_MOBILE
template <class T>
struct TargetTraits<
T,
std::enable_if_t<
std::is_base_of_v<c10::StorageImpl, std::remove_cv_t<T>>>> {
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
} // namespace c10

View File

@ -277,6 +277,7 @@ void TensorImpl::release_resources() {
if (storage_) {
storage_ = {};
}
pyobj_slot_.maybe_destroy_pyobj();
}
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
@ -988,30 +989,6 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
}
}
void TensorImpl::incref_pyobject() const {
// Because intrusive_ptr incref uses relaxed memory order, we need to
// do an acquire fence to ensure that the kHasPyObject bit was
// observed before the load of the PyObject* below.
// NB: This is a no-op on x86/x86-64
std::atomic_thread_fence(std::memory_order_acquire);
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}
void TensorImpl::decref_pyobject() const {
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
}
bool TensorImpl::try_incref_pyobject() const {
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
if (C10_UNLIKELY(!interp)) {
return false;
}
return (*interp)->try_incref(pyobj_slot_);
}
namespace impl {
namespace {

View File

@ -57,8 +57,6 @@ C10_DECLARE_bool(caffe2_keep_on_shrink);
// respect caffe2_keep_on_shrink.
C10_DECLARE_int64(caffe2_max_keep_on_shrink_memory);
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
namespace at {
class Tensor;
class TensorBase;
@ -2178,12 +2176,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return &pyobj_slot_;
}
void incref_pyobject() const override final;
void decref_pyobject() const override final;
bool try_incref_pyobject() const override final;
private:
// See NOTE [std::optional operator usage in CUDA]
// We probably don't want to expose this publicly until
@ -3085,19 +3077,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
friend class C10_TensorImpl_Size_Check_Dummy_Class;
};
namespace detail {
#ifndef C10_MOBILE
template <class T>
struct TargetTraits<
T,
std::enable_if_t<std::is_base_of_v<c10::TensorImpl, std::remove_cv_t<T>>>> {
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
// Note [TensorImpl size constraints]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Changed the size of TensorImpl? If the size went down, good for
@ -3324,5 +3303,3 @@ static_assert(
#undef C10_GCC_VERSION_MINOR
} // namespace c10
C10_DIAGNOSTIC_POP()

View File

@ -11,11 +11,8 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
void incref(PyObject* pyobj) const override {} // do nothing
void decref(PyObject* pyobj) const override {} // do nothing
bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override {
return false;
}
void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
} // do nothing
#define PANIC(m) \
TORCH_INTERNAL_ASSERT( \
@ -23,10 +20,6 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
"attempted to call " #m \
" on a Tensor with nontrivial PyObject after corresponding interpreter died")
size_t refcnt(PyObject* pyobj) const override {
PANIC(refcnt);
}
c10::intrusive_ptr<TensorImpl> detach(const TensorImpl* self) const override {
PANIC(detach);
}

View File

@ -18,9 +18,6 @@ namespace c10 {
struct IValue;
class OperatorHandle;
struct TensorImpl;
namespace impl {
struct PyObjectSlot;
} // namespace impl
} // namespace c10
namespace torch::jit {
@ -129,12 +126,9 @@ struct C10_API PyInterpreterVTable {
// Run Py_INCREF on a PyObject.
virtual void incref(PyObject* pyobj) const = 0;
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call.
virtual void decref(PyObject* pyobj) const = 0;
// Run PyUnstable_TryIncRef on a PyObject if it's not NULL.
virtual bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const = 0;
// Run Py_REFCNT on a PyObject.
virtual size_t refcnt(PyObject* pyobj) const = 0;
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
// See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;
// Perform a detach by deferring to the __torch_dispatch__ implementation of
// detach, which will also arrange for the PyObject to get copied in this

View File

@ -0,0 +1,56 @@
#include <c10/core/impl/PyObjectSlot.h>
namespace c10::impl {
PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
PyObjectSlot::~PyObjectSlot() {
maybe_destroy_pyobj();
}
void PyObjectSlot::maybe_destroy_pyobj() {
if (owns_pyobj()) {
TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
(*pyobj_interpreter_.load(std::memory_order_acquire))
->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
// NB: this destructor can only be entered when there are no
// references to this C++ object (obviously), NOR any references
// to the PyObject (if there are references to the PyObject,
// then the PyObject holds an owning reference to the tensor).
// So it is OK to clear pyobj_ here as it is impossible for it to
// be used again (modulo weak reference races)
pyobj_ = nullptr; // for safety
}
}
PyInterpreter* PyObjectSlot::pyobj_interpreter() {
return pyobj_interpreter_.load(std::memory_order_acquire);
}
PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<PyObject*>(
reinterpret_cast<uintptr_t>(pyobj_) & ~0x1ULL);
}
PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter) {
return *interpreter;
}
TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set");
}
bool PyObjectSlot::owns_pyobj() {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<uintptr_t>(pyobj_) & 1;
}
void PyObjectSlot::set_owns_pyobj(bool b) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
pyobj_ = reinterpret_cast<PyObject*>(
reinterpret_cast<uintptr_t>(_unchecked_untagged_pyobj()) | b);
}
} // namespace c10::impl

View File

@ -8,58 +8,117 @@
#include <atomic>
namespace torch::utils {
class PyObjectPreservation;
}
namespace c10::impl {
struct C10_API PyObjectSlot {
public:
PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
PyObjectSlot();
~PyObjectSlot();
void maybe_destroy_pyobj();
// Associate the TensorImpl with the specified PyObject, and, if necessary,
// also tag the interpreter.
//
// NB: This lives in a header so that we can inline away the switch on status
//
// NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
// PyObject if necessary!
void init_pyobj(PyObject* pyobj) {
pyobj_interpreter_.store(
getGlobalPyInterpreter(), std::memory_order_relaxed);
pyobj_ = pyobj;
}
// Query the PyObject interpreter. This may return null if there is no
// interpreter.
PyInterpreter* pyobj_interpreter() const {
return pyobj_interpreter_.load(std::memory_order_acquire);
// interpreter. This is racy!
PyInterpreter* pyobj_interpreter();
PyObject* _unchecked_untagged_pyobj() const;
// Test the interpreter tag. If tagged for the current interpreter, return
// a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
// returns a nullopt. If it is definitely invalid, raises an error.
//
// If `ignore_hermetic_tls` is false and this function is called from a
// hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
// nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
// context is ignored, allowing you to check the interpreter tag of a
// nonhermetic PyObject from within a hermetic context. This is necessary
// because there are some cases where the deallocator function of a
// nonhermetic PyObject is called from within a hermetic context, so it must
// be properly treated as a nonhermetic PyObject.
//
// NB: this lives in header so that we can avoid actually creating the
// std::optional
// @todo alban: I'm not too sure what's going on here, we can probably delete
// it but it's worthwhile making sure
std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const {
impl::PyInterpreter* interpreter =
pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter == nullptr) {
return std::nullopt;
}
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
return std::nullopt;
} else {
return _unchecked_untagged_pyobj();
}
}
PyInterpreter& load_pyobj_interpreter() const {
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
TORCH_INTERNAL_ASSERT(
interpreter, "cannot access PyObject for Tensor - no interpreter set");
return *interpreter;
}
PyInterpreter& load_pyobj_interpreter() const;
PyObject* load_pyobj() const {
return pyobj_.load(std::memory_order_acquire);
}
bool owns_pyobj();
void store_pyobj(PyObject* obj) {
pyobj_.store(obj, std::memory_order_release);
}
bool has_unique_reference() const {
PyObject* pyobj = load_pyobj();
return pyobj != nullptr && load_pyobj_interpreter()->refcnt(pyobj) == 1;
}
void clear() {
pyobj_.store(nullptr, std::memory_order_relaxed);
pyobj_interpreter_.store(nullptr, std::memory_order_relaxed);
}
void set_owns_pyobj(bool b);
private:
// This is now always the global interpreter if the PyObject is set.
// Maybe we can remove this field some day...
// This field contains the interpreter tag for this object. See
// Note [Python interpreter tag] for general context
//
// Note [Memory ordering on Python interpreter tag]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// What memory_order do we need when accessing this atomic? We don't
// need a single total modification order (as provided by
// memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
// transition from -1 to some positive integer and never changes afterwards.
// Because there is only one modification, it trivially already has a total
// modification order (e.g., we don't need fences or locked instructions on
// x86)
//
// In fact, one could make a reasonable argument that relaxed reads are OK,
// due to the presence of external locking (GIL) to ensure that interactions
// with other data structures are still correctly synchronized, so that
// we fall in the "Single-Location Data Structures" case as described in
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
// However, on x86, it doesn't matter if I use acquire or relaxed on the load
// as I get the same assembly in both cases. So I just use the more
// conservative acquire (which will impede compiler optimizations but I don't
// care)
std::atomic<PyInterpreter*> pyobj_interpreter_;
// The PyObject representing this Tensor or nullptr. Ownership is managed
// by intrusive_ptr. By the time the PyObjectSlot is destroyed, this
// reference is already dead.
std::atomic<PyObject*> pyobj_;
friend class torch::utils::PyObjectPreservation;
// This field contains a reference to a PyObject representing this Tensor.
// If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new
// PyObject for it and set this field. This field does not have to be
// protected by an atomic as it is only allowed to be accessed when you hold
// the GIL, or during destruction of the tensor.
//
// When a PyObject dies, you are obligated to clear this field
// (otherwise, you will try to use-after-free the pyobj); this currently
// occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp
//
// NB: Ordinarily, this should not be a strong reference, as if the
// PyObject owns the Tensor, this would create a reference cycle.
// However, sometimes this ownership flips. To track who owns
// who, this has a single pointer tag indicating whether or not the
// C++ object owns the PyObject (the common case, zero, means PyObject
// owns the C++ object); see _unchecked_untagged_pyobj for raw access
// or check_pyobj for checked access. See references to PyObject
// resurrection in torch/csrc/autograd/python_variable.cpp
PyObject* pyobj_;
};
} // namespace c10::impl

View File

@ -1012,6 +1012,12 @@ PrivatePoolState::PrivatePoolState(
}
}
struct MempoolIdHash {
std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
}
};
cudaError_t allocPrimitive(void** ptr, size_t size, AllocParams& p) {
if (p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator()) {
*ptr = p.pool->owner_PrivatePool->allocator()->raw_alloc(size);
@ -4504,3 +4510,66 @@ std::atomic<CUDAAllocator*> allocator;
static BackendStaticInitializer backend_static_initializer;
} // namespace cuda::CUDACachingAllocator
} // namespace c10
namespace c10::cuda {
// uid_ is incremented when a user creates a MemPool,
// for example: using graph_pool_handle() or c10::cuda::MemPool().
//
// uuid_ is incremented when CUDAGraph creates a MemPool
// as a result of a user not providing a pool.
//
// MempoolId_t of {0, 0} is used to denote when no MemPool has been
// passed to a function, either by user or CUDAGraphs. For example,
// default value of MempoolId_t for capture_begin function is {0, 0}.
// That's why uid_ and uuid_ start at 1.
std::atomic<CaptureId_t> MemPool::uid_{1};
std::atomic<CaptureId_t> MemPool::uuid_{1};
MemPool::MemPool(
CUDACachingAllocator::CUDAAllocator* allocator,
bool is_user_created,
bool use_on_oom)
: allocator_(allocator), is_user_created_(is_user_created) {
if (is_user_created_) {
id_ = {0, uid_++};
} else {
id_ = {uuid_++, 0};
}
device_ = c10::cuda::current_device();
CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
if (use_on_oom) {
CUDACachingAllocator::setUseOnOOM(device_, id_);
}
}
MemPool::~MemPool() {
TORCH_INTERNAL_ASSERT(use_count() == 1);
CUDACachingAllocator::releasePool(device_, id_);
c10::cuda::CUDACachingAllocator::emptyCache(id_);
}
MempoolId_t MemPool::id() {
return id_;
}
CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
return allocator_;
}
int MemPool::use_count() {
return CUDACachingAllocator::getPoolUseCount(device_, id_);
}
c10::DeviceIndex MemPool::device() {
return device_;
}
MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
if (is_user_created) {
return {0, uid_++};
}
return {uuid_++, 0};
}
} // namespace c10::cuda

View File

@ -562,7 +562,41 @@ inline std::string getUserMetadata() {
} // namespace c10::cuda::CUDACachingAllocator
namespace c10::cuda {
// Keep BC only
using c10::CaptureId_t;
using c10::MempoolId_t;
// MemPool represents a pool of memory in a caching allocator. Currently,
// it's just the ID of the pool object maintained in the CUDACachingAllocator.
//
// An allocator pointer can be passed to the MemPool to define how the
// allocations should be done in the pool. For example: using a different
// system allocator such as ncclMemAlloc.
struct C10_CUDA_API MemPool {
MemPool(
CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
bool is_user_created = true,
bool use_on_oom = false);
MemPool(const MemPool&) = delete;
MemPool(MemPool&&) = default;
MemPool& operator=(const MemPool&) = delete;
MemPool& operator=(MemPool&&) = default;
~MemPool();
MempoolId_t id();
CUDACachingAllocator::CUDAAllocator* allocator();
int use_count();
c10::DeviceIndex device();
static MempoolId_t graph_pool_handle(bool is_user_created = true);
private:
static std::atomic<CaptureId_t> uid_;
static std::atomic<CaptureId_t> uuid_;
CUDACachingAllocator::CUDAAllocator* allocator_;
bool is_user_created_;
MempoolId_t id_;
c10::DeviceIndex device_;
};
} // namespace c10::cuda

View File

@ -295,19 +295,11 @@ DeviceAssertionsData* CUDAKernelLaunchRegistry::
C10_CUDA_CHECK_WO_DSA(
cudaMallocManaged(&uvm_assertions_ptr, sizeof(DeviceAssertionsData)));
#if CUDART_VERSION >= 13000
cudaMemLocation cpuDevice;
cpuDevice.type = cudaMemLocationTypeDevice;
cpuDevice.id = cudaCpuDeviceId;
#else
const auto cpuDevice = cudaCpuDeviceId;
#endif
C10_CUDA_CHECK_WO_DSA(cudaMemAdvise(
uvm_assertions_ptr,
sizeof(DeviceAssertionsData),
cudaMemAdviseSetPreferredLocation,
cpuDevice));
cudaCpuDeviceId));
// GPU will establish direct mapping of data in CPU memory, no page faults
// will be generated
@ -315,7 +307,7 @@ DeviceAssertionsData* CUDAKernelLaunchRegistry::
uvm_assertions_ptr,
sizeof(DeviceAssertionsData),
cudaMemAdviseSetAccessedBy,
cpuDevice));
cudaCpuDeviceId));
// Initialize the memory from the CPU; otherwise, pages may have to be created
// on demand. We think that UVM documentation indicates that first access may

View File

@ -12,10 +12,6 @@ template <typename, typename...>
class class_;
}
namespace torch::utils {
class PyObjectPreservation;
}
namespace c10 {
class intrusive_ptr_target;
namespace raw {
@ -37,8 +33,6 @@ constexpr uint64_t kImpracticallyHugeWeakReferenceCount =
constexpr uint64_t kReferenceCountOne = 1;
constexpr uint64_t kWeakReferenceCountOne = (kReferenceCountOne << 32);
constexpr uint64_t kUniqueRef = (kReferenceCountOne | kWeakReferenceCountOne);
// Indicates whether the object has a PyObject wrapper.
constexpr uint64_t kHasPyObject = (uint64_t(1) << 63);
template <class TTarget>
struct intrusive_target_default_null_type final {
@ -61,11 +55,7 @@ inline uint32_t refcount(uint64_t combined_refcount) {
}
inline uint32_t weakcount(uint64_t combined_refcount) {
return static_cast<uint32_t>((combined_refcount & ~kHasPyObject) >> 32);
}
inline bool has_pyobject(uint64_t combined_refcount) {
return (combined_refcount & kHasPyObject) != 0;
return static_cast<uint32_t>(combined_refcount >> 32);
}
// The only requirement for refcount increment is that it happens-before
@ -76,6 +66,12 @@ inline uint64_t atomic_combined_refcount_increment(
return combined_refcount.fetch_add(inc, std::memory_order_relaxed) + inc;
}
inline uint32_t atomic_refcount_increment(
std::atomic<uint64_t>& combined_refcount) {
return detail::refcount(atomic_combined_refcount_increment(
combined_refcount, kReferenceCountOne));
}
inline uint32_t atomic_weakcount_increment(
std::atomic<uint64_t>& combined_refcount) {
return detail::weakcount(atomic_combined_refcount_increment(
@ -103,11 +99,6 @@ inline uint32_t atomic_weakcount_decrement(
combined_refcount, kWeakReferenceCountOne));
}
template <class T, class = void>
struct TargetTraits {
static constexpr bool can_have_pyobject = false;
};
} // namespace detail
/**
@ -164,23 +155,6 @@ class C10_API intrusive_ptr_target {
// we can atomically operate on both at the same time for performance
// and defined behaviors.
//
// Note [PyObject preservation for Tensor and Storages]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// intrusive_ptr has special support for preserving PyObject wrappers
// for TensorImpl and StorageImpl. The most significant bit (kHasPyObject) of
// the combined_refcount_ is used to indicate whether the object has a
// PyObject wrapper.
//
// - The PyObject, if it exists, holds a strong reference to the
// intrusive_ptr_target.
//
// - When the refcount goes from 1 to 2, we incref the PyObject.
//
// - When the refcount goes from 2 to 1, we decref the PyObject.
//
// In other words, the intrusive_ptr keeps the PyObject alive as long as there
// are other C++ references to the intrusive_ptr_target.
mutable std::atomic<uint64_t> combined_refcount_;
static_assert(sizeof(std::atomic<uint64_t>) == 8);
static_assert(alignof(std::atomic<uint64_t>) == 8);
@ -198,8 +172,6 @@ class C10_API intrusive_ptr_target {
template <typename T>
friend struct ExclusivelyOwnedTensorTraits;
friend class torch::utils::PyObjectPreservation;
protected:
// protected destructor. We never want to destruct intrusive_ptr_target*
// directly.
@ -283,16 +255,6 @@ class C10_API intrusive_ptr_target {
*/
virtual void release_resources() {}
/**
* These two methods are called when the refcount transitions between one
* and two and the object has a PyObject wrapper.
*/
virtual void incref_pyobject() const {}
virtual void decref_pyobject() const {}
virtual bool try_incref_pyobject() const {
return false;
}
uint32_t refcount(std::memory_order order = std::memory_order_relaxed) const {
return detail::refcount(combined_refcount_.load(order));
}
@ -303,19 +265,6 @@ class C10_API intrusive_ptr_target {
}
};
namespace detail {
#ifndef C10_MOBILE
template <>
struct TargetTraits<c10::intrusive_ptr_target> {
// A generic intrusive_ptr<intrusive_ptr_target> may actually be a TensorImpl
// or StorageImpl, so we have to allow for PyObject support.
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
template <class TTarget, class NullType>
class weak_intrusive_ptr;
@ -365,34 +314,18 @@ class intrusive_ptr final {
void retain_() {
if (target_ != NullType::singleton()) {
uint64_t combined = detail::atomic_combined_refcount_increment(
target_->combined_refcount_, detail::kReferenceCountOne);
uint32_t new_refcount = detail::refcount(combined);
uint32_t new_refcount =
detail::atomic_refcount_increment(target_->combined_refcount_);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
new_refcount != 1,
"intrusive_ptr: Cannot increase refcount after it reached zero.");
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
// If the refcount transitioned from 1 to 2, we need to incref the
// PyObject. In other words, we need to ensure that the PyObject stays
// alive now that we have a C++ reference to this object in addition to
// the PyObject itself.
if (C10_UNLIKELY(
detail::has_pyobject(combined) &&
detail::refcount(combined) == 2)) {
target_->incref_pyobject();
}
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!detail::has_pyobject(combined),
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
}
}
}
void reset_() noexcept {
if (target_ != NullType::singleton()) {
if (is_uniquely_owned()) {
if (target_->combined_refcount_.load(std::memory_order_acquire) ==
detail::kUniqueRef) {
// Both counts are 1, so there are no weak references and
// we are releasing the last strong reference. No other
// threads can observe the effects of this target_ deletion
@ -404,10 +337,9 @@ class intrusive_ptr final {
auto combined_refcount = detail::atomic_combined_refcount_decrement(
target_->combined_refcount_, detail::kReferenceCountOne);
uint32_t new_refcount = detail::refcount(combined_refcount);
bool has_pyobject = detail::has_pyobject(combined_refcount);
if (new_refcount == 0) {
bool should_delete = detail::weakcount(combined_refcount) == 1;
if (detail::refcount(combined_refcount) == 0) {
bool should_delete =
(combined_refcount == detail::kWeakReferenceCountOne);
// See comment above about weakcount. As long as refcount>0,
// weakcount is one larger than the actual number of weak references.
// So we need to decrement it here.
@ -424,18 +356,6 @@ class intrusive_ptr final {
if (should_delete) {
delete target_;
}
} else if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
// If the refcount transitioned from 2 to 1, we need to decref the
// PyObject. In other words, we don't want to keep the PyObject alive if
// there are no C++ references to this object other than the PyObject
// itself.
if (C10_UNLIKELY(has_pyobject && new_refcount == 1)) {
target_->decref_pyobject();
}
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!has_pyobject,
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
}
}
}
@ -602,16 +522,6 @@ class intrusive_ptr final {
return use_count() == 1;
}
/**
* Stronger than unique() in that it must not have any weakrefs as well.
*/
bool is_uniquely_owned() const noexcept {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(target_ != NullType::singleton());
uint64_t combined =
target_->combined_refcount_.load(std::memory_order_acquire);
return (combined & ~detail::kHasPyObject) == detail::kUniqueRef;
}
/**
* Returns an owning (!) pointer to the underlying object and makes the
* intrusive_ptr instance invalid. That means the refcount is not decreased.
@ -1022,7 +932,6 @@ class weak_intrusive_ptr final {
if (target_ == NullType::singleton()) {
return intrusive_ptr<TTarget, NullType>();
} else {
bool increfed = false;
auto combined_refcount =
target_->combined_refcount_.load(std::memory_order_relaxed);
do {
@ -1031,31 +940,12 @@ class weak_intrusive_ptr final {
// Return nullptr.
return intrusive_ptr<TTarget, NullType>();
}
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
if (detail::has_pyobject(combined_refcount) &&
detail::refcount(combined_refcount) == 1 && !increfed) {
// Object has a python wrapper with no other C++ references.
// We need to to incref the Python object before we acquire a
// strong reference to the C++ object to avoid a situation
// where the Python object is deallocated concurrently.
if (!target_->try_incref_pyobject()) {
return intrusive_ptr<TTarget, NullType>();
}
increfed = true;
}
}
} while (!target_->combined_refcount_.compare_exchange_weak(
combined_refcount,
combined_refcount + detail::kReferenceCountOne,
std::memory_order_acquire,
std::memory_order_relaxed));
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
if (increfed && detail::refcount(combined_refcount) != 1) {
target_->decref_pyobject();
}
}
return intrusive_ptr<TTarget, NullType>(
target_, raw::DontIncreaseRefcount{});
}
@ -1170,18 +1060,7 @@ namespace intrusive_ptr {
// NullType::singleton to this function
inline void incref(intrusive_ptr_target* self) {
if (self) {
uint64_t combined = detail::atomic_combined_refcount_increment(
self->combined_refcount_, detail::kReferenceCountOne);
#ifndef C10_MOBILE
if (C10_UNLIKELY(
detail::has_pyobject(combined) &&
detail::refcount(combined) == 2)) {
self->incref_pyobject();
}
#else
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!detail::has_pyobject(combined));
#endif
detail::atomic_refcount_increment(self->combined_refcount_);
}
}

View File

@ -24,7 +24,6 @@ set(C10_XPU_HEADERS
XPUCachingAllocator.h
XPUDeviceProp.h
XPUException.h
XPUEvent.h
XPUFunctions.h
XPUMacros.h
XPUStream.h

View File

@ -1,178 +0,0 @@
#pragma once
#include <c10/xpu/XPUStream.h>
namespace c10::xpu {
/*
* XPUEvent are movable not copyable wrappers around SYCL event. XPUEvent are
* constructed lazily when first recorded. It has a device, and this device is
* acquired from the first recording stream. Later streams that record the event
* must match the same device.
*
* Currently, XPUEvent does NOT support to export an inter-process event from
* another process via inter-process communication(IPC). So it means that
* inter-process communication for event handles between different processes is
* not available. This could impact some applications that rely on cross-process
* synchronization and communication.
*/
struct XPUEvent {
// Constructors
XPUEvent(bool enable_timing = false) noexcept
: enable_timing_{enable_timing} {}
~XPUEvent() {
if (isCreated()) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_deletion(
c10::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
}
}
}
C10_DISABLE_COPY_AND_ASSIGN(XPUEvent);
XPUEvent(XPUEvent&& other) = default;
XPUEvent& operator=(XPUEvent&& other) = default;
operator sycl::event&() const {
return event();
}
std::optional<c10::Device> device() const {
if (isCreated()) {
return c10::Device(c10::kXPU, device_index_);
} else {
return std::nullopt;
}
}
inline bool isCreated() const {
return (event_.get() != nullptr);
}
DeviceIndex device_index() const {
return device_index_;
}
sycl::event& event() const {
return *event_;
}
bool query() const {
using namespace sycl::info;
if (!isCreated()) {
return true;
}
return event().get_info<event::command_execution_status>() ==
event_command_status::complete;
}
void record() {
record(getCurrentXPUStream());
}
void recordOnce(const XPUStream& stream) {
if (!isCreated()) {
record(stream);
}
}
void record(const XPUStream& stream) {
if (!isCreated()) {
device_index_ = stream.device_index();
assignEvent(stream.queue());
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_creation(
c10::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
}
} else {
TORCH_CHECK(
device_index_ == stream.device_index(),
"Event device ",
device_index_,
" does not match recording stream's device ",
stream.device_index(),
".");
reassignEvent(stream.queue());
}
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_record(
c10::kXPU,
reinterpret_cast<uintptr_t>(event_.get()),
reinterpret_cast<uintptr_t>(&stream.queue()));
}
}
void block(const XPUStream& stream) {
if (isCreated()) {
std::vector<sycl::event> event_list{event()};
// Make this stream wait until event_ is completed.
stream.queue().ext_oneapi_submit_barrier(event_list);
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_wait(
c10::kXPU,
reinterpret_cast<uintptr_t>(event_.get()),
reinterpret_cast<uintptr_t>(&stream.queue()));
}
}
}
double elapsed_time(const XPUEvent& other) const {
TORCH_CHECK(
isCreated() && other.isCreated(),
"Both events must be recorded before calculating elapsed time.");
TORCH_CHECK(
query() && other.query(),
"Both events must be completed before calculating elapsed time.");
TORCH_CHECK(
enable_timing_ && other.enable_timing_,
"Both events must be created with argument 'enable_timing=True'.");
using namespace sycl::info::event_profiling;
// Block until both of the recorded events are completed.
uint64_t end_time_ns = other.event().get_profiling_info<command_end>();
uint64_t start_time_ns = event().get_profiling_info<command_end>();
// Return the eplased time in milliseconds.
return 1e-6 *
(static_cast<double>(end_time_ns) - static_cast<double>(start_time_ns));
}
void synchronize() const {
if (isCreated()) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_synchronization(
c10::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
}
event().wait_and_throw();
}
}
private:
void assignEvent(sycl::queue& queue) {
if (enable_timing_) {
event_ = std::make_unique<sycl::event>(
sycl::ext::oneapi::experimental::submit_profiling_tag(queue));
} else {
event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
}
}
void reassignEvent(sycl::queue& queue) {
event_.reset();
assignEvent(queue);
}
bool enable_timing_ = false;
c10::DeviceIndex device_index_ = -1;
// Only need to track the last event, as events in an in-order queue are
// executed sequentially.
std::unique_ptr<sycl::event> event_;
};
} // namespace c10::xpu

View File

@ -1,7 +1,7 @@
# This will define the following variables:
# SYCL_FOUND : True if the system has the SYCL library.
# SYCL_INCLUDE_DIR : Include directories needed to use SYCL.
# SYCL_LIBRARY_DIR : The path to the SYCL library.
# SYCL_LIBRARY_DIR The path to the SYCL library.
# SYCL_LIBRARY : SYCL library fullname.
# SYCL_COMPILER_VERSION : SYCL compiler version.

View File

@ -1,113 +0,0 @@
# Device Management
## Background
Device management handles basic operations like querying how many devices are available and switching between them. Accelerator backends need to wrap their device runtime's APIs and expose them to PyTorch.
The OpenReg implementation ([`OpenRegFunctions.h/cpp`][OpenReg Device Management]) shows how to wrap a third-party runtime. These functions are used throughout the backend - by streams, events, generators, and Python bindings.
## Design
Accelerator vendors need to implement these core functions:
| Function Name | Description | Application Scenarios |
| ------------------------- | ---------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- |
| `device_count()` | Query the total number of available devices in the system | - Application initialization<br>- Multi-device workload distribution<br>- Validating device indices before use |
| `current_device()` | Get the currently active device for the calling thread | - Debugging and logging<br>- Determining tensor placement<br>- Guard implementations |
| `set_device()` | Change the active device for subsequent operations | - Switching context between devices<br>- Initializing specific device resources<br>- Multi-GPU training loops |
| `exchange_device()` | Atomically swap device and return the previous device | - Implementing device guards<br>- Temporarily switching device context<br>- RAII-based device management |
| `maybe_exchange_device()` | Conditionally exchange device only if the index is valid (-1 OK) | - Safe device switching with optional indices<br>- Guard implementations with nullable device values |
These functions are building blocks for more complex features like streams, events, and memory management. Make sure to validate inputs and handle errors properly.
## Implementation
This section shows how to implement device management using `set_device` as an example. The implementation requires:
1. C++ wrappers around the device runtime
2. Python bindings to expose the C++ functions
3. User-friendly Python APIs
### C++ Side
Wrap the device runtime's API and add error handling. The `SetDevice` function shows this pattern:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG SetDevice FUNCTION
:end-before: LITERALINCLUDE END: OPENREG SetDevice FUNCTION
:linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG set_device FUNCTION
:end-before: LITERALINCLUDE END: OPENREG set_device FUNCTION
:linenos:
```
### Binding
Expose the C++ functions to Python using pybind11:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: MODULE SET DEVICE HELPER
:end-before: LITERALINCLUDE END: MODULE SET DEVICE HELPER
:linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
:end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
:linenos:
:emphasize-lines: 5
```
### Python Side
Wrap the C++ bindings with user-friendly Python functions:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py
:language: python
:start-after: LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
:end-before: LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
:linenos:
```
Here's the complete mapping from C++ to Python:
| C++ Binding Function | C++ Binding API (pybind11) | Python User API | Description |
| -------------------- | ---------------------------------------- | -------------------------------- | -------------------------------------------- |
| `_getDeviceCount` | `torch_openreg._C._get_device_count()` | `torch.openreg.device_count()` | Returns the total number of devices |
| `_getDevice` | `torch_openreg._C._get_device()` | `torch.openreg.current_device()` | Returns the current active device index |
| `_setDevice` | `torch_openreg._C._set_device(idx)` | `torch.openreg.set_device(idx)` | Sets the active device |
| `_exchangeDevice` | `torch_openreg._C._exchange_device(idx)` | N/A (internal use only) | Atomically swaps device and returns previous |
## Guard
Device guards provide automatic device switching with exception safety. They're similar to lock guards in C++ - they switch device on construction and restore it on destruction.
Implement `DeviceGuardImplInterface` to integrate with PyTorch's guard system:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h
:language: c++
:start-after: LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
:end-before: LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
:linenos:
```
**What needs to be implemented:**
1. **exchangeDevice()**: Switch to a new device and return the old one (used by guard constructors)
2. **getDevice()**: Get the current device
3. **setDevice()**: Set the active device
4. **Type checking**: Validate that device type matches the backend
This makes the guard available to PyTorch for the `PrivateUse1` device type. Users can then use standard PyTorch device guards with the custom backend.
[OpenReg Device Management]: https://github.com/pytorch/pytorch/blob/main/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp "OpenReg Device Management"

View File

@ -1,164 +0,0 @@
# Accelerator Hooks
## Background
OpenReg hooks provide a mechanism for integrating custom accelerator devices into PyTorch's runtime system. OpenReg (Open Registration) is PyTorch's extensibility framework that allows accelerator vendors to register custom device backends without modifying PyTorch core code.
## Design
The following tables list all hooks that accelerator vendors need to implement when integrating a new device backend. These hooks are categorized into two priority levels:
- **High Priority Hooks**: Core APIs that PyTorch runtime directly depends on. Accelerator vendors are recommended to implement all high priority hooks to ensure full PyTorch compatibility and enable basic device functionality.
- **Low Priority Hooks**: Device management and utility APIs that PyTorch does not directly depend on. These hooks enhance user experience and multi-device support but are *optional*. Accelerator vendors can choose to implement them based on their specific requirements and use cases.
### High Priority Hooks
| Hook Method | Description | Application Scenario |
| ---------------------------------- | --------------------------------------------------------- | -------------------------------------------------------------------------------- |
| `init()` | Initializes the accelerator runtime and device contexts | Set up necessary state when PyTorch first accesses the device |
| `hasPrimaryContext(DeviceIndex)` | Checks if a primary context exists for the device | Determine whether device initialization has occurred |
| `getDefaultGenerator(DeviceIndex)` | Returns the default random number generator for a device | Access the device's primary RNG for reproducible random operations |
| `getNewGenerator(DeviceIndex)` | Creates a new independent random number generator | Create isolated RNG instances for parallel operations |
| `getDeviceFromPtr(void*)` | Determines which device a memory pointer belongs to | Identify the accelerator device associated with a memory allocation |
| `getPinnedMemoryAllocator()` | Returns an allocator for pinned (page-locked) host memory | Allocate host memory that can be efficiently transferred to/from the accelerator |
| `isPinnedPtr(void*)` | Checks if a pointer points to pinned memory | Validate memory types before performing operations |
### Low Priority Hooks
| Hook Method | Description | Application Scenario |
| ---------------------------------- | ---------------------------------------------------------------------------- | -------------------------------------------------------------------- |
| `isBuilt()` | Returns whether the accelerator backend is built/compiled into the extension | Check whether the accelerator library is available at compile time |
| `isAvailable()` | Returns whether the accelerator hardware is available at runtime | Verify whether accelerator devices can be detected and initialized |
| `deviceCount()` | Returns the number of available accelerator devices | Enumerate all available accelerator devices for device selection |
| `setCurrentDevice(DeviceIndex)` | Sets the active device for the current thread | Switch the current thread's context to a specific accelerator device |
| `getCurrentDevice()` | Returns the currently active device index | Query which accelerator device is active in the current thread |
| `exchangeDevice(DeviceIndex)` | Atomically exchanges the current device and returns the previous one | Temporarily switch devices and restore the previous device afterward |
| `maybeExchangeDevice(DeviceIndex)` | Conditionally exchanges device only if the index is valid | Safely attempt device switching with validation |
## Implementation
We can just take `getDefaultGenerator` as an implementation example:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h
:language: c++
:start-after: LITERALINCLUDE START: OPENREG HOOK EXAMPLES
:end-before: LITERALINCLUDE END: OPENREG HOOK EXAMPLES
:linenos:
```
In this implementation:
1. **Override the base interface**: The `getDefaultGenerator` method overrides the virtual method from `at::PrivateUse1HooksInterface`.
2. **Delegate to device-specific implementation**: It calls `getDefaultOpenRegGenerator(device_index)`, which manages a per-device generator instance.
3. **Return device-specific generator**: The returned `at::Generator` wraps an `OpenRegGeneratorImpl` that implements device-specific random number generation.
This pattern applies to all hooks: override the interface method, validate inputs, delegate to your device-specific API, and return results in PyTorch's expected format.
## Integration Example
The following sections demonstrate how PyTorch integrates with accelerator hooks when accessing the default random number generator. The example traces the complete flow from user-facing Python code down to the device-specific implementation.
### Layer 1: User Code
User code initiates the operation by calling `manual_seed` to set the random seed for reproducible results:
```python
import torch
torch.openreg.manual_seed(42)
```
### Layer 2: Extension Python API
The Python API layer handles device management and calls into the C++ extension (defined in [`torch_openreg/openreg/random.py`][random.py]):
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py
:language: python
:start-after: LITERALINCLUDE START: OPENREG MANUAL SEED
:end-before: LITERALINCLUDE END: OPENREG MANUAL SEED
:linenos:
```
The `manual_seed` function gets the current device index and calls `torch_openreg._C._get_default_generator(idx)` to obtain the device-specific generator, then sets the seed on it.
### Layer 3: Python/C++ Bridge
The C++ extension exposes `_getDefaultGenerator` to Python, which bridges to PyTorch's core runtime:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR
:end-before: LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR
:linenos:
:emphasize-lines: 10-11
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
:end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
:linenos:
:emphasize-lines: 3
```
This function unpacks the device index from Python, creates a `PrivateUse1` device object, and calls `at::globalContext().defaultGenerator()`. PyTorch's context then dispatches to the registered hooks.
### Layer 4: PyTorch Core Context
PyTorch's Context class dispatches to the appropriate accelerator hooks ([`aten/src/ATen/Context.h`][Context.h]):
```{eval-rst}
.. literalinclude:: ../../../aten/src/ATen/Context.h
:language: c++
:lines: 60-103
:linenos:
:emphasize-lines: 8-9, 24-25
```
This layered architecture enables PyTorch to remain device-agnostic while delegating hardware-specific operations to accelerator implementations. The hooks are registered once at module load time:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG HOOK REGISTER
:end-before: LITERALINCLUDE END: OPENREG HOOK REGISTER
:linenos:
:emphasize-lines: 4
```
### Layer 5: Accelerator Hooks
The hooks interface provides the abstraction that PyTorch uses to delegate to device-specific implementations:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h
:language: c++
:start-after: LITERALINCLUDE START: OPENREG HOOK EXAMPLES
:end-before: LITERALINCLUDE END: OPENREG HOOK EXAMPLES
:linenos:
```
The `getDefaultGenerator` hook method overrides the base interface and delegates to `getDefaultOpenRegGenerator`, which manages the actual generator instances.
### Layer 6: Device-Specific Implementation
The device-specific implementation manages per-device generator instances:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR IMPL
:end-before: LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR IMPL
:linenos:
```
This function maintains a static vector of generators (one per device), initializes them on first access, validates the device index, and returns the appropriate generator instance.
[random.py]: https://github.com/pytorch/pytorch/tree/main/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py#L48-L53 "random.py"
[Context.h]: https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/Context.h#L61-L102 "Context.h"

View File

@ -42,8 +42,6 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
:glob:
:maxdepth: 1
device
hooks
autoload
operators
amp

View File

@ -14,10 +14,6 @@ Utils
sdpa_kernel
SDPBackend
register_flash_attention_impl
activate_flash_attention_impl
list_flash_attention_impls
current_flash_attention_impl
Submodules
----------

View File

@ -24,11 +24,15 @@ def gen_data(special_op_lists, analysis_name):
all_ops = get_ops_for_key(None)
composite_ops = get_ops_for_key("CompositeImplicitAutograd")
noncomposite_ops = all_ops - composite_ops
with open("../../aten/src/ATen/native/native_functions.yaml") as f:
ops = yaml.load(f.read(), Loader=yaml.CLoader)
with open("annotated_ops") as f:
annotated_ops = {a.strip(): b.strip() for a, b in csv.reader(f)}
ops = yaml.load(
open("../../aten/src/ATen/native/native_functions.yaml").read(),
Loader=yaml.CLoader,
)
annotated_ops = {
a.strip(): b.strip() for a, b in list(csv.reader(open("annotated_ops")))
}
uniq_ops = []
uniq_names = set()

View File

@ -10,7 +10,7 @@ tp2_dir="$top_dir/third_party"
pip install ninja
# Install onnx
pip install -e "$tp2_dir/onnx"
pip install --no-use-pep517 -e "$tp2_dir/onnx"
# Install caffe2 and pytorch
pip install -r "$top_dir/caffe2/requirements.txt"

View File

@ -1358,45 +1358,6 @@ class concat_license_files:
# Need to create the proper LICENSE.txt for the wheel
class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
def _wrap_headers_with_macro(self, bdist_dir: Path) -> None:
"""Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION).
Excludes:
- torch/include/torch/headeronly/*
- torch/include/torch/csrc/stable/*
- torch/include/torch/csrc/inductor/aoti_torch/c/ (only shim headers)
- torch/include/torch/csrc/inductor/aoti_torch/generated/
"""
header_extensions = (".h", ".hpp", ".cuh")
header_files = [
f for ext in header_extensions for f in bdist_dir.rglob(f"*{ext}")
]
# Paths to exclude from wrapping
exclude_dir_patterns = [
"torch/include/torch/headeronly/",
"torch/include/torch/csrc/stable/",
"torch/include/torch/csrc/inductor/aoti_torch/c/",
"torch/include/torch/csrc/inductor/aoti_torch/generated/",
]
for header_file in header_files:
rel_path = header_file.relative_to(bdist_dir).as_posix()
if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns):
report(f"Skipping header: {rel_path}")
continue
original_content = header_file.read_text(encoding="utf-8")
wrapped_content = (
"#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
f"{original_content}"
"\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
)
header_file.write_text(wrapped_content, encoding="utf-8")
report(f"Wrapped header: {rel_path}")
def run(self) -> None:
with concat_license_files(include_files=True):
super().run()
@ -1419,14 +1380,6 @@ class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
# need an __init__.py file otherwise we wouldn't have a package
(bdist_dir / "torch" / "__init__.py").touch()
# Wrap all header files with TORCH_STABLE_ONLY macro
assert self.bdist_dir is not None, "bdist_dir should be set during wheel build"
bdist_dir = Path(self.bdist_dir)
report(
"-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)"
)
self._wrap_headers_with_macro(bdist_dir)
class clean(Command):
user_options: ClassVar[list[tuple[str, str | None, str]]] = []

View File

@ -308,16 +308,12 @@ class StepcurrentPlugin:
self.report_status = ""
assert config.cache is not None
self.cache: pytest.Cache = config.cache
directory = f"{STEPCURRENT_CACHE_DIR}/{config.getoption('stepcurrent')}"
self.lastrun_location = f"{directory}/lastrun"
self.lastrun: Optional[str] = self.cache.get(self.lastrun_location, None)
self.directory = f"{STEPCURRENT_CACHE_DIR}/{config.getoption('stepcurrent')}"
self.lastrun: Optional[str] = self.cache.get(self.directory, None)
self.initial_val = self.lastrun
self.skip: bool = config.getoption("stepcurrent_skip")
self.run_single: bool = config.getoption("run_single")
self.made_failing_xml_location = f"{directory}/made_failing_xml"
self.cache.set(self.made_failing_xml_location, False)
def pytest_collection_modifyitems(self, config: Config, items: list[Any]) -> None:
if not self.lastrun:
self.report_status = "Cannot find last run test, not skipping"
@ -353,10 +349,8 @@ class StepcurrentPlugin:
def pytest_runtest_protocol(self, item, nextitem) -> None:
self.lastrun = item.nodeid
self.cache.set(self.lastrun_location, self.lastrun)
self.cache.set(self.directory, self.lastrun)
def pytest_sessionfinish(self, session, exitstatus):
if exitstatus == 0:
self.cache.set(self.lastrun_location, self.initial_val)
if exitstatus != 0:
self.cache.set(self.made_failing_xml_location, True)
self.cache.set(self.directory, self.initial_val)

View File

@ -38,7 +38,7 @@ using torch::stable::Tensor;
Tensor sgd_out_of_place(
const Tensor param,
const Tensor grad,
const double weight_decay,
const float weight_decay,
const double lr,
const bool maximize) {
STD_TORCH_CHECK(param.dim() == 1, "param must be 1D");
@ -57,7 +57,7 @@ Tensor sgd_out_of_place(
reinterpret_cast<float*>(param.data_ptr()),
reinterpret_cast<float*>(grad.data_ptr()),
reinterpret_cast<float*>(out.data_ptr()),
float(weight_decay),
weight_decay,
lr,
maximize,
param.numel()
@ -66,29 +66,44 @@ Tensor sgd_out_of_place(
return out;
}
void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = sgd_out_of_place(
torch::stable::detail::to<Tensor>(stack[0]),
torch::stable::detail::to<Tensor>(stack[1]),
float(torch::stable::detail::to<double>(stack[2])),
torch::stable::detail::to<double>(stack[3]),
torch::stable::detail::to<bool>(stack[4]));
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
m.impl("sgd_out_of_place", TORCH_BOX(&sgd_out_of_place));
m.impl("sgd_out_of_place", &boxed_sgd_out_of_place);
}
Tensor identity(Tensor t) {
return t;
}
void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = identity(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("identity(Tensor t) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CUDA, m) {
m.impl("identity", TORCH_BOX(&identity));
m.impl("identity", &boxed_identity);
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
m.impl("identity", TORCH_BOX(&identity));
m.impl("identity", &boxed_identity);
}
Tensor my_abs(Tensor t) {
@ -99,12 +114,17 @@ Tensor my_abs(Tensor t) {
return torch::stable::detail::to<Tensor>(stack[0]);
}
void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor tensor_res = my_abs(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(tensor_res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("my_abs(Tensor t) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_abs", TORCH_BOX(&my_abs));
m.impl("my_abs", &boxed_my_abs);
}
Tensor my_ones_like(Tensor t, StableIValue device) {
@ -125,12 +145,17 @@ Tensor my_ones_like(Tensor t, StableIValue device) {
return torch::stable::detail::to<Tensor>(stack[0]);
}
void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = my_ones_like(torch::stable::detail::to<Tensor>(stack[0]), stack[1]);
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("my_ones_like(Tensor t, Device d) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_ones_like", TORCH_BOX(&my_ones_like));
m.impl("my_ones_like", &boxed_my_ones_like);
}
std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) {
@ -152,12 +177,19 @@ std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3
torch::stable::detail::to<bool>(stack_is_leaf[0]));
}
void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto tuple = exp_neg_is_leaf(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<Tensor>(stack[2]));
stack[0] = torch::stable::detail::from(std::get<0>(tuple));
stack[1] = torch::stable::detail::from(std::get<1>(tuple));
stack[2] = torch::stable::detail::from(std::get<2>(tuple));
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("exp_neg_is_leaf", TORCH_BOX(&exp_neg_is_leaf));
m.impl("exp_neg_is_leaf", &boxed_exp_neg_is_leaf);
}
Tensor neg_exp(Tensor t) {
@ -168,12 +200,17 @@ Tensor neg_exp(Tensor t) {
return torch::stable::detail::to<Tensor>(stack[0]);
}
void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("neg_exp(Tensor t) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("neg_exp", TORCH_BOX(&neg_exp));
m.impl("neg_exp", &boxed_neg_exp);
}
Tensor divide_neg_exp(Tensor t) {
@ -192,53 +229,108 @@ Tensor divide_neg_exp(Tensor t) {
return torch::stable::detail::to<Tensor>(stack_div[0]);
}
void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = divide_neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("divide_neg_exp(Tensor t) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("divide_neg_exp", TORCH_BOX(&divide_neg_exp));
m.impl("divide_neg_exp", &boxed_divide_neg_exp);
}
bool is_contiguous(Tensor t) {
return t.is_contiguous();
}
void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
bool res = is_contiguous(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("is_contiguous(Tensor t) -> bool");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("is_contiguous", TORCH_BOX(&is_contiguous));
m.impl("is_contiguous", &boxed_is_contiguous);
}
Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) {
return transpose(t, dim0, dim1);
}
void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_transpose(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<int64_t>(stack[1]), torch::stable::detail::to<int64_t>(stack[2]));
stack[0] = torch::stable::detail::from(res);
}
Tensor my_empty_like(Tensor t) {
return empty_like(t);
}
void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_empty_like(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
bool my_is_cpu(Tensor t) {
return t.is_cpu();
}
void boxed_my_is_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_is_cpu(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
Tensor fill_infinity(Tensor t) {
auto value = std::numeric_limits<float>::infinity();
return fill_(t, value);
}
void boxed_fill_infinity(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
auto res = fill_infinity(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
Tensor my_pad(Tensor t) {
std::string mode = "constant";
double value = 0.0;
return pad(t, {1, 2, 2, 1}, mode, value);
}
void boxed_my_pad(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
auto res = my_pad(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
Tensor my_narrow(Tensor t, int64_t dim, int64_t start, int64_t length) {
return narrow(t, dim, start, length);
}
void boxed_my_narrow(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
auto res = my_narrow(
torch::stable::detail::to<Tensor>(stack[0]),
torch::stable::detail::to<int64_t>(stack[1]),
torch::stable::detail::to<int64_t>(stack[2]),
torch::stable::detail::to<int64_t>(stack[3]));
stack[0] = torch::stable::detail::from(res);
}
Tensor my_new_empty_dtype_variant(Tensor t) {
// Still using a std::vector below even though people can just pass in an
// initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
@ -250,19 +342,40 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
return new_empty(t, sizes, dtype);
}
void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_new_empty_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
Tensor my_new_zeros_dtype_variant(Tensor t) {
auto dtype = std::make_optional(at::ScalarType::Float);
return new_zeros(t, {2, 5}, dtype);
}
void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_new_zeros_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) {
return copy_(dst, src, non_blocking);
}
void boxed_my_copy_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor tensor_res = my_copy_(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<bool>(stack[2]));
stack[0] = torch::stable::detail::from(tensor_res);
}
Tensor my_clone(Tensor t) {
return clone(t);
}
void boxed_my_clone(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor tensor_res = my_clone(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(tensor_res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
m.def("my_empty_like(Tensor t) -> Tensor");
@ -276,39 +389,57 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_transpose", TORCH_BOX(&my_transpose));
m.impl("my_empty_like", TORCH_BOX(&my_empty_like));
m.impl("fill_infinity", TORCH_BOX(&fill_infinity));
m.impl("my_is_cpu", TORCH_BOX(&my_is_cpu));
m.impl("my_new_empty_dtype_variant", TORCH_BOX(&my_new_empty_dtype_variant));
m.impl("my_new_zeros_dtype_variant", TORCH_BOX(&my_new_zeros_dtype_variant));
m.impl("my_copy_", TORCH_BOX(&my_copy_));
m.impl("my_clone", TORCH_BOX(&my_clone));
m.impl("my_transpose", &boxed_my_transpose);
m.impl("my_empty_like", &boxed_empty_like);
m.impl("fill_infinity", &boxed_fill_infinity);
m.impl("my_is_cpu", &boxed_my_is_cpu);
m.impl("my_new_empty_dtype_variant", &boxed_my_new_empty_dtype_variant);
m.impl("my_new_zeros_dtype_variant", &boxed_my_new_zeros_dtype_variant);
m.impl("my_copy_", &boxed_my_copy_);
m.impl("my_clone", &boxed_my_clone);
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) {
m.impl("my_pad", TORCH_BOX(&my_pad));
m.impl("my_narrow", TORCH_BOX(&my_narrow));
m.impl("my_pad", &boxed_my_pad);
m.impl("my_narrow", &boxed_my_narrow);
}
Tensor my_zero_(Tensor t) {
return zero_(t);
}
void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_zero_(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
Tensor my_amax(Tensor t) {
return amax(t, 0, false);
}
void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_amax(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
Tensor my_amax_vec(Tensor t) {
return amax(t, {0,1}, false);
}
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_amax_vec(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("my_zero_(Tensor(a!) t) -> Tensor(a!)");
m.def("my_amax(Tensor a) -> Tensor");
m.def("my_amax_vec(Tensor a) -> Tensor");
m.def("my_is_cpu(Tensor t) -> bool");
m.def("test_default_constructor(bool undefined) -> bool");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
m.impl("my_zero_", &boxed_my_zero_);
}
bool test_default_constructor(bool defined) {
@ -330,12 +461,22 @@ bool test_default_constructor(bool defined) {
return out.defined();
}
void boxed_test_default_constructor(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
bool res = test_default_constructor(torch::stable::detail::to<bool>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("test_default_constructor(bool undefined) -> bool");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_zero_", TORCH_BOX(&my_zero_));
m.impl("my_amax", TORCH_BOX(&my_amax));
m.impl("my_amax_vec", TORCH_BOX(&my_amax_vec));
m.impl("test_default_constructor", TORCH_BOX(&test_default_constructor));
m.impl("test_default_constructor", &boxed_test_default_constructor);
m.impl("my_amax", &boxed_my_amax);
m.impl("my_amax_vec", &boxed_my_amax_vec);
}
std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
@ -344,11 +485,23 @@ std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor
return torch::stable::detail::to<std::vector<Tensor>>(stack[0]);
}
void boxed_my__foreach_mul(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
// Why is the following NOT torch::stable::detail::to<HeaderOnlyArrayRef<Tensor>>(stack[0])? Because calling `to`
// on a StableIValue means that the result is owning its underlying data now! HeaderOnlyArrayRef
// is not owning, so it cannot safely steward the result of the torch::stable::detail::to<>.
auto res = my__foreach_mul(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
stack[0] = torch::stable::detail::from(res);
}
void my__foreach_mul_(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data());
}
void boxed_my__foreach_mul_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
my__foreach_mul_(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
}
std::vector<Tensor> make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) {
// This function tests that my__foreach_mul can take in std::initializer_lists
// in addition to std::vectors.
@ -359,6 +512,11 @@ std::vector<Tensor> make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) {
return my__foreach_mul({t1_1, t2_1}, {t1_2, t2_2});
}
void boxed_make_tensor_clones_and_call_foreach(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = make_tensor_clones_and_call_foreach(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]));
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]");
m.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()");
@ -366,9 +524,9 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my__foreach_mul", TORCH_BOX(&my__foreach_mul));
m.impl("my__foreach_mul_", TORCH_BOX(&my__foreach_mul_));
m.impl("make_tensor_clones_and_call_foreach", TORCH_BOX(&make_tensor_clones_and_call_foreach));
m.impl("my__foreach_mul", &boxed_my__foreach_mul);
m.impl("my__foreach_mul_", &boxed_my__foreach_mul_);
m.impl("make_tensor_clones_and_call_foreach", &boxed_make_tensor_clones_and_call_foreach);
}
// Test functions for torch::stable::Tensor device method
@ -532,6 +690,14 @@ int64_t test_device_guard(int64_t device_index) {
return currentDevice;
}
void boxed_test_device_guard(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
int res = test_device_guard(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
stack[0] = torch::stable::detail::from(res);
}
int64_t test_device_guard_set_index() {
using torch::stable::accelerator::DeviceGuard;
@ -543,6 +709,14 @@ int64_t test_device_guard_set_index() {
return currentDevice;
}
void boxed_test_device_guard_set_index(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
int64_t res = test_device_guard_set_index();
stack[0] = torch::stable::detail::from(res);
}
int64_t test_stream(int32_t device_index) {
STD_TORCH_CHECK(
device_index >= std::numeric_limits<int32_t>::min() &&
@ -552,10 +726,26 @@ int64_t test_stream(int32_t device_index) {
return torch::stable::accelerator::getCurrentStream(device_index).id();
}
void boxed_test_stream(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
int64_t res = test_stream(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
stack[0] = torch::stable::detail::from(res);
}
int64_t test_get_current_device_index() {
return torch::stable::accelerator::getCurrentDeviceIndex();
}
void boxed_test_get_current_device_index(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
int64_t res = test_get_current_device_index();
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("test_device_guard(int device_index) -> int");
m.def("test_device_guard_set_index() -> int");
@ -564,10 +754,10 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("test_device_guard", TORCH_BOX(&test_device_guard));
m.impl("test_device_guard_set_index", TORCH_BOX(&test_device_guard_set_index));
m.impl("test_stream", TORCH_BOX(&test_stream));
m.impl("test_get_current_device_index", TORCH_BOX(&test_get_current_device_index));
m.impl("test_device_guard", &boxed_test_device_guard);
m.impl("test_device_guard_set_index", &boxed_test_device_guard_set_index);
m.impl("test_stream", &boxed_test_stream);
m.impl("test_get_current_device_index", &boxed_test_get_current_device_index);
}
#endif // LAE_USE_CUDA

View File

@ -33,7 +33,7 @@ class clean(distutils.command.clean.clean):
def get_extension():
extra_compile_args = {
"cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
"cxx": ["-fdiagnostics-color=always"],
}
extension = CppExtension

View File

@ -4,12 +4,17 @@
#include <c10/util/Exception.h>
void orCheckFail(const char* func, const char* file, uint32_t line, const char* msg = "");
void orCheckFail(
const char* func,
const char* file,
uint32_t line,
const char* msg = "");
#define OPENREG_CHECK(EXPR, ...) \
do { \
const orError_t __err = EXPR; \
if (C10_UNLIKELY(__err != orSuccess)) { \
orCheckFail(__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
} \
#define OPENREG_CHECK(EXPR, ...) \
do { \
const orError_t __err = EXPR; \
if (__err != orSuccess) { \
orCheckFail( \
__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
} \
} while (0)

View File

@ -1,4 +1,3 @@
#include <c10/util/Exception.h>
#include <include/openreg.h>
#include "OpenRegException.h"
@ -10,22 +9,21 @@ orError_t GetDeviceCount(int* dev_count) {
return orGetDeviceCount(dev_count);
}
orError_t GetDevice(DeviceIndex* device) {
orError_t GetDevice(c10::DeviceIndex* device) {
int tmp_device = -1;
auto err = orGetDevice(&tmp_device);
*device = static_cast<DeviceIndex>(tmp_device);
*device = static_cast<c10::DeviceIndex>(tmp_device);
return err;
}
// LITERALINCLUDE START: OPENREG SetDevice FUNCTION
orError_t SetDevice(DeviceIndex device) {
orError_t SetDevice(c10::DeviceIndex device) {
int cur_device = -1;
OPENREG_CHECK(orGetDevice(&cur_device));
orGetDevice(&cur_device);
if (device == cur_device) {
return orSuccess;
}
return orSetDevice(device);
}
// LITERALINCLUDE END: OPENREG SetDevice FUNCTION
int device_count_impl() {
int count = 0;
@ -33,37 +31,34 @@ int device_count_impl() {
return count;
}
OPENREG_EXPORT DeviceIndex device_count() noexcept {
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
// initialize number of devices only once
static int count = []() {
try {
auto result = device_count_impl();
TORCH_CHECK(
result <= std::numeric_limits<DeviceIndex>::max(),
result <= std::numeric_limits<c10::DeviceIndex>::max(),
"Too many devices, DeviceIndex overflowed");
return result;
} catch (const Error& ex) {
} catch (const c10::Error& ex) {
// We don't want to fail, but still log the warning
// msg() returns the message without the stack trace
TORCH_WARN("Device initialization: ", ex.msg());
return 0;
}
}();
return static_cast<DeviceIndex>(count);
return static_cast<c10::DeviceIndex>(count);
}
OPENREG_EXPORT DeviceIndex current_device() {
DeviceIndex cur_device = -1;
OPENREG_CHECK(GetDevice(&cur_device));
OPENREG_EXPORT c10::DeviceIndex current_device() {
c10::DeviceIndex cur_device = -1;
GetDevice(&cur_device);
return cur_device;
}
// LITERALINCLUDE START: OPENREG set_device FUNCTION
OPENREG_EXPORT void set_device(DeviceIndex device) {
check_device_index(device);
OPENREG_CHECK(SetDevice(device));
OPENREG_EXPORT void set_device(c10::DeviceIndex device) {
SetDevice(device);
}
// LITERALINCLUDE END: OPENREG set_device FUNCTION
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
int current_device = -1;
@ -76,8 +71,4 @@ OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
return current_device;
}
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device) {
check_device_index(to_device);
return ExchangeDevice(to_device);
}
} // namespace c10::openreg

View File

@ -9,20 +9,10 @@
namespace c10::openreg {
OPENREG_EXPORT DeviceIndex device_count() noexcept;
OPENREG_EXPORT DeviceIndex current_device();
OPENREG_EXPORT void set_device(DeviceIndex device);
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device);
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept;
OPENREG_EXPORT c10::DeviceIndex current_device();
OPENREG_EXPORT void set_device(c10::DeviceIndex device);
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device);
static inline void check_device_index(int64_t device) {
TORCH_CHECK(device >= 0 && device < c10::openreg::device_count(),
"The device index is out of range. It must be in [0, ",
static_cast<int>(c10::openreg::device_count()),
"), but got ",
static_cast<int>(device),
".");
}
} // namespace c10::openreg

View File

@ -5,7 +5,6 @@ static std::vector<at::Generator> default_generators;
namespace c10::openreg {
// LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR IMPL
const at::Generator& getDefaultOpenRegGenerator(c10::DeviceIndex device_index) {
static bool flag [[maybe_unused]] = []() {
auto deivce_nums = device_count();
@ -25,6 +24,5 @@ const at::Generator& getDefaultOpenRegGenerator(c10::DeviceIndex device_index) {
}
return default_generators[idx];
}
// LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR IMPL
} // namespace c10::openreg

View File

@ -2,8 +2,6 @@
namespace c10::openreg {
// LITERALINCLUDE START: OPENREG GUARD REGISTRATION
C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl);
// LITERALINCLUDE END: OPENREG GUARD REGISTRATION
} // namespace c10::openreg

Some files were not shown because too many files have changed in this diff Show More