Compare commits

..

4 Commits

Author SHA1 Message Date
d25559423f Update
[ghstack-poisoned]
2025-11-17 22:50:04 +00:00
654c149d07 Update (base update)
[ghstack-poisoned]
2025-11-17 22:50:04 +00:00
80a0fd3f4d Update
[ghstack-poisoned]
2025-09-04 00:27:23 -07:00
4c4a6b3644 Update (base update)
[ghstack-poisoned]
2025-09-04 00:27:23 -07:00
283 changed files with 3309 additions and 12321 deletions

View File

@ -1,19 +0,0 @@
# Aarch64 (ARM/Graviton) Support Scripts
Scripts for building aarch64 PyTorch PIP Wheels. These scripts build the following wheels:
* torch
* torchvision
* torchaudio
* torchtext
* torchdata
## Aarch64_ci_build.sh
This script is design to support CD operations within PyPi manylinux aarch64 container, and be executed in the container. It prepares the container and then executes __aarch64_wheel_ci_build.py__ to build the wheels. The script "assumes" the PyTorch repo is located at: ```/pytorch``` and will put the wheels into ```/artifacts```.
### Usage
```DESIRED_PYTHON=<PythonVersion> aarch64_ci_build.sh```
__NOTE:__ CI build is currently __EXPERMINTAL__
## Build_aarch64_wheel.py
This app allows a person to build using AWS EC3 resources and requires AWS-CLI and Boto3 with AWS credentials to support building EC2 instances for the wheel builds. Can be used in a codebuild CD or from a local system.
### Usage
```build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.8 --branch <RCtag>```

View File

@ -1,53 +0,0 @@
#!/bin/bash
set -eux -o pipefail
GPU_ARCH_VERSION=${GPU_ARCH_VERSION:-}
# Set CUDA architecture lists to match x86 build_cuda.sh
if [[ "$GPU_ARCH_VERSION" == *"12.6"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0"
elif [[ "$GPU_ARCH_VERSION" == *"12.8"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
elif [[ "$GPU_ARCH_VERSION" == *"12.9"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
elif [[ "$GPU_ARCH_VERSION" == *"13.0"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;11.0;12.0+PTX"
fi
# Compress the fatbin with -compress-mode=size for CUDA 13
if [[ "$DESIRED_CUDA" == *"13"* ]]; then
export TORCH_NVCC_FLAGS="-compress-mode=size"
# Bundle ptxas into the cu13 wheel, see https://github.com/pytorch/pytorch/issues/163801
export BUILD_BUNDLE_PTXAS=1
fi
SCRIPTPATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
source $SCRIPTPATH/aarch64_ci_setup.sh
###############################################################################
# Run aarch64 builder python
###############################################################################
cd /
# adding safe directory for git as the permissions will be
# on the mounted pytorch repo
git config --global --add safe.directory /pytorch
pip install -r /pytorch/requirements.txt
pip install auditwheel==6.2.0 wheel
if [ "$DESIRED_CUDA" = "cpu" ]; then
echo "BASE_CUDA_VERSION is not set. Building cpu wheel."
python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn
else
echo "BASE_CUDA_VERSION is set to: $DESIRED_CUDA"
export USE_SYSTEM_NCCL=1
# Check if we should use NVIDIA libs from PyPI (similar to x86 build_cuda.sh logic)
if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then
echo "Bundling CUDA libraries with wheel for aarch64."
else
echo "Using nvidia libs from pypi for aarch64."
echo "Updated PYTORCH_EXTRA_INSTALL_REQUIREMENTS for aarch64: $PYTORCH_EXTRA_INSTALL_REQUIREMENTS"
export USE_NVIDIA_PYPI_LIBS=1
fi
python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn --enable-cuda
fi

View File

@ -1,21 +0,0 @@
#!/bin/bash
set -eux -o pipefail
# This script is used to prepare the Docker container for aarch64_ci_wheel_build.py python script
# By creating symlinks from desired /opt/python to /usr/local/bin/
NUMPY_VERSION=2.0.2
if [[ "$DESIRED_PYTHON" == "3.13" || "$DESIRED_PYTHON" == "3.13t" ]]; then
NUMPY_VERSION=2.1.2
fi
SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"
source $SCRIPTPATH/../manywheel/set_desired_python.sh
pip install -q numpy==${NUMPY_VERSION} pyyaml==6.0.2 scons==4.7.0 ninja==1.11.1 patchelf==0.17.2
for tool in python python3 pip pip3 ninja scons patchelf; do
ln -sf ${DESIRED_PYTHON_BIN_DIR}/${tool} /usr/local/bin;
done
python --version

View File

@ -1,333 +0,0 @@
#!/usr/bin/env python3
# encoding: UTF-8
import os
import shutil
from subprocess import check_call, check_output
def list_dir(path: str) -> list[str]:
"""'
Helper for getting paths for Python
"""
return check_output(["ls", "-1", path]).decode().split("\n")
def replace_tag(filename) -> None:
with open(filename) as f:
lines = f.readlines()
for i, line in enumerate(lines):
if line.startswith("Tag:"):
lines[i] = line.replace("-linux_", "-manylinux_2_28_")
print(f"Updated tag from {line} to {lines[i]}")
break
with open(filename, "w") as f:
f.writelines(lines)
def patch_library_rpath(
folder: str,
lib_name: str,
use_nvidia_pypi_libs: bool = False,
desired_cuda: str = "",
) -> None:
"""Apply patchelf to set RPATH for a library in torch/lib"""
lib_path = f"{folder}/tmp/torch/lib/{lib_name}"
if use_nvidia_pypi_libs:
# For PyPI NVIDIA libraries, construct CUDA RPATH
cuda_rpaths = [
"$ORIGIN/../../nvidia/cudnn/lib",
"$ORIGIN/../../nvidia/nvshmem/lib",
"$ORIGIN/../../nvidia/nccl/lib",
"$ORIGIN/../../nvidia/cusparselt/lib",
]
if "130" in desired_cuda:
cuda_rpaths.append("$ORIGIN/../../nvidia/cu13/lib")
else:
cuda_rpaths.extend(
[
"$ORIGIN/../../nvidia/cublas/lib",
"$ORIGIN/../../nvidia/cuda_cupti/lib",
"$ORIGIN/../../nvidia/cuda_nvrtc/lib",
"$ORIGIN/../../nvidia/cuda_runtime/lib",
"$ORIGIN/../../nvidia/cufft/lib",
"$ORIGIN/../../nvidia/curand/lib",
"$ORIGIN/../../nvidia/cusolver/lib",
"$ORIGIN/../../nvidia/cusparse/lib",
"$ORIGIN/../../nvidia/nvtx/lib",
"$ORIGIN/../../nvidia/cufile/lib",
]
)
# Add $ORIGIN for local torch libs
rpath = ":".join(cuda_rpaths) + ":$ORIGIN"
else:
# For bundled libraries, just use $ORIGIN
rpath = "$ORIGIN"
if os.path.exists(lib_path):
os.system(
f"cd {folder}/tmp/torch/lib/; "
f"patchelf --set-rpath '{rpath}' --force-rpath {lib_name}"
)
def copy_and_patch_library(
src_path: str,
folder: str,
use_nvidia_pypi_libs: bool = False,
desired_cuda: str = "",
) -> None:
"""Copy a library to torch/lib and patch its RPATH"""
if os.path.exists(src_path):
lib_name = os.path.basename(src_path)
shutil.copy2(src_path, f"{folder}/tmp/torch/lib/{lib_name}")
patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda)
def package_cuda_wheel(wheel_path, desired_cuda) -> None:
"""
Package the cuda wheel libraries
"""
folder = os.path.dirname(wheel_path)
os.mkdir(f"{folder}/tmp")
os.system(f"unzip {wheel_path} -d {folder}/tmp")
# Delete original wheel since it will be repackaged
os.system(f"rm {wheel_path}")
# Check if we should use PyPI NVIDIA libraries or bundle system libraries
use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"
if use_nvidia_pypi_libs:
print("Using nvidia libs from pypi - skipping CUDA library bundling")
# For PyPI approach, we don't bundle CUDA libraries - they come from PyPI packages
# We only need to bundle non-NVIDIA libraries
minimal_libs_to_copy = [
"/lib64/libgomp.so.1",
"/usr/lib64/libgfortran.so.5",
"/acl/build/libarm_compute.so",
"/acl/build/libarm_compute_graph.so",
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_lapack_core.so.0",
"/usr/local/lib/libnvpl_blas_core.so.0",
]
# Copy minimal libraries to unzipped_folder/torch/lib
for lib_path in minimal_libs_to_copy:
copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda)
# Patch torch libraries used for searching libraries
torch_libs_to_patch = [
"libtorch.so",
"libtorch_cpu.so",
"libtorch_cuda.so",
"libtorch_cuda_linalg.so",
"libtorch_global_deps.so",
"libtorch_python.so",
"libtorch_nvshmem.so",
"libc10.so",
"libc10_cuda.so",
"libcaffe2_nvrtc.so",
"libshm.so",
]
for lib_name in torch_libs_to_patch:
patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda)
else:
print("Bundling CUDA libraries with wheel")
# Original logic for bundling system CUDA libraries
# Common libraries for all CUDA versions
common_libs = [
# Non-NVIDIA system libraries
"/lib64/libgomp.so.1",
"/usr/lib64/libgfortran.so.5",
"/acl/build/libarm_compute.so",
"/acl/build/libarm_compute_graph.so",
# Common CUDA libraries (same for all versions)
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_lapack_core.so.0",
"/usr/local/lib/libnvpl_blas_core.so.0",
"/usr/local/cuda/extras/CUPTI/lib64/libnvperf_host.so",
"/usr/local/cuda/lib64/libcudnn.so.9",
"/usr/local/cuda/lib64/libcusparseLt.so.0",
"/usr/local/cuda/lib64/libcurand.so.10",
"/usr/local/cuda/lib64/libnccl.so.2",
"/usr/local/cuda/lib64/libnvshmem_host.so.3",
"/usr/local/cuda/lib64/libcudnn_adv.so.9",
"/usr/local/cuda/lib64/libcudnn_cnn.so.9",
"/usr/local/cuda/lib64/libcudnn_graph.so.9",
"/usr/local/cuda/lib64/libcudnn_ops.so.9",
"/usr/local/cuda/lib64/libcudnn_engines_runtime_compiled.so.9",
"/usr/local/cuda/lib64/libcudnn_engines_precompiled.so.9",
"/usr/local/cuda/lib64/libcudnn_heuristic.so.9",
"/usr/local/cuda/lib64/libcufile.so.0",
"/usr/local/cuda/lib64/libcufile_rdma.so.1",
"/usr/local/cuda/lib64/libcusparse.so.12",
]
# CUDA version-specific libraries
if "13" in desired_cuda:
minor_version = desired_cuda[-1]
version_specific_libs = [
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.13",
"/usr/local/cuda/lib64/libcublas.so.13",
"/usr/local/cuda/lib64/libcublasLt.so.13",
"/usr/local/cuda/lib64/libcudart.so.13",
"/usr/local/cuda/lib64/libcufft.so.12",
"/usr/local/cuda/lib64/libcusolver.so.12",
"/usr/local/cuda/lib64/libnvJitLink.so.13",
"/usr/local/cuda/lib64/libnvrtc.so.13",
f"/usr/local/cuda/lib64/libnvrtc-builtins.so.13.{minor_version}",
]
elif "12" in desired_cuda:
# Get the last character for libnvrtc-builtins version (e.g., "129" -> "9")
minor_version = desired_cuda[-1]
version_specific_libs = [
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12",
"/usr/local/cuda/lib64/libcublas.so.12",
"/usr/local/cuda/lib64/libcublasLt.so.12",
"/usr/local/cuda/lib64/libcudart.so.12",
"/usr/local/cuda/lib64/libcufft.so.11",
"/usr/local/cuda/lib64/libcusolver.so.11",
"/usr/local/cuda/lib64/libnvJitLink.so.12",
"/usr/local/cuda/lib64/libnvrtc.so.12",
f"/usr/local/cuda/lib64/libnvrtc-builtins.so.12.{minor_version}",
]
else:
raise ValueError(f"Unsupported CUDA version: {desired_cuda}.")
# Combine all libraries
libs_to_copy = common_libs + version_specific_libs
# Copy libraries to unzipped_folder/torch/lib
for lib_path in libs_to_copy:
copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda)
# Make sure the wheel is tagged with manylinux_2_28
for f in os.scandir(f"{folder}/tmp/"):
if f.is_dir() and f.name.endswith(".dist-info"):
replace_tag(f"{f.path}/WHEEL")
break
os.system(f"wheel pack {folder}/tmp/ -d {folder}")
os.system(f"rm -rf {folder}/tmp/")
def complete_wheel(folder: str) -> str:
"""
Complete wheel build and put in artifact location
"""
wheel_name = list_dir(f"/{folder}/dist")[0]
# Please note for cuda we don't run auditwheel since we use custom script to package
# the cuda dependencies to the wheel file using update_wheel() method.
# However we need to make sure filename reflects the correct Manylinux platform.
if "pytorch" in folder and not enable_cuda:
print("Repairing Wheel with AuditWheel")
check_call(["auditwheel", "repair", f"dist/{wheel_name}"], cwd=folder)
repaired_wheel_name = list_dir(f"/{folder}/wheelhouse")[0]
print(f"Moving {repaired_wheel_name} wheel to /{folder}/dist")
os.rename(
f"/{folder}/wheelhouse/{repaired_wheel_name}",
f"/{folder}/dist/{repaired_wheel_name}",
)
else:
repaired_wheel_name = list_dir(f"/{folder}/dist")[0]
print(f"Copying {repaired_wheel_name} to artifacts")
shutil.copy2(
f"/{folder}/dist/{repaired_wheel_name}", f"/artifacts/{repaired_wheel_name}"
)
return repaired_wheel_name
def parse_arguments():
"""
Parse inline arguments
"""
from argparse import ArgumentParser
parser = ArgumentParser("AARCH64 wheels python CD")
parser.add_argument("--debug", action="store_true")
parser.add_argument("--build-only", action="store_true")
parser.add_argument("--test-only", type=str)
parser.add_argument("--enable-mkldnn", action="store_true")
parser.add_argument("--enable-cuda", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
"""
Entry Point
"""
args = parse_arguments()
enable_mkldnn = args.enable_mkldnn
enable_cuda = args.enable_cuda
branch = check_output(
["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd="/pytorch"
).decode()
print("Building PyTorch wheel")
build_vars = ""
# MAX_JOB=5 is not required for CPU backend (see commit 465d98b)
if enable_cuda:
build_vars += "MAX_JOBS=5 "
# Handle PyPI NVIDIA libraries vs bundled libraries
use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"
if use_nvidia_pypi_libs:
print("Configuring build for PyPI NVIDIA libraries")
# Configure for dynamic linking (matching x86 logic)
build_vars += "ATEN_STATIC_CUDA=0 USE_CUDA_STATIC_LINK=0 USE_CUPTI_SO=1 "
else:
print("Configuring build for bundled NVIDIA libraries")
# Keep existing static linking approach - already configured above
override_package_version = os.getenv("OVERRIDE_PACKAGE_VERSION")
desired_cuda = os.getenv("DESIRED_CUDA")
if override_package_version is not None:
version = override_package_version
build_vars += (
f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version} PYTORCH_BUILD_NUMBER=1 "
)
elif branch in ["nightly", "main"]:
build_date = (
check_output(["git", "log", "--pretty=format:%cs", "-1"], cwd="/pytorch")
.decode()
.replace("-", "")
)
version = (
check_output(["cat", "version.txt"], cwd="/pytorch").decode().strip()[:-2]
)
if enable_cuda:
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date}+{desired_cuda} PYTORCH_BUILD_NUMBER=1 "
else:
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1 "
elif branch.startswith(("v1.", "v2.")):
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1 "
if enable_mkldnn:
print("build pytorch with mkldnn+acl backend")
build_vars += "USE_MKLDNN=ON USE_MKLDNN_ACL=ON "
build_vars += "ACL_ROOT_DIR=/acl "
if enable_cuda:
build_vars += "BLAS=NVPL "
else:
build_vars += "BLAS=OpenBLAS OpenBLAS_HOME=/opt/OpenBLAS "
else:
print("build pytorch without mkldnn backend")
os.system(f"cd /pytorch; {build_vars} python3 -m build --wheel --no-isolation")
if enable_cuda:
print("Updating Cuda Dependency")
filename = os.listdir("/pytorch/dist/")
wheel_path = f"/pytorch/dist/{filename[0]}"
package_cuda_wheel(wheel_path, desired_cuda)
pytorch_wheel_name = complete_wheel("/pytorch/")
print(f"Build Complete. Created {pytorch_wheel_name}..")

View File

@ -1,999 +0,0 @@
#!/usr/bin/env python3
# This script is for building AARCH64 wheels using AWS EC2 instances.
# To generate binaries for the release follow these steps:
# 1. Update mappings for each of the Domain Libraries by adding new row to a table like this:
# "v1.11.0": ("0.11.0", "rc1"),
# 2. Run script with following arguments for each of the supported python versions and required tag, for example:
# build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.8 --branch v1.11.0-rc3
import os
import subprocess
import sys
import time
from typing import Optional, Union
import boto3
# AMI images for us-east-1, change the following based on your ~/.aws/config
os_amis = {
"ubuntu20_04": "ami-052eac90edaa9d08f", # login_name: ubuntu
"ubuntu22_04": "ami-0c6c29c5125214c77", # login_name: ubuntu
"redhat8": "ami-0698b90665a2ddcf1", # login_name: ec2-user
}
ubuntu20_04_ami = os_amis["ubuntu20_04"]
def compute_keyfile_path(key_name: Optional[str] = None) -> tuple[str, str]:
if key_name is None:
key_name = os.getenv("AWS_KEY_NAME")
if key_name is None:
return os.getenv("SSH_KEY_PATH", ""), ""
homedir_path = os.path.expanduser("~")
default_path = os.path.join(homedir_path, ".ssh", f"{key_name}.pem")
return os.getenv("SSH_KEY_PATH", default_path), key_name
ec2 = boto3.resource("ec2")
def ec2_get_instances(filter_name, filter_value):
return ec2.instances.filter(
Filters=[{"Name": filter_name, "Values": [filter_value]}]
)
def ec2_instances_of_type(instance_type="t4g.2xlarge"):
return ec2_get_instances("instance-type", instance_type)
def ec2_instances_by_id(instance_id):
rc = list(ec2_get_instances("instance-id", instance_id))
return rc[0] if len(rc) > 0 else None
def start_instance(
key_name, ami=ubuntu20_04_ami, instance_type="t4g.2xlarge", ebs_size: int = 50
):
inst = ec2.create_instances(
ImageId=ami,
InstanceType=instance_type,
SecurityGroups=["ssh-allworld"],
KeyName=key_name,
MinCount=1,
MaxCount=1,
BlockDeviceMappings=[
{
"DeviceName": "/dev/sda1",
"Ebs": {
"DeleteOnTermination": True,
"VolumeSize": ebs_size,
"VolumeType": "standard",
},
}
],
)[0]
print(f"Create instance {inst.id}")
inst.wait_until_running()
running_inst = ec2_instances_by_id(inst.id)
print(f"Instance started at {running_inst.public_dns_name}")
return running_inst
class RemoteHost:
addr: str
keyfile_path: str
login_name: str
container_id: Optional[str] = None
ami: Optional[str] = None
def __init__(self, addr: str, keyfile_path: str, login_name: str = "ubuntu"):
self.addr = addr
self.keyfile_path = keyfile_path
self.login_name = login_name
def _gen_ssh_prefix(self) -> list[str]:
return [
"ssh",
"-o",
"StrictHostKeyChecking=no",
"-i",
self.keyfile_path,
f"{self.login_name}@{self.addr}",
"--",
]
@staticmethod
def _split_cmd(args: Union[str, list[str]]) -> list[str]:
return args.split() if isinstance(args, str) else args
def run_ssh_cmd(self, args: Union[str, list[str]]) -> None:
subprocess.check_call(self._gen_ssh_prefix() + self._split_cmd(args))
def check_ssh_output(self, args: Union[str, list[str]]) -> str:
return subprocess.check_output(
self._gen_ssh_prefix() + self._split_cmd(args)
).decode("utf-8")
def scp_upload_file(self, local_file: str, remote_file: str) -> None:
subprocess.check_call(
[
"scp",
"-i",
self.keyfile_path,
local_file,
f"{self.login_name}@{self.addr}:{remote_file}",
]
)
def scp_download_file(
self, remote_file: str, local_file: Optional[str] = None
) -> None:
if local_file is None:
local_file = "."
subprocess.check_call(
[
"scp",
"-i",
self.keyfile_path,
f"{self.login_name}@{self.addr}:{remote_file}",
local_file,
]
)
def start_docker(self, image="quay.io/pypa/manylinux2014_aarch64:latest") -> None:
self.run_ssh_cmd("sudo apt-get install -y docker.io")
self.run_ssh_cmd(f"sudo usermod -a -G docker {self.login_name}")
self.run_ssh_cmd("sudo service docker start")
self.run_ssh_cmd(f"docker pull {image}")
self.container_id = self.check_ssh_output(
f"docker run -t -d -w /root {image}"
).strip()
def using_docker(self) -> bool:
return self.container_id is not None
def run_cmd(self, args: Union[str, list[str]]) -> None:
if not self.using_docker():
return self.run_ssh_cmd(args)
assert self.container_id is not None
docker_cmd = self._gen_ssh_prefix() + [
"docker",
"exec",
"-i",
self.container_id,
"bash",
]
p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE)
p.communicate(
input=" ".join(["source .bashrc && "] + self._split_cmd(args)).encode(
"utf-8"
)
)
rc = p.wait()
if rc != 0:
raise subprocess.CalledProcessError(rc, docker_cmd)
def check_output(self, args: Union[str, list[str]]) -> str:
if not self.using_docker():
return self.check_ssh_output(args)
assert self.container_id is not None
docker_cmd = self._gen_ssh_prefix() + [
"docker",
"exec",
"-i",
self.container_id,
"bash",
]
p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
(out, err) = p.communicate(
input=" ".join(["source .bashrc && "] + self._split_cmd(args)).encode(
"utf-8"
)
)
rc = p.wait()
if rc != 0:
raise subprocess.CalledProcessError(rc, docker_cmd, output=out, stderr=err)
return out.decode("utf-8")
def upload_file(self, local_file: str, remote_file: str) -> None:
if not self.using_docker():
return self.scp_upload_file(local_file, remote_file)
tmp_file = os.path.join("/tmp", os.path.basename(local_file))
self.scp_upload_file(local_file, tmp_file)
self.run_ssh_cmd(
["docker", "cp", tmp_file, f"{self.container_id}:/root/{remote_file}"]
)
self.run_ssh_cmd(["rm", tmp_file])
def download_file(self, remote_file: str, local_file: Optional[str] = None) -> None:
if not self.using_docker():
return self.scp_download_file(remote_file, local_file)
tmp_file = os.path.join("/tmp", os.path.basename(remote_file))
self.run_ssh_cmd(
["docker", "cp", f"{self.container_id}:/root/{remote_file}", tmp_file]
)
self.scp_download_file(tmp_file, local_file)
self.run_ssh_cmd(["rm", tmp_file])
def download_wheel(
self, remote_file: str, local_file: Optional[str] = None
) -> None:
if self.using_docker() and local_file is None:
basename = os.path.basename(remote_file)
local_file = basename.replace(
"-linux_aarch64.whl", "-manylinux2014_aarch64.whl"
)
self.download_file(remote_file, local_file)
def list_dir(self, path: str) -> list[str]:
return self.check_output(["ls", "-1", path]).split("\n")
def wait_for_connection(addr, port, timeout=15, attempt_cnt=5):
import socket
for i in range(attempt_cnt):
try:
with socket.create_connection((addr, port), timeout=timeout):
return
except (ConnectionRefusedError, TimeoutError): # noqa: PERF203
if i == attempt_cnt - 1:
raise
time.sleep(timeout)
def update_apt_repo(host: RemoteHost) -> None:
time.sleep(5)
host.run_cmd("sudo systemctl stop apt-daily.service || true")
host.run_cmd("sudo systemctl stop unattended-upgrades.service || true")
host.run_cmd(
"while systemctl is-active --quiet apt-daily.service; do sleep 1; done"
)
host.run_cmd(
"while systemctl is-active --quiet unattended-upgrades.service; do sleep 1; done"
)
host.run_cmd("sudo apt-get update")
time.sleep(3)
host.run_cmd("sudo apt-get update")
def install_condaforge(
host: RemoteHost, suffix: str = "latest/download/Miniforge3-Linux-aarch64.sh"
) -> None:
print("Install conda-forge")
host.run_cmd(f"curl -OL https://github.com/conda-forge/miniforge/releases/{suffix}")
host.run_cmd(f"sh -f {os.path.basename(suffix)} -b")
host.run_cmd(f"rm -f {os.path.basename(suffix)}")
if host.using_docker():
host.run_cmd("echo 'PATH=$HOME/miniforge3/bin:$PATH'>>.bashrc")
else:
host.run_cmd(
[
"sed",
"-i",
"'/^# If not running interactively.*/i PATH=$HOME/miniforge3/bin:$PATH'",
".bashrc",
]
)
def install_condaforge_python(host: RemoteHost, python_version="3.8") -> None:
if python_version == "3.6":
# Python-3.6 EOLed and not compatible with conda-4.11
install_condaforge(
host, suffix="download/4.10.3-10/Miniforge3-4.10.3-10-Linux-aarch64.sh"
)
host.run_cmd(f"conda install -y python={python_version} numpy pyyaml")
else:
install_condaforge(
host, suffix="download/4.11.0-4/Miniforge3-4.11.0-4-Linux-aarch64.sh"
)
# Pytorch-1.10 or older are not compatible with setuptools=59.6 or newer
host.run_cmd(
f"conda install -y python={python_version} numpy pyyaml setuptools>=59.5.0"
)
def embed_libgomp(host: RemoteHost, use_conda, wheel_name) -> None:
host.run_cmd("pip3 install auditwheel")
host.run_cmd(
"conda install -y patchelf" if use_conda else "sudo apt-get install -y patchelf"
)
from tempfile import NamedTemporaryFile
with NamedTemporaryFile() as tmp:
tmp.write(embed_library_script.encode("utf-8"))
tmp.flush()
host.upload_file(tmp.name, "embed_library.py")
print("Embedding libgomp into wheel")
if host.using_docker():
host.run_cmd(f"python3 embed_library.py {wheel_name} --update-tag")
else:
host.run_cmd(f"python3 embed_library.py {wheel_name}")
def checkout_repo(
host: RemoteHost,
*,
branch: str = "main",
url: str,
git_clone_flags: str,
mapping: dict[str, tuple[str, str]],
) -> Optional[str]:
for prefix in mapping:
if not branch.startswith(prefix):
continue
tag = f"v{mapping[prefix][0]}-{mapping[prefix][1]}"
host.run_cmd(f"git clone {url} -b {tag} {git_clone_flags}")
return mapping[prefix][0]
host.run_cmd(f"git clone {url} -b {branch} {git_clone_flags}")
return None
def build_torchvision(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str,
run_smoke_tests: bool = True,
) -> str:
print("Checking out TorchVision repo")
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/vision",
git_clone_flags=git_clone_flags,
mapping={
"v1.7.1": ("0.8.2", "rc2"),
"v1.8.0": ("0.9.0", "rc3"),
"v1.8.1": ("0.9.1", "rc1"),
"v1.9.0": ("0.10.0", "rc1"),
"v1.10.0": ("0.11.1", "rc1"),
"v1.10.1": ("0.11.2", "rc1"),
"v1.10.2": ("0.11.3", "rc1"),
"v1.11.0": ("0.12.0", "rc1"),
"v1.12.0": ("0.13.0", "rc4"),
"v1.12.1": ("0.13.1", "rc6"),
"v1.13.0": ("0.14.0", "rc4"),
"v1.13.1": ("0.14.1", "rc2"),
"v2.0.0": ("0.15.1", "rc2"),
"v2.0.1": ("0.15.2", "rc2"),
},
)
print("Building TorchVision wheel")
# Please note libnpg and jpeg are required to build image.so extension
if use_conda:
host.run_cmd("conda install -y libpng jpeg")
# Remove .so files to force static linking
host.run_cmd(
"rm miniforge3/lib/libpng.so miniforge3/lib/libpng16.so miniforge3/lib/libjpeg.so"
)
# And patch setup.py to include libz dependency for libpng
host.run_cmd(
[
'sed -i -e \'s/image_link_flags\\.append("png")/image_link_flags += ["png", "z"]/\' vision/setup.py'
]
)
build_vars = ""
if branch == "nightly":
version = host.check_output(
["if [ -f vision/version.txt ]; then cat vision/version.txt; fi"]
).strip()
if len(version) == 0:
# In older revisions, version was embedded in setup.py
version = (
host.check_output(["grep", '"version = \'"', "vision/setup.py"])
.strip()
.split("'")[1][:-2]
)
build_date = (
host.check_output("cd vision && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(f"cd vision && {build_vars} python3 -m build --wheel --no-isolation")
vision_wheel_name = host.list_dir("vision/dist")[0]
embed_libgomp(host, use_conda, os.path.join("vision", "dist", vision_wheel_name))
print("Copying TorchVision wheel")
host.download_wheel(os.path.join("vision", "dist", vision_wheel_name))
if run_smoke_tests:
host.run_cmd(
f"pip3 install {os.path.join('vision', 'dist', vision_wheel_name)}"
)
host.run_cmd("python3 vision/test/smoke_test.py")
print("Delete vision checkout")
host.run_cmd("rm -rf vision")
return vision_wheel_name
def build_torchdata(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> str:
print("Checking out TorchData repo")
git_clone_flags += " --recurse-submodules"
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/data",
git_clone_flags=git_clone_flags,
mapping={
"v1.13.1": ("0.5.1", ""),
"v2.0.0": ("0.6.0", "rc5"),
"v2.0.1": ("0.6.1", "rc1"),
},
)
print("Building TorchData wheel")
build_vars = ""
if branch == "nightly":
version = host.check_output(
["if [ -f data/version.txt ]; then cat data/version.txt; fi"]
).strip()
build_date = (
host.check_output("cd data && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(f"cd data && {build_vars} python3 -m build --wheel --no-isolation")
wheel_name = host.list_dir("data/dist")[0]
embed_libgomp(host, use_conda, os.path.join("data", "dist", wheel_name))
print("Copying TorchData wheel")
host.download_wheel(os.path.join("data", "dist", wheel_name))
return wheel_name
def build_torchtext(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> str:
print("Checking out TorchText repo")
git_clone_flags += " --recurse-submodules"
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/text",
git_clone_flags=git_clone_flags,
mapping={
"v1.9.0": ("0.10.0", "rc1"),
"v1.10.0": ("0.11.0", "rc2"),
"v1.10.1": ("0.11.1", "rc1"),
"v1.10.2": ("0.11.2", "rc1"),
"v1.11.0": ("0.12.0", "rc1"),
"v1.12.0": ("0.13.0", "rc2"),
"v1.12.1": ("0.13.1", "rc5"),
"v1.13.0": ("0.14.0", "rc3"),
"v1.13.1": ("0.14.1", "rc1"),
"v2.0.0": ("0.15.1", "rc2"),
"v2.0.1": ("0.15.2", "rc2"),
},
)
print("Building TorchText wheel")
build_vars = ""
if branch == "nightly":
version = host.check_output(
["if [ -f text/version.txt ]; then cat text/version.txt; fi"]
).strip()
build_date = (
host.check_output("cd text && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(f"cd text && {build_vars} python3 -m build --wheel --no-isolation")
wheel_name = host.list_dir("text/dist")[0]
embed_libgomp(host, use_conda, os.path.join("text", "dist", wheel_name))
print("Copying TorchText wheel")
host.download_wheel(os.path.join("text", "dist", wheel_name))
return wheel_name
def build_torchaudio(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> str:
print("Checking out TorchAudio repo")
git_clone_flags += " --recurse-submodules"
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/audio",
git_clone_flags=git_clone_flags,
mapping={
"v1.9.0": ("0.9.0", "rc2"),
"v1.10.0": ("0.10.0", "rc5"),
"v1.10.1": ("0.10.1", "rc1"),
"v1.10.2": ("0.10.2", "rc1"),
"v1.11.0": ("0.11.0", "rc1"),
"v1.12.0": ("0.12.0", "rc3"),
"v1.12.1": ("0.12.1", "rc5"),
"v1.13.0": ("0.13.0", "rc4"),
"v1.13.1": ("0.13.1", "rc2"),
"v2.0.0": ("2.0.1", "rc3"),
"v2.0.1": ("2.0.2", "rc2"),
},
)
print("Building TorchAudio wheel")
build_vars = ""
if branch == "nightly":
version = (
host.check_output(["grep", '"version = \'"', "audio/setup.py"])
.strip()
.split("'")[1][:-2]
)
build_date = (
host.check_output("cd audio && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(
f"cd audio && export FFMPEG_ROOT=$(pwd)/third_party/ffmpeg && export USE_FFMPEG=1 \
&& ./packaging/ffmpeg/build.sh \
&& {build_vars} python3 -m build --wheel --no-isolation"
)
wheel_name = host.list_dir("audio/dist")[0]
embed_libgomp(host, use_conda, os.path.join("audio", "dist", wheel_name))
print("Copying TorchAudio wheel")
host.download_wheel(os.path.join("audio", "dist", wheel_name))
return wheel_name
def configure_system(
host: RemoteHost,
*,
compiler: str = "gcc-8",
use_conda: bool = True,
python_version: str = "3.8",
) -> None:
if use_conda:
install_condaforge_python(host, python_version)
print("Configuring the system")
if not host.using_docker():
update_apt_repo(host)
host.run_cmd("sudo apt-get install -y ninja-build g++ git cmake gfortran unzip")
else:
host.run_cmd("yum install -y sudo")
host.run_cmd("conda install -y ninja scons")
if not use_conda:
host.run_cmd(
"sudo apt-get install -y python3-dev python3-yaml python3-setuptools python3-wheel python3-pip"
)
host.run_cmd("pip3 install dataclasses typing-extensions")
if not use_conda:
print("Installing Cython + numpy from PyPy")
host.run_cmd("sudo pip3 install Cython")
host.run_cmd("sudo pip3 install numpy")
def build_domains(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> tuple[str, str, str, str]:
vision_wheel_name = build_torchvision(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
audio_wheel_name = build_torchaudio(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
data_wheel_name = build_torchdata(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
text_wheel_name = build_torchtext(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
return (vision_wheel_name, audio_wheel_name, data_wheel_name, text_wheel_name)
def start_build(
host: RemoteHost,
*,
branch: str = "main",
compiler: str = "gcc-8",
use_conda: bool = True,
python_version: str = "3.8",
pytorch_only: bool = False,
pytorch_build_number: Optional[str] = None,
shallow_clone: bool = True,
enable_mkldnn: bool = False,
) -> tuple[str, str, str, str, str]:
git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
if host.using_docker() and not use_conda:
print("Auto-selecting conda option for docker images")
use_conda = True
if not host.using_docker():
print("Disable mkldnn for host builds")
enable_mkldnn = False
configure_system(
host, compiler=compiler, use_conda=use_conda, python_version=python_version
)
if host.using_docker():
print("Move libgfortant.a into a standard location")
# HACK: pypa gforntran.a is compiled without PIC, which leads to the following error
# libgfortran.a(error.o)(.text._gfortrani_st_printf+0x34): unresolvable R_AARCH64_ADR_PREL_PG_HI21 relocation against symbol `__stack_chk_guard@@GLIBC_2.17' # noqa: E501, B950
# Workaround by copying gfortran library from the host
host.run_ssh_cmd("sudo apt-get install -y gfortran-8")
host.run_cmd("mkdir -p /usr/lib/gcc/aarch64-linux-gnu/8")
host.run_ssh_cmd(
[
"docker",
"cp",
"/usr/lib/gcc/aarch64-linux-gnu/8/libgfortran.a",
f"{host.container_id}:/opt/rh/devtoolset-10/root/usr/lib/gcc/aarch64-redhat-linux/10/",
]
)
print("Checking out PyTorch repo")
host.run_cmd(
f"git clone --recurse-submodules -b {branch} https://github.com/pytorch/pytorch {git_clone_flags}"
)
host.run_cmd("pytorch/.ci/docker/common/install_openblas.sh")
print("Building PyTorch wheel")
build_opts = ""
if pytorch_build_number is not None:
build_opts += f" -C--build-option=--build-number={pytorch_build_number}"
# Breakpad build fails on aarch64
build_vars = "USE_BREAKPAD=0 "
if branch == "nightly":
build_date = (
host.check_output("cd pytorch && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
version = host.check_output("cat pytorch/version.txt").strip()[:-2]
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1"
if branch.startswith(("v1.", "v2.")):
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
if enable_mkldnn:
host.run_cmd("pytorch/.ci/docker/common/install_acl.sh")
print("build pytorch with mkldnn+acl backend")
build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
build_vars += " BLAS=OpenBLAS"
build_vars += " OpenBLAS_HOME=/opt/OpenBLAS"
build_vars += " ACL_ROOT_DIR=/acl"
host.run_cmd(
f"cd $HOME/pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
)
print("Repair the wheel")
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
ld_library_path = "/acl/build:$HOME/pytorch/build/lib"
host.run_cmd(
f"export LD_LIBRARY_PATH={ld_library_path} && auditwheel repair $HOME/pytorch/dist/{pytorch_wheel_name}"
)
print("replace the original wheel with the repaired one")
pytorch_repaired_wheel_name = host.list_dir("wheelhouse")[0]
host.run_cmd(
f"cp $HOME/wheelhouse/{pytorch_repaired_wheel_name} $HOME/pytorch/dist/{pytorch_wheel_name}"
)
else:
print("build pytorch without mkldnn backend")
host.run_cmd(
f"cd pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
)
print("Deleting build folder")
host.run_cmd("cd pytorch && rm -rf build")
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
embed_libgomp(host, use_conda, os.path.join("pytorch", "dist", pytorch_wheel_name))
print("Copying the wheel")
host.download_wheel(os.path.join("pytorch", "dist", pytorch_wheel_name))
print("Installing PyTorch wheel")
host.run_cmd(f"pip3 install pytorch/dist/{pytorch_wheel_name}")
if pytorch_only:
return (pytorch_wheel_name, None, None, None, None)
domain_wheels = build_domains(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
return (pytorch_wheel_name, *domain_wheels)
embed_library_script = """
#!/usr/bin/env python3
from auditwheel.patcher import Patchelf
from auditwheel.wheeltools import InWheelCtx
from auditwheel.elfutils import elf_file_filter
from auditwheel.repair import copylib
from auditwheel.lddtree import lddtree
from subprocess import check_call
import os
import shutil
import sys
from tempfile import TemporaryDirectory
def replace_tag(filename):
with open(filename, 'r') as f:
lines = f.read().split("\\n")
for i,line in enumerate(lines):
if not line.startswith("Tag: "):
continue
lines[i] = line.replace("-linux_", "-manylinux2014_")
print(f'Updated tag from {line} to {lines[i]}')
with open(filename, 'w') as f:
f.write("\\n".join(lines))
class AlignedPatchelf(Patchelf):
def set_soname(self, file_name: str, new_soname: str) -> None:
check_call(['patchelf', '--page-size', '65536', '--set-soname', new_soname, file_name])
def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
check_call(['patchelf', '--page-size', '65536', '--replace-needed', soname, new_soname, file_name])
def embed_library(whl_path, lib_soname, update_tag=False):
patcher = AlignedPatchelf()
out_dir = TemporaryDirectory()
whl_name = os.path.basename(whl_path)
tmp_whl_name = os.path.join(out_dir.name, whl_name)
with InWheelCtx(whl_path) as ctx:
torchlib_path = os.path.join(ctx._tmpdir.name, 'torch', 'lib')
ctx.out_wheel=tmp_whl_name
new_lib_path, new_lib_soname = None, None
for filename, elf in elf_file_filter(ctx.iter_files()):
if not filename.startswith('torch/lib'):
continue
libtree = lddtree(filename)
if lib_soname not in libtree['needed']:
continue
lib_path = libtree['libs'][lib_soname]['path']
if lib_path is None:
print(f"Can't embed {lib_soname} as it could not be found")
break
if lib_path.startswith(torchlib_path):
continue
if new_lib_path is None:
new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
patcher.replace_needed(filename, lib_soname, new_lib_soname)
print(f'Replacing {lib_soname} with {new_lib_soname} for {filename}')
if update_tag:
# Add manylinux2014 tag
for filename in ctx.iter_files():
if os.path.basename(filename) != 'WHEEL':
continue
replace_tag(filename)
shutil.move(tmp_whl_name, whl_path)
if __name__ == '__main__':
embed_library(sys.argv[1], 'libgomp.so.1', len(sys.argv) > 2 and sys.argv[2] == '--update-tag')
"""
def run_tests(host: RemoteHost, whl: str, branch="main") -> None:
print("Configuring the system")
update_apt_repo(host)
host.run_cmd("sudo apt-get install -y python3-pip git")
host.run_cmd("sudo pip3 install Cython")
host.run_cmd("sudo pip3 install numpy")
host.upload_file(whl, ".")
host.run_cmd(f"sudo pip3 install {whl}")
host.run_cmd("python3 -c 'import torch;print(torch.rand((3,3))'")
host.run_cmd(f"git clone -b {branch} https://github.com/pytorch/pytorch")
host.run_cmd("cd pytorch/test; python3 test_torch.py -v")
def get_instance_name(instance) -> Optional[str]:
if instance.tags is None:
return None
for tag in instance.tags:
if tag["Key"] == "Name":
return tag["Value"]
return None
def list_instances(instance_type: str) -> None:
print(f"All instances of type {instance_type}")
for instance in ec2_instances_of_type(instance_type):
ifaces = instance.network_interfaces
az = ifaces[0].subnet.availability_zone if len(ifaces) > 0 else None
print(
f"{instance.id} {get_instance_name(instance)} {instance.public_dns_name} {instance.state['Name']} {az}"
)
def terminate_instances(instance_type: str) -> None:
print(f"Terminating all instances of type {instance_type}")
instances = list(ec2_instances_of_type(instance_type))
for instance in instances:
print(f"Terminating {instance.id}")
instance.terminate()
print("Waiting for termination to complete")
for instance in instances:
instance.wait_until_terminated()
def parse_arguments():
from argparse import ArgumentParser
parser = ArgumentParser("Build and test AARCH64 wheels using EC2")
parser.add_argument("--key-name", type=str)
parser.add_argument("--debug", action="store_true")
parser.add_argument("--build-only", action="store_true")
parser.add_argument("--test-only", type=str)
group = parser.add_mutually_exclusive_group()
group.add_argument("--os", type=str, choices=list(os_amis.keys()))
group.add_argument("--ami", type=str)
parser.add_argument(
"--python-version",
type=str,
choices=[f"3.{d}" for d in range(6, 12)],
default=None,
)
parser.add_argument("--alloc-instance", action="store_true")
parser.add_argument("--list-instances", action="store_true")
parser.add_argument("--pytorch-only", action="store_true")
parser.add_argument("--keep-running", action="store_true")
parser.add_argument("--terminate-instances", action="store_true")
parser.add_argument("--instance-type", type=str, default="t4g.2xlarge")
parser.add_argument("--ebs-size", type=int, default=50)
parser.add_argument("--branch", type=str, default="main")
parser.add_argument("--use-docker", action="store_true")
parser.add_argument(
"--compiler",
type=str,
choices=["gcc-7", "gcc-8", "gcc-9", "clang"],
default="gcc-8",
)
parser.add_argument("--use-torch-from-pypi", action="store_true")
parser.add_argument("--pytorch-build-number", type=str, default=None)
parser.add_argument("--disable-mkldnn", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
args = parse_arguments()
ami = (
args.ami
if args.ami is not None
else os_amis[args.os]
if args.os is not None
else ubuntu20_04_ami
)
keyfile_path, key_name = compute_keyfile_path(args.key_name)
if args.list_instances:
list_instances(args.instance_type)
sys.exit(0)
if args.terminate_instances:
terminate_instances(args.instance_type)
sys.exit(0)
if len(key_name) == 0:
raise RuntimeError("""
Cannot start build without key_name, please specify
--key-name argument or AWS_KEY_NAME environment variable.""")
if len(keyfile_path) == 0 or not os.path.exists(keyfile_path):
raise RuntimeError(f"""
Cannot find keyfile with name: [{key_name}] in path: [{keyfile_path}], please
check `~/.ssh/` folder or manually set SSH_KEY_PATH environment variable.""")
# Starting the instance
inst = start_instance(
key_name, ami=ami, instance_type=args.instance_type, ebs_size=args.ebs_size
)
instance_name = f"{args.key_name}-{args.os}"
if args.python_version is not None:
instance_name += f"-py{args.python_version}"
inst.create_tags(
DryRun=False,
Tags=[
{
"Key": "Name",
"Value": instance_name,
}
],
)
addr = inst.public_dns_name
wait_for_connection(addr, 22)
host = RemoteHost(addr, keyfile_path)
host.ami = ami
if args.use_docker:
update_apt_repo(host)
host.start_docker()
if args.test_only:
run_tests(host, args.test_only)
sys.exit(0)
if args.alloc_instance:
if args.python_version is None:
sys.exit(0)
install_condaforge_python(host, args.python_version)
sys.exit(0)
python_version = args.python_version if args.python_version is not None else "3.10"
if args.use_torch_from_pypi:
configure_system(host, compiler=args.compiler, python_version=python_version)
print("Installing PyTorch wheel")
host.run_cmd("pip3 install torch")
build_domains(
host, branch=args.branch, git_clone_flags=" --depth 1 --shallow-submodules"
)
else:
start_build(
host,
branch=args.branch,
compiler=args.compiler,
python_version=python_version,
pytorch_only=args.pytorch_only,
pytorch_build_number=args.pytorch_build_number,
enable_mkldnn=not args.disable_mkldnn,
)
if not args.keep_running:
print(f"Waiting for instance {inst.id} to terminate")
inst.terminate()
inst.wait_until_terminated()

View File

@ -1,87 +0,0 @@
#!/usr/bin/env python3
import os
import shutil
import sys
from subprocess import check_call
from tempfile import TemporaryDirectory
from auditwheel.elfutils import elf_file_filter
from auditwheel.lddtree import lddtree
from auditwheel.patcher import Patchelf
from auditwheel.repair import copylib
from auditwheel.wheeltools import InWheelCtx
def replace_tag(filename):
with open(filename) as f:
lines = f.read().split("\\n")
for i, line in enumerate(lines):
if not line.startswith("Tag: "):
continue
lines[i] = line.replace("-linux_", "-manylinux2014_")
print(f"Updated tag from {line} to {lines[i]}")
with open(filename, "w") as f:
f.write("\\n".join(lines))
class AlignedPatchelf(Patchelf):
def set_soname(self, file_name: str, new_soname: str) -> None:
check_call(
["patchelf", "--page-size", "65536", "--set-soname", new_soname, file_name]
)
def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
check_call(
[
"patchelf",
"--page-size",
"65536",
"--replace-needed",
soname,
new_soname,
file_name,
]
)
def embed_library(whl_path, lib_soname, update_tag=False):
patcher = AlignedPatchelf()
out_dir = TemporaryDirectory()
whl_name = os.path.basename(whl_path)
tmp_whl_name = os.path.join(out_dir.name, whl_name)
with InWheelCtx(whl_path) as ctx:
torchlib_path = os.path.join(ctx._tmpdir.name, "torch", "lib")
ctx.out_wheel = tmp_whl_name
new_lib_path, new_lib_soname = None, None
for filename, _ in elf_file_filter(ctx.iter_files()):
if not filename.startswith("torch/lib"):
continue
libtree = lddtree(filename)
if lib_soname not in libtree["needed"]:
continue
lib_path = libtree["libs"][lib_soname]["path"]
if lib_path is None:
print(f"Can't embed {lib_soname} as it could not be found")
break
if lib_path.startswith(torchlib_path):
continue
if new_lib_path is None:
new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
patcher.replace_needed(filename, lib_soname, new_lib_soname)
print(f"Replacing {lib_soname} with {new_lib_soname} for {filename}")
if update_tag:
# Add manylinux2014 tag
for filename in ctx.iter_files():
if os.path.basename(filename) != "WHEEL":
continue
replace_tag(filename)
shutil.move(tmp_whl_name, whl_path)
if __name__ == "__main__":
embed_library(
sys.argv[1], "libgomp.so.1", len(sys.argv) > 2 and sys.argv[2] == "--update-tag"
)

View File

@ -125,10 +125,10 @@ case "$tag" in
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
;;
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks)
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks)
CUDA_VERSION=12.8.1
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=11
GCC_VERSION=9
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
@ -146,6 +146,16 @@ case "$tag" in
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
;;
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9)
CUDA_VERSION=12.8.1
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=9
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
;;
pytorch-linux-jammy-py3-clang12-onnx)
ANACONDA_PYTHON_VERSION=3.10
CLANG_VERSION=12
@ -178,7 +188,7 @@ case "$tag" in
fi
GCC_VERSION=11
VISION=yes
ROCM_VERSION=7.1
ROCM_VERSION=7.0
NINJA_VERSION=1.9.0
TRITON=yes
KATEX=yes

View File

@ -60,16 +60,14 @@ EOF
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated rocm-llvm-dev
fi
if [[ $(ver $ROCM_VERSION) -lt $(ver 7.1) ]]; then
# precompiled miopen kernels added in ROCm 3.5, renamed in ROCm 5.5, removed in ROCm 7.1
# search for all unversioned packages
# if search fails it will abort this script; use true to avoid case where search fails
MIOPENHIPGFX=$(apt-cache search --names-only miopen-hip-gfx | awk '{print $1}' | grep -F -v . || true)
if [[ "x${MIOPENHIPGFX}" = x ]]; then
echo "miopen-hip-gfx package not available" && exit 1
else
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ${MIOPENHIPGFX}
fi
# precompiled miopen kernels added in ROCm 3.5, renamed in ROCm 5.5
# search for all unversioned packages
# if search fails it will abort this script; use true to avoid case where search fails
MIOPENHIPGFX=$(apt-cache search --names-only miopen-hip-gfx | awk '{print $1}' | grep -F -v . || true)
if [[ "x${MIOPENHIPGFX}" = x ]]; then
echo "miopen-hip-gfx package not available" && exit 1
else
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ${MIOPENHIPGFX}
fi
# ROCm 6.0 had a regression where journal_mode was enabled on the kdb files resulting in permission errors at runtime

View File

@ -12,8 +12,8 @@ function do_install() {
rocm_version_nodot=${rocm_version//./}
# https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
# post merge of https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f
magma_archive="magma-rocm${rocm_version_nodot}-${MAGMA_VERSION}-1.tar.bz2"
rocm_dir="/opt/rocm"

View File

@ -402,6 +402,3 @@ scikit-build==0.18.1
pyre-extensions==0.0.32
tabulate==0.9.0
#Description: These package are needed to build FBGEMM and torchrec on PyTorch CI
Jinja2==3.1.6
#Description: required for torch.distributed.debug

View File

@ -4,14 +4,17 @@ set -ex
SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
# Source the common build script for architecture-specific configurations (MKLDNN, ACL, etc.)
source "${SCRIPTPATH}/../pytorch/build.sh" || true
case "${GPU_ARCH_TYPE:-BLANK}" in
cuda)
cuda | cuda-aarch64)
bash "${SCRIPTPATH}/build_cuda.sh"
;;
rocm)
bash "${SCRIPTPATH}/build_rocm.sh"
;;
cpu | cpu-cxx11-abi | cpu-s390x)
cpu | cpu-cxx11-abi | cpu-aarch64 | cpu-s390x)
bash "${SCRIPTPATH}/build_cpu.sh"
;;
xpu)

View File

@ -18,12 +18,31 @@ retry () {
$* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*)
}
# Detect architecture first
ARCH=$(uname -m)
echo "Detected architecture: $ARCH"
PLATFORM=""
# TODO move this into the Docker images
OS_NAME=$(awk -F= '/^NAME/{print $2}' /etc/os-release)
if [[ "$OS_NAME" == *"AlmaLinux"* ]]; then
retry yum install -q -y zip openssl
PLATFORM="manylinux_2_28_x86_64"
# Set platform based on architecture
case $ARCH in
x86_64)
PLATFORM="manylinux_2_28_x86_64"
;;
aarch64)
PLATFORM="manylinux_2_28_aarch64"
;;
s390x)
PLATFORM="manylinux_2_28_s390x"
;;
*)
echo "Unsupported architecture: $ARCH"
exit 1
;;
esac
elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then
retry dnf install -q -y zip openssl
elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
@ -38,6 +57,8 @@ else
exit 1
fi
echo "Platform set to: $PLATFORM"
# We use the package name to test the package by passing this to 'pip install'
# This is the env variable that setup.py uses to name the package. Note that
# pip 'normalizes' the name first by changing all - to _
@ -299,8 +320,8 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
# ROCm workaround for roctracer dlopens
if [[ "$DESIRED_CUDA" == *"rocm"* ]]; then
patchedpath=$(fname_without_so_number $destpath)
# Keep the so number for XPU dependencies and libgomp.so.1 to avoid twice load
elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" ]]; then
# Keep the so number for XPU dependencies, libgomp.so.1, ACL libraries, and NVPL libraries to avoid twice load
elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" || "$filename" == libarm_compute* || "$filename" == libnvpl* || "$filename" == "libgfortran.so.5" ]]; then
patchedpath=$destpath
else
patchedpath=$(fname_with_sha256 $destpath)
@ -346,9 +367,22 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
done
# create Manylinux 2_28 tag this needs to happen before regenerate the RECORD
if [[ $PLATFORM == "manylinux_2_28_x86_64" && $GPU_ARCH_TYPE != "cpu-s390x" && $GPU_ARCH_TYPE != "xpu" ]]; then
# Support all architectures (x86_64, aarch64, s390x)
if [[ "$IS_MANYLINUX2_28" == "1" && $GPU_ARCH_TYPE != "xpu" ]]; then
wheel_file=$(echo $(basename $pkg) | sed -e 's/-cp.*$/.dist-info\/WHEEL/g')
sed -i -e s#linux_x86_64#"${PLATFORM}"# $wheel_file;
echo "Updating wheel tag for $ARCH architecture"
# Replace linux_* with manylinux_2_28_* based on architecture
case $ARCH in
x86_64)
sed -i -e 's#linux_x86_64#manylinux_2_28_x86_64#g' $wheel_file
;;
aarch64)
sed -i -e 's#linux_aarch64#manylinux_2_28_aarch64#g' $wheel_file
;;
s390x)
sed -i -e 's#linux_s390x#manylinux_2_28_s390x#g' $wheel_file
;;
esac
fi
# regenerate the RECORD file with new hashes

View File

@ -15,6 +15,10 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
EXTRA_CAFFE2_CMAKE_FLAGS=()
fi
# Detect architecture
ARCH=$(uname -m)
echo "Building CPU wheel for architecture: $ARCH"
WHEELHOUSE_DIR="wheelhousecpu"
LIBTORCH_HOUSE_DIR="libtorch_housecpu"
if [[ -z "$PYTORCH_FINAL_PACKAGE_DIR" ]]; then
@ -34,8 +38,10 @@ elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then
elif [[ "$OS_NAME" == *"AlmaLinux"* ]]; then
LIBGOMP_PATH="/usr/lib64/libgomp.so.1"
elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
if [[ "$(uname -m)" == "s390x" ]]; then
if [[ "$ARCH" == "s390x" ]]; then
LIBGOMP_PATH="/usr/lib/s390x-linux-gnu/libgomp.so.1"
elif [[ "$ARCH" == "aarch64" ]]; then
LIBGOMP_PATH="/usr/lib/aarch64-linux-gnu/libgomp.so.1"
else
LIBGOMP_PATH="/usr/lib/x86_64-linux-gnu/libgomp.so.1"
fi
@ -49,6 +55,34 @@ DEPS_SONAME=(
"libgomp.so.1"
)
# Add ARM-specific library dependencies for CPU builds
if [[ "$ARCH" == "aarch64" ]]; then
echo "Adding ARM-specific CPU library dependencies"
# ARM Compute Library (if available)
if [[ -d "/acl/build" ]]; then
echo "Adding ARM Compute Library for CPU"
DEPS_LIST+=(
"/acl/build/libarm_compute.so"
"/acl/build/libarm_compute_graph.so"
)
DEPS_SONAME+=(
"libarm_compute.so"
"libarm_compute_graph.so"
)
fi
# ARM system libraries
DEPS_LIST+=(
"/usr/lib64/libgfortran.so.5"
"/opt/OpenBLAS/lib/libopenblas.so.0"
)
DEPS_SONAME+=(
"libgfortran.so.5"
"libopenblas.so.0"
)
fi
rm -rf /usr/local/cuda*
SOURCE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )"

View File

@ -29,6 +29,10 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
EXTRA_CAFFE2_CMAKE_FLAGS=()
fi
# Detect architecture
ARCH=$(uname -m)
echo "Building for architecture: $ARCH"
# Determine CUDA version and architectures to build for
#
# NOTE: We should first check `DESIRED_CUDA` when determining `CUDA_VERSION`,
@ -53,34 +57,60 @@ fi
cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.')
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
# Function to remove architectures from a list
remove_archs() {
local result="$1"
shift
for arch in "$@"; do
result="${result//${arch};/}"
done
echo "$result"
}
# Function to filter CUDA architectures for aarch64
# aarch64 ARM GPUs only support certain compute capabilities
# Keep: 8.0 (A100), 9.0+ (Hopper, Grace Hopper, newer)
# Remove: < 8.0 (no ARM GPUs), 8.6 (x86_64 RTX 3090/A6000 only)
filter_aarch64_archs() {
local arch_list="$1"
# Explicitly remove architectures not needed on aarch64
arch_list=$(remove_archs "$arch_list" "5.0" "6.0" "7.0" "7.5" "8.6")
echo "$arch_list"
}
# Base: Common architectures across all modern CUDA versions
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0"
case ${CUDA_VERSION} in
#removing sm_50-sm_60 as these architectures are deprecated in CUDA 12.8/9 and will be removed in future releases
#however we would like to keep sm_70 architecture see: https://github.com/pytorch/pytorch/issues/157517
12.8)
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0"
;;
12.9)
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0+PTX"
# WAR to resolve the ld error in libtorch build with CUDA 12.9
12.6) TORCH_CUDA_ARCH_LIST="5.0;6.0;${TORCH_CUDA_ARCH_LIST}" ;; # Only 12.6 includes Legacy Maxwell/Pascal that will be removed in future releases
12.8) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0" ;; # +Hopper/Blackwell support
12.9) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0+PTX" # +Hopper/Blackwell support + PTX for forward compatibility
if [[ "$PACKAGE_TYPE" == "libtorch" ]]; then
TORCH_CUDA_ARCH_LIST="7.5;8.0;9.0;10.0;12.0+PTX"
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//7.0;/}" # Remove 7.0 to resolve the ld error
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//8.6;/}" # Remove 8.6 for libtorch
fi
;;
13.0)
TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;12.0+PTX"
;;
12.6)
TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6;9.0"
;;
*)
echo "unknown cuda version $CUDA_VERSION"
exit 1
TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;$([[ "$ARCH" == "aarch64" ]] && echo "11.0;" || echo "")12.0+PTX"
export TORCH_NVCC_FLAGS="-compress-mode=size"
export BUILD_BUNDLE_PTXAS=1
;;
*) echo "unknown cuda version $CUDA_VERSION"; exit 1 ;;
esac
# Filter for aarch64: Remove < 8.0 and 8.6
[[ "$ARCH" == "aarch64" ]] && TORCH_CUDA_ARCH_LIST=$(filter_aarch64_archs "$TORCH_CUDA_ARCH_LIST")
echo "TORCH_CUDA_ARCH_LIST set to: $TORCH_CUDA_ARCH_LIST"
export TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST}
echo "${TORCH_CUDA_ARCH_LIST}"
# Disable MAGMA for aarch64 as pre-built libraries are x86-64 only
if [[ "$ARCH" == "aarch64" ]]; then
echo "Disabling MAGMA for aarch64 architecture"
export USE_MAGMA=0
fi
# Package directories
WHEELHOUSE_DIR="wheelhouse$cuda_version_nodot"
LIBTORCH_HOUSE_DIR="libtorch_house$cuda_version_nodot"
@ -244,6 +274,51 @@ else
exit 1
fi
# Add ARM-specific library dependencies
if [[ "$ARCH" == "aarch64" ]]; then
echo "Adding ARM-specific library dependencies"
# ARM Compute Library (if available)
if [[ -d "/acl/build" ]]; then
echo "Adding ARM Compute Library"
DEPS_LIST+=(
"/acl/build/libarm_compute.so"
"/acl/build/libarm_compute_graph.so"
)
DEPS_SONAME+=(
"libarm_compute.so"
"libarm_compute_graph.so"
)
fi
# ARM system libraries
DEPS_LIST+=(
"/lib64/libgomp.so.1"
"/usr/lib64/libgfortran.so.5"
)
DEPS_SONAME+=(
"libgomp.so.1"
"libgfortran.so.5"
)
# NVPL libraries (ARM optimized BLAS/LAPACK)
if [[ -d "/usr/local/lib" && -f "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0" ]]; then
echo "Adding NVPL libraries for ARM"
DEPS_LIST+=(
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0"
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0"
"/usr/local/lib/libnvpl_lapack_core.so.0"
"/usr/local/lib/libnvpl_blas_core.so.0"
)
DEPS_SONAME+=(
"libnvpl_lapack_lp64_gomp.so.0"
"libnvpl_blas_lp64_gomp.so.0"
"libnvpl_lapack_core.so.0"
"libnvpl_blas_core.so.0"
)
fi
fi
# run_tests.sh requires DESIRED_CUDA to know what tests to exclude
export DESIRED_CUDA="$cuda_version_nodot"
@ -251,9 +326,11 @@ export DESIRED_CUDA="$cuda_version_nodot"
rm -rf /usr/local/cuda || true
ln -s "/usr/local/cuda-${CUDA_VERSION}" /usr/local/cuda
# Switch `/usr/local/magma` to the desired CUDA version
rm -rf /usr/local/magma || true
ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma
# Switch `/usr/local/magma` to the desired CUDA version (skip for aarch64)
if [[ "$ARCH" != "aarch64" ]]; then
rm -rf /usr/local/magma || true
ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma
fi
export CUDA_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev) # 10.0.130
export CUDA_VERSION_SHORT=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev | cut -f1,2 -d".") # 10.0

View File

@ -21,87 +21,3 @@ if [[ "${BUILD_ENVIRONMENT}" == *rocm* ]]; then
fi
mkdir -p "$pytest_reports_dir" || true
##########################################
# copied from .ci/pytorch/common_utils.sh
##########################################
function get_pinned_commit() {
cat .github/ci_commit_pins/"${1}".txt
}
function pip_install_whl() {
# This is used to install PyTorch and other build artifacts wheel locally
# without using any network connection
# Convert the input arguments into an array
local args=("$@")
# Check if the first argument contains multiple paths separated by spaces
if [[ "${args[0]}" == *" "* ]]; then
# Split the string by spaces into an array
IFS=' ' read -r -a paths <<< "${args[0]}"
# Loop through each path and install individually
for path in "${paths[@]}"; do
echo "Installing $path"
python3 -mpip install --no-index --no-deps "$path"
done
else
# Loop through each argument and install individually
for path in "${args[@]}"; do
echo "Installing $path"
python3 -mpip install --no-index --no-deps "$path"
done
fi
}
function pip_build_and_install() {
local build_target=$1
local wheel_dir=$2
local found_whl=0
for file in "${wheel_dir}"/*.whl
do
if [[ -f "${file}" ]]; then
found_whl=1
break
fi
done
# Build the wheel if it doesn't exist
if [ "${found_whl}" == "0" ]; then
python3 -m pip wheel \
--no-build-isolation \
--no-deps \
-w "${wheel_dir}" \
"${build_target}"
fi
for file in "${wheel_dir}"/*.whl
do
pip_install_whl "${file}"
done
}
function install_torchvision() {
local orig_preload
local commit
commit=$(get_pinned_commit vision)
orig_preload=${LD_PRELOAD}
if [ -n "${LD_PRELOAD}" ]; then
# Silence dlerror to work-around glibc ASAN bug, see https://sourceware.org/bugzilla/show_bug.cgi?id=27653#c9
echo 'char* dlerror(void) { return "";}'|gcc -fpic -shared -o "${HOME}/dlerror.so" -x c -
LD_PRELOAD=${orig_preload}:${HOME}/dlerror.so
fi
if [[ "${BUILD_ENVIRONMENT}" == *cuda* ]]; then
# Not sure if both are needed, but why not
export FORCE_CUDA=1
export WITH_CUDA=1
fi
pip_build_and_install "git+https://github.com/pytorch/vision.git@${commit}" dist/vision
if [ -n "${LD_PRELOAD}" ]; then
LD_PRELOAD=${orig_preload}
fi
}

View File

@ -19,7 +19,7 @@ git config --global --add safe.directory /var/lib/jenkins/workspace
if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
# TODO: This can be removed later once vision is also part of the Docker image
install_torchvision
pip install -q --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)"
# JIT C++ extensions require ninja, so put it into PATH.
export PATH="/var/lib/jenkins/.local/bin:$PATH"
# NB: ONNX test is fast (~15m) so it's ok to retry it few more times to avoid any flaky issue, we

View File

@ -86,10 +86,20 @@ else
fi
fi
# Enable MKLDNN with ARM Compute Library for ARM builds
if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then
export USE_MKLDNN=1
# ACL is required for aarch64 builds
if [[ ! -d "/acl" ]]; then
echo "ERROR: ARM Compute Library not found at /acl"
echo "ACL is required for aarch64 builds. Check Docker image setup."
exit 1
fi
export USE_MKLDNN_ACL=1
export ACL_ROOT_DIR=/acl
echo "ARM Compute Library enabled for MKLDNN: ACL_ROOT_DIR=/acl"
fi
if [[ "$BUILD_ENVIRONMENT" == *riscv64* ]]; then

View File

@ -1250,97 +1250,6 @@ test_custom_script_ops() {
assert_git_not_dirty
}
test_libtorch_agnostic_targetting() {
echo "Testing libtorch_agnostic runs correctly on TORCH_TARGET_VERSION"
REPO_DIR=$(pwd)
WHEEL_DIR="${REPO_DIR}/test/cpp_extensions/.wheels"
# Build wheel with current PyTorch (this has TORCH_TARGET_VERSION 2_9_0)
echo "Building 2.9 extension wheel with current PyTorch..."
pushd test/cpp_extensions/libtorch_agnostic_2_9_extension
time python setup.py bdist_wheel
# Save the wheel
mkdir -p "$WHEEL_DIR"
cp dist/*.whl "$WHEEL_DIR/"
WHEEL_FILE=$(find "$WHEEL_DIR" -maxdepth 1 -name "*.whl" -type f | head -1)
echo "Built wheel: $(basename "$WHEEL_FILE")"
popd
# Create venv and install PyTorch 2.9
python -m venv venv_pytorch_2_9
# shellcheck disable=SC1091
. venv_pytorch_2_9/bin/activate
# Clear PYTHONPATH to avoid using the development PyTorch
echo "Clearing PYTHONPATH to use only venv packages..."
unset PYTHONPATH
# Upgrade pip to latest version
echo "Upgrading pip to latest version..."
pip install --upgrade pip
pip --version
echo "Installing PyTorch 2.9..."
# Install from release channel only
PYTORCH_VERSION="2.9.0"
# Extract CUDA version from BUILD_ENVIRONMENT (e.g., "cuda12.1" -> "cu121")
if [[ "$BUILD_ENVIRONMENT" =~ cuda([0-9]+)\.([0-9]+) ]]; then
CUDA_MAJOR="${BASH_REMATCH[1]}"
CUDA_MINOR="${BASH_REMATCH[2]}"
CUDA_VERSION="cu${CUDA_MAJOR}${CUDA_MINOR}"
echo " Detected CUDA ${CUDA_MAJOR}.${CUDA_MINOR} from BUILD_ENVIRONMENT, using ${CUDA_VERSION}"
else
# Default to CPU build
CUDA_VERSION="cpu"
echo " No CUDA detected in BUILD_ENVIRONMENT, using CPU build"
fi
if pip install torch=="${PYTORCH_VERSION}" --index-url https://download.pytorch.org/whl/${CUDA_VERSION}/; then
echo "Installed PyTorch ${PYTORCH_VERSION} from release channel (${CUDA_VERSION})"
else
echo " FAILED to install PyTorch 2.9.0 from release channel"
echo " URL: https://download.pytorch.org/whl/${CUDA_VERSION}/"
deactivate
rm -rf venv_pytorch_2_9
return 1
fi
INSTALLED_VERSION=$(python -c "import torch; print(torch.__version__)" 2>/dev/null || echo "unknown")
echo " Installed version: $INSTALLED_VERSION"
# Install test dependencies
echo "Installing test dependencies..."
pip install expecttest numpy unittest-xml-reporting
# Install the pre-built wheel
echo ""
echo "Installing pre-built 2.9 extension wheel (built with PyTorch 2.10)..."
pip install "$WHEEL_FILE"
echo "Installed $(basename "$WHEEL_FILE") into PyTorch 2.9 environment"
# Run tests with PyTorch 2.9 runtime (2.10 tests will be skipped automatically)
echo ""
echo "Running tests with PyTorch 2.9 runtime (using wheel built on PyTorch 2.10)..."
if time python test/cpp_extensions/test_libtorch_agnostic.py -v; then
echo ""
echo " Wheel built with current torch and TORCH_TARGET_VERSION 2_9_0 works with PyTorch 2.9 runtime!"
else
echo "targeting test failed"
deactivate
rm -rf venv_pytorch_2_9 "$WHEEL_DIR"
return 1
fi
deactivate
rm -rf venv_pytorch_2_9 "$WHEEL_DIR"
assert_git_not_dirty
}
test_jit_hooks() {
echo "Testing jit hooks in cpp"
HOOK_BUILD="${CUSTOM_TEST_ARTIFACT_BUILD_DIR}/jit-hook-build"
@ -1813,8 +1722,6 @@ elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" == 'default' ]];
elif [[ "${TEST_CONFIG}" == *backward* ]]; then
test_forward_backward_compatibility
# Do NOT add tests after bc check tests, see its comment.
elif [[ "${TEST_CONFIG}" == *libtorch_agnostic_targetting* ]]; then
test_libtorch_agnostic_targetting
elif [[ "${TEST_CONFIG}" == *xla* ]]; then
install_torchvision
build_xla

7
.github/labeler.yml vendored
View File

@ -91,6 +91,13 @@
"ciflow/trunk":
- .ci/docker/ci_commit_pins/triton.txt
"oncall: distributed":
- torch/csrc/distributed/**
- torch/distributed/**
- torch/nn/parallel/**
- test/distributed/**
- torch/testing/_internal/distributed/**
"release notes: distributed (checkpoint)":
- torch/distributed/checkpoint/**
- test/distributed/checkpoint/**

View File

@ -260,11 +260,8 @@ jobs:
"${DOCKER_IMAGE}"
)
docker exec -t -w "${PYTORCH_ROOT}" "${container_name}" bash -c "bash .circleci/scripts/binary_populate_env.sh"
if [[ ${BUILD_ENVIRONMENT} == *"aarch64"* ]]; then
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/aarch64_linux/aarch64_ci_build.sh"
else
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh"
fi
# Unified build script for all architectures (x86_64, aarch64, s390x)
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh"
- name: Chown artifacts
if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' && inputs.build_environment != 'linux-s390x-binary-manywheel' }}

View File

@ -23,7 +23,7 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
with:
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '8.0 9.0'
test-matrix: |
@ -39,7 +39,7 @@ jobs:
needs: attn-microbenchmark-build
with:
timeout-minutes: 500
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image: ${{ needs.attn-microbenchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.attn-microbenchmark-build.outputs.test-matrix }}
secrets: inherit
@ -51,7 +51,7 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
with:
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '10.0'
test-matrix: |
@ -66,7 +66,7 @@ jobs:
needs: opmicrobenchmark-build-b200
with:
timeout-minutes: 500
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }}
test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }}
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only

View File

@ -52,7 +52,8 @@ jobs:
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11,
pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11,
pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm,
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks,
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9,
pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11,
pytorch-linux-jammy-py3.10-clang12,
pytorch-linux-jammy-py3.11-clang12,
@ -74,8 +75,7 @@ jobs:
pytorch-linux-jammy-py3-clang12-onnx,
pytorch-linux-jammy-linter,
pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-linter,
# TODO: Re-enable me when docker pin update happens
# pytorch-linux-jammy-py3-clang12-executorch,
pytorch-linux-jammy-py3-clang12-executorch,
pytorch-linux-jammy-py3.12-triton-cpu,
pytorch-linux-noble-riscv64-py3.12-gcc14
]

View File

@ -50,10 +50,9 @@ jobs:
matrix:
runner: [linux.rocm.gfx942.docker-cache]
docker-image: [
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}"
#"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}",
#"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}",
#"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}"
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}",
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}",
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}"
]
runs-on: "${{ matrix.runner }}"
steps:

View File

@ -30,14 +30,14 @@ jobs:
opt_out_experiments: lf
build:
name: cuda12.8-py3.10-gcc11-sm80
name: cuda12.8-py3.10-gcc9-sm80
uses: ./.github/workflows/_linux-build.yml
needs:
- get-default-label-prefix
with:
runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '8.0'
test-matrix: |
{ include: [
@ -46,11 +46,11 @@ jobs:
secrets: inherit
test:
name: cuda12.8-py3.10-gcc11-sm80
name: cuda12.8-py3.10-gcc9-sm80
uses: ./.github/workflows/_linux-test.yml
needs: build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
timeout-minutes: 720

View File

@ -27,14 +27,14 @@ jobs:
opt_out_experiments: lf
build:
name: cuda12.8-py3.10-gcc11-sm80
name: cuda12.8-py3.10-gcc9-sm80
uses: ./.github/workflows/_linux-build.yml
needs:
- get-default-label-prefix
with:
runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '8.0'
test-matrix: |
{ include: [
@ -47,11 +47,11 @@ jobs:
secrets: inherit
test:
name: cuda12.8-py3.10-gcc11-sm80
name: cuda12.8-py3.10-gcc9-sm80
uses: ./.github/workflows/_linux-test.yml
needs: build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
# disable monitor in perf tests for more investigation

View File

@ -80,7 +80,7 @@ jobs:
opt_out_experiments: lf
build:
name: cuda12.8-py3.10-gcc11-sm100
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
@ -90,8 +90,8 @@ jobs:
# from trunk. Also use a memory-intensive runner here because memory is
# usually the bottleneck
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '10.0'
test-matrix: |
{ include: [
@ -104,12 +104,12 @@ jobs:
secrets: inherit
test-periodically:
name: cuda12.8-py3.10-gcc11-sm100
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-test.yml
needs: build
if: github.event.schedule == '0 7 * * 1-6'
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
@ -121,12 +121,12 @@ jobs:
secrets: inherit
test-weekly:
name: cuda12.8-py3.10-gcc11-sm100
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-test.yml
needs: build
if: github.event.schedule == '0 7 * * 0'
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
@ -138,11 +138,11 @@ jobs:
secrets: inherit
test:
name: cuda12.8-py3.10-gcc11-sm100
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-test.yml
needs: build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}

View File

@ -95,8 +95,8 @@ jobs:
# from trunk. Also use a memory-intensive runner here because memory is
# usually the bottleneck
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '9.0'
test-matrix: |
{ include: [
@ -132,7 +132,7 @@ jobs:
needs: build
if: github.event.schedule == '15 0 * * 1-6'
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
@ -149,7 +149,7 @@ jobs:
needs: build
if: github.event.schedule == '0 7 * * 0'
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
@ -168,7 +168,7 @@ jobs:
# needs one round of benchmark
if: ${{ github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request' }}
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90
dashboard-tag: training-${{ inputs.training || 'true' }}-inference-${{ inputs.inference || 'true' }}-default-${{ inputs.default || 'true' }}-dynamic-${{ inputs.dynamic || 'true' }}-cudagraphs-${{ inputs.cudagraphs || 'true' }}-cppwrapper-${{ inputs.cppwrapper || 'false' }}-aotinductor-${{ inputs.aotinductor || 'false' }}-maxautotune-${{ inputs.maxautotune || 'false' }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs || 'false' }}-cudagraphs_low_precision-${{ inputs.cudagraphs || 'false' }}
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}

View File

@ -80,15 +80,15 @@ jobs:
opt_out_experiments: lf
build:
name: cuda12.8-py3.10-gcc11-sm80
name: cuda12.8-py3.10-gcc9-sm80
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
# Every bit to make perf run faster helps
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '8.0'
test-matrix: |
{ include: [
@ -117,12 +117,12 @@ jobs:
secrets: inherit
test-nightly:
name: cuda12.8-py3.10-gcc11-sm80
name: cuda12.8-py3.10-gcc9-sm80
uses: ./.github/workflows/_linux-test.yml
needs: build
if: github.event.schedule == '0 7 * * 1-6'
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
@ -133,12 +133,12 @@ jobs:
secrets: inherit
test-weekly:
name: cuda12.8-py3.10-gcc11-sm80
name: cuda12.8-py3.10-gcc9-sm80
uses: ./.github/workflows/_linux-test.yml
needs: build
if: github.event.schedule == '0 7 * * 0'
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
@ -150,12 +150,12 @@ jobs:
secrets: inherit
test:
name: cuda12.8-py3.10-gcc11-sm80
name: cuda12.8-py3.10-gcc9-sm80
uses: ./.github/workflows/_linux-test.yml
needs: build
if: github.event_name == 'workflow_dispatch'
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}

View File

@ -37,8 +37,8 @@ jobs:
needs: get-default-label-prefix
with:
runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '8.0;8.6'
test-matrix: |
{ include: [
@ -76,7 +76,7 @@ jobs:
uses: ./.github/workflows/_linux-test.yml
needs: periodic-dynamo-benchmarks-build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86
docker-image: ${{ needs.periodic-dynamo-benchmarks-build.outputs.docker-image }}
test-matrix: ${{ needs.periodic-dynamo-benchmarks-build.outputs.test-matrix }}
secrets: inherit
@ -138,8 +138,8 @@ jobs:
- get-default-label-prefix
with:
runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '8.0'
test-matrix: |
{ include: [
@ -153,7 +153,7 @@ jobs:
uses: ./.github/workflows/_linux-test.yml
needs: inductor-smoke-build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image: ${{ needs.inductor-smoke-build.outputs.docker-image }}
test-matrix: ${{ needs.inductor-smoke-build.outputs.test-matrix }}
secrets: inherit

View File

@ -33,8 +33,8 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '8.6'
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
@ -52,7 +52,7 @@ jobs:
uses: ./.github/workflows/_linux-test.yml
needs: inductor-build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86
docker-image: ${{ needs.inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.inductor-build.outputs.test-matrix }}
secrets: inherit

View File

@ -49,8 +49,8 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '8.6'
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
@ -69,7 +69,7 @@ jobs:
uses: ./.github/workflows/_linux-test.yml
needs: inductor-build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86
docker-image: ${{ needs.inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.inductor-build.outputs.test-matrix }}
secrets: inherit

View File

@ -25,7 +25,7 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
with:
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '8.0 9.0'
test-matrix: |
@ -41,7 +41,7 @@ jobs:
needs: opmicrobenchmark-build
with:
timeout-minutes: 500
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image: ${{ needs.opmicrobenchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.opmicrobenchmark-build.outputs.test-matrix }}
secrets: inherit
@ -53,7 +53,7 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
with:
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '10.0'
test-matrix: |
@ -68,7 +68,7 @@ jobs:
needs: opmicrobenchmark-build-b200
with:
timeout-minutes: 500
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }}
test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }}
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only

View File

@ -90,7 +90,6 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: 8.6
test-matrix: |
{ include: [
{ config: "nogpu_AVX512", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
@ -98,9 +97,7 @@ jobs:
{ config: "nogpu_AVX512", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "multigpu", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] },
{ config: "multigpu", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] },
{ config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
]}
secrets: inherit
@ -116,14 +113,40 @@ jobs:
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-debug-build:
name: linux-jammy-cuda12.8-py3.10-gcc11-debug
linux-jammy-cuda12_8-py3_10-gcc9-build:
name: linux-jammy-cuda12.8-py3.10-gcc9
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-debug
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
build-environment: linux-jammy-cuda12.8-py3.10-gcc9
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
cuda-arch-list: 8.6
test-matrix: |
{ include: [
{ config: "multigpu", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] },
{ config: "multigpu", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] },
]}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc9-test:
name: linux-jammy-cuda12.8-py3.10-gcc9
uses: ./.github/workflows/_linux-test.yml
needs: linux-jammy-cuda12_8-py3_10-gcc9-build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc9
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc9-debug-build:
name: linux-jammy-cuda12.8-py3.10-gcc9-debug
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
cuda-arch-list: 8.9
test-matrix: |
{ include: [
@ -137,16 +160,16 @@ jobs:
]}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-debug-test:
name: linux-jammy-cuda12.8-py3.10-gcc11-debug
linux-jammy-cuda12_8-py3_10-gcc9-debug-test:
name: linux-jammy-cuda12.8-py3.10-gcc9-debug
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-jammy-cuda12_8-py3_10-gcc11-debug-build
- linux-jammy-cuda12_8-py3_10-gcc9-debug-build
- target-determination
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-debug
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-debug-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-debug-build.outputs.test-matrix }}
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-debug-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-debug-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-cuda13_0-py3_10-gcc11-build:

View File

@ -70,7 +70,6 @@ jobs:
{ config: "distributed", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "distributed", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "libtorch_agnostic_targetting", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
]}
secrets: inherit
@ -318,14 +317,14 @@ jobs:
]}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-inductor-build:
name: cuda12.8-py3.10-gcc11-sm75
linux-jammy-cuda12_8-py3_10-gcc9-inductor-build:
name: cuda12.8-py3.10-gcc9-sm75
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm75
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm75
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '7.5'
test-matrix: |
{ include: [
@ -333,14 +332,14 @@ jobs:
]}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-inductor-test:
name: cuda12.8-py3.10-gcc11-sm75
linux-jammy-cuda12_8-py3_10-gcc9-inductor-test:
name: cuda12.8-py3.10-gcc9-sm75
uses: ./.github/workflows/_linux-test.yml
needs: linux-jammy-cuda12_8-py3_10-gcc11-inductor-build
needs: linux-jammy-cuda12_8-py3_10-gcc9-inductor-build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm75
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-inductor-build.outputs.test-matrix }}
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm75
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }}
secrets: inherit
linux-noble-xpu-n-py3_10-build:

View File

@ -26,14 +26,14 @@ jobs:
curr_ref_type: ${{ github.ref_type }}
build:
name: cuda12.8-py3.10-gcc11-sm80
name: cuda12.8-py3.10-gcc9-sm80
uses: ./.github/workflows/_linux-build.yml
needs:
- get-default-label-prefix
with:
runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '8.0'
test-matrix: |
{ include: [
@ -42,11 +42,11 @@ jobs:
secrets: inherit
test:
name: cuda12.8-py3.10-gcc11-sm80
name: cuda12.8-py3.10-gcc9-sm80
uses: ./.github/workflows/_linux-test.yml
needs: build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
secrets: inherit

View File

@ -83,7 +83,6 @@ jobs:
{ config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
{ config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
{ config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" },
{ config: "libtorch_agnostic_targetting", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" },
]}
secrets: inherit
@ -231,8 +230,8 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
build-environment: linux-jammy-cuda12.8-py3.12-gcc11-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks
build-environment: linux-jammy-cuda12.8-py3.12-gcc9-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '8.0'
secrets: inherit
@ -283,7 +282,6 @@ jobs:
name: linux-jammy-py3-clang12-executorch
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
if: false # Has been broken for a while
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-py3-clang12-executorch

View File

@ -3541,9 +3541,9 @@ Tensor _dyn_quant_matmul_4bit_cpu(
const int64_t out_features) {
auto M = inp.size(0);
TORCH_CHECK(
inp.dtype() == kFloat,
inp.dtype() == kFloat || (inp.dtype() == kBFloat16 && block_size == in_features),
__func__,
" : expect input to be 32-bit float tensor.");
" : expect input to be float32 or bfloat16 tensor.");
TORCH_CHECK(
block_size == in_features ||
(!(block_size % 32) && !(in_features % block_size)),

View File

@ -813,43 +813,8 @@ void smooth_l1_kernel(TensorIteratorBase& iter, double beta) {
}
void huber_kernel(TensorIterator& iter, double delta) {
// Special-case kHalf: compute in float for numerical stability
if (iter.dtype() == kHalf) {
const float delta_val(static_cast<float>(delta));
const Vectorized<float> delta_vec(static_cast<float>(delta));
const Vectorized<float> point_five_vec(static_cast<float>(0.5));
cpu_kernel_vec(
iter,
// scalar lambda: convert half -> float, compute in float, cast back to half
[&delta_val] (at::Half a, at::Half b) -> at::Half {
float af = static_cast<float>(a);
float bf = static_cast<float>(b);
float z = std::abs(af - bf);
float out = z < delta_val
? 0.5f * z * z
: delta_val * (z - 0.5f * delta_val);
return static_cast<at::Half>(out);
},
[&delta_vec, &point_five_vec] (Vectorized<Half> a, Vectorized<Half> b) {
auto [a0, a1] = convert_half_float(a);
auto [b0, b1] = convert_half_float(b);
auto z = (a0 - b0).abs();
a0 = Vectorized<float>::blendv(
point_five_vec * z * z,
delta_vec * (z - point_five_vec * delta_vec),
z >= delta_vec);
z = (a1 - b1).abs();
a1 = Vectorized<float>::blendv(
point_five_vec * z * z,
delta_vec * (z - point_five_vec * delta_vec),
z >= delta_vec);
return convert_float_half(a0, a1);
}
);
return;
}
else {
AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), "huber_cpu", [&]() {
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16, kHalf, iter.dtype(), "huber_cpu", [&]() {
using Vec = Vectorized<scalar_t>;
const scalar_t delta_val(delta);
const Vec delta_val_vec(delta_val);
@ -870,7 +835,6 @@ void huber_kernel(TensorIterator& iter, double delta) {
z >= delta_val_vec);
});
});
}
}
void sigmoid_backward_kernel(TensorIteratorBase& iter) {

View File

@ -8,6 +8,7 @@
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/cpu/utils.h>
#include <cmath>
#include <c10/util/Unroll.h>
#include <c10/util/irange.h>
@ -793,6 +794,139 @@ bool can_use_kleidiai(
}
#endif
static void ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
size_t m,
size_t n,
size_t k,
const uint16_t* lhs_bf16,
const uint8_t* rhs_qs4cx,
const float* rhs_scales,
uint16_t* dst_bf16,
float scalar_min,
float scalar_max,
const float* bias) {
// Roundup lambda for internal stride calculations
auto roundup = [](size_t a, size_t b) { return ((a + b - 1) / b) * b; };
// Cast bfloat16 to float32 inline
auto cast_bf16_to_f32 = [](uint16_t bf16_val) {
uint32_t tmp = static_cast<uint32_t>(bf16_val) << 16;
float f;
std::memcpy(&f, &tmp, sizeof(f));
return f;
};
// Cast float32 to bfloat16 inline
auto cast_f32_to_bf16 = [](float f) {
uint32_t bits;
std::memcpy(&bits, &f, sizeof(bits));
return static_cast<uint16_t>(bits >> 16);
};
// Quantization pack lambda (channelwise QA8DX)
auto quant_pack_8bit_channelwise =
[&](size_t M, size_t K, const uint16_t* src_bf16, int8_t* dst_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
const size_t dst_stride =
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
for (size_t i = 0; i < M; ++i) {
const uint16_t* row_ptr = src_bf16 + i * K;
// find min/max
float mn = FLT_MAX, mx = -FLT_MAX;
for (size_t j = 0; j < K; ++j) {
float v = cast_bf16_to_f32(row_ptr[j]);
mn = std::min(mn, v);
mx = std::max(mx, v);
}
float rmin = std::min(0.0f, mn);
float rmax = std::max(0.0f, mx);
constexpr float qmin = static_cast<float>(kI8Min);
constexpr float qmax = static_cast<float>(kI8Max);
float scale = (rmin == rmax) ? 1.f : (qmax - qmin) / (rmax - rmin);
float recip = scale ? 1.0f / scale : 0.0f;
int32_t zp;
float des_min = rmin * scale;
float des_max = rmax * scale;
float err_min = qmin + des_min;
float err_max = qmax + des_max;
float zp_f =
(err_min + err_max) > 0 ? qmin - des_min : qmax - des_max;
zp_f = std::clamp(zp_f, qmin, qmax);
zp = std::lrintf(zp_f);
int8_t* out_ptr = dst_qa8dx + i * dst_stride;
// store header
*reinterpret_cast<float*>(out_ptr) = recip;
*reinterpret_cast<int32_t*>(out_ptr + sizeof(float)) = -zp;
out_ptr += sizeof(float) + sizeof(int32_t);
// quantize
for (size_t j = 0; j < K; ++j) {
float v = cast_bf16_to_f32(row_ptr[j]);
int32_t q = static_cast<int32_t>(std::round(v * scale)) + zp;
q = std::clamp(
q, static_cast<int32_t>(kI8Min), static_cast<int32_t>(kI8Max));
*out_ptr++ = static_cast<int8_t>(q);
}
}
};
// MatMul lambda (MXN x MXK -> MNXK BF16)
auto matmul_kernel = [&](size_t M,
size_t N,
size_t K,
const int8_t* lhs,
const uint8_t* rhs,
const float* scales,
uint16_t* dst,
float lo,
float hi) {
const size_t lhs_stride =
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
const size_t rhs_stride = roundup(K, 2) / 2;
for (size_t i = 0; i < M; ++i) {
const int8_t* lhs_row = lhs + i * lhs_stride;
for (size_t j = 0; j < N; ++j) {
int32_t acc = 0;
const int8_t* lptr = lhs_row;
const uint8_t* rptr = rhs + j * rhs_stride;
float lhs_scale = *reinterpret_cast<const float*>(lptr);
int32_t lhs_off =
*reinterpret_cast<const int32_t*>(lptr + sizeof(float));
lptr += sizeof(float) + sizeof(int32_t);
for (size_t t = 0; t < K; ++t) {
int32_t lv = static_cast<int32_t>(lptr[t]);
uint8_t bv = rptr[t / 2];
int32_t rv = ((t & 1) == 0) ? (static_cast<int32_t>(bv & 0xF) - 8)
: (static_cast<int32_t>(bv >> 4) - 8);
acc += lv * rv + lhs_off * rv;
}
float res = static_cast<float>(acc) * scales[j] * lhs_scale;
if (bias) {
res += bias[j];
}
res = std::clamp(res, lo, hi);
*dst++ = cast_f32_to_bf16(res);
}
}
};
// allocate and run
std::unique_ptr<int8_t[]> packed(
new int8_t[m * (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t))]);
quant_pack_8bit_channelwise(m, k, lhs_bf16, packed.get());
matmul_kernel(
m,
n,
k,
packed.get(),
rhs_qs4cx,
rhs_scales,
dst_bf16,
scalar_min,
scalar_max);
}
/**
* The Int4 quantized weights must be represented as a uint8 tensor
* For matrix multiplication with a weight shape of (N x K)
@ -819,21 +953,21 @@ void dyn_quant_pack_4bit_weight_kernel(
#if AT_KLEIDIAI_ENABLED()
if (can_use_kleidiai(scales_zeros, K, block_size)) {
const int64_t weight_packed_size =
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, weights.scalar_type());
packed_weights.resize_({weight_packed_size});
kleidiai::kai_pack_int4_rhs(
packed_weights, weights, scales_zeros, bias, N, K, block_size);
} else
#endif
{
TORCH_CHECK(
bias.has_value() == 0,
__func__,
" : Bias is unsupported in reference implementation");
packed_weights = packed_weights.to(kFloat);
auto weight_reshaped = weights.view({-1}).to(kFloat);
auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat);
auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0);
auto weight_reshaped = weights.reshape({-1}).to(kFloat);
auto scales_zeros_reshaped = scales_zeros.reshape({-1}).to(kFloat);
std::vector<at::Tensor> tensors_to_cat = {weight_reshaped, scales_zeros_reshaped};
if (bias.has_value()) {
tensors_to_cat.push_back(bias.value().view({-1}).to(kFloat));
}
auto res = at::cat(tensors_to_cat, 0);
packed_weights.resize_(res.sizes()).copy_(res);
}
}
@ -847,7 +981,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
const float* rhs_scales_f32,
float* dst_f32,
float scalar_min,
float scalar_max) {
float scalar_max,
const float* bias) {
const size_t input_size_8bit = m * (k + sizeof(int32_t) + sizeof(float));
auto lhs_qa8dx_buffer = std::make_unique<uint8_t[]>(input_size_8bit);
@ -857,6 +992,9 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
// required format for matmul
auto input_quant_pack_8bit_channelwise =
[&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
const size_t dst_stride =
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
@ -877,8 +1015,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
}
// Maximum/minimum int8 values
const float qmin = (float)INT8_MIN;
const float qmax = (float)INT8_MAX;
constexpr float qmin = static_cast<float>(kI8Min);
constexpr float qmax = static_cast<float>(kI8Max);
const float rmin0 = std::min(0.0f, min0);
const float rmax0 = std::max(0.0f, max0);
@ -904,7 +1042,7 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
zero_point0 = std::min(zero_point0, qmax);
// Round to nearest integer
const int32_t nudged_zero_point0 = lrintf(zero_point0);
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
int8_t* dst_ptr = lhs_qa8dx + m_idx * dst_stride;
@ -922,8 +1060,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
v0_s32 = v0_s32 + nudged_zero_point0;
v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT8_MIN));
v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT8_MAX));
v0_s32 = std::max(v0_s32, static_cast<int32_t>(kI8Min));
v0_s32 = std::min(v0_s32, static_cast<int32_t>(kI8Max));
dst_ptr[0] = (int8_t)v0_s32;
dst_ptr += sizeof(int8_t);
}
@ -987,6 +1125,10 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
main_acc = main_acc * lhs_scale;
if (bias) {
main_acc += bias[n_idx];
}
// Clamp (min-max) operation
main_acc = std::max(main_acc, scalar_min);
main_acc = std::min(main_acc, scalar_max);
@ -1007,12 +1149,16 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
const float* rhs_scales_fp32,
float* dst_f32,
float scalar_min,
float scalar_max) {
float scalar_max,
const float* bias) {
// Lambda for LHS quantization
auto lhs_quant_pack = [&](size_t m,
size_t k,
const float* lhs_f32,
int8_t* lhs_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
const size_t dst_stride =
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
@ -1028,8 +1174,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
min0 = std::min(src0_0, min0);
}
const float qmin = (float)INT8_MIN;
const float qmax = (float)INT8_MAX;
constexpr float qmin = static_cast<float>(kI8Min);
constexpr float qmax = static_cast<float>(kI8Max);
const float rmin0 = std::min(0.0f, min0);
const float rmax0 = std::max(0.0f, max0);
@ -1046,7 +1192,7 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
zero_point0 = std::max(zero_point0, qmin);
zero_point0 = std::min(zero_point0, qmax);
const int32_t nudged_zero_point0 = lrintf(zero_point0);
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
int8_t* dst_ptr = lhs_qa8dx + row_idx * dst_stride;
@ -1059,9 +1205,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
const float src0_0 = src_ptr[k_idx];
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
v0_s32 = std::max(
std::min(
v0_s32 + nudged_zero_point0, static_cast<int32_t>(INT8_MAX)),
static_cast<int32_t>(INT8_MIN));
std::min(v0_s32 + nudged_zero_point0, static_cast<int32_t>(kI8Max)),
static_cast<int32_t>(kI8Min));
dst_ptr[0] = (int8_t)v0_s32;
dst_ptr += sizeof(int8_t);
}
@ -1118,6 +1263,11 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
}
main_acc = main_acc * lhs_scale;
if (bias) {
main_acc += bias[col_idx];
}
main_acc = std::max(main_acc, scalar_min);
main_acc = std::min(main_acc, scalar_max);
@ -1128,28 +1278,27 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
}
/**
* Dynamic Input Quant 4 bit weights matmul execution flow
(INT4 Weights + FP scales + FP32 Bias)
FP32 Input Packed Buffer
| |
Quantize Cast
to INT8 to INT8
| |
v v
INT8 Input INT8 Weights
\ /
\ /
\ /
INT8 Matrix Multiplication
|
v
FP32 Dequantized and Accumulate in FP32
|
v
FP32 Final Output
* The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
* Float32 Scales. If not provided, we will use fallback implementation.
* Dynamic INT4 weight-only MatMul with per-row input quantization.
*
* Execution Flow:
*
* (INT4 Weights + FP Scales [+ optional Bias])
*
* Input (FP32 or BF16) Packed Weight Buffer
* | |
* Row-wise Quantization (INT8) |
* | |
* INT8 Input Activation INT4 Quantized Weights + Scales
* \ /
* \ /
* Quantized Matrix Multiply
* |
* Output Tensor (BF16 or FP32)
*
* Notes:
* - Groupwise kernels expect BF16 scales
* - Channelwise kernels expect FP32 scales
* - Bias is currently unsupported in fallback path
*/
void dyn_quant_matmul_4bit_kernel(
const Tensor& output,
@ -1161,65 +1310,75 @@ void dyn_quant_matmul_4bit_kernel(
const int64_t block_size) {
#if AT_KLEIDIAI_ENABLED()
const int64_t weight_packed_size =
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, inp.scalar_type());
if (weight_packed_size == packed_weights.numel()) {
// KleidiAI interface internally handles the Channelwise and groupwise
// distinction
kleidiai::kai_quant_pack_lhs_int4_mm(
output, inp, packed_weights, M, N, K, block_size);
kleidiai::kai_quant_pack_lhs_int4_mm(output, inp, packed_weights, M, N, K, block_size);
} else
#endif
{
float* lhs_f32 = reinterpret_cast<float*>(inp.data_ptr());
const auto weights_size = N * K / 2;
// The weights needs to be in uint8_t data type after quantization
auto extracted_weights =
(packed_weights.narrow(0, 0, weights_size)).to(kByte);
auto float32_scales =
(packed_weights.narrow(
0, weights_size, packed_weights.size(0) - weights_size))
.to(kFloat);
uint8_t* rhs_4bit =
reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
float* rhs_scales_f32 = reinterpret_cast<float*>(float32_scales.data_ptr());
float* dst_f32 = reinterpret_cast<float*>(output.data_ptr());
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel(
M,
N,
K,
lhs_f32,
rhs_4bit,
rhs_scales_f32,
dst_f32,
-FLT_MAX,
FLT_MAX);
} else if (!(block_size % 32) && !(K % block_size)) {
ref_dyn_quant_matmul_4bit_groupwise_kernel(
M,
N,
K,
block_size,
lhs_f32,
rhs_4bit,
rhs_scales_f32,
dst_f32,
-FLT_MAX,
FLT_MAX);
} else {
TORCH_CHECK(
block_size == K || (!(block_size % 32) && !(K % block_size)),
__func__,
": Group size should be multiple 32 or in_features [",
K,
"]. Provided ",
block_size);
{
void* input = inp.data_ptr();
void* dst = output.data_ptr();
// Extract weights, sclaes and biases form from packed tensor
const int weights_elements = N * K / 2;
const int scale_elements = N * (K / block_size);
TORCH_CHECK(packed_weights.numel() >= (weights_elements + scale_elements), "Invalid packed weight tensor size");
auto extracted_weights = packed_weights.narrow(0, 0, weights_elements).to(kByte);
auto extracted_scales_and_bias = packed_weights.narrow(0, weights_elements, packed_weights.size(0) - weights_elements).to(kFloat);
auto float32_scales = extracted_scales_and_bias.narrow(0, 0, scale_elements);
int bias_elements = packed_weights.numel() - (weights_elements + scale_elements);
float* weight_scales = float32_scales.data_ptr<float>();
void* bias_data = nullptr;
if (bias_elements) {
auto float32_bias = extracted_scales_and_bias.narrow(0, scale_elements, bias_elements);
TORCH_CHECK(float32_bias.size(0) == N, "Expected bias length to match output dimension");
bias_data = float32_bias.data_ptr();
}
// 2 elements of 4 bit weights are packed into 1 uint8 packet
uint8_t* weights_4bit = reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
// Dispatch to reference kernels
if (inp.scalar_type() == at::kBFloat16) {
// BF16 input, BF16 output
constexpr float BF16_MAX = 3.38953139e+38f;
constexpr float BF16_MIN = -BF16_MAX;
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
M, N, K,
(uint16_t*)input, weights_4bit, weight_scales,
(uint16_t*)dst, BF16_MIN, BF16_MAX, (float*)bias_data);
} else {
TORCH_CHECK(false, "Unsupported block size for BF16 fallback");
}
} else if (inp.scalar_type() == at::kFloat) {
// FP32 input, FP32 output
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel(
M, N, K,
(float*)input, weights_4bit, weight_scales,
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
} else if (!(block_size % 32) && !(K % block_size)) {
ref_dyn_quant_matmul_4bit_groupwise_kernel(
M, N, K, block_size,
(float*)input, weights_4bit, weight_scales,
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
} else {
TORCH_CHECK(false, "Unsupported block size for FP32 fallback");
}
} else {
TORCH_CHECK(false, "Unsupported input/output dtype combination for int4mm kernel");
}
}
}
}
} // anonymous namespace
}
ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel)
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel)
REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel)

View File

@ -296,7 +296,7 @@ template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmAndBiasCublasLt(
// args contains result which is modified
cublasCommonArgs& args,
const std::optional<Tensor>& self,
const Tensor& self,
const Scalar& alpha,
Activation activation = Activation::None
) {
@ -304,8 +304,12 @@ bool launchGemmAndBiasCublasLt(
// or when it can be squeezed to 1D.
// self_ptr == nullptr implies ignore bias epilogue
// and use standard gemm-like API.
const auto* self_ptr = self.has_value() ? self.value().const_data_ptr<scalar_t>() : static_cast<const scalar_t*>(nullptr);
const auto* self_ptr = [&]() -> auto {
if (self.dim() == 1 || self.squeeze().dim() == 1) {
return self.const_data_ptr<scalar_t>();
}
return static_cast<const scalar_t*>(nullptr);
}();
const auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
@ -388,30 +392,35 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#ifdef USE_ROCM
// Conditioned on the device index, which is not persistent
disable_addmm_cuda_lt = disable_addmm_cuda_lt || isGloballyDisabledAddmmCudaLt(self.device());
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
#endif
// Condition on the input
disable_addmm_cuda_lt = disable_addmm_cuda_lt || !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha, activation);
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha, activation) || disable_addmm_cuda_lt;
// }
at::ScalarType scalar_type = mat1.scalar_type();
bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
#ifdef USE_ROCM
disable_addmm_cuda_lt = disable_addmm_cuda_lt || is_float_output_with_half_input;
#endif
bool use_bias_ptr_lt = (self.dim() == 1) && !disable_addmm_cuda_lt;
// for float output with half input cublasLT with bias produces wrong results
use_bias_ptr_lt &= !is_float_output_with_half_input;
// Handle result/self shapes
if (!result.is_same(self)) {
at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
// We do not copy bias only when we need the bias ptr
// We use bias ptr in the Lt path only when bias is 1D
const auto use_bias_ptr_lt = (self.dim() == 1) && !disable_addmm_cuda_lt;
const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
if (!use_bias_ptr_lt) {
// We do expand self even before
// check for beta != 0.0 to make sure that
// test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
// runs green.
return expand_size(self, result.sizes(), "addmm");
}
return c10::MaybeOwned<Tensor>::borrowed(self);
}();
// We do not copy bias only when we need the bias ptr
if (beta.toComplexDouble() != 0.0 && !use_bias_ptr_lt) {
// NOTE: self should broadcast over result
at::native::copy_(result, *expand_size(self, result.sizes(), "addmm"));
at::native::copy_(result, *self_maybe_expanded);
}
}
@ -459,7 +468,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
scalar_type,
"addmm_cuda_lt",
[&] {
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, use_bias_ptr_lt ? std::make_optional(self) : std::nullopt, alpha, activation);
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation);
}
);
#endif
@ -471,7 +480,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
scalar_type,
"addmm_cuda_lt",
[&] {
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, use_bias_ptr_lt ? std::make_optional(self) : std::nullopt, alpha, activation);
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation);
}
);
} // end is_float_output_with_half_input
@ -927,7 +936,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
return _int_mm_out_cuda(self, mat2, result);
}
static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const at::ScalarType out_dtype, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const at::ScalarType out_dtype, bool is_bmm, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
// ref ATen/native/LinearAlgebra.cpp common_checks_baddbmm_bmm
TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");
@ -951,7 +960,7 @@ static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& bat
(out_dtype == at::ScalarType::Float && (batch1.scalar_type() == at::ScalarType::Half || batch1.scalar_type() == at::ScalarType::BFloat16)),
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
if (self_baddbmm.has_value()) {
if (!is_bmm && self_baddbmm.has_value()) {
const auto& self = self_baddbmm.value();
TORCH_CHECK(self.dim() == 3, "self must be a 3D tensor");
TORCH_CHECK(self.sizes() == output_size, "self must have the same shape as the output");
@ -959,12 +968,15 @@ static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& bat
}
Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype) {
Tensor out = at::empty({batch1.size(0), batch1.size(1), batch2.size(2)}, batch1.options().dtype(out_dtype));
IntArrayRef batch1_sizes = batch1.sizes();
IntArrayRef batch2_sizes = batch2.sizes();
Tensor out = at::empty({batch1_sizes[0], batch1_sizes[1], batch2_sizes[2]}, batch1.options().dtype(out_dtype));
return _bmm_out_dtype_cuda(batch1, batch2, out_dtype, out);
}
Tensor& _bmm_out_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, Tensor &out) {
baddbmm_bmm_out_dtype_checks(batch1, batch2, 0.0, 1.0, out_dtype);
baddbmm_bmm_out_dtype_checks(batch1, batch2, 0.0, 1.0, out_dtype, true);
Scalar beta(0.0);
Scalar alpha(1.0);
{
@ -976,16 +988,14 @@ Tensor& _bmm_out_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at
}
Tensor _baddbmm_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha) {
TORCH_CHECK(self.scalar_type() == out_dtype || self.scalar_type() == batch1.dtype(),
"self dtype must match either out_dtype or batch1 dtype");
Tensor out = at::empty({batch1.size(0), batch1.size(1), batch2.size(2)}, batch1.options().dtype(out_dtype));
return _baddbmm_out_dtype_cuda(self, batch1, batch2, out_dtype, beta, alpha, out);
// We need to copy the tensor
Tensor out = self.clone().to(self.options().dtype(out_dtype));
return _baddbmm_out_dtype_cuda(out, batch1, batch2, out_dtype, beta, alpha, out);
}
Tensor& _baddbmm_out_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) {
baddbmm_bmm_out_dtype_checks(batch1, batch2, beta, alpha, out_dtype, out);
// We need to copy the tensor
out.copy_(self);
baddbmm_bmm_out_dtype_checks(batch1, batch2, beta, alpha, out_dtype, false, self);
{
NoNamesGuard guard;
baddbmm_out_cuda_impl(out, out, batch1, batch2, beta, alpha);
@ -1020,27 +1030,24 @@ Tensor& _mm_dtype_out_cuda(const Tensor& self, const Tensor& mat2, const at::Sca
}
Tensor _addmm_dtype_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha) {
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
Tensor result = at::empty({mat1.size(0), mat2.size(1)}, self.options().dtype(out_dtype));
Tensor result = at::empty(self.sizes(), self.options().dtype(out_dtype));
return _addmm_dtype_out_cuda(self, mat1, mat2, out_dtype, beta, alpha, result);
}
Tensor& _addmm_dtype_out_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) {
// repeat dimensionality checks for direct calls to `out` overload
TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "self and mat2 must have the same dtype, but got ", self.scalar_type(), " and ", mat2.scalar_type());
TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type());
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type());
TORCH_CHECK(out_dtype == mat1.scalar_type() ||
(out_dtype == at::ScalarType::Float && (mat1.scalar_type() == at::ScalarType::Half || mat1.scalar_type() == at::ScalarType::BFloat16)),
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
TORCH_CHECK(out_dtype == self.scalar_type() || self.scalar_type() == mat1.scalar_type(),
"self dtype must match either out_dtype or mat1 dtype");
TORCH_CHECK(out_dtype == self.scalar_type() ||
(out_dtype == at::ScalarType::Float && (self.scalar_type() == at::ScalarType::Half || self.scalar_type() == at::ScalarType::BFloat16)),
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
addmm_out_cuda_impl(out, self, mat1, mat2, beta, alpha);

View File

@ -346,9 +346,8 @@ void dispatch_bf16_grouped_kernel_on_tile_size(
bool small = (M <= 128 || N <= 128);
cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
const bool sm10x = properties != nullptr && properties->major == 10;
const bool sm11x = properties != nullptr && properties->major == 11;
if (sm10x || sm11x) {
if (sm10x) {
if (small){
bf16bf16_grouped_gemm_impl_sm90_sm100<
cutlass::arch::Sm100,

View File

@ -607,8 +607,6 @@ _scaled_grouped_mm_cuda_v2(
// scale shape checks
_check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_blocked(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
// swizze checks
TORCH_CHECK_VALUE(swizzle_a_enum.size() == 1 && swizzle_b_enum.size() == 1, "Expected single swizzle argument");
return _mx8_mx8_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,

View File

@ -5,69 +5,11 @@
#include <cuda_bf16.h>
#endif
// ROCm 6.3 is planned to have these functions, but until then here they are.
#if defined(USE_ROCM)
#include <device_functions.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) {
#if (defined(__gfx942__)) && \
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16)
typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2;
static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw));
union {
__hip_bfloat162_raw bf162_raw;
vec_short2 vs2;
} u{static_cast<__hip_bfloat162_raw>(value)};
u.vs2 = __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)address, u.vs2);
return static_cast<__hip_bfloat162>(u.bf162_raw);
#else
static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw));
union u_hold {
__hip_bfloat162_raw h2r;
unsigned int u32;
};
u_hold old_val, new_val;
old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
do {
new_val.h2r = __hadd2(old_val.h2r, value);
} while (!__hip_atomic_compare_exchange_strong(
(unsigned int*)address, &old_val.u32, new_val.u32,
__ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
return old_val.h2r;
#endif
}
__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) {
#if (defined(__gfx942__)) && \
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16)
// The api expects an ext_vector_type of half
typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162;
static_assert(sizeof(vec_fp162) == sizeof(__half2_raw));
union {
__half2_raw h2r;
vec_fp162 fp16;
} u {static_cast<__half2_raw>(value)};
u.fp16 = __builtin_amdgcn_flat_atomic_fadd_v2f16((vec_fp162*)address, u.fp16);
return static_cast<__half2>(u.h2r);
#else
static_assert(sizeof(__half2_raw) == sizeof(unsigned int));
union u_hold {
__half2_raw h2r;
unsigned int u32;
};
u_hold old_val, new_val;
old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
do {
new_val.h2r = __hadd2(old_val.h2r, value);
} while (!__hip_atomic_compare_exchange_strong(
(unsigned int*)address, &old_val.u32, new_val.u32,
__ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
return old_val.h2r;
#endif
}
#define ATOMICADD preview_unsafeAtomicAdd
#define ATOMICADD unsafeAtomicAdd
#define NATIVE_ZERO_BF16 __float2bfloat16(0.0f)
#else
#define ATOMICADD atomicAdd

View File

@ -2,250 +2,18 @@
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/native/cuda/ScanUtils.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/OpMathType.h>
#include <c10/util/MathConstants.h>
#include <c10/util/complex.h>
#include <cmath>
#include <limits>
// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
namespace at::native {
// custom min and max to be used in logaddexp for complex arguments
template <typename scalar_t, bool min>
__host__ __device__ c10::complex<scalar_t> _logaddexp_minmax(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
scalar_t xr = std::real(x);
scalar_t yr = std::real(y);
if (::isnan(yr) || (::isnan(std::imag(y)))) {
return y;
} else if (::isnan(xr) || (::isnan(std::imag(x)))) {
return x;
} else if (min) { // min
return (xr < yr) ? x : y;
} else { // max
return (xr >= yr) ? x : y;
}
}
template <typename scalar_t>
__host__ __device__ scalar_t _log_add_exp_helper(const scalar_t& x, const scalar_t& y) {
// Reference : https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp
// Using the original expression: `at::_isnan(y) ? y : std::min(x, y)` causes an error in ROCM
const auto isnan_x = at::_isnan(x);
const auto isnan_y = at::_isnan(y);
scalar_t min = isnan_y ? y : (isnan_x ? x : std::min(x, y));
scalar_t max = isnan_y ? y : (isnan_x ? x : std::max(x, y));
if (min != max || ::isfinite(min)) {
// nan will be propagated here
return ::log1p(std::exp(min - max)) + max;
} else {
// special case to correctly handle infinite cases
return x;
}
}
template <typename scalar_t>
__host__ __device__ c10::complex<scalar_t> _fast_build_exp(const c10::complex<scalar_t>& x) {
// complex exponential function, but implemented manually to get fast compilation time
// this function only handles the case where the x is finite (not inf nor nan)
const auto xreal = std::real(x);
const auto ximag = std::imag(x);
const auto exp_x_abs = std::exp(xreal);
auto exp_x_real = exp_x_abs * std::cos(ximag);
auto exp_x_imag = exp_x_abs * std::sin(ximag);
return {exp_x_real, exp_x_imag};
}
template <typename scalar_t>
__host__ __device__ c10::complex<scalar_t> _fast_build_exp_inf(const c10::complex<scalar_t>& x) {
// complex exponential function, but implemented manually to get fast compilation time
// this function only handles the case where the real part of x is infinite
const auto ximag = std::imag(x);
constexpr auto exp_x_abs = std::numeric_limits<scalar_t>::infinity();
if (!::isfinite(ximag)) { // add this to make consitent with std::exp(x+yi)
return {exp_x_abs, std::numeric_limits<scalar_t>::quiet_NaN()};
}
const auto sin = std::sin(ximag);
const auto cos = std::cos(ximag);
// special case if the angle is exactly the multiple of pi/2
auto exp_x_real = (cos == 0) ? (scalar_t)0.0 : exp_x_abs * cos;
auto exp_x_imag = (sin == 0) ? (scalar_t)0.0 : exp_x_abs * sin;
return {exp_x_real, exp_x_imag};
}
template <typename scalar_t>
__host__ __device__ c10::complex<scalar_t> _log_add_exp_helper(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
c10::complex<scalar_t> min = _logaddexp_minmax<scalar_t, /*min=*/true>(x, y);
c10::complex<scalar_t> max = _logaddexp_minmax<scalar_t, /*min=*/false>(x, y);
scalar_t min_real = std::real(min);
scalar_t max_real = std::real(max);
if (::isnan(min_real) || ::isnan(std::imag(min))) {
// handling the "infectious" NaNs
return {std::numeric_limits<scalar_t>::quiet_NaN(), std::numeric_limits<scalar_t>::quiet_NaN()};
}
else if ((!::isfinite(min_real)) && (min_real == max_real)) {
if (min_real < 0) {
// handle the -inf case, the imaginary part here does not really matter as the exp(value)
// will be around 0.0 and the angle (i.e. the imaginary part) cannot be determined.
// It does not matter if we're taking the exp of this value
return min;
} else {
// handle the +inf case, we don't need the special precision for log1p for small values
// and to avoid producing nan in case of real(max) == real(min) == +inf
const auto exp_min = _fast_build_exp_inf(min);
const auto exp_max = _fast_build_exp_inf(max);
return ::log1p(exp_min + exp_max - 1); // log1p(x - 1) builds faster than log
}
} else {
const auto minmax = min - max;
c10::complex<scalar_t> exp_minmax;
if (!::isfinite(minmax.real())) {
exp_minmax = minmax.real() < 0 ? c10::complex<scalar_t>{0.0, 0.0} : _fast_build_exp_inf(minmax);
} else {
exp_minmax = _fast_build_exp(minmax);
}
return ::log1p(exp_minmax) + max;
}
}
// Complex logaddexp jiterator string
const auto logaddexp_complex_string = jiterator_stringify(
template<typename T>
std::complex<T> log1p(const std::complex<T>& z)
{
using complex_t = std::complex<T>;
T x = z.real();
T y = z.imag();
T zabs = abs(z);
T theta = atan2(y, x + T(1));
if (zabs < 0.5) {
T r = x * (T(2) + x) + y * y;
if (r == 0) { // handle underflow
return complex_t(x, theta);
}
return complex_t(T(0.5) * std::log1p(r), theta);
} else {
T z0 = std::hypot(x + 1, y);
return complex_t(log(z0), theta);
}
}
// separated _logaddexp_minmax into 2 different functions for jiterator_string
template <typename T>
std::complex<T> logaddexp_min(const std::complex<T>& x, const std::complex<T>& y) {
T xr = x.real();
T yr = y.real();
if (isnan(yr) || isnan(y.imag())) {
return y;
} else if (isnan(xr) || isnan(x.imag())) {
return x;
} else {
return (xr < yr) ? x : y;
}
}
template <typename T>
std::complex<T> logaddexp_max(const std::complex<T>& x, const std::complex<T>& y) {
T xr = x.real();
T yr = y.real();
if (isnan(yr) || isnan(y.imag())) {
return y;
} else if (isnan(xr) || isnan(x.imag())) {
return x;
} else {
return (xr >= yr) ? x : y;
}
}
template <typename T>
std::complex<T> fast_build_exp(const std::complex<T>& x) {
const auto xreal = x.real();
const auto ximag = x.imag();
const auto exp_x_abs = exp(xreal);
auto exp_x_real = exp_x_abs * cos(ximag);
auto exp_x_imag = exp_x_abs * sin(ximag);
return std::complex<T>(exp_x_real, exp_x_imag);
}
template <typename T>
std::complex<T> fast_build_exp_inf(const std::complex<T>& x) {
using complex_t = std::complex<T>;
const auto ximag = x.imag();
const T exp_x_abs = INFINITY;
if (!isfinite(ximag)) {
return complex_t(exp_x_abs, NAN);
}
const auto sin_val = sin(ximag);
const auto cos_val = cos(ximag);
auto exp_x_real = (cos_val == T(0)) ? T(0) : exp_x_abs * cos_val;
auto exp_x_imag = (sin_val == T(0)) ? T(0) : exp_x_abs * sin_val;
return complex_t(exp_x_real, exp_x_imag);
}
template <typename complex_t>
complex_t logaddexp_complex(complex_t x, complex_t y) {
using T = typename complex_t::value_type;
complex_t min_val = logaddexp_min(x, y);
complex_t max_val = logaddexp_max(x, y);
T min_real = min_val.real();
T max_real = max_val.real();
if (isnan(min_real) || isnan(min_val.imag())) {
return complex_t(NAN, NAN);
}
else if ((!isfinite(min_real)) && (min_real == max_real)) {
if (min_real < T(0)) {
return min_val;
} else {
const auto exp_min = fast_build_exp_inf<T>(min_val);
const auto exp_max = fast_build_exp_inf<T>(max_val);
return log1p(exp_min + exp_max - complex_t(1, 0));
}
} else {
const auto minmax = min_val - max_val;
complex_t exp_minmax;
if (!isfinite(minmax.real())) {
exp_minmax = (minmax.real() < T(0)) ? complex_t(0, 0) : fast_build_exp_inf<T>(minmax);
} else {
exp_minmax = fast_build_exp<T>(minmax);
}
return log1p(exp_minmax) + max_val;
}
}
);
constexpr char logaddexp_complex_name[] = "logaddexp_complex";
void logaddexp_kernel_cuda(TensorIteratorBase& iter) {
if (at::isComplexType(iter.dtype())) {
#if AT_USE_JITERATOR()
AT_DISPATCH_COMPLEX_TYPES_AND(at::ScalarType::ComplexHalf, iter.dtype(), "logaddexp_cuda", [&]() {
jitted_gpu_kernel<
/*name=*/logaddexp_complex_name,
/*return_dtype=*/scalar_t,
/*common_dtype=*/scalar_t,
/*arity=*/2>(iter, logaddexp_complex_string);
});
#else
AT_DISPATCH_COMPLEX_TYPES_AND(at::ScalarType::ComplexHalf, iter.dtype(), "logaddexp_cuda", [&]() {
using opmath_t = at::opmath_type<scalar_t>;
gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a_, scalar_t b_) -> scalar_t {
const auto a = static_cast<opmath_t>(a_);
const auto b = static_cast<opmath_t>(b_);
return static_cast<scalar_t>(_log_add_exp_helper(a, b));
});
});
#endif
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::BFloat16, ScalarType::Half,
iter.dtype(), "logaddexp_cuda",
[&]() {
@ -261,7 +29,6 @@ void logaddexp_kernel_cuda(TensorIteratorBase& iter) {
}
});
});
}
}
void logaddexp2_kernel_cuda(TensorIteratorBase& iter) {

View File

@ -958,9 +958,8 @@ void dispatch_fp8_rowwise_kernel_on_sm(
const bool sm89 = properties != nullptr && properties->major == 8 && properties->minor == 9;
const bool sm9x = properties != nullptr && properties->major == 9;
const bool sm10x = properties != nullptr && properties->major == 10;
const bool sm11x = properties != nullptr && properties->major == 11;
const bool sm12x = properties != nullptr && properties->major == 12;
if (!(sm89 || sm9x || sm10x || sm11x || sm12x)) {
if (!(sm89 || sm9x || sm10x || sm12x)) {
TORCH_CHECK(
false, "Rowwise scaling is not currently supported on your device");
}
@ -969,7 +968,7 @@ void dispatch_fp8_rowwise_kernel_on_sm(
dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose<
/*ArchTag=*/cutlass::arch::Sm90,
Types...>(XQ, WQ, x_scale, w_scale, bias, out);
} else if (sm10x || sm11x) {
} else if (sm10x) {
dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose<
/*ArchTag=*/cutlass::arch::Sm100,
Types...>(XQ, WQ, x_scale, w_scale, bias, out);

View File

@ -1101,19 +1101,6 @@ _scaled_mxfp8_mxfp8(
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}
void
_check_mxfp4_support() {
#ifndef USE_ROCM
auto dprops = at::cuda::getCurrentDeviceProperties();
// Only on B200 GPUs
TORCH_CHECK_NOT_IMPLEMENTED(
// B200 = 10.0, B300 = 10.3
dprops->major == 10,
"MXFP4 scaling only supported in CUDA for B200/B300"
);
#endif
}
Tensor&
_scaled_mxfp4_mxfp4(
@ -1126,7 +1113,6 @@ _scaled_mxfp4_mxfp4(
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
#else
_check_mxfp4_support();
// Restrictions:
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",

View File

@ -21,18 +21,27 @@ void kai_pack_int4_rhs(
const int64_t n,
const int64_t k,
const int64_t bl) {
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
// Channelwise
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
if (weight.scalar_type() == at::kBFloat16) {
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
} else {
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
}
} else if (!(bl % 32) && !(k % bl)) {
// Groupwise
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
@ -63,19 +72,29 @@ void kai_pack_int4_rhs(
size_t kai_pack_rhs_int4_size(
const int64_t n,
const int64_t k,
const int64_t bl) {
const int64_t bl,
at::ScalarType tensor_dtype) {
size_t packed_size = n * k;
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
// Channelwise
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
if (tensor_dtype == at::kBFloat16) {
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
} else {
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
}
} else if (!(bl % 32) && !(k % bl)) {
// Groupwise
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
@ -148,8 +167,7 @@ static void kai_quant_pack_lhs_int4_mm_groupwise(
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
const int64_t m_idx = thread_id * vec_per_thread;
auto lhs_packed_ptr = lhs_packed_base +
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
m_idx, k, mr, kr, sr);
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
const int64_t vec_num = (thread_id == num_threads - 1)
? (m - vec_per_thread * thread_id)
: vec_per_thread;
@ -259,8 +277,7 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
const int64_t m_idx = thread_id * vec_per_thread;
auto lhs_packed_ptr = lhs_packed_base +
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
m_idx, k, mr, kr, sr);
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
const int64_t vec_num = (thread_id == num_threads - 1)
? (m - vec_per_thread * thread_id)
: vec_per_thread;
@ -320,19 +337,144 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
});
}
void kai_quant_pack_lhs_int4_mm(
static void kai_quant_pack_lhs_int4_mm_bf16_channelwise(
const Tensor& output,
const Tensor& input,
const Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k) {
// Kernel IDs for GEMM and GEMV
constexpr kai_kernel_id gemm_id =
kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm;
constexpr kai_kernel_id gemv_id =
kai_kernel_id::matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod;
// Get total threads and select kernel
const int64_t total_threads = at::get_num_threads();
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemv_id);
if (cpuinfo_has_arm_i8mm() && m > 1) {
kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemm_id);
}
// Thread blocking parameters
const int64_t n_step = kernel_packet.ukernel.get_n_step();
const size_t mr = kernel_packet.ukernel.get_mr();
const size_t kr = kernel_packet.ukernel.get_kr();
const size_t sr = kernel_packet.ukernel.get_sr();
const size_t lhs_packed_size =
kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
uint8_t* dst_act_mtx_bf16 = reinterpret_cast<uint8_t*>(output.data_ptr());
const uint8_t* lhs_native_mtx_bf16 =
reinterpret_cast<const uint8_t*>(input.data_ptr());
const uint8_t* rhs_packed_mtx_qs4cx =
reinterpret_cast<const uint8_t*>(weight.data_ptr());
uint8_t* lhs_packed_base = lhs_packed.get();
constexpr int32_t element_size = sizeof(uint16_t);
const size_t lhs_stride = k * element_size;
const size_t dst_stride = n * element_size;
// LHS quantization packing
int64_t vec_per_thread = get_vec_per_thread(m, total_threads, mr);
int64_t num_threads = (m + vec_per_thread - 1) / vec_per_thread;
const size_t src_stride = vec_per_thread * lhs_stride;
auto lhs_quant_pack = [=, &kernel_packet](int64_t thread_id) {
const auto lhs_src_ptr = lhs_native_mtx_bf16 + thread_id * src_stride;
const int64_t m_idx = thread_id * vec_per_thread;
auto lhs_packed_ptr = lhs_packed_base +
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
const int64_t vec_num = (thread_id == num_threads - 1)
? (m - vec_per_thread * thread_id)
: vec_per_thread;
kernel_packet.kai_run_lhs_quant_pack(
vec_num,
k,
mr,
kr,
sr,
0,
(const uint16_t*)lhs_src_ptr,
lhs_stride,
lhs_packed_ptr);
};
at::parallel_for(
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
lhs_quant_pack(thread_id);
}
});
// Matrix multiplication
vec_per_thread = get_vec_per_thread(n, total_threads, n_step);
num_threads = (n + vec_per_thread - 1) / vec_per_thread;
auto mm = [=, &kernel_packet](int64_t thread_id) {
const auto rhs_packed_ptr = rhs_packed_mtx_qs4cx +
kernel_packet.ukernel.get_rhs_packed_offset(
thread_id * vec_per_thread, k);
auto dst_ptr = dst_act_mtx_bf16 +
kernel_packet.ukernel.get_dst_offset(
0, thread_id * vec_per_thread, dst_stride);
const int64_t vec_num = (thread_id == num_threads - 1)
? (n - vec_per_thread * thread_id)
: vec_per_thread;
kernel_packet.ukernel.run_matmul(
m,
vec_num,
k,
lhs_packed_base,
rhs_packed_ptr,
(uint16_t*)dst_ptr,
dst_stride,
element_size, // dst_stride_col
-FLT_MAX,
FLT_MAX);
};
at::parallel_for(
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
mm(thread_id);
}
});
}
void kai_quant_pack_lhs_int4_mm(
const at::Tensor& output,
const at::Tensor& input,
const at::Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k,
const int64_t bl) {
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
output, input, weight, m, n, k);
} else if (!(bl % 32) && !(k % bl)) {
const auto input_dtype = input.dtype();
if (input_dtype == at::kBFloat16) {
if (cpuinfo_has_arm_bf16()) {
kleidiai::kai_quant_pack_lhs_int4_mm_bf16_channelwise(
output, input, weight, m, n, k);
} else {
TORCH_CHECK(
false,
"BF16 Unsupported: CPU does not support BF16. Please use a CPU with BF16 support.");
}
} else if (input_dtype == at::kFloat) {
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
output, input, weight, m, n, k);
} else {
TORCH_CHECK(
false,
"Unsupported input data type: Only Bfloat16 and Float inputs are supported.");
}
} else if ((bl % 32 == 0) && (k % bl == 0)) {
kleidiai::kai_quant_pack_lhs_int4_mm_groupwise(
output, input, weight, m, n, k, bl);
}

View File

@ -25,7 +25,8 @@ void kai_pack_int4_rhs(
size_t kai_pack_rhs_int4_size(
const int64_t n,
const int64_t k,
const int64_t bl);
const int64_t bl,
at::ScalarType tensor_dtype = at::kFloat);
/**
* @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul )

View File

@ -36,7 +36,8 @@ void kai_pack_rhs_groupwise_int4(
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
}
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
float* bias_ptr =
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
auto& params = kernel.rhs_pack_params;
kernel.kai_run_rhs_pack(
@ -73,7 +74,8 @@ void kai_pack_rhs_channelwise_int4(
auto weight_packed_data =
reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
const auto weight_data = weight.data_ptr<uint8_t>();
const auto scales_data = scales.data_ptr<float>();
const auto scales_data = scales.to(kFloat).data_ptr<float>();
if (weight_data == nullptr) {
AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
@ -83,7 +85,8 @@ void kai_pack_rhs_channelwise_int4(
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
}
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
float* bias_ptr =
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
auto& params = kernel.rhs_pack_params;
kernel.kai_run_rhs_pack(

View File

@ -68,5 +68,39 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel(
const kai_kernel_id id) {
return channelwise_8bit_4bit_kernels.at(id);
}
// Kernel Mapping - BF16 Channelwise
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>
bf16_channelwise_8bit_4bit_kernels = {
{kai_kernel_id::
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_n_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_mr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_nr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_kr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_sr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod}}},
{kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_n_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_mr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_nr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_kr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_sr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_dst_size_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm}}}};
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp kai_select_bf16_channelwise_matmul_ukernel(
const kai_kernel_id id) {
return bf16_channelwise_8bit_4bit_kernels.at(id);
}
} // namespace at::native::kleidiai
#endif

View File

@ -10,21 +10,32 @@
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h>
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h>
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h>
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h>
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h>
namespace at::native::kleidiai {
enum class kai_kernel_id {
// FP32 inputs, 4-bit weights, FP32 output
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod =
0, // Groupwise 4 bit GEMV
0, // Groupwise 4-bit GEMV (per-group scales, NEON DOTPROD)
matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm =
1, // Groupwise 4 bit GEMM
1, // Groupwise 4-bit GEMM (per-group scales, NEON I8MM)
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod =
2, // Channelwise 4 bit GEMV
2, // Channelwise 4-bit GEMV (per-channel scales, NEON DOTPROD)
matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm =
3 // Channelwise 4 bit GEMM
3, // Channelwise 4-bit GEMM (per-channel scales, NEON I8MM)
// BF16 inputs, 4-bit weights, BF16 output
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod =
4, // Channelwise 4-bit GEMV with BF16 input/output
matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm =
5 // Channelwise 4-bit GEMM with BF16 input/output
};
// Channelwise Kernel mapping
@ -66,6 +77,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
size_t(*kai_get_lhs_quant_pack_offset)(
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
);
kai_matmul_ukernel_f32_qa8dxp_qs4cxp(
const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel)
@ -75,12 +89,71 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {}
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32){}
};
struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp
kai_select_channelwise_matmul_ukernel(const kai_kernel_id id);
// bf16 Channelwise Kernel mapping
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp {
struct kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel ukernel;
struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params;
size_t (*kai_get_lhs_packed_size)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr);
size_t (*kai_get_rhs_packed_size)(
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr);
void (*kai_run_lhs_quant_pack)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr,
size_t m_idx_start,
const void* lhs,
size_t lhs_stride,
void* lhs_packed);
void (*kai_run_rhs_pack)(
size_t num_groups,
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr,
const uint8_t* rhs,
const float* bias,
const float* scale,
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
size_t(*kai_get_lhs_quant_pack_offset)(
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
);
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp(
const kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel& kernel)
: ukernel(kernel),
kai_get_lhs_packed_size(
&kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon),
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_bf16_neon),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon){}
};
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp
kai_select_bf16_channelwise_matmul_ukernel(const kai_kernel_id id);
// Groupwise Kernel mapping
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel;
@ -125,6 +198,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params);
size_t(*kai_get_lhs_quant_pack_offset)(
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
);
kai_matmul_ukernel_f32_qa8dxp_qs4c32p(
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel)
@ -134,7 +210,8 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {}
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32) {}
};
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(

View File

@ -5,7 +5,6 @@
#include <ATen/native/Resize.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/xpu/Blas.h>
#include <ATen/xpu/XPUScaledBlas.h>
#include <torch/library.h>
#ifndef AT_PER_OPERATOR_HEADERS
@ -340,399 +339,4 @@ Tensor _scaled_mm_xpu(
out);
}
using acceptance_fn = std::function<bool(
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&)>;
using namespace std::placeholders;
namespace scaled_blas = at::native::onednn::scaled;
using scaled_blas::convert_int_to_enum;
using scaled_blas::ScaledGemmImplementation;
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2>
scale_kernel_dispatch = {{
{"tensorwise_tensorwise",
scaled_blas::check_tensorwise_recipe,
ScaledGemmImplementation::TENSORWISE_TENSORWISE},
{"rowwise_rowwise",
scaled_blas::check_rowwise_recipe,
ScaledGemmImplementation::ROWWISE_ROWWISE},
}};
Tensor& _scaled_tensorwise_tensorwise(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
bool use_fast_accum,
Tensor& out) {
// Restrictions:
// A, B are FP8, scales are fp32
TORCH_CHECK_VALUE(
isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()),
"mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(),
mat_b.scalar_type());
TORCH_CHECK_VALUE(
scale_a.numel() == 1 && scale_a.scalar_type() == kFloat,
"scale_a must have 1 Float element")
TORCH_CHECK_VALUE(
scale_b.numel() == 1 && scale_b.scalar_type() == kFloat,
"scale_b must have 1 Float element")
auto scaling_choice_a = ScalingType::TensorWise;
auto scaling_choice_b = ScalingType::TensorWise;
_scaled_gemm(
mat_a,
mat_b,
scale_a,
scale_b,
scaling_choice_a,
scaling_choice_b,
bias,
use_fast_accum,
out);
return out;
}
Tensor& _scaled_rowwise_rowwise(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
bool use_fast_accum,
Tensor& out) {
// Restrictions:
// A, B are FP8, scales are fp32, shape M/N for A/B
TORCH_CHECK_VALUE(
isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()),
"mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(),
mat_b.scalar_type());
TORCH_CHECK_VALUE(
scale_a.size(0) == mat_a.size(0) && scale_a.size(1) == 1,
"scale_a must have shape [",
mat_a.size(0),
", 1], got [",
scale_a.sizes(),
"]");
TORCH_CHECK_VALUE(
scale_a.numel() == mat_a.size(0) && scale_a.scalar_type() == kFloat,
"scale_a must have ",
mat_a.size(0),
" Float elements, got ",
scale_a.numel())
TORCH_CHECK_VALUE(
scale_b.numel() == mat_b.size(1) && scale_b.scalar_type() == kFloat,
"scale_b must have ",
mat_b.size(1),
" Float elements, got ",
scale_b.numel())
TORCH_CHECK_VALUE(
scale_a.stride(1) == 1,
"expected scale_a.stride(1) to be 1, but got ",
scale_a.stride(1));
TORCH_CHECK_VALUE(
scale_b.stride(1) == 1,
"expected scale_b.stride(1) to be 1, but got ",
scale_b.stride(1));
auto scaling_choice_a = ScalingType::RowWise;
auto scaling_choice_b = ScalingType::RowWise;
_scaled_gemm(
mat_a,
mat_b,
scale_a,
scale_b,
scaling_choice_a,
scaling_choice_b,
bias,
use_fast_accum,
out);
return out;
}
// V2: Computes matrix multiply + bias while applying scaling to input and
// output matrices Scales are only applicable when matrices are of Float8 type
// and assumed to be equal to 1.0 by default. If output matrix type is 16 or
// 32-bit type, scale_result is not applied. Known limitations:
// - Only works if mat1 is row-major and mat2 is column-major
// - Only works if matrices sizes are divisible by 32
// - If 1-dimensional tensors are used then scale_a should be size =
// mat1.size(0)
// and scale_b should have size = to mat2.size(1)
// Arguments:
// - `mat_a`: the first operand of the matrix multiply, can be type
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `mat_b`: the second operand of the matrix multiply, can be type
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `scale_a`: a tensor with the inverse scale of `mat1`, whose
// shape/strides/dtype depend on the scaling scheme
// - `scale_recipe_a`: An integer corresponding to an enum describing the
// scaling scheme used for `scale_a`
// - `swizzle_a`: An integer corresponding to a `SwizzleType` enum describing
// the swizzling scheme for `scale_a`.
// Not supported for XPU for now.
// - `scale_b`: a tensor with the inverse scale of `mat2`, whose
// shape/strides/dtype depend on the scaling scheme
// - `scale_recipe_b`: An integer corresponding to an enum describing the
// scaling scheme used for `scale_b`
// - `swizzle_b`: An integer corresponding to a `SwizzleType` enum describing
// the swizzling scheme for `scale_b`.
// Not supported for XPU for now.
// - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
// - `out_dtype`: the output dtype, can either be a float8 or a higher
// precision floating point type
// - `contraction_dim`: describe which dimensions are `K` in the matmul.
// Not supported for XPU. Should always be empty.
// - `use_fast_accum`: Not supported for XPU, should always be false.
// - `out`: a reference to the output tensor
Tensor& _scaled_mm_xpu_v2_out(
const Tensor& mat_a,
const Tensor& mat_b,
ArrayRef<Tensor> scale_a,
IntArrayRef scale_recipe_a,
IntArrayRef swizzle_a,
ArrayRef<Tensor> scale_b,
IntArrayRef scale_recipe_b,
IntArrayRef swizzle_b,
const std::optional<Tensor>& bias,
const std::optional<c10::ScalarType> out_dtype,
IntArrayRef contraction_dim,
bool use_fast_accum,
Tensor& out) {
TORCH_CHECK_VALUE(mat_a.dim() == 2, "mat_a must be a matrix");
TORCH_CHECK_VALUE(mat_b.dim() == 2, "mat_b must be a matrix");
// If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm
// kernels do not support this case).
if (mat_a.size(0) == 0 || mat_a.size(1) == 0 || mat_b.size(1) == 0) {
// `out` was created with `at::empty`. In the case where we are multiplying
// MxK by KxN and K is the zero dim, we need to initialize here to properly
// return a tensor of zeros.
at::native::resize_output(out, {mat_a.size(0), mat_b.size(1)});
if (mat_a.size(1) == 0) {
out.zero_();
}
return out;
}
// Note: The `contraction_dim` is not actually used for now. We will need to
// align this code when upstreamed CUDA code is done. Currently, only keeps
// the code here for check.
// Check if the input matrix sizes can be multiplied
// - if optional contraction dims are provided, use those
// -- mostly for < 1B formats (i.e. nvfp4x2) where cheap .t() is not
// available.
if (contraction_dim.size() > 0) {
TORCH_CHECK_VALUE(
contraction_dim.size() == 2,
"contraction_dim must have exactly 2 elements");
auto mat_a_dim = contraction_dim[0];
auto mat_b_dim = contraction_dim[1];
TORCH_CHECK_VALUE(
mat_a.size(mat_a_dim) == mat_b.size(mat_b_dim),
"mat_a and mat_b shapes cannot be multiplied (",
mat_a.size(0),
"x",
mat_a.size(1),
" and ",
mat_b.size(0),
"x",
mat_b.size(1),
") ",
"with contraction dims mat_a: ",
mat_a_dim,
", mat_b: ",
mat_b_dim);
} else {
TORCH_CHECK_VALUE(
mat_a.size(1) == mat_b.size(0),
"mat_a and mat_b shapes cannot be multiplied (",
mat_a.size(0),
"x",
mat_a.size(1),
" and ",
mat_b.size(0),
"x",
mat_b.size(1),
")");
}
TORCH_CHECK_VALUE(
!bias || bias->numel() == mat_b.sizes()[1],
"Bias must be size ",
mat_b.sizes()[1],
" but got ",
bias->numel());
TORCH_CHECK_VALUE(
!out_dtype || *out_dtype == out.scalar_type(),
"out_dtype must match output matrix type");
if (bias) {
TORCH_CHECK_VALUE(
bias->scalar_type() == kFloat ||
bias->scalar_type() == c10::ScalarType::BFloat16 ||
bias->scalar_type() == c10::ScalarType::Half,
"Bias must be Float32 or BFloat16 or Half, but got ",
bias->scalar_type());
}
{
auto bias_ = bias.value_or(Tensor());
// NOLINTNEXTLINE(*c-array*)
TensorArg targs[]{
{out, "out", 0},
{mat_a, "mat_a", 1},
{mat_b, "mat_b", 2},
{bias_, "bias", 3},
{scale_a[0], "scale_a", 4},
{scale_b[0], "scale_b", 5}};
checkAllSameGPU(__func__, targs);
}
// Align with CUDA's default out to be bf16
auto out_dtype_ = out_dtype.value_or(c10::ScalarType::BFloat16);
// Conversion of implicitly-defined enums to explicit
auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
// XPU does not support swizzle for now. So directly return false.
TORCH_CHECK_VALUE(
swizzle_a_enum[0] == at::blas::SwizzleType::NO_SWIZZLE &&
swizzle_b_enum[0] == at::blas::SwizzleType::NO_SWIZZLE,
"XPU does not support swizzle yet.");
// at this point we can start working out what we want to be doing
// Try to do as few steps as possible.
// NOTE: support is deliberately sparse, can explicitly enumerate all
// combinations allowed. Do this via a list of defined (name, acceptance,
// concrete_impl) tuples.
bool found_impl = false;
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
for (const auto& fn_entry : scale_kernel_dispatch) {
const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
bool ok = accept_fn(
mat_a.scalar_type(),
scale_recipe_a_enum,
scale_a,
mat_b.scalar_type(),
scale_recipe_b_enum,
scale_b);
if (ok) {
gemm_impl = scaled_gemm_impl;
found_impl = true;
break;
}
}
TORCH_CHECK_VALUE(
found_impl,
"Invalid scaling configuration.\n"
"- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
"- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (",
mat_a.size(0),
", 1) and scale_b should be (1, ",
mat_b.size(1),
"), and both should be contiguous.\n"
"Got mat_a.dtype()=",
mat_a.scalar_type(),
", scale_a[0].dtype()=",
scale_a[0].scalar_type(),
", scale_a[0].size()=",
scale_a[0].sizes(),
", scale_a[0].stride()=",
scale_a[0].strides(),
", ",
"mat_b.dtype()=",
mat_b.scalar_type(),
", scale_b[0].dtype()=",
scale_b[0].scalar_type(),
", scale_b[0].size()=",
scale_b[0].sizes(),
" and scale_b[0].stride()=",
scale_b[0].strides());
at::native::resize_output(out, {mat_a.size(0), mat_b.size(1)});
auto bias_ = bias.value_or(Tensor());
// dispatch to appropriate lower-level calls for error checking & execution
if (gemm_impl == ScaledGemmImplementation::TENSORWISE_TENSORWISE) {
return _scaled_tensorwise_tensorwise(
mat_a,
mat_b,
scale_a[0],
scale_b[0],
bias,
out_dtype_,
use_fast_accum,
out);
} else if (gemm_impl == ScaledGemmImplementation::ROWWISE_ROWWISE) {
return _scaled_rowwise_rowwise(
mat_a,
mat_b,
scale_a[0],
scale_b[0],
bias,
out_dtype_,
use_fast_accum,
out);
} else {
TORCH_CHECK_VALUE(
false, "Invalid state - found an implementation, but not really");
}
}
Tensor _scaled_mm_xpu_v2(
const Tensor& mat_a,
const Tensor& mat_b,
ArrayRef<Tensor> scale_a,
IntArrayRef scale_recipe_a,
IntArrayRef swizzle_a,
ArrayRef<Tensor> scale_b,
IntArrayRef scale_recipe_b,
IntArrayRef swizzle_b,
const std::optional<Tensor>& bias,
const std::optional<c10::ScalarType> out_dtype,
IntArrayRef contraction_dim,
bool use_fast_accum) {
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
return _scaled_mm_xpu_v2_out(
mat_a,
mat_b,
scale_a,
scale_recipe_a,
swizzle_a,
scale_b,
scale_recipe_b,
swizzle_b,
bias,
out_dtype,
contraction_dim,
use_fast_accum,
out);
}
} // namespace at::native

View File

@ -147,19 +147,6 @@ class MetalShaderLibrary {
const std::optional<c10::Scalar> alpha = std::nullopt,
const std::optional<c10::ScalarType> scalar_arg_type = std::nullopt);
template <typename T>
void exec_unary_kernel_with_params(
TensorIteratorBase& iter,
const std::string& name,
T params,
const std::string& params_type_name);
template <typename T>
void exec_binary_kernel_with_params(
TensorIteratorBase& iter,
const std::string& name,
T params,
const std::string& params_type_name);
protected:
virtual MTLLibrary_t getLibrary();
virtual MTLLibrary_t getLibrary(

View File

@ -7,12 +7,10 @@
#include <ATen/Tensor.h>
#include <ATen/TensorIterator.h>
#include <ATen/Utils.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/MetalShaderLibrary.h>
#include <ATen/native/mps/TensorFactory.h>
#include <c10/core/ScalarType.h>
#include <fmt/format.h>
#include <torch/library.h>
#include <unordered_map>
@ -632,147 +630,4 @@ inline bool needsGather(const TensorBase& t) {
return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset());
}
template <typename T>
void MetalShaderLibrary::exec_unary_kernel_with_params(TensorIteratorBase& iter,
const std::string& name,
T params,
const std::string& params_type_name) {
using namespace at::mps;
// Decompose 64-bit tensor into 32-bit ones
if (!iter.can_use_32bit_indexing()) {
for (auto&& sub_iter : iter.with_32bit_indexing()) {
exec_unary_kernel_with_params(sub_iter, name, params, params_type_name);
}
return;
}
auto inputTensor = iter.input(0);
auto outputTensor = iter.output(0);
uint32_t length = iter.numel();
if (length == 0) {
return;
}
auto kernel_name = fmt::format("{}_{}_{}_{}{}",
name,
iter.is_contiguous() ? "dense" : "strided",
scalarToMetalTypeString(outputTensor),
scalarToMetalTypeString(inputTensor),
fmt::format("_{}", params_type_name));
@autoreleasepool {
auto cplState = getPipelineStateForFunc(kernel_name);
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
auto computeEncoder = mpsStream->commandEncoder();
getMPSProfiler().beginProfileKernel(cplState, name, {inputTensor});
[computeEncoder setComputePipelineState:cplState];
bind_iter_tensors(computeEncoder, iter);
if (!iter.is_contiguous()) {
mtl_setArgs<2>(computeEncoder,
outputTensor.sizes(),
inputTensor.strides(),
outputTensor.strides(),
inputTensor.ndimension());
}
detail::mtl_setArg(computeEncoder, params, iter.is_contiguous() ? 2 : 6);
mtl_dispatch1DJob(computeEncoder, cplState, length);
getMPSProfiler().endProfileKernel(cplState);
});
}
}
template <typename T>
void MetalShaderLibrary::exec_binary_kernel_with_params(TensorIteratorBase& iter,
const std::string& name,
T params,
const std::string& params_type_name) {
using namespace mps;
// TODO: Figure a better place to downcast double scalars (probably in tensor iterator itself?)
// Right now running something like 1.0-torch.rand(5, device='mps') will create iterator with
// double as common dtype (because Python floating point are always 64-bit values)
TORCH_CHECK(iter.output().scalar_type() != at::kDouble, "float64 is not supported on MPS");
// Skip for empty iterators
if (iter.numel() == 0) {
return;
}
// Decompose 64-bit tensor into 32-bit ones
if (!iter.can_use_32bit_indexing()) {
for (auto&& sub_iter : iter.with_32bit_indexing()) {
exec_binary_kernel_with_params(sub_iter, name, params, params_type_name);
}
return;
}
auto convert_double_scalar = [](Tensor& t) {
if (t.dim() != 0) {
return;
}
if (t.scalar_type() == kDouble) {
t = t.to(kFloat);
} else if (t.scalar_type() == kComplexDouble) {
t = t.to(kComplexFloat);
}
};
Tensor input = iter.input(0);
Tensor other = iter.input(1);
Tensor out = iter.output();
convert_double_scalar(input);
convert_double_scalar(other);
MPSStream* mpsStream = getCurrentMPSStream();
const auto cast_needed = input.scalar_type() != other.scalar_type();
const auto suffix = iter.is_contiguous() ? "dense" : "strided";
// TODO: Implicitly pass both input and output types to non-cast kernels
const auto kernel_name = cast_needed
? fmt::format("{}_{}_cast_{}_{}", name, suffix, scalarToMetalTypeString(out), params_type_name)
: fmt::format("{}_{}_{}_{}_{}",
name,
suffix,
scalarToMetalTypeString(out),
scalarToMetalTypeString(input),
params_type_name);
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
auto computeEncoder = mpsStream->commandEncoder();
auto binaryPSO = getPipelineStateForFunc(kernel_name);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(binaryPSO, kernel_name, {input, other});
[computeEncoder setComputePipelineState:binaryPSO];
// Set input and output tensors
bind_iter_tensors(computeEncoder, iter);
// Iterator is contiguous if all of its elements are dense in storage,
// i.e. it's true for both row-first and column-first tensors
if (iter.is_contiguous()) {
detail::mtl_setArg(computeEncoder, params, 3);
if (cast_needed) {
std::array<int, 4> size_and_types = {static_cast<int>(c10::elementSize(input.scalar_type())),
static_cast<int>(c10::elementSize(other.scalar_type())),
static_cast<int>(input.scalar_type()),
static_cast<int>(other.scalar_type())};
mtl_setBytes(computeEncoder, size_and_types, 4);
}
} else {
// Please note that shapes and strides of the iterator might be
// different than that of its operands, for example binary op
// between 4x4 tensor and scalar will result in 1D 16 element iterator
std::array<int, 4> ndim_and_types = {iter.ndim(),
static_cast<int>(input.scalar_type()),
static_cast<int>(other.scalar_type()),
static_cast<int>(out.scalar_type())};
mtl_setArgs<3>(
computeEncoder, params, iter.shape(), iter.strides(0), iter.strides(1), iter.strides(2), ndim_and_types);
}
mtl_dispatch1DJob(computeEncoder, binaryPSO, iter.numel());
getMPSProfiler().endProfileKernel(binaryPSO);
}
});
}
} // namespace at::native::mps

View File

@ -1,16 +0,0 @@
#pragma once
template <typename T>
struct ELUParams {
T alpha;
T scale;
T input_scale;
};
template <typename T>
struct ELUBackwardParams {
T alpha;
T scale;
T input_scale;
bool is_result;
};

View File

@ -1,4 +1,3 @@
#include <ATen/native/mps/kernels/Activation.h>
#include <c10/metal/indexing.h>
#include <c10/metal/special_math.h>
#include <metal_stdlib>
@ -100,59 +99,6 @@ REGISTER_BINARY_OP(hardswish_backward, float, float);
REGISTER_BINARY_OP(hardswish_backward, half, half);
REGISTER_BINARY_OP(hardswish_backward, bfloat, bfloat);
struct elu_functor {
template <typename T>
inline T operator()(const T self_, const ELUParams<T> params) {
using op_T = opmath_t<T>;
auto alpha = static_cast<op_T>(params.alpha);
auto scale = static_cast<op_T>(params.scale);
auto input_scale = static_cast<op_T>(params.input_scale);
auto self = static_cast<op_T>(self_);
auto neg_res = alpha * (::metal::precise::exp(self * input_scale) - 1);
return static_cast<T>(scale * (self < 0 ? neg_res : self));
}
};
struct elu_backward_functor {
template <typename T>
inline T operator()(
const T grad_output_,
const T self_,
ELUBackwardParams<T> params) {
using op_T = opmath_t<T>;
auto alpha = static_cast<op_T>(params.alpha);
auto scale = static_cast<op_T>(params.scale);
auto input_scale = static_cast<op_T>(params.input_scale);
auto grad_output = static_cast<op_T>(grad_output_);
auto self = static_cast<op_T>(self_);
if (params.is_result) {
auto neg_coef = input_scale * (self + alpha * scale);
return static_cast<T>(grad_output * (self <= 0 ? neg_coef : scale));
} else {
auto neg_coef = input_scale * alpha * scale *
::metal::precise::exp(self * input_scale);
return static_cast<T>(grad_output * (self <= 0 ? neg_coef : scale));
}
}
};
#define REGISTER_ELU_OP(T) \
typedef ELUParams<T> ELUParams_##T; \
REGISTER_UNARY_ALPHA_OP(elu, T, ELUParams_##T, T);
REGISTER_ELU_OP(float);
REGISTER_ELU_OP(half);
REGISTER_ELU_OP(bfloat);
#define REGISTER_ELU_BACKWARD_OP(T) \
typedef ELUBackwardParams<T> ELUBackwardParams_##T; \
REGISTER_BINARY_ALPHA_OP(elu_backward, T, ELUBackwardParams_##T, T);
REGISTER_ELU_BACKWARD_OP(float);
REGISTER_ELU_BACKWARD_OP(half);
REGISTER_ELU_BACKWARD_OP(bfloat);
struct leaky_relu_functor {
template <typename T>
inline T operator()(const T x, const T negative_slope) {

View File

@ -11,6 +11,8 @@
#include <ATen/ops/_log_softmax_native.h>
#include <ATen/ops/_prelu_kernel_backward_native.h>
#include <ATen/ops/_prelu_kernel_native.h>
#include <ATen/ops/elu_backward_native.h>
#include <ATen/ops/elu_native.h>
#include <ATen/ops/gelu_backward_native.h>
#include <ATen/ops/gelu_native.h>
#include <ATen/ops/glu_backward_native.h>
@ -696,6 +698,194 @@ TORCH_IMPL_FUNC(gelu_backward_out_mps)
}
}
static void elu_variants_out_mps(const Tensor& self,
const Scalar& alpha,
const Scalar& scale,
const Scalar& input_scale,
const Tensor& result,
std::string func_name) {
using namespace mps;
using CachedGraph = MPSUnaryCachedGraph;
auto resultMemFormat = result.suggest_memory_format();
bool executeGatherOp = !(self.is_contiguous(resultMemFormat) && result.is_contiguous(resultMemFormat));
Tensor out;
if (executeGatherOp) {
out = at::empty_like(result, MemoryFormat::Contiguous);
}
// Empty output
if (result.numel() == 0) {
return;
}
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
std::string key = func_name + ":" + getTensorsStringKey({self}) + ":" + std::to_string(alpha.to<double>()) + ":" +
std::to_string(scale.to<double>()) + ":" + std::to_string(input_scale.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
// scale * (max(0, x) + min(0, alpha * (exp(input_scale * x) - 1) ))
MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.to<double>()
shape:@[ @1 ]
dataType:getMPSDataType(self)];
MPSGraphTensor* inputScaleTensor = [mpsGraph constantWithScalar:input_scale.to<double>()
shape:@[ @1 ]
dataType:getMPSDataType(self)];
MPSGraphTensor* scaleTensor = [mpsGraph constantWithScalar:scale.to<double>()
shape:@[ @1 ]
dataType:getMPSDataType(self)];
MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0f shape:@[ @1 ] dataType:getMPSDataType(self)];
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f shape:@[ @1 ] dataType:getMPSDataType(self)];
MPSGraphTensor* scaledInputTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor
secondaryTensor:inputScaleTensor
name:nil];
MPSGraphTensor* exponentTensor = [mpsGraph exponentWithTensor:scaledInputTensor name:nil];
MPSGraphTensor* exponentMinusOneTensor = [mpsGraph subtractionWithPrimaryTensor:exponentTensor
secondaryTensor:unitTensor
name:nil];
MPSGraphTensor* alphaTimesTensor = [mpsGraph multiplicationWithPrimaryTensor:exponentMinusOneTensor
secondaryTensor:alphaTensor
name:nil];
MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor
secondaryTensor:zeroTensor
name:nil];
MPSGraphTensor* fusedOutput = [mpsGraph selectWithPredicateTensor:predicateTensor
truePredicateTensor:inputTensor
falsePredicateTensor:alphaTimesTensor
name:nil];
MPSGraphTensor* outputTensor = [mpsGraph multiplicationWithPrimaryTensor:fusedOutput
secondaryTensor:scaleTensor
name:nil];
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
});
auto selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out.has_storage() ? out : result, nil, false);
auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
if (out.has_storage()) {
result.copy_(out);
}
}
}
// scale * (max(0, x) + min(0, alpha * (exp(input_scale * x) - 1) ))
TORCH_IMPL_FUNC(elu_out_mps)
(const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, const Tensor& result) {
elu_variants_out_mps(self, alpha, scale, input_scale, result, "elu_out_mps");
}
TORCH_IMPL_FUNC(elu_backward_out_mps)
(const Tensor& grad_output,
const Scalar& alpha,
const Scalar& scale,
const Scalar& input_scale,
bool is_result,
const Tensor& self_or_result,
const Tensor& grad_input) {
using namespace mps;
using CachedGraph = MPSUnaryGradCachedGraph;
auto gradMemFormat = grad_input.suggest_memory_format();
bool executeGatherOp = !(grad_output.is_contiguous(gradMemFormat) && self_or_result.is_contiguous(gradMemFormat) &&
grad_input.is_contiguous(gradMemFormat));
Tensor out;
if (executeGatherOp && gradMemFormat == MemoryFormat::ChannelsLast) {
out = at::empty_like(grad_input, MemoryFormat::Contiguous);
}
// Empty output
if (grad_input.numel() == 0) {
return;
}
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
std::string key = "elu_backward_out_mps:" + getTensorsStringKey({grad_output, self_or_result}) + ":" +
std::to_string(alpha.to<double>()) + ":" + std::to_string(scale.to<double>()) + ":" +
std::to_string(input_scale.to<double>()) + ":" + std::to_string(is_result);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* selfOrResultTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_or_result);
MPSGraphTensor* lessThanZeroGradTensor = nil;
if (is_result) {
MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.to<double>()
shape:@[ @1 ]
dataType:getMPSDataType(grad_output)];
MPSGraphTensor* resultPlusAlphaTensor = [mpsGraph additionWithPrimaryTensor:selfOrResultTensor
secondaryTensor:alphaTensor
name:nil];
auto constMul = scale.to<double>() * input_scale.to<double>();
MPSGraphTensor* constMulTensor = [mpsGraph constantWithScalar:constMul
shape:@[ @1 ]
dataType:getMPSDataType(grad_output)];
lessThanZeroGradTensor = [mpsGraph multiplicationWithPrimaryTensor:resultPlusAlphaTensor
secondaryTensor:constMulTensor
name:nil];
} else {
MPSGraphTensor* inputScaleTensor = [mpsGraph constantWithScalar:input_scale.to<double>()
shape:@[ @1 ]
dataType:getMPSDataType(grad_output)];
MPSGraphTensor* scaledInputTensor = [mpsGraph multiplicationWithPrimaryTensor:selfOrResultTensor
secondaryTensor:inputScaleTensor
name:nil];
MPSGraphTensor* expTensor = [mpsGraph exponentWithTensor:scaledInputTensor name:nil];
auto constMul = scale.to<double>() * input_scale.to<double>() * alpha.to<double>();
MPSGraphTensor* constMulTensor = [mpsGraph constantWithScalar:constMul
shape:@[ @1 ]
dataType:getMPSDataType(grad_output)];
lessThanZeroGradTensor = [mpsGraph multiplicationWithPrimaryTensor:expTensor
secondaryTensor:constMulTensor
name:nil];
}
MPSGraphTensor* scaleTensor = [mpsGraph constantWithScalar:scale.to<double>()
shape:@[ @1 ]
dataType:getMPSDataType(grad_output)];
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f
shape:@[ @1 ]
dataType:getMPSDataType(grad_output)];
MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:selfOrResultTensor
secondaryTensor:zeroTensor
name:nil];
MPSGraphTensor* gradTensor = [mpsGraph selectWithPredicateTensor:predicateTensor
truePredicateTensor:scaleTensor
falsePredicateTensor:lessThanZeroGradTensor
name:nil];
MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:gradTensor
secondaryTensor:gradOutputTensor
name:nil];
newCachedGraph->gradOutputTensor_ = gradOutputTensor;
newCachedGraph->inputTensor_ = selfOrResultTensor;
newCachedGraph->gradInputTensor_ = gradInputTensor;
});
Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output, nil, executeGatherOp);
Placeholder selfOrResultPlaceholder = Placeholder(cachedGraph->inputTensor_, self_or_result, nil, executeGatherOp);
Placeholder gradInputPlaceholder =
Placeholder(cachedGraph->gradInputTensor_, out.has_storage() ? out : grad_input, nil, false);
auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, selfOrResultPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, gradInputPlaceholder);
if (out.has_storage()) {
grad_input.copy_(out);
}
}
}
TORCH_IMPL_FUNC(glu_out_mps)(const Tensor& self, const int64_t dim, const Tensor& output) {
using namespace mps;
using CachedGraph = MPSUnaryCachedGraph;

View File

@ -1,10 +1,8 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/Activation.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/kernels/Activation.h>
#include <fmt/format.h>
namespace at::native {
@ -43,30 +41,6 @@ static void hardswish_backward_kernel(at::TensorIterator& iter) {
lib.exec_binary_kernel(iter, "hardswish_backward");
}
static void elu_kernel(TensorIteratorBase& iter, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale) {
AT_DISPATCH_FLOATING_TYPES_AND2(c10::kHalf, c10::kBFloat16, iter.common_dtype(), "elu_mps", [&]() {
ELUParams<scalar_t> params{alpha.to<scalar_t>(), scale.to<scalar_t>(), input_scale.to<scalar_t>()};
lib.exec_unary_kernel_with_params(
iter, "elu", params, fmt::format("ELUParams_{}", mps::scalarToMetalTypeString(iter.common_dtype())));
});
}
static void elu_backward_kernel(TensorIteratorBase& iter,
const Scalar& alpha,
const Scalar& scale,
const Scalar& input_scale,
bool is_result) {
AT_DISPATCH_FLOATING_TYPES_AND2(c10::kHalf, c10::kBFloat16, iter.common_dtype(), "elu_backward_mps", [&]() {
ELUBackwardParams<scalar_t> params{
alpha.to<scalar_t>(), scale.to<scalar_t>(), input_scale.to<scalar_t>(), is_result};
lib.exec_binary_kernel_with_params(
iter,
"elu_backward",
params,
fmt::format("ELUBackwardParams_{}", mps::scalarToMetalTypeString(iter.common_dtype())));
});
}
static void leaky_relu_kernel(TensorIteratorBase& iter, const Scalar& negative_slope) {
lib.exec_unary_kernel(iter, "leaky_relu", negative_slope);
}
@ -82,8 +56,6 @@ REGISTER_DISPATCH(hardsigmoid_stub, hardsigmoid_kernel);
REGISTER_DISPATCH(hardsigmoid_backward_stub, hardsigmoid_backward_kernel);
REGISTER_DISPATCH(hardswish_stub, hardswish_kernel);
REGISTER_DISPATCH(hardswish_backward_stub, hardswish_backward_kernel);
REGISTER_DISPATCH(elu_stub, elu_kernel);
REGISTER_DISPATCH(elu_backward_stub, elu_backward_kernel);
REGISTER_DISPATCH(leaky_relu_stub, leaky_relu_kernel);
REGISTER_DISPATCH(leaky_relu_backward_stub, leaky_relu_backward_kernel);

View File

@ -91,30 +91,25 @@ static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#include <ATen/native/mps/Repeat_metallib.h>
#endif
Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
TORCH_CHECK(repeat.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
template <typename index_t>
void computeRepeatIndices(const index_t* repeat_ptr,
const int64_t* cumsum_ptr,
index_t* result_ptr,
int64_t size,
int64_t result_size) {
id<MTLBuffer> repeatBuffer = reinterpret_cast<id<MTLBuffer>>(repeat_ptr);
id<MTLBuffer> cumsumBuffer = reinterpret_cast<id<MTLBuffer>>(cumsum_ptr);
id<MTLBuffer> resultBuffer = reinterpret_cast<id<MTLBuffer>>(result_ptr);
TORCH_CHECK(repeatBuffer && cumsumBuffer && resultBuffer);
std::string scalar_type;
if (repeat.scalar_type() == kInt) {
if constexpr (std::is_same_v<index_t, int32_t>) {
scalar_type = "int32_t";
} else if (repeat.scalar_type() == kLong) {
} else if constexpr (std::is_same_v<index_t, int64_t>) {
scalar_type = "int64_t";
} else {
TORCH_CHECK(false, "repeats has to be Long or Int tensor");
TORCH_CHECK(false, "repeat_interleave: unsupported indexing data type");
}
if (repeat.size(0) == 0) {
return at::empty_like(repeat, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
Tensor repeat_ = repeat.contiguous();
Tensor cumsum = repeat.cumsum(0);
int64_t total = 0;
if (output_size.has_value()) {
total = output_size.value();
} else {
total = cumsum[-1].item<int64_t>();
TORCH_CHECK((repeat >= 0).all().item<uint8_t>(), "repeats can not be negative");
}
auto result = at::empty({total}, repeat.options());
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@ -126,13 +121,20 @@ Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output
getMPSProfiler().beginProfileKernel(pipelineState, "repeat_interleave:" + scalar_type, false);
[computeEncoder setComputePipelineState:pipelineState];
mps::mtl_setArgs(computeEncoder, repeat_, cumsum, result, repeat.size(0));
mps::mtl_dispatch1DJob(computeEncoder, pipelineState, repeat.size(0));
mps::mtl_setArgs(computeEncoder, repeatBuffer, cumsumBuffer, resultBuffer, size);
mps::mtl_dispatch1DJob(computeEncoder, pipelineState, size);
getMPSProfiler().endProfileKernel(pipelineState);
}
});
return result;
}
Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
Tensor output;
AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {
output = repeat_interleave_common<index_t, computeRepeatIndices<index_t>>(repeat, output_size);
});
return output;
}
} // namespace at::native

View File

@ -4225,7 +4225,7 @@
MTIA: mm_out_mtia
MPS: mm_out_mps
XPU: mm_out_xpu
SparseCPU, SparseCUDA, SparseMPS: _sparse_mm_out
SparseCPU, SparseCUDA: _sparse_mm_out
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm_out
- func: mm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor
@ -12064,7 +12064,8 @@
device_check: NoCheck # TensorIterator
python_module: nn
dispatch:
CPU, CUDA, MPS: elu_out
CPU, CUDA: elu_out
MPS: elu_out_mps
- func: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor
structured_delegate: elu.out
@ -12077,7 +12078,8 @@
structured_inherits: TensorIteratorBase
python_module: nn
dispatch:
CPU, CUDA, MPS: elu_backward_out
CPU, CUDA: elu_backward_out
MPS: elu_backward_out_mps
- func: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor
structured_delegate: elu_backward.grad_input

View File

@ -1,122 +0,0 @@
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/util/typeid.h>
#include <cstdint>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/BlasBackend.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/GroupedMMUtils.h>
#include <ATen/native/Resize.h>
#include <c10/util/MaybeOwned.h>
#include <ATen/ceil_div.h>
#include <ATen/xpu/XPUScaledBlas.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif
using at::blas::ScalingType;
namespace at::native::onednn::scaled {
/**
* Both inputs must be fp8,
* Each needs a single scale, {Tensorwise (float)}
*/
bool check_tensorwise_recipe(
c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
return false;
}
// 1 scale each, {Tensorwise, float}
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 ||
recipe_b.size() != 1) {
return false;
}
// Need {Blockwise_1x32, e8m0} for A & B
if (recipe_a[0] != ScalingType::TensorWise)
return false;
if (scales_a[0].scalar_type() != ScalarType::Float)
return false;
if (recipe_b[0] != ScalingType::TensorWise)
return false;
if (scales_b[0].scalar_type() != ScalarType::Float)
return false;
return true;
}
/**
* Both inputs must be fp8,
* Each needs scales, {Rowwise (float)}
*/
bool check_rowwise_recipe(
c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
return false;
}
// 1 scale each, {Tensorwise, float}
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 ||
recipe_b.size() != 1) {
return false;
}
// Need {RowWise, dp32} for A & B
if (recipe_a[0] != ScalingType::RowWise)
return false;
if (scales_a[0].scalar_type() != ScalarType::Float)
return false;
if (recipe_b[0] != ScalingType::RowWise)
return false;
if (scales_b[0].scalar_type() != ScalarType::Float)
return false;
return true;
}
} // namespace at::native::onednn::scaled

View File

@ -1,95 +0,0 @@
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/util/typeid.h>
#include <cstdint>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/Resize.h>
#include <c10/util/MaybeOwned.h>
#include <ATen/BlasBackend.h>
#include <ATen/ceil_div.h>
#ifdef USE_FBGEMM_GENAI
#include <fbgemm_gpu/torch_ops.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif
using at::blas::ScalingType;
namespace at::native::onednn::scaled {
/**
* Track concrete implementations available
*/
enum class ScaledGemmImplementation {
NONE = 0,
TENSORWISE_TENSORWISE = 1,
ROWWISE_ROWWISE = 2,
};
/**
* Convert passed int (enum) from python back into a
* strictly-typed enum
*/
template <class EnumType, class ArrayType>
std::vector<EnumType> convert_int_to_enum(ArrayType& v) {
std::vector<EnumType> converted;
converted.reserve(v.size());
for (auto vi : v) {
converted.push_back(static_cast<EnumType>(vi));
}
return converted;
}
bool check_tensorwise_recipe(
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&);
bool check_rowwise_recipe(
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&,
c10::ScalarType,
std::vector<ScalingType>&,
ArrayRef<Tensor>&);
} // namespace at::native::onednn::scaled

View File

@ -9,61 +9,28 @@ def check_perf_csv(filename, threshold, threshold_scale):
"""
Basic performance checking.
"""
try:
df = pd.read_csv(filename)
except FileNotFoundError:
print(f"Error: File {filename} not found")
sys.exit(1)
effective_threshold = threshold * threshold_scale
print(f"Checking {filename} (speedup threshold >= {effective_threshold:.2f}x)\n")
df = pd.read_csv(filename)
failed = []
for _, row in df.iterrows():
model_name = row["name"]
speedup = float(row["speedup"])
abs_latency = float(row["abs_latency"])
compilation_latency = float(row["compilation_latency"])
compression_ratio = float(row["compression_ratio"])
eager_peak_mem = float(row["eager_peak_mem"])
dynamo_peak_mem = float(row["dynamo_peak_mem"])
speedup = row["speedup"]
if speedup < threshold * threshold_scale:
failed.append(model_name)
perf_summary = f"{model_name:34} speedup={speedup:.3f}x"
if pd.notna(abs_latency):
perf_summary += f", latency={abs_latency:.1f} ms/iter"
if pd.notna(compilation_latency):
perf_summary += f", compile={compilation_latency:.3f}s"
if pd.notna(compression_ratio):
perf_summary += f", mem_ratio={1 / compression_ratio:.2f}x"
if pd.notna(eager_peak_mem) and pd.notna(dynamo_peak_mem):
perf_summary += (
f" (eager={eager_peak_mem:.1f} GB, dynamo={dynamo_peak_mem:.1f} GB)"
)
if speedup < effective_threshold:
failed.append((model_name, speedup))
print(perf_summary)
print(f"{model_name:34} {speedup}")
if failed:
print(
textwrap.dedent(
f"""
Error {len(failed)} model(s) performance regressed
{" ".join([name for name, _ in failed])}
Error {len(failed)} models performance regressed
{" ".join(failed)}
"""
)
)
for name, sp in sorted(failed, key=lambda x: x[1]):
pct_from_target = (sp / effective_threshold - 1.0) * 100.0
print(
f" - {name}: {sp:.3f}x (< {effective_threshold:.2f}x; {pct_from_target:.1f}% from target)"
)
sys.exit(1)
else:
print(
f"\nAll {len(df)} model(s) passed threshold check (>= {effective_threshold:.2f}x)"
)
if __name__ == "__main__":
@ -77,7 +44,7 @@ if __name__ == "__main__":
"-s",
type=float,
default=1.0,
help="multiply threshold by this value to relax the check",
help="multiple threshold by this value to relax the check",
)
args = parser.parse_args()
check_perf_csv(args.file, args.threshold, args.threshold_scale)

View File

@ -2379,9 +2379,7 @@ class BenchmarkRunner:
print(
f"Load model outputs from {self.args.compare_model_outputs_with} to compare"
)
saved_result = torch.load(
self.args.compare_model_outputs_with, weights_only=False
)
saved_result = torch.load(self.args.compare_model_outputs_with)
is_bitwise_same = bitwise_same(saved_result, new_result)
if not is_bitwise_same:
print(

View File

@ -92,6 +92,13 @@ inline bool isComplexType(ScalarType t) {
t == ScalarType::ComplexDouble);
}
inline bool isQIntType(ScalarType t) {
// Don't forget to extend this when adding new QInt types
return t == ScalarType::QInt8 || t == ScalarType::QUInt8 ||
t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 ||
t == ScalarType::QUInt2x4;
}
inline bool isBitsType(ScalarType t) {
return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 ||
t == ScalarType::Bits4x2 || t == ScalarType::Bits8 ||

View File

@ -1,4 +1,5 @@
#include <c10/util/Exception.h>
#include <c10/util/FileSystem.h>
#include <c10/util/Logging.h>
#include <c10/util/Type.h>
@ -27,7 +28,7 @@ Error::Error(
const void* caller)
: Error(
str("[enforce fail at ",
detail::StripBasename(file),
c10::filesystem::path(file).filename(),
":",
line,
"] ",

View File

@ -1,4 +1,5 @@
#include <c10/util/Backtrace.h>
#include <c10/util/FileSystem.h>
#include <c10/util/Flags.h>
#include <c10/util/Lazy.h>
#include <c10/util/Logging.h>
@ -478,8 +479,7 @@ MessageLogger::MessageLogger(
<< std::setfill('0') << ' ' << std::setw(2) << timeinfo->tm_hour
<< ':' << std::setw(2) << timeinfo->tm_min << ':' << std::setw(2)
<< timeinfo->tm_sec << '.' << std::setw(9) << ns << ' '
<< c10::detail::StripBasename(std::string(file)) << ':' << line
<< "] ";
<< c10::filesystem::path(file).filename() << ':' << line << "] ";
}
// Output the contents of the stream to the proper channel on destruction.

View File

@ -1643,8 +1643,6 @@ if(USE_CUDA)
target_link_libraries(torch_cuda PUBLIC c10_cuda)
if(TARGET torch::nvtx3)
target_link_libraries(torch_cuda PRIVATE torch::nvtx3)
else()
target_link_libraries(torch_cuda PUBLIC torch::nvtoolsext)
endif()
target_include_directories(
@ -1741,9 +1739,6 @@ if(BUILD_SHARED_LIBS)
if(USE_CUDA)
target_link_libraries(torch_global_deps ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS})
target_link_libraries(torch_global_deps torch::cudart)
if(TARGET torch::nvtoolsext)
target_link_libraries(torch_global_deps torch::nvtoolsext)
endif()
endif()
install(TARGETS torch_global_deps DESTINATION "${TORCH_INSTALL_LIB_DIR}")
endif()

View File

@ -113,23 +113,11 @@ if(INTERN_BUILD_ATEN_OPS)
list(APPEND _file_compile_flags "-gencode;arch=compute_103a,code=sm_103a")
endif()
endif()
# We will need to gate against CUDA version, because sm_110a is available on CUDA 13.0+
if("${_arch}" STREQUAL "110a" AND CUDA_VERSION VERSION_GREATER_EQUAL 13.0)
if(_existing_arch_flags MATCHES ".*compute_110.*")
list(APPEND _file_compile_flags "-gencode;arch=compute_110a,code=sm_110a")
endif()
endif()
if("${_arch}" STREQUAL "120a")
if(_existing_arch_flags MATCHES ".*compute_120.*")
list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
endif()
endif()
# We will need to gate against CUDA version, sm_121a was introduced in CUDA 12.9
if("${_arch}" STREQUAL "121a" AND CUDA_VERSION VERSION_GREATER_EQUAL 12.9)
if(_existing_arch_flags MATCHES ".*compute_120.*")
list(APPEND _file_compile_flags "-gencode;arch=compute_121a,code=sm_121a")
endif()
endif()
endforeach()
list(JOIN _file_compile_flags " " _file_compile_flags)
@ -138,13 +126,13 @@ if(INTERN_BUILD_ATEN_OPS)
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
"89;90a;100a;103a;110a;120a;121a")
"89;90a;100a;103a;120a")
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
"90a")
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/GroupMM.cu"
"90a;100a;103a;110a")
"90a;100a;103a")
endif()

View File

@ -968,11 +968,8 @@ find_package_handle_standard_args(nvtx3 DEFAULT_MSG nvtx3_dir)
if(nvtx3_FOUND)
add_library(torch::nvtx3 INTERFACE IMPORTED)
target_include_directories(torch::nvtx3 INTERFACE "${nvtx3_dir}")
target_compile_definitions(torch::nvtx3 INTERFACE TORCH_CUDA_USE_NVTX3)
else()
message(WARNING "Cannot find NVTX3, find old NVTX instead")
add_library(torch::nvtoolsext INTERFACE IMPORTED)
set_property(TARGET torch::nvtoolsext PROPERTY INTERFACE_LINK_LIBRARIES CUDA::nvToolsExt)
message(FATAL_ERROR "Cannot find NVTX3!")
endif()

View File

@ -15,14 +15,12 @@ if(NOT __AOTRITON_INCLUDED)
"manylinux_2_28" # rocm6.3
"manylinux_2_28" # rocm6.4
"manylinux_2_28" # rocm7.0
"manylinux_2_28" # rocm7.1
)
set(__AOTRITON_ROCM_LIST
"rocm6.2"
"rocm6.3"
"rocm6.4"
"rocm7.0"
"rocm7.1"
)
set(__AOTRITON_CI_COMMIT "972223c501ffc22068bb035ac5d64cf54318d895")
set(__AOTRITON_SHA256_LIST
@ -30,7 +28,6 @@ if(NOT __AOTRITON_INCLUDED)
"72a153549ea20707331e8a1f1e3d1b8de2913f9d5af2b900c56235d578b57efe" # rocm6.3
"c7f319dd7448cbbbab81889dd8a37d47dbc25ebcbd89760f09e6a0904e556393" # rocm6.4
"a2a974e0ad929a5e5827c0f896c59bda4872459cbaf8dd8e0a00407f404491cf" # rocm7.0
"d4eb24c9f1a0cfedb35f9292efb41d16589cf5a4b98c3c0940181bbefc49d722" # rocm7.1
)
set(__AOTRITON_IMAGE_LIST
"amd-gfx90a"

View File

@ -132,9 +132,6 @@ if(@USE_CUDA@)
else()
set(TORCH_CUDA_LIBRARIES ${CUDA_NVRTC_LIB})
endif()
if(TARGET torch::nvtoolsext)
list(APPEND TORCH_CUDA_LIBRARIES torch::nvtoolsext)
endif()
if(@BUILD_SHARED_LIBS@)
find_library(C10_CUDA_LIBRARY c10_cuda PATHS "${TORCH_INSTALL_PREFIX}/lib")

View File

@ -987,24 +987,6 @@ In addition, `TORCH_DISTRIBUTED_DEBUG=DETAIL` can be used in conjunction with `T
collective desynchronization checks will work for all applications that use `c10d` collective calls backed by process groups created with the
{func}`torch.distributed.init_process_group` and {func}`torch.distributed.new_group` APIs.
### torch.distributed.debug HTTP Server
The `torch.distributed.debug` module provides a HTTP server that can be used to debug distributed applications. The server can
be started by calling {func}`torch.distributed.debug.start_debug_server`. This
allows users to collect data across all workers at runtime.
```{eval-rst}
.. automodule:: torch.distributed.debug
:members:
:undoc-members:
:show-inheritance:
:special-members: __init__
:member-order: bysource
```
## Logging
In addition to explicit debugging support via {func}`torch.distributed.monitored_barrier` and `TORCH_DISTRIBUTED_DEBUG`, the underlying C++ library of `torch.distributed` also outputs log

View File

@ -1,238 +0,0 @@
# Owner(s): ["module: complex"]
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
import torch.distributed as dist
# Support both when imported from elsewhere or directly as a file
try:
from .utils import (
COMPLEX_DTYPES,
Descriptor,
force_test_op_db,
get_overload_packet_from_name,
implemented_op_db,
TestCase,
Variant,
)
except ImportError:
from utils import (
COMPLEX_DTYPES,
Descriptor,
force_test_op_db,
get_overload_packet_from_name,
implemented_op_db,
TestCase,
Variant,
)
from torch._subclasses.complex_tensor._ops.common import ComplexTensorMode
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
OpDTypes,
ops,
)
from torch.testing._internal.common_utils import (
run_tests,
TestGradients,
unMarkDynamoStrictTest,
)
if TYPE_CHECKING:
from torch.testing._internal.opinfo.core import OpInfo
aten = torch.ops.aten
SKIPS = {
Descriptor(op=aten.empty_like, variant=None): "Non-deterministic output",
Descriptor(op=aten.randn_like, variant=None): "Non-deterministic output",
Descriptor(op=aten.angle, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.asinh, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.atanh, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(
op=aten.reciprocal, variant=Variant.GradCheck
): "Numerical inconsistency",
Descriptor(op=aten.rsqrt, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.select, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.asin, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.log, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.sgn, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.cumprod, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.slice, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.sqrt, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.tan, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(
op=aten.true_divide, variant=Variant.GradCheck
): "Numerical inconsistency",
Descriptor(op=aten.prod, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.div, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.expm1, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.var, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.bmm, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.diagonal, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.sinh, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.abs, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.sin, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.atan, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.acos, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.acosh, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.cos, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.cosh, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.addmm, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.pow, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.log1p, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.tanh, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.mm, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.dot, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.mul, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.exp, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(op=aten.to, variant=Variant.GradCheck): "Numerical inconsistency",
Descriptor(
op=aten.any, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.all, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.allclose, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.conj_physical, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten._conj_physical, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.cumprod, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.index_add, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.diagonal_scatter, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.flip, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.masked_fill, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.masked_scatter, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.rsub, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.ne, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.squeeze, variant=Variant.Distributed
): "does not have a sharding strategy registered",
Descriptor(
op=aten.index_select, variant=Variant.Distributed
): "Sharding propagation failed",
Descriptor(op=aten.real, variant=Variant.Distributed): "No scalar support",
Descriptor(op=aten.imag, variant=Variant.Distributed): "No scalar support",
Descriptor(op=aten.isfinite, variant=Variant.Distributed): "No scalar support",
Descriptor(op=aten.transpose, variant=Variant.Distributed): "No scalar support",
Descriptor(op=aten.view_as_real, variant=Variant.Distributed): "No scalar support",
}
EXTRA_KWARGS = {
Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Op): {
"rtol": 2e-5,
"atol": 5e-5,
},
Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Op): {
"rtol": 1e-4,
"atol": 1e-5,
},
Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Op): {
"rtol": 2e-2,
"atol": 2e-6,
},
Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Distributed): {
"rtol": 2e-5,
"atol": 5e-5,
},
Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Distributed): {
"rtol": 1e-4,
"atol": 1e-5,
},
Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Distributed): {
"rtol": 2e-2,
"atol": 2e-6,
},
Descriptor(op=aten.tan, dtype=torch.complex64, variant=Variant.Distributed): {
"rtol": 2e-6,
"atol": 1e-2,
},
}
class TestComplexTensor(TestCase):
_default_dtype_check_enabled = True
@ops(
implemented_op_db,
dtypes=OpDTypes.supported,
allowed_dtypes=list(COMPLEX_DTYPES),
)
def test_consistency(self, device, dtype, op: OpInfo):
self.check_consistency(device, dtype, op, Variant.Op)
@ops(force_test_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
def test_maybe_error(self, device, dtype, op: OpInfo):
self.check_consistency(device, dtype, op, Variant.Op)
@unMarkDynamoStrictTest
class TestComplexBwdGradients(TestGradients):
_default_dtype_check_enabled = True
@ops(
implemented_op_db,
dtypes=OpDTypes.supported_backward,
allowed_dtypes=[torch.complex128],
)
def test_fn_grad(self, device: str, dtype: torch.dtype, op: OpInfo) -> None:
test_info = Descriptor(
op=get_overload_packet_from_name(op.name),
device_type=torch.device(device).type,
dtype=dtype,
variant=Variant.GradCheck,
)
for xfail_info, reason in SKIPS.items():
if xfail_info.matches(test_info):
self.skipTest(reason)
if dtype not in op.supported_backward_dtypes(torch.device(device).type):
self.skipTest(f"Skipped! {dtype=} is not in supported backward dtypes!")
with ComplexTensorMode():
op.gradcheck_fast_mode = False
self._grad_test_helper(device, dtype, op, op.get_op())
instantiate_device_type_tests(TestComplexTensor, globals())
instantiate_device_type_tests(TestComplexBwdGradients, globals())
if dist.is_available():
from torch.testing._internal.common_distributed import MultiProcessTestCase
@unMarkDynamoStrictTest
class TestComplexDistributed(TestCase, MultiProcessTestCase):
@ops(implemented_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
def test_distributed(self, device, dtype, op: OpInfo):
self.check_consistency(device, dtype, op, Variant.Distributed)
instantiate_device_type_tests(TestComplexDistributed, globals())
if __name__ == "__main__":
run_tests()

View File

@ -1,214 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass, field, fields
from enum import auto, Enum
from typing import Any, TYPE_CHECKING
import torch
import torch.distributed as dist
from torch._subclasses.complex_tensor._ops.common import (
_as_complex_tensor,
_as_interleaved,
_get_op_name,
COMPLEX_OPS_TABLE,
COMPLEX_TO_REAL,
FORCE_TEST_LIST,
OpOverloadPacket,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
from torch.utils._pytree import tree_flatten
if TYPE_CHECKING:
from collections.abc import Callable
from torch.distributed.tensor import DTensor
from torch.testing._internal.opinfo.core import OpInfo
COMPLEX_DTYPES = set(COMPLEX_TO_REAL)
class Variant(Enum):
Op = auto()
GradCheck = auto()
Distributed = auto()
def _as_local(arg: DTensor | Any) -> torch.Tensor | Any:
if not (dist.is_available() and isinstance(arg, dist.tensor.DTensor)):
return arg
return arg.full_tensor()
def _as_complex_dtensor(arg: torch.Tensor | Any) -> torch.Tensor | Any:
if not isinstance(arg, torch.Tensor):
return arg
return dist.tensor.DTensor.from_local(_as_complex_tensor(arg))
TRANSFORM_FUNCS = {
Variant.Op: _as_complex_tensor,
Variant.Distributed: _as_complex_dtensor,
}
@dataclass(frozen=True, kw_only=True)
class Descriptor:
op: OpOverloadPacket
variant: Variant | None
device_type: str | None = field(default=None)
dtype: torch.dtype | None = field(default=None)
def matches(self, other: Descriptor) -> bool:
fields1 = fields(self)
fields2 = fields(other)
if fields1 != fields2:
return False
for f in fields1:
f1 = getattr(self, f.name)
f2 = getattr(other, f.name)
if f1 is not None and f2 is not None and f1 != f2:
return False
return True
class TestCase(PytorchTestCase):
def assertSameResult(
self,
expected: Callable[[], Any],
actual: Callable[[], Any],
*args,
**kwargs,
) -> None:
try:
result_e = expected()
exception_e = None
except Exception as e: # noqa: BLE001
result_e = None
exception_e = e
try:
result_a = actual()
exception_a = None
except Exception as e: # noqa: BLE001
result_a = None
exception_a = e
if (exception_e is None) != (exception_a is None):
if exception_a is not None and exception_e is None:
raise exception_a
self.assertIs(
type(exception_e),
type(exception_a),
f"\n{exception_e=}\n{exception_a=}",
)
if exception_e is None:
flattened_e, spec_e = tree_flatten(result_e)
flattened_a, spec_a = tree_flatten(result_a)
self.assertEqual(
spec_e,
spec_a,
"Both functions must return a result with the same tree structure.",
)
for value_e, value_a in zip(flattened_e, flattened_a, strict=True):
value_e = _as_interleaved(_as_local(value_e))
value_a = _as_interleaved(_as_local(value_a))
self.assertEqual(value_e, value_a, *args, **kwargs)
def check_consistency(
self, device: str, dtype, op: OpInfo, variant: Variant
) -> None:
try:
from .test_complex_tensor import EXTRA_KWARGS, SKIPS
except ImportError:
from test_complex_tensor import EXTRA_KWARGS, SKIPS
test_info = Descriptor(
op=get_overload_packet_from_name(op.name),
device_type=torch.device(device).type,
dtype=dtype,
variant=variant,
)
for xfail_info, reason in SKIPS.items():
if xfail_info.matches(test_info):
self.skipTest(reason)
kwargs = {}
for extra_info, extra_kw in EXTRA_KWARGS.items():
if extra_info.matches(test_info):
kwargs = extra_kw
break
sample_inputs = op.sample_inputs(device, dtype)
transform_fn = TRANSFORM_FUNCS[variant]
for sample_input in sample_inputs:
def expected(sample_input=sample_input):
return op(sample_input.input, *sample_input.args, **sample_input.kwargs)
subclass_sample = sample_input.transform(transform_fn)
def actual(subclass_sample=subclass_sample):
return op(
subclass_sample.input,
*subclass_sample.args,
**subclass_sample.kwargs,
)
self.assertSameResult(expected, actual, **kwargs)
aten = torch.ops.aten
complex_op_db = tuple(
filter(lambda op: any(op.supports_dtype(ct, "cpu") for ct in COMPLEX_DTYPES), op_db)
)
def get_overload_packet_from_name(name: str) -> OpOverloadPacket:
for domain_name in torch.ops:
op_namespace = getattr(torch.ops, domain_name)
op: OpOverloadPacket | None = getattr(op_namespace, name, None)
if op is not None:
return op
raise RuntimeError(f"No op with {name=} found.")
force_test_names = set(map(_get_op_name, FORCE_TEST_LIST))
implemented_op_names = (
set(map(_get_op_name, COMPLEX_OPS_TABLE.keys())) - force_test_names
)
implemented_op_db = tuple(
filter(lambda op: op.name in implemented_op_names, complex_op_db)
)
force_test_op_db = tuple(filter(lambda op: op.name in force_test_names, op_db))
tested_op_names = {op.name for op in implemented_op_db} | {
op.name for op in force_test_op_db
}
non_tested_ops = {
op for op in COMPLEX_OPS_TABLE if _get_op_name(op) not in tested_op_names
}
# TODO (hameerabbasi): There are a number of ops that don't have any associated
# OpInfos. We still need to write tests for those ops.
if len(non_tested_ops) != 0:
import textwrap
import warnings
list_missing_ops = "\n".join(sorted([str(op) for op in non_tested_ops]))
warnings.warn(
"Not all implemented ops are tested. List of ops missing tests:"
f"\n{textwrap.indent(list_missing_ops, ' ')}",
UserWarning,
stacklevel=2,
)

View File

@ -101,14 +101,3 @@ TEST(TestScalarType, toUnderlying) {
AT_FORALL_FLOAT8_TYPES(DEFINE_CHECK);
#undef DEFINE_CHECK
}
TEST(TestScalarType, isQIntType) {
using torch::headeronly::isQIntType;
using torch::headeronly::ScalarType;
#define DEFINE_CHECK(_, name) EXPECT_TRUE(isQIntType(ScalarType::name));
AT_FORALL_QINT_TYPES(DEFINE_CHECK);
#undef DEFINE_CHECK
#define DEFINE_CHECK(_, name) EXPECT_FALSE(isQIntType(ScalarType::name));
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CHECK);
#undef DEFINE_CHECK
}

View File

@ -15,7 +15,7 @@ namespace jit {
TEST(CustomOperatorTest, InferredSchema) {
torch::RegisterOperators reg(
"foo::bar", [](double a, at::Tensor b) { return a + b; });
auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar"));
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar"));
ASSERT_EQ(ops.size(), 1);
auto& op = ops.front();
@ -43,7 +43,8 @@ TEST(CustomOperatorTest, ExplicitSchema) {
"foo::bar_with_schema(float a, Tensor b) -> Tensor",
[](double a, at::Tensor b) { return a + b; });
auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema"));
auto& ops =
getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema"));
ASSERT_EQ(ops.size(), 1);
auto& op = ops.front();
@ -76,7 +77,7 @@ TEST(CustomOperatorTest, ListParameters) {
torch::List<c10::complex<double>> complexdoubles,
torch::List<at::Tensor> tensors) { return floats; });
auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists"));
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists"));
ASSERT_EQ(ops.size(), 1);
auto& op = ops.front();
@ -122,7 +123,7 @@ TEST(CustomOperatorTest, ListParameters2) {
"foo::lists2(Tensor[] tensors) -> Tensor[]",
[](torch::List<at::Tensor> tensors) { return tensors; });
auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
ASSERT_EQ(ops.size(), 1);
auto& op = ops.front();
@ -212,7 +213,7 @@ TEST(TestCustomOperator, OperatorGeneratorUndeclared) {
},
aliasAnalysisFromSchema())});
auto ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist"));
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist"));
ASSERT_EQ(ops.size(), 0);
}
@ -231,7 +232,7 @@ TEST(TestCustomOperator, OperatorGeneratorBasic) {
},
aliasAnalysisFromSchema())});
auto ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar"));
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar"));
ASSERT_EQ(ops.size(), 1);
auto& op = ops.front();

View File

@ -1,20 +0,0 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
using torch::stable::Tensor;
uint64_t get_any_data_ptr(Tensor t, bool mutable_) {
if (mutable_) {
return reinterpret_cast<uint64_t>(t.mutable_data_ptr());
} else {
return reinterpret_cast<uint64_t>(t.const_data_ptr());
}
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def("get_any_data_ptr(Tensor t, bool mutable_) -> int");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
m.impl("get_any_data_ptr", TORCH_BOX(&get_any_data_ptr));
}

View File

@ -1,34 +0,0 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
using torch::stable::Tensor;
uint64_t get_template_any_data_ptr(Tensor t, torch::headeronly::ScalarType dtype, bool mutable_) {
#define DEFINE_CASE(T, name) \
case torch::headeronly::ScalarType::name: { \
if (mutable_) { \
return reinterpret_cast<uint64_t>(t.mutable_data_ptr<T>()); \
} else { \
return reinterpret_cast<uint64_t>(t.const_data_ptr<T>()); \
} \
}
switch (dtype) {
// per aten/src/ATen/templates/TensorMethods.cpp:
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
DEFINE_CASE(uint16_t, UInt16)
DEFINE_CASE(uint32_t, UInt32)
DEFINE_CASE(uint64_t, UInt64)
default:
return 0;
}
#undef DEFINE_CASE
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def("get_template_any_data_ptr(Tensor t, ScalarType dtype, bool mutable_) -> int");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
m.impl("get_template_any_data_ptr", TORCH_BOX(&get_template_any_data_ptr));
}

View File

@ -1,41 +0,0 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <vector>
using torch::stable::Tensor;
// Declare my__foreach_mul (defined in my__foreach_mul.cpp)
extern std::vector<Tensor> my__foreach_mul(
torch::headeronly::HeaderOnlyArrayRef<Tensor> self,
torch::headeronly::HeaderOnlyArrayRef<Tensor> other);
// Helper function for cloning
Tensor my_clone(Tensor t) {
return clone(t);
}
std::vector<Tensor> make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) {
// This function tests that my__foreach_mul can take in std::initializer_lists
// in addition to std::vectors.
Tensor t1_1 = my_clone(t1);
Tensor t1_2 = my_clone(t1);
Tensor t2_1 = my_clone(t2);
Tensor t2_2 = my_clone(t2);
return my__foreach_mul({t1_1, t2_1}, {t1_2, t2_2});
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def(
"make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) -> Tensor[]");
}
STABLE_TORCH_LIBRARY_IMPL(
libtorch_agnostic_2_10,
CompositeExplicitAutograd,
m) {
m.impl(
"make_tensor_clones_and_call_foreach",
TORCH_BOX(&make_tensor_clones_and_call_foreach));
}

View File

@ -1,40 +0,0 @@
// This is duplicated from the libtorch_agnostic_2_9_extension
// as a negative test for test_version_compatibility.py
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/core/Dispatch_v2.h>
#include <torch/headeronly/core/TensorAccessor.h>
#include "tensor_accessor_kernel.h"
using torch::stable::Tensor;
Tensor mv_tensor_accessor_cpu(Tensor m, Tensor v) {
STD_TORCH_CHECK(m.dim() == 2, "m must be 2D");
STD_TORCH_CHECK(v.dim() == 1, "v must be 1D");
STD_TORCH_CHECK(m.size(1) == v.size(0), "m.shape[1] == v.shape[0] must hold");
STD_TORCH_CHECK(m.scalar_type() == v.scalar_type(), "m and v must have the same dtype");
STD_TORCH_CHECK(m.device() == v.device(), "m and v must be on the same device");
Tensor res = new_empty(m, {m.size(0)});
THO_DISPATCH_V2(m.scalar_type(), "mv_tensor_accessor_cpu",
AT_WRAP(([&]() {
auto resa = Accessor_cpu<scalar_t, 1>(reinterpret_cast<scalar_t*>(res.data_ptr()), res.sizes().data(), res.strides().data());
auto ma = Accessor_cpu<scalar_t, 2>(reinterpret_cast<scalar_t*>(m.data_ptr()), m.sizes().data(), m.strides().data());
auto va = Accessor_cpu<scalar_t, 1>(reinterpret_cast<scalar_t*>(v.data_ptr()), v.sizes().data(), v.strides().data());
mv_tensor_accessor_kernel<Accessor_cpu, scalar_t>(resa, ma, va);
})),
AT_FLOATING_TYPES);
return res;
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def("mv_tensor_accessor_cpu(Tensor res, Tensor m, Tensor v) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
m.impl("mv_tensor_accessor_cpu", TORCH_BOX(&mv_tensor_accessor_cpu));
}

View File

@ -1,47 +0,0 @@
// This is duplicated from the libtorch_agnostic_2_9_extension
// as a negative test for test_version_compatibility.py
#include "tensor_accessor_kernel.h"
#include <cuda_runtime.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
using torch::stable::Tensor;
Tensor mv_tensor_accessor_cuda(Tensor m, Tensor v) {
STD_TORCH_CHECK(m.dim() == 2, "m must be 2D");
STD_TORCH_CHECK(v.dim() == 1, "v must be 1D");
STD_TORCH_CHECK(m.size(1) == v.size(0), "m.shape[1] == v.shape[0] must hold");
STD_TORCH_CHECK(
m.scalar_type() == v.scalar_type(), "m and v must have the same dtype");
STD_TORCH_CHECK(
m.device() == v.device(), "m and v must be on the same device");
Tensor res = new_empty(m, {m.size(0)});
THO_DISPATCH_V2(
m.scalar_type(),
"mv_tensor_accessor_cuda",
AT_WRAP(([&]() {
auto resa = Accessor_cuda<scalar_t, 1>(
reinterpret_cast<scalar_t*>(res.data_ptr()),
res.sizes().data(),
res.strides().data());
auto ma = Accessor_cuda<scalar_t, 2>(
reinterpret_cast<scalar_t*>(m.data_ptr()),
m.sizes().data(),
m.strides().data());
auto va = Accessor_cuda<scalar_t, 1>(
reinterpret_cast<scalar_t*>(v.data_ptr()),
v.sizes().data(),
v.strides().data());
mv_tensor_accessor_kernel<Accessor_cuda, scalar_t>
<<<1, 1, 0, 0>>>(resa, ma, va);
})),
AT_FLOATING_TYPES);
return res;
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CUDA, m) {
m.impl("mv_tensor_accessor", TORCH_BOX(&mv_tensor_accessor_cuda));
}

View File

@ -1,20 +0,0 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <vector>
using torch::stable::Tensor;
std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
aoti_torch_call_dispatcher("aten::_foreach_mul", "List", stack.data());
return torch::stable::detail::to<std::vector<Tensor>>(stack[0]);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
m.impl("my__foreach_mul", TORCH_BOX(&my__foreach_mul));
}

View File

@ -1,19 +0,0 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/stableivalue_conversions.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
using torch::stable::Tensor;
void my__foreach_mul_(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data());
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
m.impl("my__foreach_mul_", TORCH_BOX(&my__foreach_mul_));
}

View File

@ -1,25 +0,0 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/device.h>
#include <torch/csrc/stable/ops.h>
#include <optional>
using torch::stable::Tensor;
Tensor my_empty(
torch::headeronly::HeaderOnlyArrayRef<int64_t> size,
std::optional<torch::headeronly::ScalarType> dtype,
std::optional<torch::stable::Device> device,
std::optional<bool> pin_memory) {
return empty(size, dtype, device, pin_memory);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def(
"my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
m.impl("my_empty", TORCH_BOX(&my_empty));
}

View File

@ -1,17 +0,0 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
using torch::stable::Tensor;
Tensor my_reshape(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> shape) {
return reshape(t, shape);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def("my_reshape(Tensor t, int[] shape) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
m.impl("my_reshape", TORCH_BOX(&my_reshape));
}

View File

@ -1,20 +0,0 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
using torch::stable::Tensor;
Tensor my_view(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> size) {
return view(t, size);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def("my_view(Tensor t, int[] size) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(
libtorch_agnostic_2_10,
CompositeExplicitAutograd,
m) {
m.impl("my_view", TORCH_BOX(&my_view));
}

View File

@ -1,31 +0,0 @@
// This is duplicated from the libtorch_agnostic_2_9_extension
// as a negative test for test_version_compatibility.py
#pragma once
#include <torch/headeronly/core/Dispatch_v2.h>
#include <torch/headeronly/core/TensorAccessor.h>
template <typename T, size_t N>
using Accessor_cpu = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;
#if defined(__CUDACC__) || defined(__HIPCC__)
#define MAYBE_GLOBAL __global__
template <typename T, size_t N>
using Accessor_cuda = torch::headeronly::HeaderOnlyGenericPackedTensorAccessor<T, N, torch::headeronly::RestrictPtrTraits>;
#else
#define MAYBE_GLOBAL
#endif
template <template <typename, size_t> class Accessor, typename scalar_t>
MAYBE_GLOBAL void mv_tensor_accessor_kernel(Accessor<scalar_t, 1> resa, Accessor<scalar_t, 2> ma, Accessor<scalar_t, 1> va) {
for (int64_t i = 0; i < resa.size(0); i++) {
scalar_t val = 0;
for (int64_t j = 0; j < ma.size(1); j++) {
val += ma[i][j] * va[j];
}
resa[i] = val;
}
}

View File

@ -1,37 +0,0 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/device.h>
#include <string>
torch::stable::Device test_device_constructor(
bool is_cuda,
torch::stable::DeviceIndex index,
bool use_str) {
using torch::stable::Device;
using torch::stable::DeviceType;
if (use_str) {
std::string device_str;
if (is_cuda) {
device_str = "cuda:" + std::to_string(index);
} else {
device_str = "cpu";
}
return Device(device_str);
} else {
if (is_cuda) {
return Device(DeviceType::CUDA, index);
} else {
return Device(DeviceType::CPU);
}
}
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def(
"test_device_constructor(bool is_cuda, DeviceIndex index, bool use_str) -> Device");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
m.impl("test_device_constructor", TORCH_BOX(&test_device_constructor));
}

View File

@ -1,14 +0,0 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/device.h>
bool test_device_equality(torch::stable::Device d1, torch::stable::Device d2) {
return d1 == d2;
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def("test_device_equality(Device d1, Device d2) -> bool");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
m.impl("test_device_equality", TORCH_BOX(&test_device_equality));
}

View File

@ -1,14 +0,0 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/device.h>
torch::stable::DeviceIndex test_device_index(torch::stable::Device device) {
return device.index();
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def("test_device_index(Device device) -> DeviceIndex");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
m.impl("test_device_index", TORCH_BOX(&test_device_index));
}

View File

@ -1,14 +0,0 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/device.h>
bool test_device_is_cpu(torch::stable::Device device) {
return device.is_cpu();
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def("test_device_is_cpu(Device device) -> bool");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
m.impl("test_device_is_cpu", TORCH_BOX(&test_device_is_cpu));
}

View File

@ -1,14 +0,0 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/device.h>
bool test_device_is_cuda(torch::stable::Device device) {
return device.is_cuda();
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def("test_device_is_cuda(Device device) -> bool");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
m.impl("test_device_is_cuda", TORCH_BOX(&test_device_is_cuda));
}

View File

@ -1,17 +0,0 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/device.h>
torch::stable::Device test_device_set_index(
torch::stable::Device device,
torch::stable::DeviceIndex index) {
device.set_index(index);
return device;
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def("test_device_set_index(Device device, DeviceIndex index) -> Device");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
m.impl("test_device_set_index", TORCH_BOX(&test_device_set_index));
}

View File

@ -1,14 +0,0 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
uint32_t test_get_num_threads() {
return torch::stable::get_num_threads();
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def("test_get_num_threads() -> int");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
m.impl("test_get_num_threads", TORCH_BOX(&test_get_num_threads));
}

View File

@ -1,49 +0,0 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/device.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
using torch::stable::Tensor;
Tensor test_parallel_for(int64_t size, int64_t grain_size) {
AtenTensorHandle tensor_handle;
int64_t stride = 1;
aoti_torch_empty_strided(
1,
&size,
&stride,
aoti_torch_dtype_int64(),
aoti_torch_device_type_cpu(),
0,
&tensor_handle);
Tensor tensor(tensor_handle);
int64_t* data_ptr = reinterpret_cast<int64_t*>(tensor.data_ptr());
torch::stable::zero_(tensor);
// Use parallel_for to fill each element with its index
// If using a parallel path, the thread id is encoded in the upper 32 bits
torch::stable::parallel_for(
0, size, grain_size, [data_ptr](int64_t begin, int64_t end) {
for (auto i = begin; i < end; i++) {
STD_TORCH_CHECK(i <= UINT32_MAX);
uint32_t thread_id;
torch_get_thread_idx(&thread_id);
data_ptr[i] = i | (static_cast<int64_t>(thread_id) << 32);
}
});
return tensor;
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def("test_parallel_for(int size, int grain_size) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
m.impl("test_parallel_for", TORCH_BOX(&test_parallel_for));
}

View File

@ -1,17 +0,0 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/device.h>
using torch::stable::Tensor;
torch::stable::Device test_tensor_device(torch::stable::Tensor tensor) {
return tensor.device();
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def("test_tensor_device(Tensor t) -> Device");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
m.impl("test_tensor_device", TORCH_BOX(&test_tensor_device));
}

View File

@ -1,225 +0,0 @@
import torch
from torch import Tensor
def my__foreach_mul_(tensors, others) -> ():
"""
Updates tensors to be the result of pointwise multiplying with others.
Args:
tensors: list of tensors
others: list of tensors (with the same corresponding shapes as tensors)
Returns: nothing, tensors is updated in place.
"""
torch.ops.libtorch_agnostic_2_10.my__foreach_mul_.default(tensors, others)
def my__foreach_mul(tensors, others) -> list[Tensor]:
"""
Returns a list of tensors that are the results of pointwise multiplying
tensors and others.
Args:
tensors: list of tensors
others: list of tensors (with the same corresponding shapes as tensors)
Returns: list of multiplied tensors
"""
return torch.ops.libtorch_agnostic_2_10.my__foreach_mul.default(tensors, others)
def make_tensor_clones_and_call_foreach(t1, t2) -> list[Tensor]:
"""
Returns a list of 2 tensors corresponding to the square of the inputs.
Args:
t1: Tensor
t2: Tensor
Returns: list of [t1^2, t2^2]
"""
return torch.ops.libtorch_agnostic_2_10.make_tensor_clones_and_call_foreach.default(
t1, t2
)
def test_tensor_device(t):
"""
Tests Tensor device() method.
Args:
t: Tensor - tensor to get device from
Returns: Device - device of the tensor
"""
return torch.ops.libtorch_agnostic_2_10.test_tensor_device.default(t)
def test_device_constructor(is_cuda, index, use_str):
"""
Tests creating a Device from DeviceType and index, or from a string.
Args:
is_cuda: bool - if True, creates CUDA device; if False, creates CPU device
index: int - device index
use_str: bool - if True, constructs from string; if False, constructs from DeviceType
Returns: Device - A device with the specified type and index
"""
return torch.ops.libtorch_agnostic_2_10.test_device_constructor.default(
is_cuda, index, use_str
)
def test_device_equality(d1, d2) -> bool:
"""
Tests Device equality operator.
Args:
d1: Device - first device
d2: Device - second device
Returns: bool - True if devices are equal
"""
return torch.ops.libtorch_agnostic_2_10.test_device_equality.default(d1, d2)
def test_device_set_index(device, index):
"""
Tests Device set_index() method.
Args:
device: Device - device to modify
index: int - new device index
Returns: Device - device with updated index
"""
return torch.ops.libtorch_agnostic_2_10.test_device_set_index.default(device, index)
def test_device_index(device) -> int:
"""
Tests Device index() method.
Args:
device: Device - device to query
Returns: int - device index
"""
return torch.ops.libtorch_agnostic_2_10.test_device_index.default(device)
def test_device_is_cuda(device) -> bool:
"""
Tests Device is_cuda() method.
Args:
device: Device - device to check
Returns: bool - True if device is CUDA
"""
return torch.ops.libtorch_agnostic_2_10.test_device_is_cuda.default(device)
def test_device_is_cpu(device) -> bool:
"""
Tests Device is_cpu() method.
Args:
device: Device - device to check
Returns: bool - True if device is CPU
"""
return torch.ops.libtorch_agnostic_2_10.test_device_is_cpu.default(device)
def test_parallel_for(size, grain_size) -> Tensor:
"""
Tests the parallel_for functionality by using it to fill a tensor with indices.
Args:
size: int - size of the tensor to create
grain_size: int - grain size for parallel_for
Returns: Tensor - a 1D int64 tensor where each element contains its index
(if multiple threads are used the threadid will be encoded in the upper 32 bits)
"""
return torch.ops.libtorch_agnostic_2_10.test_parallel_for.default(size, grain_size)
def test_get_num_threads() -> int:
"""
Tests the get_num_threads functionality by returning the number of threads
for the parallel backend.
Returns: int - the number of threads for the parallel backend
"""
return torch.ops.libtorch_agnostic_2_10.test_get_num_threads.default()
def my_empty(size, dtype=None, device=None, pin_memory=None) -> Tensor:
"""
Creates an empty tensor with the specified size, dtype, device, and pin_memory.
Args:
size: list[int] - size of the tensor to create
dtype: ScalarType or None - data type of the tensor
device: Device or None - device on which to create the tensor
pin_memory: bool or None - whether to use pinned memory
Returns: Tensor - an uninitialized tensor with the specified properties
"""
return torch.ops.libtorch_agnostic_2_10.my_empty.default(
size, dtype, device, pin_memory
)
def my_reshape(t, shape) -> Tensor:
"""
Returns a tensor with the same data but different shape.
Args:
t: Tensor - tensor to reshape
shape: list[int] - new shape for the tensor
Returns: Tensor - reshaped tensor
"""
return torch.ops.libtorch_agnostic_2_10.my_reshape.default(t, shape)
def my_view(t, size) -> Tensor:
"""
Returns a new tensor with the same data as the input tensor but of a different shape.
Args:
t: Tensor - tensor to view
size: list[int] - new size for the tensor
Returns: Tensor - tensor with new view
"""
return torch.ops.libtorch_agnostic_2_10.my_view.default(t, size)
def get_any_data_ptr(t, mutable) -> int:
"""
Return data pointer value of the tensor.
Args:
t: Input tensor
mutable: whether data pointer qualifier is mutable or const
Returns: int - pointer value
"""
return torch.ops.libtorch_agnostic_2_10.get_any_data_ptr.default(t, mutable)
def get_template_any_data_ptr(t, dtype, mutable) -> int:
"""
Return data pointer value of the tensor iff it has dtype.
Args:
t: Input tensor
dtype: Input dtype
mutable: whether data pointer qualifier is mutable or const
Returns: int - pointer value
Raises RuntimeError when t.dtype() != dtype.
"""
return torch.ops.libtorch_agnostic_2_10.get_template_any_data_ptr.default(
t, dtype, mutable
)

View File

@ -1,308 +0,0 @@
# Owner(s): ["module: cpp"]
"""
Unit tests to verify that each function file requires PyTorch 2.10+.
This test suite compiles each .cpp file in the csrc directory with
TORCH_TARGET_VERSION=2.9.0 and expects compilation to fail.
If compilation succeeds, it means that either
(1) The test function works with 2.9.0 and should not be in this directory.
(2) The test function tests APIs that do not have proper TORCH_FEATURE_VERSION
guards. If this is the case, and you incorrectly move the test function into
libtorch_agnostic_2_9_extension the libtorch_agnostic_targetting CI workflow
will catch this.
Run this script with VERSION_COMPAT_DEBUG=1 to see compilation errors.
"""
import os
import subprocess
import tempfile
from pathlib import Path
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
from torch.utils.cpp_extension import CUDA_HOME, include_paths as torch_include_paths
# TODO: Fix this error in Windows:
# numba.cuda.cudadrv.driver:driver.py:384 Call to cuInit results in CUDA_ERROR_NO_DEVICE
if not IS_WINDOWS:
class FunctionVersionCompatibilityTest(TestCase):
"""Test that all function files require PyTorch 2.10+."""
@classmethod
def setUpClass(cls):
"""Set up test environment once for all tests."""
cls.csrc_dir = Path(__file__).parent / "libtorch_agnostic_2_10" / "csrc"
cls.build_dir = Path(tempfile.mkdtemp(prefix="version_check_"))
cls.pytorch_includes = [
f"-I{path}" for path in torch_include_paths(device_type="cpu")
]
cls.cuda_includes = []
if CUDA_HOME:
cuda_include_path = os.path.join(CUDA_HOME, "include")
if os.path.exists(cuda_include_path):
cls.cuda_includes = [f"-I{cuda_include_path}"]
cls.cuda_available = cls._check_cuda_available()
@classmethod
def tearDownClass(cls):
"""Clean up build directory."""
import shutil
if cls.build_dir.exists():
shutil.rmtree(cls.build_dir)
@staticmethod
def _check_cuda_available() -> bool:
"""Check if CUDA is available."""
try:
import torch
return torch.cuda.is_available()
except ImportError:
return False
def _compile_cpp_file(
self, source_file: Path, output_file: Path
) -> tuple[bool, str]:
"""
Compile a C++ file with TORCH_TARGET_VERSION=2.9.0.
Returns (success, error_message).
"""
torch_version_2_9 = "0x0209000000000000"
cmd = [
"g++",
"-c",
"-std=c++17",
f"-DTORCH_TARGET_VERSION={torch_version_2_9}",
f"-I{source_file.parent}", # For includes in same directory
*self.pytorch_includes,
]
# Add CUDA flags if available
if self.cuda_available:
cmd.extend(self.cuda_includes)
cmd.extend([str(source_file), "-o", str(output_file)])
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
if result.returncode == 0:
return True, ""
else:
return False, result.stderr
def _compile_cu_file(
self, source_file: Path, output_file: Path
) -> tuple[bool, str]:
"""
Compile a CUDA file with TORCH_TARGET_VERSION=2.9.0.
Returns (success, error_message).
"""
if not CUDA_HOME:
return False, "CUDA_HOME not set"
torch_version_2_9 = "0x0209000000000000"
cmd = [
os.path.join(CUDA_HOME, "bin", "nvcc"),
"-c",
"-std=c++17",
f"-DTORCH_TARGET_VERSION={torch_version_2_9}",
f"-I{source_file.parent}", # For includes in same directory
*self.pytorch_includes,
*self.cuda_includes,
]
cmd.extend([str(source_file), "-o", str(output_file)])
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
if result.returncode == 0:
return True, ""
else:
return False, result.stderr
def _test_function_file(self, source_file: Path):
"""Test that a function file fails to compile with TORCH_TARGET_VERSION=2.9.0."""
func_name = source_file.stem
obj_file = self.build_dir / f"{func_name}.o"
# Choose the appropriate compiler based on file extension
if source_file.suffix == ".cu":
if not self.cuda_available:
self.skipTest(f"CUDA not available, skipping {source_file.name}")
success, error_msg = self._compile_cu_file(source_file, obj_file)
else:
success, error_msg = self._compile_cpp_file(source_file, obj_file)
obj_file.unlink(missing_ok=True)
# Print error details for debugging
if not success:
relevant_errors = self._extract_relevant_errors(error_msg)
if relevant_errors:
print(f"\n Compilation errors for {func_name} (requires 2.10+):")
for err in relevant_errors:
print(f" {err}")
self.assertFalse(
success,
f"Function {func_name} compiled successfully with TORCH_TARGET_VERSION=2.9.0. "
f"This could mean two things.\n\t1. It should run with 2.9.0 and should be "
"moved to libtorch_agnostic_2_9_extension\n\t2. The function(s) it tests do not use the "
"proper TORCH_FEATURE_VERSION guards\n\nThe libtorch_agnostic_targetting CI workflow will "
"verify if you incorrectly move this to the 2_9 extension instead of adding "
"the appropriate version guards.",
)
def test_mv_tensor_accessor_cpu_works_with_2_9(self):
"""Test that mv_tensor_accessor_cpu.cpp compiles successfully with 2.9.0.
This is a negative test - it ensures that a file we expect to work with 2.9.0
actually does compile. This validates that our test infrastructure correctly
distinguishes between files that require 2.10+ and those that don't.
"""
cpp_file = self.csrc_dir / "mv_tensor_accessor_cpu.cpp"
if not cpp_file.exists():
self.skipTest(f"{cpp_file} not found - this is a test file only")
obj_file = self.build_dir / "mv_tensor_accessor_cpu.o"
success, error_msg = self._compile_cpp_file(cpp_file, obj_file)
# Clean up
obj_file.unlink(missing_ok=True)
if not success:
relevant_errors = self._extract_relevant_errors(error_msg)
if relevant_errors:
print(
"\n Unexpected compilation errors for mv_tensor_accessor_cpu:"
)
for err in relevant_errors:
print(f"{err}")
self.assertTrue(
success,
f"mv_tensor_accessor_cpu.cpp failed to compile with TORCH_TARGET_VERSION=2.9.0. "
f"This file is expected to work with 2.9.0 since it doesn't use 2.10+ features. "
f"Error: {error_msg}",
)
def test_mv_tensor_accessor_cuda_works_with_2_9(self):
"""Test that mv_tensor_accessor_cuda.cu compiles successfully with 2.9.0.
This is a negative test - it ensures that a .cu file we expect to work with 2.9.0
actually does compile. This validates that our test infrastructure correctly
compiles CUDA files and distinguishes between files that require 2.10+ and those
that don't.
"""
if not self.cuda_available:
self.skipTest(
"CUDA not available, skipping mv_tensor_accessor_cuda.cu test"
)
cu_file = self.csrc_dir / "mv_tensor_accessor_cuda.cu"
if not cu_file.exists():
self.skipTest(f"{cu_file} not found - this is a test file only")
obj_file = self.build_dir / "cuda_kernel.o"
success, error_msg = self._compile_cu_file(cu_file, obj_file)
# Clean up
obj_file.unlink(missing_ok=True)
if not success:
relevant_errors = self._extract_relevant_errors(error_msg)
if relevant_errors:
print(
"\n Unexpected compilation errors for mv_tensor_accessor_cuda.cu:"
)
for err in relevant_errors:
print(f"{err}")
self.assertTrue(
success,
f"mv_tensor_accessor_cuda.cu failed to compile with TORCH_TARGET_VERSION=2.9.0. "
f"This file is expected to work with 2.9.0 since it doesn't use 2.10+ features. "
f"Error: {error_msg}",
)
@staticmethod
def _extract_relevant_errors(error_msg: str) -> list[str]:
"""Extract the most relevant error messages."""
error_lines = error_msg.strip().split("\n")
relevant_errors = []
for line in error_lines:
line_lower = line.lower()
if (
"error:" in line_lower
or "undefined" in line_lower
or "undeclared" in line_lower
or "no member named" in line_lower
):
relevant_errors.append(line.strip())
return relevant_errors
# Dynamically create test methods for each .cpp and .cu file
def _create_test_method_for_file(source_file: Path):
"""Create a test method for a specific source file."""
def test_method_impl(self):
self._test_function_file(source_file)
# Set a descriptive name and docstring
func_name = source_file.stem
file_ext = source_file.suffix
test_method_impl.__name__ = f"test_{func_name}_requires_2_10"
test_method_impl.__doc__ = (
f"Test that {func_name}{file_ext} requires PyTorch 2.10+"
)
return test_method_impl
# Test discovery: generate a test for each .cpp and .cu file
_csrc_dir = Path(__file__).parent / "libtorch_agnostic_2_10" / "csrc"
if _csrc_dir.exists():
# Collect both .cpp and .cu files, excluding those used for negative test
# already defined above
_source_files = sorted(
[
f
for f in _csrc_dir.rglob("*.cpp")
if f.name not in ("mv_tensor_accessor_cpu.cpp",)
]
+ [
f
for f in _csrc_dir.rglob("*.cu")
if f.name not in ("mv_tensor_accessor_cuda.cu",)
]
)
for _source_file in _source_files:
_test_method = _create_test_method_for_file(_source_file)
setattr(
FunctionVersionCompatibilityTest, _test_method.__name__, _test_method
)
del (
_create_test_method_for_file,
_csrc_dir,
_source_files,
_source_file,
_test_method,
)
if __name__ == "__main__":
run_tests()

View File

@ -1,21 +0,0 @@
import ctypes
from pathlib import Path
import torch
so_files = list(Path(__file__).parent.glob("_C*.so"))
assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
# use ctypes.CDLL instead of load_library to be able to test the unload logic
# below code is reduced from the load_library code
with torch._ops.dl_open_guard():
loaded_lib = ctypes.CDLL(so_files[0])
from . import ops
__all__ = [
"loaded_lib",
"ops",
]

Some files were not shown because too many files have changed in this diff Show More