mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-19 01:54:54 +08:00
Compare commits
1 Commits
dev/joona/
...
ciflow/tru
| Author | SHA1 | Date | |
|---|---|---|---|
| e570b2b4f0 |
19
.ci/aarch64_linux/README.md
Normal file
19
.ci/aarch64_linux/README.md
Normal file
@ -0,0 +1,19 @@
|
||||
# Aarch64 (ARM/Graviton) Support Scripts
|
||||
Scripts for building aarch64 PyTorch PIP Wheels. These scripts build the following wheels:
|
||||
* torch
|
||||
* torchvision
|
||||
* torchaudio
|
||||
* torchtext
|
||||
* torchdata
|
||||
## Aarch64_ci_build.sh
|
||||
This script is design to support CD operations within PyPi manylinux aarch64 container, and be executed in the container. It prepares the container and then executes __aarch64_wheel_ci_build.py__ to build the wheels. The script "assumes" the PyTorch repo is located at: ```/pytorch``` and will put the wheels into ```/artifacts```.
|
||||
### Usage
|
||||
```DESIRED_PYTHON=<PythonVersion> aarch64_ci_build.sh```
|
||||
|
||||
__NOTE:__ CI build is currently __EXPERMINTAL__
|
||||
|
||||
## Build_aarch64_wheel.py
|
||||
This app allows a person to build using AWS EC3 resources and requires AWS-CLI and Boto3 with AWS credentials to support building EC2 instances for the wheel builds. Can be used in a codebuild CD or from a local system.
|
||||
|
||||
### Usage
|
||||
```build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.8 --branch <RCtag>```
|
||||
53
.ci/aarch64_linux/aarch64_ci_build.sh
Normal file
53
.ci/aarch64_linux/aarch64_ci_build.sh
Normal file
@ -0,0 +1,53 @@
|
||||
#!/bin/bash
|
||||
set -eux -o pipefail
|
||||
|
||||
GPU_ARCH_VERSION=${GPU_ARCH_VERSION:-}
|
||||
|
||||
# Set CUDA architecture lists to match x86 build_cuda.sh
|
||||
if [[ "$GPU_ARCH_VERSION" == *"12.6"* ]]; then
|
||||
export TORCH_CUDA_ARCH_LIST="8.0;9.0"
|
||||
elif [[ "$GPU_ARCH_VERSION" == *"12.8"* ]]; then
|
||||
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
|
||||
elif [[ "$GPU_ARCH_VERSION" == *"12.9"* ]]; then
|
||||
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
|
||||
elif [[ "$GPU_ARCH_VERSION" == *"13.0"* ]]; then
|
||||
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;11.0;12.0+PTX"
|
||||
fi
|
||||
|
||||
# Compress the fatbin with -compress-mode=size for CUDA 13
|
||||
if [[ "$DESIRED_CUDA" == *"13"* ]]; then
|
||||
export TORCH_NVCC_FLAGS="-compress-mode=size"
|
||||
# Bundle ptxas into the cu13 wheel, see https://github.com/pytorch/pytorch/issues/163801
|
||||
export BUILD_BUNDLE_PTXAS=1
|
||||
fi
|
||||
|
||||
SCRIPTPATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
|
||||
source $SCRIPTPATH/aarch64_ci_setup.sh
|
||||
|
||||
###############################################################################
|
||||
# Run aarch64 builder python
|
||||
###############################################################################
|
||||
cd /
|
||||
# adding safe directory for git as the permissions will be
|
||||
# on the mounted pytorch repo
|
||||
git config --global --add safe.directory /pytorch
|
||||
pip install -r /pytorch/requirements.txt
|
||||
pip install auditwheel==6.2.0 wheel
|
||||
if [ "$DESIRED_CUDA" = "cpu" ]; then
|
||||
echo "BASE_CUDA_VERSION is not set. Building cpu wheel."
|
||||
python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn
|
||||
else
|
||||
echo "BASE_CUDA_VERSION is set to: $DESIRED_CUDA"
|
||||
export USE_SYSTEM_NCCL=1
|
||||
|
||||
# Check if we should use NVIDIA libs from PyPI (similar to x86 build_cuda.sh logic)
|
||||
if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then
|
||||
echo "Bundling CUDA libraries with wheel for aarch64."
|
||||
else
|
||||
echo "Using nvidia libs from pypi for aarch64."
|
||||
echo "Updated PYTORCH_EXTRA_INSTALL_REQUIREMENTS for aarch64: $PYTORCH_EXTRA_INSTALL_REQUIREMENTS"
|
||||
export USE_NVIDIA_PYPI_LIBS=1
|
||||
fi
|
||||
|
||||
python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn --enable-cuda
|
||||
fi
|
||||
21
.ci/aarch64_linux/aarch64_ci_setup.sh
Executable file
21
.ci/aarch64_linux/aarch64_ci_setup.sh
Executable file
@ -0,0 +1,21 @@
|
||||
#!/bin/bash
|
||||
set -eux -o pipefail
|
||||
|
||||
# This script is used to prepare the Docker container for aarch64_ci_wheel_build.py python script
|
||||
# By creating symlinks from desired /opt/python to /usr/local/bin/
|
||||
|
||||
NUMPY_VERSION=2.0.2
|
||||
if [[ "$DESIRED_PYTHON" == "3.13" || "$DESIRED_PYTHON" == "3.13t" ]]; then
|
||||
NUMPY_VERSION=2.1.2
|
||||
fi
|
||||
|
||||
SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"
|
||||
source $SCRIPTPATH/../manywheel/set_desired_python.sh
|
||||
|
||||
pip install -q numpy==${NUMPY_VERSION} pyyaml==6.0.2 scons==4.7.0 ninja==1.11.1 patchelf==0.17.2
|
||||
|
||||
for tool in python python3 pip pip3 ninja scons patchelf; do
|
||||
ln -sf ${DESIRED_PYTHON_BIN_DIR}/${tool} /usr/local/bin;
|
||||
done
|
||||
|
||||
python --version
|
||||
333
.ci/aarch64_linux/aarch64_wheel_ci_build.py
Executable file
333
.ci/aarch64_linux/aarch64_wheel_ci_build.py
Executable file
@ -0,0 +1,333 @@
|
||||
#!/usr/bin/env python3
|
||||
# encoding: UTF-8
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from subprocess import check_call, check_output
|
||||
|
||||
|
||||
def list_dir(path: str) -> list[str]:
|
||||
"""'
|
||||
Helper for getting paths for Python
|
||||
"""
|
||||
return check_output(["ls", "-1", path]).decode().split("\n")
|
||||
|
||||
|
||||
def replace_tag(filename) -> None:
|
||||
with open(filename) as f:
|
||||
lines = f.readlines()
|
||||
for i, line in enumerate(lines):
|
||||
if line.startswith("Tag:"):
|
||||
lines[i] = line.replace("-linux_", "-manylinux_2_28_")
|
||||
print(f"Updated tag from {line} to {lines[i]}")
|
||||
break
|
||||
|
||||
with open(filename, "w") as f:
|
||||
f.writelines(lines)
|
||||
|
||||
|
||||
def patch_library_rpath(
|
||||
folder: str,
|
||||
lib_name: str,
|
||||
use_nvidia_pypi_libs: bool = False,
|
||||
desired_cuda: str = "",
|
||||
) -> None:
|
||||
"""Apply patchelf to set RPATH for a library in torch/lib"""
|
||||
lib_path = f"{folder}/tmp/torch/lib/{lib_name}"
|
||||
|
||||
if use_nvidia_pypi_libs:
|
||||
# For PyPI NVIDIA libraries, construct CUDA RPATH
|
||||
cuda_rpaths = [
|
||||
"$ORIGIN/../../nvidia/cudnn/lib",
|
||||
"$ORIGIN/../../nvidia/nvshmem/lib",
|
||||
"$ORIGIN/../../nvidia/nccl/lib",
|
||||
"$ORIGIN/../../nvidia/cusparselt/lib",
|
||||
]
|
||||
|
||||
if "130" in desired_cuda:
|
||||
cuda_rpaths.append("$ORIGIN/../../nvidia/cu13/lib")
|
||||
else:
|
||||
cuda_rpaths.extend(
|
||||
[
|
||||
"$ORIGIN/../../nvidia/cublas/lib",
|
||||
"$ORIGIN/../../nvidia/cuda_cupti/lib",
|
||||
"$ORIGIN/../../nvidia/cuda_nvrtc/lib",
|
||||
"$ORIGIN/../../nvidia/cuda_runtime/lib",
|
||||
"$ORIGIN/../../nvidia/cufft/lib",
|
||||
"$ORIGIN/../../nvidia/curand/lib",
|
||||
"$ORIGIN/../../nvidia/cusolver/lib",
|
||||
"$ORIGIN/../../nvidia/cusparse/lib",
|
||||
"$ORIGIN/../../nvidia/nvtx/lib",
|
||||
"$ORIGIN/../../nvidia/cufile/lib",
|
||||
]
|
||||
)
|
||||
|
||||
# Add $ORIGIN for local torch libs
|
||||
rpath = ":".join(cuda_rpaths) + ":$ORIGIN"
|
||||
else:
|
||||
# For bundled libraries, just use $ORIGIN
|
||||
rpath = "$ORIGIN"
|
||||
|
||||
if os.path.exists(lib_path):
|
||||
os.system(
|
||||
f"cd {folder}/tmp/torch/lib/; "
|
||||
f"patchelf --set-rpath '{rpath}' --force-rpath {lib_name}"
|
||||
)
|
||||
|
||||
|
||||
def copy_and_patch_library(
|
||||
src_path: str,
|
||||
folder: str,
|
||||
use_nvidia_pypi_libs: bool = False,
|
||||
desired_cuda: str = "",
|
||||
) -> None:
|
||||
"""Copy a library to torch/lib and patch its RPATH"""
|
||||
if os.path.exists(src_path):
|
||||
lib_name = os.path.basename(src_path)
|
||||
shutil.copy2(src_path, f"{folder}/tmp/torch/lib/{lib_name}")
|
||||
patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda)
|
||||
|
||||
|
||||
def package_cuda_wheel(wheel_path, desired_cuda) -> None:
|
||||
"""
|
||||
Package the cuda wheel libraries
|
||||
"""
|
||||
folder = os.path.dirname(wheel_path)
|
||||
os.mkdir(f"{folder}/tmp")
|
||||
os.system(f"unzip {wheel_path} -d {folder}/tmp")
|
||||
# Delete original wheel since it will be repackaged
|
||||
os.system(f"rm {wheel_path}")
|
||||
|
||||
# Check if we should use PyPI NVIDIA libraries or bundle system libraries
|
||||
use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"
|
||||
|
||||
if use_nvidia_pypi_libs:
|
||||
print("Using nvidia libs from pypi - skipping CUDA library bundling")
|
||||
# For PyPI approach, we don't bundle CUDA libraries - they come from PyPI packages
|
||||
# We only need to bundle non-NVIDIA libraries
|
||||
minimal_libs_to_copy = [
|
||||
"/lib64/libgomp.so.1",
|
||||
"/usr/lib64/libgfortran.so.5",
|
||||
"/acl/build/libarm_compute.so",
|
||||
"/acl/build/libarm_compute_graph.so",
|
||||
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
|
||||
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
|
||||
"/usr/local/lib/libnvpl_lapack_core.so.0",
|
||||
"/usr/local/lib/libnvpl_blas_core.so.0",
|
||||
]
|
||||
|
||||
# Copy minimal libraries to unzipped_folder/torch/lib
|
||||
for lib_path in minimal_libs_to_copy:
|
||||
copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda)
|
||||
|
||||
# Patch torch libraries used for searching libraries
|
||||
torch_libs_to_patch = [
|
||||
"libtorch.so",
|
||||
"libtorch_cpu.so",
|
||||
"libtorch_cuda.so",
|
||||
"libtorch_cuda_linalg.so",
|
||||
"libtorch_global_deps.so",
|
||||
"libtorch_python.so",
|
||||
"libtorch_nvshmem.so",
|
||||
"libc10.so",
|
||||
"libc10_cuda.so",
|
||||
"libcaffe2_nvrtc.so",
|
||||
"libshm.so",
|
||||
]
|
||||
for lib_name in torch_libs_to_patch:
|
||||
patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda)
|
||||
else:
|
||||
print("Bundling CUDA libraries with wheel")
|
||||
# Original logic for bundling system CUDA libraries
|
||||
# Common libraries for all CUDA versions
|
||||
common_libs = [
|
||||
# Non-NVIDIA system libraries
|
||||
"/lib64/libgomp.so.1",
|
||||
"/usr/lib64/libgfortran.so.5",
|
||||
"/acl/build/libarm_compute.so",
|
||||
"/acl/build/libarm_compute_graph.so",
|
||||
# Common CUDA libraries (same for all versions)
|
||||
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
|
||||
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
|
||||
"/usr/local/lib/libnvpl_lapack_core.so.0",
|
||||
"/usr/local/lib/libnvpl_blas_core.so.0",
|
||||
"/usr/local/cuda/extras/CUPTI/lib64/libnvperf_host.so",
|
||||
"/usr/local/cuda/lib64/libcudnn.so.9",
|
||||
"/usr/local/cuda/lib64/libcusparseLt.so.0",
|
||||
"/usr/local/cuda/lib64/libcurand.so.10",
|
||||
"/usr/local/cuda/lib64/libnccl.so.2",
|
||||
"/usr/local/cuda/lib64/libnvshmem_host.so.3",
|
||||
"/usr/local/cuda/lib64/libcudnn_adv.so.9",
|
||||
"/usr/local/cuda/lib64/libcudnn_cnn.so.9",
|
||||
"/usr/local/cuda/lib64/libcudnn_graph.so.9",
|
||||
"/usr/local/cuda/lib64/libcudnn_ops.so.9",
|
||||
"/usr/local/cuda/lib64/libcudnn_engines_runtime_compiled.so.9",
|
||||
"/usr/local/cuda/lib64/libcudnn_engines_precompiled.so.9",
|
||||
"/usr/local/cuda/lib64/libcudnn_heuristic.so.9",
|
||||
"/usr/local/cuda/lib64/libcufile.so.0",
|
||||
"/usr/local/cuda/lib64/libcufile_rdma.so.1",
|
||||
"/usr/local/cuda/lib64/libcusparse.so.12",
|
||||
]
|
||||
|
||||
# CUDA version-specific libraries
|
||||
if "13" in desired_cuda:
|
||||
minor_version = desired_cuda[-1]
|
||||
version_specific_libs = [
|
||||
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.13",
|
||||
"/usr/local/cuda/lib64/libcublas.so.13",
|
||||
"/usr/local/cuda/lib64/libcublasLt.so.13",
|
||||
"/usr/local/cuda/lib64/libcudart.so.13",
|
||||
"/usr/local/cuda/lib64/libcufft.so.12",
|
||||
"/usr/local/cuda/lib64/libcusolver.so.12",
|
||||
"/usr/local/cuda/lib64/libnvJitLink.so.13",
|
||||
"/usr/local/cuda/lib64/libnvrtc.so.13",
|
||||
f"/usr/local/cuda/lib64/libnvrtc-builtins.so.13.{minor_version}",
|
||||
]
|
||||
elif "12" in desired_cuda:
|
||||
# Get the last character for libnvrtc-builtins version (e.g., "129" -> "9")
|
||||
minor_version = desired_cuda[-1]
|
||||
version_specific_libs = [
|
||||
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12",
|
||||
"/usr/local/cuda/lib64/libcublas.so.12",
|
||||
"/usr/local/cuda/lib64/libcublasLt.so.12",
|
||||
"/usr/local/cuda/lib64/libcudart.so.12",
|
||||
"/usr/local/cuda/lib64/libcufft.so.11",
|
||||
"/usr/local/cuda/lib64/libcusolver.so.11",
|
||||
"/usr/local/cuda/lib64/libnvJitLink.so.12",
|
||||
"/usr/local/cuda/lib64/libnvrtc.so.12",
|
||||
f"/usr/local/cuda/lib64/libnvrtc-builtins.so.12.{minor_version}",
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unsupported CUDA version: {desired_cuda}.")
|
||||
|
||||
# Combine all libraries
|
||||
libs_to_copy = common_libs + version_specific_libs
|
||||
|
||||
# Copy libraries to unzipped_folder/torch/lib
|
||||
for lib_path in libs_to_copy:
|
||||
copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda)
|
||||
|
||||
# Make sure the wheel is tagged with manylinux_2_28
|
||||
for f in os.scandir(f"{folder}/tmp/"):
|
||||
if f.is_dir() and f.name.endswith(".dist-info"):
|
||||
replace_tag(f"{f.path}/WHEEL")
|
||||
break
|
||||
|
||||
os.system(f"wheel pack {folder}/tmp/ -d {folder}")
|
||||
os.system(f"rm -rf {folder}/tmp/")
|
||||
|
||||
|
||||
def complete_wheel(folder: str) -> str:
|
||||
"""
|
||||
Complete wheel build and put in artifact location
|
||||
"""
|
||||
wheel_name = list_dir(f"/{folder}/dist")[0]
|
||||
|
||||
# Please note for cuda we don't run auditwheel since we use custom script to package
|
||||
# the cuda dependencies to the wheel file using update_wheel() method.
|
||||
# However we need to make sure filename reflects the correct Manylinux platform.
|
||||
if "pytorch" in folder and not enable_cuda:
|
||||
print("Repairing Wheel with AuditWheel")
|
||||
check_call(["auditwheel", "repair", f"dist/{wheel_name}"], cwd=folder)
|
||||
repaired_wheel_name = list_dir(f"/{folder}/wheelhouse")[0]
|
||||
|
||||
print(f"Moving {repaired_wheel_name} wheel to /{folder}/dist")
|
||||
os.rename(
|
||||
f"/{folder}/wheelhouse/{repaired_wheel_name}",
|
||||
f"/{folder}/dist/{repaired_wheel_name}",
|
||||
)
|
||||
else:
|
||||
repaired_wheel_name = list_dir(f"/{folder}/dist")[0]
|
||||
|
||||
print(f"Copying {repaired_wheel_name} to artifacts")
|
||||
shutil.copy2(
|
||||
f"/{folder}/dist/{repaired_wheel_name}", f"/artifacts/{repaired_wheel_name}"
|
||||
)
|
||||
|
||||
return repaired_wheel_name
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
"""
|
||||
Parse inline arguments
|
||||
"""
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser("AARCH64 wheels python CD")
|
||||
parser.add_argument("--debug", action="store_true")
|
||||
parser.add_argument("--build-only", action="store_true")
|
||||
parser.add_argument("--test-only", type=str)
|
||||
parser.add_argument("--enable-mkldnn", action="store_true")
|
||||
parser.add_argument("--enable-cuda", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
Entry Point
|
||||
"""
|
||||
args = parse_arguments()
|
||||
enable_mkldnn = args.enable_mkldnn
|
||||
enable_cuda = args.enable_cuda
|
||||
branch = check_output(
|
||||
["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd="/pytorch"
|
||||
).decode()
|
||||
|
||||
print("Building PyTorch wheel")
|
||||
build_vars = ""
|
||||
# MAX_JOB=5 is not required for CPU backend (see commit 465d98b)
|
||||
if enable_cuda:
|
||||
build_vars += "MAX_JOBS=5 "
|
||||
|
||||
# Handle PyPI NVIDIA libraries vs bundled libraries
|
||||
use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"
|
||||
if use_nvidia_pypi_libs:
|
||||
print("Configuring build for PyPI NVIDIA libraries")
|
||||
# Configure for dynamic linking (matching x86 logic)
|
||||
build_vars += "ATEN_STATIC_CUDA=0 USE_CUDA_STATIC_LINK=0 USE_CUPTI_SO=1 "
|
||||
else:
|
||||
print("Configuring build for bundled NVIDIA libraries")
|
||||
# Keep existing static linking approach - already configured above
|
||||
|
||||
override_package_version = os.getenv("OVERRIDE_PACKAGE_VERSION")
|
||||
desired_cuda = os.getenv("DESIRED_CUDA")
|
||||
if override_package_version is not None:
|
||||
version = override_package_version
|
||||
build_vars += (
|
||||
f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version} PYTORCH_BUILD_NUMBER=1 "
|
||||
)
|
||||
elif branch in ["nightly", "main"]:
|
||||
build_date = (
|
||||
check_output(["git", "log", "--pretty=format:%cs", "-1"], cwd="/pytorch")
|
||||
.decode()
|
||||
.replace("-", "")
|
||||
)
|
||||
version = (
|
||||
check_output(["cat", "version.txt"], cwd="/pytorch").decode().strip()[:-2]
|
||||
)
|
||||
if enable_cuda:
|
||||
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date}+{desired_cuda} PYTORCH_BUILD_NUMBER=1 "
|
||||
else:
|
||||
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1 "
|
||||
elif branch.startswith(("v1.", "v2.")):
|
||||
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1 "
|
||||
|
||||
if enable_mkldnn:
|
||||
print("build pytorch with mkldnn+acl backend")
|
||||
build_vars += "USE_MKLDNN=ON USE_MKLDNN_ACL=ON "
|
||||
build_vars += "ACL_ROOT_DIR=/acl "
|
||||
if enable_cuda:
|
||||
build_vars += "BLAS=NVPL "
|
||||
else:
|
||||
build_vars += "BLAS=OpenBLAS OpenBLAS_HOME=/opt/OpenBLAS "
|
||||
else:
|
||||
print("build pytorch without mkldnn backend")
|
||||
|
||||
os.system(f"cd /pytorch; {build_vars} python3 -m build --wheel --no-isolation")
|
||||
if enable_cuda:
|
||||
print("Updating Cuda Dependency")
|
||||
filename = os.listdir("/pytorch/dist/")
|
||||
wheel_path = f"/pytorch/dist/{filename[0]}"
|
||||
package_cuda_wheel(wheel_path, desired_cuda)
|
||||
pytorch_wheel_name = complete_wheel("/pytorch/")
|
||||
print(f"Build Complete. Created {pytorch_wheel_name}..")
|
||||
999
.ci/aarch64_linux/build_aarch64_wheel.py
Executable file
999
.ci/aarch64_linux/build_aarch64_wheel.py
Executable file
@ -0,0 +1,999 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# This script is for building AARCH64 wheels using AWS EC2 instances.
|
||||
# To generate binaries for the release follow these steps:
|
||||
# 1. Update mappings for each of the Domain Libraries by adding new row to a table like this:
|
||||
# "v1.11.0": ("0.11.0", "rc1"),
|
||||
# 2. Run script with following arguments for each of the supported python versions and required tag, for example:
|
||||
# build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.8 --branch v1.11.0-rc3
|
||||
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
|
||||
import boto3
|
||||
|
||||
|
||||
# AMI images for us-east-1, change the following based on your ~/.aws/config
|
||||
os_amis = {
|
||||
"ubuntu20_04": "ami-052eac90edaa9d08f", # login_name: ubuntu
|
||||
"ubuntu22_04": "ami-0c6c29c5125214c77", # login_name: ubuntu
|
||||
"redhat8": "ami-0698b90665a2ddcf1", # login_name: ec2-user
|
||||
}
|
||||
|
||||
ubuntu20_04_ami = os_amis["ubuntu20_04"]
|
||||
|
||||
|
||||
def compute_keyfile_path(key_name: Optional[str] = None) -> tuple[str, str]:
|
||||
if key_name is None:
|
||||
key_name = os.getenv("AWS_KEY_NAME")
|
||||
if key_name is None:
|
||||
return os.getenv("SSH_KEY_PATH", ""), ""
|
||||
|
||||
homedir_path = os.path.expanduser("~")
|
||||
default_path = os.path.join(homedir_path, ".ssh", f"{key_name}.pem")
|
||||
return os.getenv("SSH_KEY_PATH", default_path), key_name
|
||||
|
||||
|
||||
ec2 = boto3.resource("ec2")
|
||||
|
||||
|
||||
def ec2_get_instances(filter_name, filter_value):
|
||||
return ec2.instances.filter(
|
||||
Filters=[{"Name": filter_name, "Values": [filter_value]}]
|
||||
)
|
||||
|
||||
|
||||
def ec2_instances_of_type(instance_type="t4g.2xlarge"):
|
||||
return ec2_get_instances("instance-type", instance_type)
|
||||
|
||||
|
||||
def ec2_instances_by_id(instance_id):
|
||||
rc = list(ec2_get_instances("instance-id", instance_id))
|
||||
return rc[0] if len(rc) > 0 else None
|
||||
|
||||
|
||||
def start_instance(
|
||||
key_name, ami=ubuntu20_04_ami, instance_type="t4g.2xlarge", ebs_size: int = 50
|
||||
):
|
||||
inst = ec2.create_instances(
|
||||
ImageId=ami,
|
||||
InstanceType=instance_type,
|
||||
SecurityGroups=["ssh-allworld"],
|
||||
KeyName=key_name,
|
||||
MinCount=1,
|
||||
MaxCount=1,
|
||||
BlockDeviceMappings=[
|
||||
{
|
||||
"DeviceName": "/dev/sda1",
|
||||
"Ebs": {
|
||||
"DeleteOnTermination": True,
|
||||
"VolumeSize": ebs_size,
|
||||
"VolumeType": "standard",
|
||||
},
|
||||
}
|
||||
],
|
||||
)[0]
|
||||
print(f"Create instance {inst.id}")
|
||||
inst.wait_until_running()
|
||||
running_inst = ec2_instances_by_id(inst.id)
|
||||
print(f"Instance started at {running_inst.public_dns_name}")
|
||||
return running_inst
|
||||
|
||||
|
||||
class RemoteHost:
|
||||
addr: str
|
||||
keyfile_path: str
|
||||
login_name: str
|
||||
container_id: Optional[str] = None
|
||||
ami: Optional[str] = None
|
||||
|
||||
def __init__(self, addr: str, keyfile_path: str, login_name: str = "ubuntu"):
|
||||
self.addr = addr
|
||||
self.keyfile_path = keyfile_path
|
||||
self.login_name = login_name
|
||||
|
||||
def _gen_ssh_prefix(self) -> list[str]:
|
||||
return [
|
||||
"ssh",
|
||||
"-o",
|
||||
"StrictHostKeyChecking=no",
|
||||
"-i",
|
||||
self.keyfile_path,
|
||||
f"{self.login_name}@{self.addr}",
|
||||
"--",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _split_cmd(args: Union[str, list[str]]) -> list[str]:
|
||||
return args.split() if isinstance(args, str) else args
|
||||
|
||||
def run_ssh_cmd(self, args: Union[str, list[str]]) -> None:
|
||||
subprocess.check_call(self._gen_ssh_prefix() + self._split_cmd(args))
|
||||
|
||||
def check_ssh_output(self, args: Union[str, list[str]]) -> str:
|
||||
return subprocess.check_output(
|
||||
self._gen_ssh_prefix() + self._split_cmd(args)
|
||||
).decode("utf-8")
|
||||
|
||||
def scp_upload_file(self, local_file: str, remote_file: str) -> None:
|
||||
subprocess.check_call(
|
||||
[
|
||||
"scp",
|
||||
"-i",
|
||||
self.keyfile_path,
|
||||
local_file,
|
||||
f"{self.login_name}@{self.addr}:{remote_file}",
|
||||
]
|
||||
)
|
||||
|
||||
def scp_download_file(
|
||||
self, remote_file: str, local_file: Optional[str] = None
|
||||
) -> None:
|
||||
if local_file is None:
|
||||
local_file = "."
|
||||
subprocess.check_call(
|
||||
[
|
||||
"scp",
|
||||
"-i",
|
||||
self.keyfile_path,
|
||||
f"{self.login_name}@{self.addr}:{remote_file}",
|
||||
local_file,
|
||||
]
|
||||
)
|
||||
|
||||
def start_docker(self, image="quay.io/pypa/manylinux2014_aarch64:latest") -> None:
|
||||
self.run_ssh_cmd("sudo apt-get install -y docker.io")
|
||||
self.run_ssh_cmd(f"sudo usermod -a -G docker {self.login_name}")
|
||||
self.run_ssh_cmd("sudo service docker start")
|
||||
self.run_ssh_cmd(f"docker pull {image}")
|
||||
self.container_id = self.check_ssh_output(
|
||||
f"docker run -t -d -w /root {image}"
|
||||
).strip()
|
||||
|
||||
def using_docker(self) -> bool:
|
||||
return self.container_id is not None
|
||||
|
||||
def run_cmd(self, args: Union[str, list[str]]) -> None:
|
||||
if not self.using_docker():
|
||||
return self.run_ssh_cmd(args)
|
||||
assert self.container_id is not None
|
||||
docker_cmd = self._gen_ssh_prefix() + [
|
||||
"docker",
|
||||
"exec",
|
||||
"-i",
|
||||
self.container_id,
|
||||
"bash",
|
||||
]
|
||||
p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE)
|
||||
p.communicate(
|
||||
input=" ".join(["source .bashrc && "] + self._split_cmd(args)).encode(
|
||||
"utf-8"
|
||||
)
|
||||
)
|
||||
rc = p.wait()
|
||||
if rc != 0:
|
||||
raise subprocess.CalledProcessError(rc, docker_cmd)
|
||||
|
||||
def check_output(self, args: Union[str, list[str]]) -> str:
|
||||
if not self.using_docker():
|
||||
return self.check_ssh_output(args)
|
||||
assert self.container_id is not None
|
||||
docker_cmd = self._gen_ssh_prefix() + [
|
||||
"docker",
|
||||
"exec",
|
||||
"-i",
|
||||
self.container_id,
|
||||
"bash",
|
||||
]
|
||||
p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
|
||||
(out, err) = p.communicate(
|
||||
input=" ".join(["source .bashrc && "] + self._split_cmd(args)).encode(
|
||||
"utf-8"
|
||||
)
|
||||
)
|
||||
rc = p.wait()
|
||||
if rc != 0:
|
||||
raise subprocess.CalledProcessError(rc, docker_cmd, output=out, stderr=err)
|
||||
return out.decode("utf-8")
|
||||
|
||||
def upload_file(self, local_file: str, remote_file: str) -> None:
|
||||
if not self.using_docker():
|
||||
return self.scp_upload_file(local_file, remote_file)
|
||||
tmp_file = os.path.join("/tmp", os.path.basename(local_file))
|
||||
self.scp_upload_file(local_file, tmp_file)
|
||||
self.run_ssh_cmd(
|
||||
["docker", "cp", tmp_file, f"{self.container_id}:/root/{remote_file}"]
|
||||
)
|
||||
self.run_ssh_cmd(["rm", tmp_file])
|
||||
|
||||
def download_file(self, remote_file: str, local_file: Optional[str] = None) -> None:
|
||||
if not self.using_docker():
|
||||
return self.scp_download_file(remote_file, local_file)
|
||||
tmp_file = os.path.join("/tmp", os.path.basename(remote_file))
|
||||
self.run_ssh_cmd(
|
||||
["docker", "cp", f"{self.container_id}:/root/{remote_file}", tmp_file]
|
||||
)
|
||||
self.scp_download_file(tmp_file, local_file)
|
||||
self.run_ssh_cmd(["rm", tmp_file])
|
||||
|
||||
def download_wheel(
|
||||
self, remote_file: str, local_file: Optional[str] = None
|
||||
) -> None:
|
||||
if self.using_docker() and local_file is None:
|
||||
basename = os.path.basename(remote_file)
|
||||
local_file = basename.replace(
|
||||
"-linux_aarch64.whl", "-manylinux2014_aarch64.whl"
|
||||
)
|
||||
self.download_file(remote_file, local_file)
|
||||
|
||||
def list_dir(self, path: str) -> list[str]:
|
||||
return self.check_output(["ls", "-1", path]).split("\n")
|
||||
|
||||
|
||||
def wait_for_connection(addr, port, timeout=15, attempt_cnt=5):
|
||||
import socket
|
||||
|
||||
for i in range(attempt_cnt):
|
||||
try:
|
||||
with socket.create_connection((addr, port), timeout=timeout):
|
||||
return
|
||||
except (ConnectionRefusedError, TimeoutError): # noqa: PERF203
|
||||
if i == attempt_cnt - 1:
|
||||
raise
|
||||
time.sleep(timeout)
|
||||
|
||||
|
||||
def update_apt_repo(host: RemoteHost) -> None:
|
||||
time.sleep(5)
|
||||
host.run_cmd("sudo systemctl stop apt-daily.service || true")
|
||||
host.run_cmd("sudo systemctl stop unattended-upgrades.service || true")
|
||||
host.run_cmd(
|
||||
"while systemctl is-active --quiet apt-daily.service; do sleep 1; done"
|
||||
)
|
||||
host.run_cmd(
|
||||
"while systemctl is-active --quiet unattended-upgrades.service; do sleep 1; done"
|
||||
)
|
||||
host.run_cmd("sudo apt-get update")
|
||||
time.sleep(3)
|
||||
host.run_cmd("sudo apt-get update")
|
||||
|
||||
|
||||
def install_condaforge(
|
||||
host: RemoteHost, suffix: str = "latest/download/Miniforge3-Linux-aarch64.sh"
|
||||
) -> None:
|
||||
print("Install conda-forge")
|
||||
host.run_cmd(f"curl -OL https://github.com/conda-forge/miniforge/releases/{suffix}")
|
||||
host.run_cmd(f"sh -f {os.path.basename(suffix)} -b")
|
||||
host.run_cmd(f"rm -f {os.path.basename(suffix)}")
|
||||
if host.using_docker():
|
||||
host.run_cmd("echo 'PATH=$HOME/miniforge3/bin:$PATH'>>.bashrc")
|
||||
else:
|
||||
host.run_cmd(
|
||||
[
|
||||
"sed",
|
||||
"-i",
|
||||
"'/^# If not running interactively.*/i PATH=$HOME/miniforge3/bin:$PATH'",
|
||||
".bashrc",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def install_condaforge_python(host: RemoteHost, python_version="3.8") -> None:
|
||||
if python_version == "3.6":
|
||||
# Python-3.6 EOLed and not compatible with conda-4.11
|
||||
install_condaforge(
|
||||
host, suffix="download/4.10.3-10/Miniforge3-4.10.3-10-Linux-aarch64.sh"
|
||||
)
|
||||
host.run_cmd(f"conda install -y python={python_version} numpy pyyaml")
|
||||
else:
|
||||
install_condaforge(
|
||||
host, suffix="download/4.11.0-4/Miniforge3-4.11.0-4-Linux-aarch64.sh"
|
||||
)
|
||||
# Pytorch-1.10 or older are not compatible with setuptools=59.6 or newer
|
||||
host.run_cmd(
|
||||
f"conda install -y python={python_version} numpy pyyaml setuptools>=59.5.0"
|
||||
)
|
||||
|
||||
|
||||
def embed_libgomp(host: RemoteHost, use_conda, wheel_name) -> None:
|
||||
host.run_cmd("pip3 install auditwheel")
|
||||
host.run_cmd(
|
||||
"conda install -y patchelf" if use_conda else "sudo apt-get install -y patchelf"
|
||||
)
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
with NamedTemporaryFile() as tmp:
|
||||
tmp.write(embed_library_script.encode("utf-8"))
|
||||
tmp.flush()
|
||||
host.upload_file(tmp.name, "embed_library.py")
|
||||
|
||||
print("Embedding libgomp into wheel")
|
||||
if host.using_docker():
|
||||
host.run_cmd(f"python3 embed_library.py {wheel_name} --update-tag")
|
||||
else:
|
||||
host.run_cmd(f"python3 embed_library.py {wheel_name}")
|
||||
|
||||
|
||||
def checkout_repo(
|
||||
host: RemoteHost,
|
||||
*,
|
||||
branch: str = "main",
|
||||
url: str,
|
||||
git_clone_flags: str,
|
||||
mapping: dict[str, tuple[str, str]],
|
||||
) -> Optional[str]:
|
||||
for prefix in mapping:
|
||||
if not branch.startswith(prefix):
|
||||
continue
|
||||
tag = f"v{mapping[prefix][0]}-{mapping[prefix][1]}"
|
||||
host.run_cmd(f"git clone {url} -b {tag} {git_clone_flags}")
|
||||
return mapping[prefix][0]
|
||||
|
||||
host.run_cmd(f"git clone {url} -b {branch} {git_clone_flags}")
|
||||
return None
|
||||
|
||||
|
||||
def build_torchvision(
|
||||
host: RemoteHost,
|
||||
*,
|
||||
branch: str = "main",
|
||||
use_conda: bool = True,
|
||||
git_clone_flags: str,
|
||||
run_smoke_tests: bool = True,
|
||||
) -> str:
|
||||
print("Checking out TorchVision repo")
|
||||
build_version = checkout_repo(
|
||||
host,
|
||||
branch=branch,
|
||||
url="https://github.com/pytorch/vision",
|
||||
git_clone_flags=git_clone_flags,
|
||||
mapping={
|
||||
"v1.7.1": ("0.8.2", "rc2"),
|
||||
"v1.8.0": ("0.9.0", "rc3"),
|
||||
"v1.8.1": ("0.9.1", "rc1"),
|
||||
"v1.9.0": ("0.10.0", "rc1"),
|
||||
"v1.10.0": ("0.11.1", "rc1"),
|
||||
"v1.10.1": ("0.11.2", "rc1"),
|
||||
"v1.10.2": ("0.11.3", "rc1"),
|
||||
"v1.11.0": ("0.12.0", "rc1"),
|
||||
"v1.12.0": ("0.13.0", "rc4"),
|
||||
"v1.12.1": ("0.13.1", "rc6"),
|
||||
"v1.13.0": ("0.14.0", "rc4"),
|
||||
"v1.13.1": ("0.14.1", "rc2"),
|
||||
"v2.0.0": ("0.15.1", "rc2"),
|
||||
"v2.0.1": ("0.15.2", "rc2"),
|
||||
},
|
||||
)
|
||||
print("Building TorchVision wheel")
|
||||
|
||||
# Please note libnpg and jpeg are required to build image.so extension
|
||||
if use_conda:
|
||||
host.run_cmd("conda install -y libpng jpeg")
|
||||
# Remove .so files to force static linking
|
||||
host.run_cmd(
|
||||
"rm miniforge3/lib/libpng.so miniforge3/lib/libpng16.so miniforge3/lib/libjpeg.so"
|
||||
)
|
||||
# And patch setup.py to include libz dependency for libpng
|
||||
host.run_cmd(
|
||||
[
|
||||
'sed -i -e \'s/image_link_flags\\.append("png")/image_link_flags += ["png", "z"]/\' vision/setup.py'
|
||||
]
|
||||
)
|
||||
|
||||
build_vars = ""
|
||||
if branch == "nightly":
|
||||
version = host.check_output(
|
||||
["if [ -f vision/version.txt ]; then cat vision/version.txt; fi"]
|
||||
).strip()
|
||||
if len(version) == 0:
|
||||
# In older revisions, version was embedded in setup.py
|
||||
version = (
|
||||
host.check_output(["grep", '"version = \'"', "vision/setup.py"])
|
||||
.strip()
|
||||
.split("'")[1][:-2]
|
||||
)
|
||||
build_date = (
|
||||
host.check_output("cd vision && git log --pretty=format:%s -1")
|
||||
.strip()
|
||||
.split()[0]
|
||||
.replace("-", "")
|
||||
)
|
||||
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
|
||||
elif build_version is not None:
|
||||
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
|
||||
if host.using_docker():
|
||||
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
|
||||
|
||||
host.run_cmd(f"cd vision && {build_vars} python3 -m build --wheel --no-isolation")
|
||||
vision_wheel_name = host.list_dir("vision/dist")[0]
|
||||
embed_libgomp(host, use_conda, os.path.join("vision", "dist", vision_wheel_name))
|
||||
|
||||
print("Copying TorchVision wheel")
|
||||
host.download_wheel(os.path.join("vision", "dist", vision_wheel_name))
|
||||
if run_smoke_tests:
|
||||
host.run_cmd(
|
||||
f"pip3 install {os.path.join('vision', 'dist', vision_wheel_name)}"
|
||||
)
|
||||
host.run_cmd("python3 vision/test/smoke_test.py")
|
||||
print("Delete vision checkout")
|
||||
host.run_cmd("rm -rf vision")
|
||||
|
||||
return vision_wheel_name
|
||||
|
||||
|
||||
def build_torchdata(
|
||||
host: RemoteHost,
|
||||
*,
|
||||
branch: str = "main",
|
||||
use_conda: bool = True,
|
||||
git_clone_flags: str = "",
|
||||
) -> str:
|
||||
print("Checking out TorchData repo")
|
||||
git_clone_flags += " --recurse-submodules"
|
||||
build_version = checkout_repo(
|
||||
host,
|
||||
branch=branch,
|
||||
url="https://github.com/pytorch/data",
|
||||
git_clone_flags=git_clone_flags,
|
||||
mapping={
|
||||
"v1.13.1": ("0.5.1", ""),
|
||||
"v2.0.0": ("0.6.0", "rc5"),
|
||||
"v2.0.1": ("0.6.1", "rc1"),
|
||||
},
|
||||
)
|
||||
print("Building TorchData wheel")
|
||||
build_vars = ""
|
||||
if branch == "nightly":
|
||||
version = host.check_output(
|
||||
["if [ -f data/version.txt ]; then cat data/version.txt; fi"]
|
||||
).strip()
|
||||
build_date = (
|
||||
host.check_output("cd data && git log --pretty=format:%s -1")
|
||||
.strip()
|
||||
.split()[0]
|
||||
.replace("-", "")
|
||||
)
|
||||
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
|
||||
elif build_version is not None:
|
||||
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
|
||||
if host.using_docker():
|
||||
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
|
||||
|
||||
host.run_cmd(f"cd data && {build_vars} python3 -m build --wheel --no-isolation")
|
||||
wheel_name = host.list_dir("data/dist")[0]
|
||||
embed_libgomp(host, use_conda, os.path.join("data", "dist", wheel_name))
|
||||
|
||||
print("Copying TorchData wheel")
|
||||
host.download_wheel(os.path.join("data", "dist", wheel_name))
|
||||
|
||||
return wheel_name
|
||||
|
||||
|
||||
def build_torchtext(
|
||||
host: RemoteHost,
|
||||
*,
|
||||
branch: str = "main",
|
||||
use_conda: bool = True,
|
||||
git_clone_flags: str = "",
|
||||
) -> str:
|
||||
print("Checking out TorchText repo")
|
||||
git_clone_flags += " --recurse-submodules"
|
||||
build_version = checkout_repo(
|
||||
host,
|
||||
branch=branch,
|
||||
url="https://github.com/pytorch/text",
|
||||
git_clone_flags=git_clone_flags,
|
||||
mapping={
|
||||
"v1.9.0": ("0.10.0", "rc1"),
|
||||
"v1.10.0": ("0.11.0", "rc2"),
|
||||
"v1.10.1": ("0.11.1", "rc1"),
|
||||
"v1.10.2": ("0.11.2", "rc1"),
|
||||
"v1.11.0": ("0.12.0", "rc1"),
|
||||
"v1.12.0": ("0.13.0", "rc2"),
|
||||
"v1.12.1": ("0.13.1", "rc5"),
|
||||
"v1.13.0": ("0.14.0", "rc3"),
|
||||
"v1.13.1": ("0.14.1", "rc1"),
|
||||
"v2.0.0": ("0.15.1", "rc2"),
|
||||
"v2.0.1": ("0.15.2", "rc2"),
|
||||
},
|
||||
)
|
||||
print("Building TorchText wheel")
|
||||
build_vars = ""
|
||||
if branch == "nightly":
|
||||
version = host.check_output(
|
||||
["if [ -f text/version.txt ]; then cat text/version.txt; fi"]
|
||||
).strip()
|
||||
build_date = (
|
||||
host.check_output("cd text && git log --pretty=format:%s -1")
|
||||
.strip()
|
||||
.split()[0]
|
||||
.replace("-", "")
|
||||
)
|
||||
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
|
||||
elif build_version is not None:
|
||||
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
|
||||
if host.using_docker():
|
||||
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
|
||||
|
||||
host.run_cmd(f"cd text && {build_vars} python3 -m build --wheel --no-isolation")
|
||||
wheel_name = host.list_dir("text/dist")[0]
|
||||
embed_libgomp(host, use_conda, os.path.join("text", "dist", wheel_name))
|
||||
|
||||
print("Copying TorchText wheel")
|
||||
host.download_wheel(os.path.join("text", "dist", wheel_name))
|
||||
|
||||
return wheel_name
|
||||
|
||||
|
||||
def build_torchaudio(
|
||||
host: RemoteHost,
|
||||
*,
|
||||
branch: str = "main",
|
||||
use_conda: bool = True,
|
||||
git_clone_flags: str = "",
|
||||
) -> str:
|
||||
print("Checking out TorchAudio repo")
|
||||
git_clone_flags += " --recurse-submodules"
|
||||
build_version = checkout_repo(
|
||||
host,
|
||||
branch=branch,
|
||||
url="https://github.com/pytorch/audio",
|
||||
git_clone_flags=git_clone_flags,
|
||||
mapping={
|
||||
"v1.9.0": ("0.9.0", "rc2"),
|
||||
"v1.10.0": ("0.10.0", "rc5"),
|
||||
"v1.10.1": ("0.10.1", "rc1"),
|
||||
"v1.10.2": ("0.10.2", "rc1"),
|
||||
"v1.11.0": ("0.11.0", "rc1"),
|
||||
"v1.12.0": ("0.12.0", "rc3"),
|
||||
"v1.12.1": ("0.12.1", "rc5"),
|
||||
"v1.13.0": ("0.13.0", "rc4"),
|
||||
"v1.13.1": ("0.13.1", "rc2"),
|
||||
"v2.0.0": ("2.0.1", "rc3"),
|
||||
"v2.0.1": ("2.0.2", "rc2"),
|
||||
},
|
||||
)
|
||||
print("Building TorchAudio wheel")
|
||||
build_vars = ""
|
||||
if branch == "nightly":
|
||||
version = (
|
||||
host.check_output(["grep", '"version = \'"', "audio/setup.py"])
|
||||
.strip()
|
||||
.split("'")[1][:-2]
|
||||
)
|
||||
build_date = (
|
||||
host.check_output("cd audio && git log --pretty=format:%s -1")
|
||||
.strip()
|
||||
.split()[0]
|
||||
.replace("-", "")
|
||||
)
|
||||
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
|
||||
elif build_version is not None:
|
||||
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
|
||||
if host.using_docker():
|
||||
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
|
||||
|
||||
host.run_cmd(
|
||||
f"cd audio && export FFMPEG_ROOT=$(pwd)/third_party/ffmpeg && export USE_FFMPEG=1 \
|
||||
&& ./packaging/ffmpeg/build.sh \
|
||||
&& {build_vars} python3 -m build --wheel --no-isolation"
|
||||
)
|
||||
|
||||
wheel_name = host.list_dir("audio/dist")[0]
|
||||
embed_libgomp(host, use_conda, os.path.join("audio", "dist", wheel_name))
|
||||
|
||||
print("Copying TorchAudio wheel")
|
||||
host.download_wheel(os.path.join("audio", "dist", wheel_name))
|
||||
|
||||
return wheel_name
|
||||
|
||||
|
||||
def configure_system(
|
||||
host: RemoteHost,
|
||||
*,
|
||||
compiler: str = "gcc-8",
|
||||
use_conda: bool = True,
|
||||
python_version: str = "3.8",
|
||||
) -> None:
|
||||
if use_conda:
|
||||
install_condaforge_python(host, python_version)
|
||||
|
||||
print("Configuring the system")
|
||||
if not host.using_docker():
|
||||
update_apt_repo(host)
|
||||
host.run_cmd("sudo apt-get install -y ninja-build g++ git cmake gfortran unzip")
|
||||
else:
|
||||
host.run_cmd("yum install -y sudo")
|
||||
host.run_cmd("conda install -y ninja scons")
|
||||
|
||||
if not use_conda:
|
||||
host.run_cmd(
|
||||
"sudo apt-get install -y python3-dev python3-yaml python3-setuptools python3-wheel python3-pip"
|
||||
)
|
||||
host.run_cmd("pip3 install dataclasses typing-extensions")
|
||||
if not use_conda:
|
||||
print("Installing Cython + numpy from PyPy")
|
||||
host.run_cmd("sudo pip3 install Cython")
|
||||
host.run_cmd("sudo pip3 install numpy")
|
||||
|
||||
|
||||
def build_domains(
|
||||
host: RemoteHost,
|
||||
*,
|
||||
branch: str = "main",
|
||||
use_conda: bool = True,
|
||||
git_clone_flags: str = "",
|
||||
) -> tuple[str, str, str, str]:
|
||||
vision_wheel_name = build_torchvision(
|
||||
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
|
||||
)
|
||||
audio_wheel_name = build_torchaudio(
|
||||
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
|
||||
)
|
||||
data_wheel_name = build_torchdata(
|
||||
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
|
||||
)
|
||||
text_wheel_name = build_torchtext(
|
||||
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
|
||||
)
|
||||
return (vision_wheel_name, audio_wheel_name, data_wheel_name, text_wheel_name)
|
||||
|
||||
|
||||
def start_build(
|
||||
host: RemoteHost,
|
||||
*,
|
||||
branch: str = "main",
|
||||
compiler: str = "gcc-8",
|
||||
use_conda: bool = True,
|
||||
python_version: str = "3.8",
|
||||
pytorch_only: bool = False,
|
||||
pytorch_build_number: Optional[str] = None,
|
||||
shallow_clone: bool = True,
|
||||
enable_mkldnn: bool = False,
|
||||
) -> tuple[str, str, str, str, str]:
|
||||
git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
|
||||
if host.using_docker() and not use_conda:
|
||||
print("Auto-selecting conda option for docker images")
|
||||
use_conda = True
|
||||
if not host.using_docker():
|
||||
print("Disable mkldnn for host builds")
|
||||
enable_mkldnn = False
|
||||
|
||||
configure_system(
|
||||
host, compiler=compiler, use_conda=use_conda, python_version=python_version
|
||||
)
|
||||
|
||||
if host.using_docker():
|
||||
print("Move libgfortant.a into a standard location")
|
||||
# HACK: pypa gforntran.a is compiled without PIC, which leads to the following error
|
||||
# libgfortran.a(error.o)(.text._gfortrani_st_printf+0x34): unresolvable R_AARCH64_ADR_PREL_PG_HI21 relocation against symbol `__stack_chk_guard@@GLIBC_2.17' # noqa: E501, B950
|
||||
# Workaround by copying gfortran library from the host
|
||||
host.run_ssh_cmd("sudo apt-get install -y gfortran-8")
|
||||
host.run_cmd("mkdir -p /usr/lib/gcc/aarch64-linux-gnu/8")
|
||||
host.run_ssh_cmd(
|
||||
[
|
||||
"docker",
|
||||
"cp",
|
||||
"/usr/lib/gcc/aarch64-linux-gnu/8/libgfortran.a",
|
||||
f"{host.container_id}:/opt/rh/devtoolset-10/root/usr/lib/gcc/aarch64-redhat-linux/10/",
|
||||
]
|
||||
)
|
||||
|
||||
print("Checking out PyTorch repo")
|
||||
host.run_cmd(
|
||||
f"git clone --recurse-submodules -b {branch} https://github.com/pytorch/pytorch {git_clone_flags}"
|
||||
)
|
||||
|
||||
host.run_cmd("pytorch/.ci/docker/common/install_openblas.sh")
|
||||
|
||||
print("Building PyTorch wheel")
|
||||
build_opts = ""
|
||||
if pytorch_build_number is not None:
|
||||
build_opts += f" -C--build-option=--build-number={pytorch_build_number}"
|
||||
# Breakpad build fails on aarch64
|
||||
build_vars = "USE_BREAKPAD=0 "
|
||||
if branch == "nightly":
|
||||
build_date = (
|
||||
host.check_output("cd pytorch && git log --pretty=format:%s -1")
|
||||
.strip()
|
||||
.split()[0]
|
||||
.replace("-", "")
|
||||
)
|
||||
version = host.check_output("cat pytorch/version.txt").strip()[:-2]
|
||||
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1"
|
||||
if branch.startswith(("v1.", "v2.")):
|
||||
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1"
|
||||
if host.using_docker():
|
||||
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
|
||||
if enable_mkldnn:
|
||||
host.run_cmd("pytorch/.ci/docker/common/install_acl.sh")
|
||||
print("build pytorch with mkldnn+acl backend")
|
||||
build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
|
||||
build_vars += " BLAS=OpenBLAS"
|
||||
build_vars += " OpenBLAS_HOME=/opt/OpenBLAS"
|
||||
build_vars += " ACL_ROOT_DIR=/acl"
|
||||
host.run_cmd(
|
||||
f"cd $HOME/pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
|
||||
)
|
||||
print("Repair the wheel")
|
||||
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
|
||||
ld_library_path = "/acl/build:$HOME/pytorch/build/lib"
|
||||
host.run_cmd(
|
||||
f"export LD_LIBRARY_PATH={ld_library_path} && auditwheel repair $HOME/pytorch/dist/{pytorch_wheel_name}"
|
||||
)
|
||||
print("replace the original wheel with the repaired one")
|
||||
pytorch_repaired_wheel_name = host.list_dir("wheelhouse")[0]
|
||||
host.run_cmd(
|
||||
f"cp $HOME/wheelhouse/{pytorch_repaired_wheel_name} $HOME/pytorch/dist/{pytorch_wheel_name}"
|
||||
)
|
||||
else:
|
||||
print("build pytorch without mkldnn backend")
|
||||
host.run_cmd(
|
||||
f"cd pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
|
||||
)
|
||||
|
||||
print("Deleting build folder")
|
||||
host.run_cmd("cd pytorch && rm -rf build")
|
||||
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
|
||||
embed_libgomp(host, use_conda, os.path.join("pytorch", "dist", pytorch_wheel_name))
|
||||
print("Copying the wheel")
|
||||
host.download_wheel(os.path.join("pytorch", "dist", pytorch_wheel_name))
|
||||
|
||||
print("Installing PyTorch wheel")
|
||||
host.run_cmd(f"pip3 install pytorch/dist/{pytorch_wheel_name}")
|
||||
|
||||
if pytorch_only:
|
||||
return (pytorch_wheel_name, None, None, None, None)
|
||||
domain_wheels = build_domains(
|
||||
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
|
||||
)
|
||||
|
||||
return (pytorch_wheel_name, *domain_wheels)
|
||||
|
||||
|
||||
embed_library_script = """
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from auditwheel.patcher import Patchelf
|
||||
from auditwheel.wheeltools import InWheelCtx
|
||||
from auditwheel.elfutils import elf_file_filter
|
||||
from auditwheel.repair import copylib
|
||||
from auditwheel.lddtree import lddtree
|
||||
from subprocess import check_call
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
|
||||
def replace_tag(filename):
|
||||
with open(filename, 'r') as f:
|
||||
lines = f.read().split("\\n")
|
||||
for i,line in enumerate(lines):
|
||||
if not line.startswith("Tag: "):
|
||||
continue
|
||||
lines[i] = line.replace("-linux_", "-manylinux2014_")
|
||||
print(f'Updated tag from {line} to {lines[i]}')
|
||||
|
||||
with open(filename, 'w') as f:
|
||||
f.write("\\n".join(lines))
|
||||
|
||||
|
||||
class AlignedPatchelf(Patchelf):
|
||||
def set_soname(self, file_name: str, new_soname: str) -> None:
|
||||
check_call(['patchelf', '--page-size', '65536', '--set-soname', new_soname, file_name])
|
||||
|
||||
def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
|
||||
check_call(['patchelf', '--page-size', '65536', '--replace-needed', soname, new_soname, file_name])
|
||||
|
||||
|
||||
def embed_library(whl_path, lib_soname, update_tag=False):
|
||||
patcher = AlignedPatchelf()
|
||||
out_dir = TemporaryDirectory()
|
||||
whl_name = os.path.basename(whl_path)
|
||||
tmp_whl_name = os.path.join(out_dir.name, whl_name)
|
||||
with InWheelCtx(whl_path) as ctx:
|
||||
torchlib_path = os.path.join(ctx._tmpdir.name, 'torch', 'lib')
|
||||
ctx.out_wheel=tmp_whl_name
|
||||
new_lib_path, new_lib_soname = None, None
|
||||
for filename, elf in elf_file_filter(ctx.iter_files()):
|
||||
if not filename.startswith('torch/lib'):
|
||||
continue
|
||||
libtree = lddtree(filename)
|
||||
if lib_soname not in libtree['needed']:
|
||||
continue
|
||||
lib_path = libtree['libs'][lib_soname]['path']
|
||||
if lib_path is None:
|
||||
print(f"Can't embed {lib_soname} as it could not be found")
|
||||
break
|
||||
if lib_path.startswith(torchlib_path):
|
||||
continue
|
||||
|
||||
if new_lib_path is None:
|
||||
new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
|
||||
patcher.replace_needed(filename, lib_soname, new_lib_soname)
|
||||
print(f'Replacing {lib_soname} with {new_lib_soname} for {filename}')
|
||||
if update_tag:
|
||||
# Add manylinux2014 tag
|
||||
for filename in ctx.iter_files():
|
||||
if os.path.basename(filename) != 'WHEEL':
|
||||
continue
|
||||
replace_tag(filename)
|
||||
shutil.move(tmp_whl_name, whl_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
embed_library(sys.argv[1], 'libgomp.so.1', len(sys.argv) > 2 and sys.argv[2] == '--update-tag')
|
||||
"""
|
||||
|
||||
|
||||
def run_tests(host: RemoteHost, whl: str, branch="main") -> None:
|
||||
print("Configuring the system")
|
||||
update_apt_repo(host)
|
||||
host.run_cmd("sudo apt-get install -y python3-pip git")
|
||||
host.run_cmd("sudo pip3 install Cython")
|
||||
host.run_cmd("sudo pip3 install numpy")
|
||||
host.upload_file(whl, ".")
|
||||
host.run_cmd(f"sudo pip3 install {whl}")
|
||||
host.run_cmd("python3 -c 'import torch;print(torch.rand((3,3))'")
|
||||
host.run_cmd(f"git clone -b {branch} https://github.com/pytorch/pytorch")
|
||||
host.run_cmd("cd pytorch/test; python3 test_torch.py -v")
|
||||
|
||||
|
||||
def get_instance_name(instance) -> Optional[str]:
|
||||
if instance.tags is None:
|
||||
return None
|
||||
for tag in instance.tags:
|
||||
if tag["Key"] == "Name":
|
||||
return tag["Value"]
|
||||
return None
|
||||
|
||||
|
||||
def list_instances(instance_type: str) -> None:
|
||||
print(f"All instances of type {instance_type}")
|
||||
for instance in ec2_instances_of_type(instance_type):
|
||||
ifaces = instance.network_interfaces
|
||||
az = ifaces[0].subnet.availability_zone if len(ifaces) > 0 else None
|
||||
print(
|
||||
f"{instance.id} {get_instance_name(instance)} {instance.public_dns_name} {instance.state['Name']} {az}"
|
||||
)
|
||||
|
||||
|
||||
def terminate_instances(instance_type: str) -> None:
|
||||
print(f"Terminating all instances of type {instance_type}")
|
||||
instances = list(ec2_instances_of_type(instance_type))
|
||||
for instance in instances:
|
||||
print(f"Terminating {instance.id}")
|
||||
instance.terminate()
|
||||
print("Waiting for termination to complete")
|
||||
for instance in instances:
|
||||
instance.wait_until_terminated()
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser("Build and test AARCH64 wheels using EC2")
|
||||
parser.add_argument("--key-name", type=str)
|
||||
parser.add_argument("--debug", action="store_true")
|
||||
parser.add_argument("--build-only", action="store_true")
|
||||
parser.add_argument("--test-only", type=str)
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument("--os", type=str, choices=list(os_amis.keys()))
|
||||
group.add_argument("--ami", type=str)
|
||||
parser.add_argument(
|
||||
"--python-version",
|
||||
type=str,
|
||||
choices=[f"3.{d}" for d in range(6, 12)],
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument("--alloc-instance", action="store_true")
|
||||
parser.add_argument("--list-instances", action="store_true")
|
||||
parser.add_argument("--pytorch-only", action="store_true")
|
||||
parser.add_argument("--keep-running", action="store_true")
|
||||
parser.add_argument("--terminate-instances", action="store_true")
|
||||
parser.add_argument("--instance-type", type=str, default="t4g.2xlarge")
|
||||
parser.add_argument("--ebs-size", type=int, default=50)
|
||||
parser.add_argument("--branch", type=str, default="main")
|
||||
parser.add_argument("--use-docker", action="store_true")
|
||||
parser.add_argument(
|
||||
"--compiler",
|
||||
type=str,
|
||||
choices=["gcc-7", "gcc-8", "gcc-9", "clang"],
|
||||
default="gcc-8",
|
||||
)
|
||||
parser.add_argument("--use-torch-from-pypi", action="store_true")
|
||||
parser.add_argument("--pytorch-build-number", type=str, default=None)
|
||||
parser.add_argument("--disable-mkldnn", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments()
|
||||
ami = (
|
||||
args.ami
|
||||
if args.ami is not None
|
||||
else os_amis[args.os]
|
||||
if args.os is not None
|
||||
else ubuntu20_04_ami
|
||||
)
|
||||
keyfile_path, key_name = compute_keyfile_path(args.key_name)
|
||||
|
||||
if args.list_instances:
|
||||
list_instances(args.instance_type)
|
||||
sys.exit(0)
|
||||
|
||||
if args.terminate_instances:
|
||||
terminate_instances(args.instance_type)
|
||||
sys.exit(0)
|
||||
|
||||
if len(key_name) == 0:
|
||||
raise RuntimeError("""
|
||||
Cannot start build without key_name, please specify
|
||||
--key-name argument or AWS_KEY_NAME environment variable.""")
|
||||
if len(keyfile_path) == 0 or not os.path.exists(keyfile_path):
|
||||
raise RuntimeError(f"""
|
||||
Cannot find keyfile with name: [{key_name}] in path: [{keyfile_path}], please
|
||||
check `~/.ssh/` folder or manually set SSH_KEY_PATH environment variable.""")
|
||||
|
||||
# Starting the instance
|
||||
inst = start_instance(
|
||||
key_name, ami=ami, instance_type=args.instance_type, ebs_size=args.ebs_size
|
||||
)
|
||||
instance_name = f"{args.key_name}-{args.os}"
|
||||
if args.python_version is not None:
|
||||
instance_name += f"-py{args.python_version}"
|
||||
inst.create_tags(
|
||||
DryRun=False,
|
||||
Tags=[
|
||||
{
|
||||
"Key": "Name",
|
||||
"Value": instance_name,
|
||||
}
|
||||
],
|
||||
)
|
||||
addr = inst.public_dns_name
|
||||
wait_for_connection(addr, 22)
|
||||
host = RemoteHost(addr, keyfile_path)
|
||||
host.ami = ami
|
||||
if args.use_docker:
|
||||
update_apt_repo(host)
|
||||
host.start_docker()
|
||||
|
||||
if args.test_only:
|
||||
run_tests(host, args.test_only)
|
||||
sys.exit(0)
|
||||
|
||||
if args.alloc_instance:
|
||||
if args.python_version is None:
|
||||
sys.exit(0)
|
||||
install_condaforge_python(host, args.python_version)
|
||||
sys.exit(0)
|
||||
|
||||
python_version = args.python_version if args.python_version is not None else "3.10"
|
||||
|
||||
if args.use_torch_from_pypi:
|
||||
configure_system(host, compiler=args.compiler, python_version=python_version)
|
||||
print("Installing PyTorch wheel")
|
||||
host.run_cmd("pip3 install torch")
|
||||
build_domains(
|
||||
host, branch=args.branch, git_clone_flags=" --depth 1 --shallow-submodules"
|
||||
)
|
||||
else:
|
||||
start_build(
|
||||
host,
|
||||
branch=args.branch,
|
||||
compiler=args.compiler,
|
||||
python_version=python_version,
|
||||
pytorch_only=args.pytorch_only,
|
||||
pytorch_build_number=args.pytorch_build_number,
|
||||
enable_mkldnn=not args.disable_mkldnn,
|
||||
)
|
||||
if not args.keep_running:
|
||||
print(f"Waiting for instance {inst.id} to terminate")
|
||||
inst.terminate()
|
||||
inst.wait_until_terminated()
|
||||
87
.ci/aarch64_linux/embed_library.py
Normal file
87
.ci/aarch64_linux/embed_library.py
Normal file
@ -0,0 +1,87 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from subprocess import check_call
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
from auditwheel.elfutils import elf_file_filter
|
||||
from auditwheel.lddtree import lddtree
|
||||
from auditwheel.patcher import Patchelf
|
||||
from auditwheel.repair import copylib
|
||||
from auditwheel.wheeltools import InWheelCtx
|
||||
|
||||
|
||||
def replace_tag(filename):
|
||||
with open(filename) as f:
|
||||
lines = f.read().split("\\n")
|
||||
for i, line in enumerate(lines):
|
||||
if not line.startswith("Tag: "):
|
||||
continue
|
||||
lines[i] = line.replace("-linux_", "-manylinux2014_")
|
||||
print(f"Updated tag from {line} to {lines[i]}")
|
||||
|
||||
with open(filename, "w") as f:
|
||||
f.write("\\n".join(lines))
|
||||
|
||||
|
||||
class AlignedPatchelf(Patchelf):
|
||||
def set_soname(self, file_name: str, new_soname: str) -> None:
|
||||
check_call(
|
||||
["patchelf", "--page-size", "65536", "--set-soname", new_soname, file_name]
|
||||
)
|
||||
|
||||
def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
|
||||
check_call(
|
||||
[
|
||||
"patchelf",
|
||||
"--page-size",
|
||||
"65536",
|
||||
"--replace-needed",
|
||||
soname,
|
||||
new_soname,
|
||||
file_name,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def embed_library(whl_path, lib_soname, update_tag=False):
|
||||
patcher = AlignedPatchelf()
|
||||
out_dir = TemporaryDirectory()
|
||||
whl_name = os.path.basename(whl_path)
|
||||
tmp_whl_name = os.path.join(out_dir.name, whl_name)
|
||||
with InWheelCtx(whl_path) as ctx:
|
||||
torchlib_path = os.path.join(ctx._tmpdir.name, "torch", "lib")
|
||||
ctx.out_wheel = tmp_whl_name
|
||||
new_lib_path, new_lib_soname = None, None
|
||||
for filename, _ in elf_file_filter(ctx.iter_files()):
|
||||
if not filename.startswith("torch/lib"):
|
||||
continue
|
||||
libtree = lddtree(filename)
|
||||
if lib_soname not in libtree["needed"]:
|
||||
continue
|
||||
lib_path = libtree["libs"][lib_soname]["path"]
|
||||
if lib_path is None:
|
||||
print(f"Can't embed {lib_soname} as it could not be found")
|
||||
break
|
||||
if lib_path.startswith(torchlib_path):
|
||||
continue
|
||||
|
||||
if new_lib_path is None:
|
||||
new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
|
||||
patcher.replace_needed(filename, lib_soname, new_lib_soname)
|
||||
print(f"Replacing {lib_soname} with {new_lib_soname} for {filename}")
|
||||
if update_tag:
|
||||
# Add manylinux2014 tag
|
||||
for filename in ctx.iter_files():
|
||||
if os.path.basename(filename) != "WHEEL":
|
||||
continue
|
||||
replace_tag(filename)
|
||||
shutil.move(tmp_whl_name, whl_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
embed_library(
|
||||
sys.argv[1], "libgomp.so.1", len(sys.argv) > 2 and sys.argv[2] == "--update-tag"
|
||||
)
|
||||
@ -4,17 +4,14 @@ set -ex
|
||||
|
||||
SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||
|
||||
# Source the common build script for architecture-specific configurations (MKLDNN, ACL, etc.)
|
||||
source "${SCRIPTPATH}/../pytorch/build.sh" || true
|
||||
|
||||
case "${GPU_ARCH_TYPE:-BLANK}" in
|
||||
cuda | cuda-aarch64)
|
||||
cuda)
|
||||
bash "${SCRIPTPATH}/build_cuda.sh"
|
||||
;;
|
||||
rocm)
|
||||
bash "${SCRIPTPATH}/build_rocm.sh"
|
||||
;;
|
||||
cpu | cpu-cxx11-abi | cpu-aarch64 | cpu-s390x)
|
||||
cpu | cpu-cxx11-abi | cpu-s390x)
|
||||
bash "${SCRIPTPATH}/build_cpu.sh"
|
||||
;;
|
||||
xpu)
|
||||
|
||||
@ -18,31 +18,12 @@ retry () {
|
||||
$* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*)
|
||||
}
|
||||
|
||||
# Detect architecture first
|
||||
ARCH=$(uname -m)
|
||||
echo "Detected architecture: $ARCH"
|
||||
|
||||
PLATFORM=""
|
||||
# TODO move this into the Docker images
|
||||
OS_NAME=$(awk -F= '/^NAME/{print $2}' /etc/os-release)
|
||||
if [[ "$OS_NAME" == *"AlmaLinux"* ]]; then
|
||||
retry yum install -q -y zip openssl
|
||||
# Set platform based on architecture
|
||||
case $ARCH in
|
||||
x86_64)
|
||||
PLATFORM="manylinux_2_28_x86_64"
|
||||
;;
|
||||
aarch64)
|
||||
PLATFORM="manylinux_2_28_aarch64"
|
||||
;;
|
||||
s390x)
|
||||
PLATFORM="manylinux_2_28_s390x"
|
||||
;;
|
||||
*)
|
||||
echo "Unsupported architecture: $ARCH"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
PLATFORM="manylinux_2_28_x86_64"
|
||||
elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then
|
||||
retry dnf install -q -y zip openssl
|
||||
elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
|
||||
@ -57,8 +38,6 @@ else
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Platform set to: $PLATFORM"
|
||||
|
||||
# We use the package name to test the package by passing this to 'pip install'
|
||||
# This is the env variable that setup.py uses to name the package. Note that
|
||||
# pip 'normalizes' the name first by changing all - to _
|
||||
@ -320,8 +299,8 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
|
||||
# ROCm workaround for roctracer dlopens
|
||||
if [[ "$DESIRED_CUDA" == *"rocm"* ]]; then
|
||||
patchedpath=$(fname_without_so_number $destpath)
|
||||
# Keep the so number for XPU dependencies, libgomp.so.1, ACL libraries, and NVPL libraries to avoid twice load
|
||||
elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" || "$filename" == libarm_compute* || "$filename" == libnvpl* || "$filename" == "libgfortran.so.5" ]]; then
|
||||
# Keep the so number for XPU dependencies and libgomp.so.1 to avoid twice load
|
||||
elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" ]]; then
|
||||
patchedpath=$destpath
|
||||
else
|
||||
patchedpath=$(fname_with_sha256 $destpath)
|
||||
@ -367,22 +346,9 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
|
||||
done
|
||||
|
||||
# create Manylinux 2_28 tag this needs to happen before regenerate the RECORD
|
||||
# Support all architectures (x86_64, aarch64, s390x)
|
||||
if [[ "$IS_MANYLINUX2_28" == "1" && $GPU_ARCH_TYPE != "xpu" ]]; then
|
||||
if [[ $PLATFORM == "manylinux_2_28_x86_64" && $GPU_ARCH_TYPE != "cpu-s390x" && $GPU_ARCH_TYPE != "xpu" ]]; then
|
||||
wheel_file=$(echo $(basename $pkg) | sed -e 's/-cp.*$/.dist-info\/WHEEL/g')
|
||||
echo "Updating wheel tag for $ARCH architecture"
|
||||
# Replace linux_* with manylinux_2_28_* based on architecture
|
||||
case $ARCH in
|
||||
x86_64)
|
||||
sed -i -e 's#linux_x86_64#manylinux_2_28_x86_64#g' $wheel_file
|
||||
;;
|
||||
aarch64)
|
||||
sed -i -e 's#linux_aarch64#manylinux_2_28_aarch64#g' $wheel_file
|
||||
;;
|
||||
s390x)
|
||||
sed -i -e 's#linux_s390x#manylinux_2_28_s390x#g' $wheel_file
|
||||
;;
|
||||
esac
|
||||
sed -i -e s#linux_x86_64#"${PLATFORM}"# $wheel_file;
|
||||
fi
|
||||
|
||||
# regenerate the RECORD file with new hashes
|
||||
|
||||
@ -15,10 +15,6 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
|
||||
EXTRA_CAFFE2_CMAKE_FLAGS=()
|
||||
fi
|
||||
|
||||
# Detect architecture
|
||||
ARCH=$(uname -m)
|
||||
echo "Building CPU wheel for architecture: $ARCH"
|
||||
|
||||
WHEELHOUSE_DIR="wheelhousecpu"
|
||||
LIBTORCH_HOUSE_DIR="libtorch_housecpu"
|
||||
if [[ -z "$PYTORCH_FINAL_PACKAGE_DIR" ]]; then
|
||||
@ -38,10 +34,8 @@ elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then
|
||||
elif [[ "$OS_NAME" == *"AlmaLinux"* ]]; then
|
||||
LIBGOMP_PATH="/usr/lib64/libgomp.so.1"
|
||||
elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
|
||||
if [[ "$ARCH" == "s390x" ]]; then
|
||||
if [[ "$(uname -m)" == "s390x" ]]; then
|
||||
LIBGOMP_PATH="/usr/lib/s390x-linux-gnu/libgomp.so.1"
|
||||
elif [[ "$ARCH" == "aarch64" ]]; then
|
||||
LIBGOMP_PATH="/usr/lib/aarch64-linux-gnu/libgomp.so.1"
|
||||
else
|
||||
LIBGOMP_PATH="/usr/lib/x86_64-linux-gnu/libgomp.so.1"
|
||||
fi
|
||||
@ -55,34 +49,6 @@ DEPS_SONAME=(
|
||||
"libgomp.so.1"
|
||||
)
|
||||
|
||||
# Add ARM-specific library dependencies for CPU builds
|
||||
if [[ "$ARCH" == "aarch64" ]]; then
|
||||
echo "Adding ARM-specific CPU library dependencies"
|
||||
|
||||
# ARM Compute Library (if available)
|
||||
if [[ -d "/acl/build" ]]; then
|
||||
echo "Adding ARM Compute Library for CPU"
|
||||
DEPS_LIST+=(
|
||||
"/acl/build/libarm_compute.so"
|
||||
"/acl/build/libarm_compute_graph.so"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libarm_compute.so"
|
||||
"libarm_compute_graph.so"
|
||||
)
|
||||
fi
|
||||
|
||||
# ARM system libraries
|
||||
DEPS_LIST+=(
|
||||
"/usr/lib64/libgfortran.so.5"
|
||||
"/opt/OpenBLAS/lib/libopenblas.so.0"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libgfortran.so.5"
|
||||
"libopenblas.so.0"
|
||||
)
|
||||
fi
|
||||
|
||||
rm -rf /usr/local/cuda*
|
||||
|
||||
SOURCE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )"
|
||||
|
||||
@ -29,10 +29,6 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
|
||||
EXTRA_CAFFE2_CMAKE_FLAGS=()
|
||||
fi
|
||||
|
||||
# Detect architecture
|
||||
ARCH=$(uname -m)
|
||||
echo "Building for architecture: $ARCH"
|
||||
|
||||
# Determine CUDA version and architectures to build for
|
||||
#
|
||||
# NOTE: We should first check `DESIRED_CUDA` when determining `CUDA_VERSION`,
|
||||
@ -57,60 +53,34 @@ fi
|
||||
cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.')
|
||||
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
|
||||
|
||||
# Function to remove architectures from a list
|
||||
remove_archs() {
|
||||
local result="$1"
|
||||
shift
|
||||
for arch in "$@"; do
|
||||
result="${result//${arch};/}"
|
||||
done
|
||||
echo "$result"
|
||||
}
|
||||
|
||||
# Function to filter CUDA architectures for aarch64
|
||||
# aarch64 ARM GPUs only support certain compute capabilities
|
||||
# Keep: 8.0 (A100), 9.0+ (Hopper, Grace Hopper, newer)
|
||||
# Remove: < 8.0 (no ARM GPUs), 8.6 (x86_64 RTX 3090/A6000 only)
|
||||
filter_aarch64_archs() {
|
||||
local arch_list="$1"
|
||||
# Explicitly remove architectures not needed on aarch64
|
||||
arch_list=$(remove_archs "$arch_list" "5.0" "6.0" "7.0" "7.5" "8.6")
|
||||
echo "$arch_list"
|
||||
}
|
||||
|
||||
# Base: Common architectures across all modern CUDA versions
|
||||
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0"
|
||||
|
||||
case ${CUDA_VERSION} in
|
||||
12.6) TORCH_CUDA_ARCH_LIST="5.0;6.0;${TORCH_CUDA_ARCH_LIST}" ;; # Only 12.6 includes Legacy Maxwell/Pascal that will be removed in future releases
|
||||
12.8) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0" ;; # +Hopper/Blackwell support
|
||||
12.9) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0+PTX" # +Hopper/Blackwell support + PTX for forward compatibility
|
||||
#removing sm_50-sm_60 as these architectures are deprecated in CUDA 12.8/9 and will be removed in future releases
|
||||
#however we would like to keep sm_70 architecture see: https://github.com/pytorch/pytorch/issues/157517
|
||||
12.8)
|
||||
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0"
|
||||
;;
|
||||
12.9)
|
||||
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0+PTX"
|
||||
# WAR to resolve the ld error in libtorch build with CUDA 12.9
|
||||
if [[ "$PACKAGE_TYPE" == "libtorch" ]]; then
|
||||
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//7.0;/}" # Remove 7.0 to resolve the ld error
|
||||
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//8.6;/}" # Remove 8.6 for libtorch
|
||||
TORCH_CUDA_ARCH_LIST="7.5;8.0;9.0;10.0;12.0+PTX"
|
||||
fi
|
||||
;;
|
||||
13.0)
|
||||
TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;$([[ "$ARCH" == "aarch64" ]] && echo "11.0;" || echo "")12.0+PTX"
|
||||
export TORCH_NVCC_FLAGS="-compress-mode=size"
|
||||
export BUILD_BUNDLE_PTXAS=1
|
||||
TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;12.0+PTX"
|
||||
;;
|
||||
12.6)
|
||||
TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6;9.0"
|
||||
;;
|
||||
*)
|
||||
echo "unknown cuda version $CUDA_VERSION"
|
||||
exit 1
|
||||
;;
|
||||
*) echo "unknown cuda version $CUDA_VERSION"; exit 1 ;;
|
||||
esac
|
||||
|
||||
# Filter for aarch64: Remove < 8.0 and 8.6
|
||||
[[ "$ARCH" == "aarch64" ]] && TORCH_CUDA_ARCH_LIST=$(filter_aarch64_archs "$TORCH_CUDA_ARCH_LIST")
|
||||
|
||||
echo "TORCH_CUDA_ARCH_LIST set to: $TORCH_CUDA_ARCH_LIST"
|
||||
export TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST}
|
||||
echo "${TORCH_CUDA_ARCH_LIST}"
|
||||
|
||||
# Disable MAGMA for aarch64 as pre-built libraries are x86-64 only
|
||||
if [[ "$ARCH" == "aarch64" ]]; then
|
||||
echo "Disabling MAGMA for aarch64 architecture"
|
||||
export USE_MAGMA=0
|
||||
fi
|
||||
|
||||
# Package directories
|
||||
WHEELHOUSE_DIR="wheelhouse$cuda_version_nodot"
|
||||
LIBTORCH_HOUSE_DIR="libtorch_house$cuda_version_nodot"
|
||||
@ -274,51 +244,6 @@ else
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Add ARM-specific library dependencies
|
||||
if [[ "$ARCH" == "aarch64" ]]; then
|
||||
echo "Adding ARM-specific library dependencies"
|
||||
|
||||
# ARM Compute Library (if available)
|
||||
if [[ -d "/acl/build" ]]; then
|
||||
echo "Adding ARM Compute Library"
|
||||
DEPS_LIST+=(
|
||||
"/acl/build/libarm_compute.so"
|
||||
"/acl/build/libarm_compute_graph.so"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libarm_compute.so"
|
||||
"libarm_compute_graph.so"
|
||||
)
|
||||
fi
|
||||
|
||||
# ARM system libraries
|
||||
DEPS_LIST+=(
|
||||
"/lib64/libgomp.so.1"
|
||||
"/usr/lib64/libgfortran.so.5"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libgomp.so.1"
|
||||
"libgfortran.so.5"
|
||||
)
|
||||
|
||||
# NVPL libraries (ARM optimized BLAS/LAPACK)
|
||||
if [[ -d "/usr/local/lib" && -f "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0" ]]; then
|
||||
echo "Adding NVPL libraries for ARM"
|
||||
DEPS_LIST+=(
|
||||
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0"
|
||||
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0"
|
||||
"/usr/local/lib/libnvpl_lapack_core.so.0"
|
||||
"/usr/local/lib/libnvpl_blas_core.so.0"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libnvpl_lapack_lp64_gomp.so.0"
|
||||
"libnvpl_blas_lp64_gomp.so.0"
|
||||
"libnvpl_lapack_core.so.0"
|
||||
"libnvpl_blas_core.so.0"
|
||||
)
|
||||
fi
|
||||
fi
|
||||
|
||||
# run_tests.sh requires DESIRED_CUDA to know what tests to exclude
|
||||
export DESIRED_CUDA="$cuda_version_nodot"
|
||||
|
||||
@ -326,11 +251,9 @@ export DESIRED_CUDA="$cuda_version_nodot"
|
||||
rm -rf /usr/local/cuda || true
|
||||
ln -s "/usr/local/cuda-${CUDA_VERSION}" /usr/local/cuda
|
||||
|
||||
# Switch `/usr/local/magma` to the desired CUDA version (skip for aarch64)
|
||||
if [[ "$ARCH" != "aarch64" ]]; then
|
||||
rm -rf /usr/local/magma || true
|
||||
ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma
|
||||
fi
|
||||
# Switch `/usr/local/magma` to the desired CUDA version
|
||||
rm -rf /usr/local/magma || true
|
||||
ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma
|
||||
|
||||
export CUDA_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev) # 10.0.130
|
||||
export CUDA_VERSION_SHORT=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev | cut -f1,2 -d".") # 10.0
|
||||
|
||||
@ -86,20 +86,10 @@ else
|
||||
fi
|
||||
fi
|
||||
|
||||
# Enable MKLDNN with ARM Compute Library for ARM builds
|
||||
if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then
|
||||
export USE_MKLDNN=1
|
||||
|
||||
# ACL is required for aarch64 builds
|
||||
if [[ ! -d "/acl" ]]; then
|
||||
echo "ERROR: ARM Compute Library not found at /acl"
|
||||
echo "ACL is required for aarch64 builds. Check Docker image setup."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export USE_MKLDNN_ACL=1
|
||||
export ACL_ROOT_DIR=/acl
|
||||
echo "ARM Compute Library enabled for MKLDNN: ACL_ROOT_DIR=/acl"
|
||||
fi
|
||||
|
||||
if [[ "$BUILD_ENVIRONMENT" == *riscv64* ]]; then
|
||||
|
||||
@ -96,6 +96,7 @@ function pip_build_and_install() {
|
||||
python3 -m pip wheel \
|
||||
--no-build-isolation \
|
||||
--no-deps \
|
||||
--no-use-pep517 \
|
||||
-w "${wheel_dir}" \
|
||||
"${build_target}"
|
||||
fi
|
||||
@ -307,28 +308,6 @@ function install_torchao() {
|
||||
pip_build_and_install "git+https://github.com/pytorch/ao.git@${commit}" dist/ao
|
||||
}
|
||||
|
||||
function install_flash_attn_cute() {
|
||||
echo "Installing FlashAttention CuTe from GitHub..."
|
||||
# Grab latest main til we have a pinned commit
|
||||
local flash_attn_commit
|
||||
flash_attn_commit=$(git ls-remote https://github.com/Dao-AILab/flash-attention.git HEAD | cut -f1)
|
||||
|
||||
# Clone the repo to a temporary directory
|
||||
rm -rf flash-attention-build
|
||||
git clone --depth 1 --recursive https://github.com/Dao-AILab/flash-attention.git flash-attention-build
|
||||
|
||||
pushd flash-attention-build
|
||||
git checkout "${flash_attn_commit}"
|
||||
|
||||
# Install only the 'cute' sub-directory
|
||||
pip_install -e flash_attn/cute/
|
||||
popd
|
||||
|
||||
# remove the local repo
|
||||
rm -rf flash-attention-build
|
||||
echo "FlashAttention CuTe installation complete."
|
||||
}
|
||||
|
||||
function print_sccache_stats() {
|
||||
echo 'PyTorch Build Statistics'
|
||||
sccache --show-stats
|
||||
|
||||
@ -353,17 +353,6 @@ def test_linalg(device="cpu") -> None:
|
||||
torch.linalg.svd(A)
|
||||
|
||||
|
||||
def test_sdpa(device="cpu", dtype=torch.float16) -> None:
|
||||
"""Regression test for https://github.com/pytorch/pytorch/issues/167602
|
||||
Without nvrtc_builtins on CuDNN-9.13 on CUDA-13 fails with ` No valid execution plans built.`
|
||||
"""
|
||||
print(f"Testing SDPA on {device} using type {dtype}")
|
||||
k, q, v = torch.rand(3, 1, 16, 77, 64, dtype=dtype, device=device).unbind(0)
|
||||
attn = torch.rand(1, 1, 77, 77, dtype=dtype, device=device)
|
||||
rc = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn)
|
||||
assert rc.isnan().any().item() is False
|
||||
|
||||
|
||||
def smoke_test_compile(device: str = "cpu") -> None:
|
||||
supported_dtypes = [torch.float16, torch.float32, torch.float64]
|
||||
|
||||
@ -500,12 +489,10 @@ def main() -> None:
|
||||
smoke_test_conv2d()
|
||||
test_linalg()
|
||||
test_numpy()
|
||||
test_sdpa()
|
||||
|
||||
if is_cuda_system:
|
||||
test_linalg("cuda")
|
||||
test_cuda_gds_errors_captured()
|
||||
test_sdpa("cuda")
|
||||
|
||||
if options.package == "all":
|
||||
smoke_test_modules()
|
||||
|
||||
@ -344,18 +344,8 @@ test_python_smoke() {
|
||||
}
|
||||
|
||||
test_python_smoke_b200() {
|
||||
# Targeted smoke tests for B200 including FlashAttention CuTe coverage
|
||||
install_flash_attn_cute
|
||||
time python test/run_test.py \
|
||||
--include \
|
||||
test_matmul_cuda \
|
||||
test_scaled_matmul_cuda \
|
||||
inductor/test_fp8 \
|
||||
nn/attention/test_fa4 \
|
||||
nn/attention/test_open_registry \
|
||||
inductor/test_flex_flash \
|
||||
$PYTHON_TEST_EXTRA_OPTION \
|
||||
--upload-artifacts-while-running
|
||||
# Targeted smoke tests for B200 - staged approach to avoid too many failures
|
||||
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
@ -389,13 +379,6 @@ test_lazy_tensor_meta_reference_disabled() {
|
||||
export -n TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE
|
||||
}
|
||||
|
||||
test_dynamo_core() {
|
||||
time python test/run_test.py \
|
||||
--include-dynamo-core-tests \
|
||||
--verbose \
|
||||
--upload-artifacts-while-running
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
test_dynamo_wrapped_shard() {
|
||||
if [[ -z "$NUM_TEST_SHARDS" ]]; then
|
||||
@ -1687,22 +1670,6 @@ test_operator_microbenchmark() {
|
||||
done
|
||||
}
|
||||
|
||||
test_attention_microbenchmark() {
|
||||
TEST_REPORTS_DIR=$(pwd)/test/test-reports
|
||||
mkdir -p "$TEST_REPORTS_DIR"
|
||||
TEST_DIR=$(pwd)
|
||||
|
||||
# Install attention-gym dependency
|
||||
echo "Installing attention-gym..."
|
||||
python -m pip install git+https://github.com/meta-pytorch/attention-gym.git@main
|
||||
pip show triton
|
||||
|
||||
cd "${TEST_DIR}"/benchmarks/transformer
|
||||
|
||||
$TASKSET python score_mod.py --config configs/config_basic.yaml \
|
||||
--output-json-for-dashboard "${TEST_REPORTS_DIR}/attention_microbenchmark.json"
|
||||
}
|
||||
|
||||
if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
|
||||
(cd test && python -c "import torch; print(torch.__config__.show())")
|
||||
(cd test && python -c "import torch; print(torch.__config__.parallel_info())")
|
||||
@ -1760,8 +1727,6 @@ elif [[ "${TEST_CONFIG}" == *operator_benchmark* ]]; then
|
||||
fi
|
||||
elif [[ "${TEST_CONFIG}" == *operator_microbenchmark* ]]; then
|
||||
test_operator_microbenchmark
|
||||
elif [[ "${TEST_CONFIG}" == *attention_microbenchmark* ]]; then
|
||||
test_attention_microbenchmark
|
||||
elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
|
||||
test_inductor_distributed
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
|
||||
@ -1821,8 +1786,6 @@ elif [[ "${TEST_CONFIG}" == *inductor* ]]; then
|
||||
test_inductor_shard "${SHARD_NUMBER}"
|
||||
elif [[ "${TEST_CONFIG}" == *einops* ]]; then
|
||||
test_einops
|
||||
elif [[ "${TEST_CONFIG}" == *dynamo_core* ]]; then
|
||||
test_dynamo_core
|
||||
elif [[ "${TEST_CONFIG}" == *dynamo_wrapped* ]]; then
|
||||
install_torchvision
|
||||
test_dynamo_wrapped_shard "${SHARD_NUMBER}"
|
||||
|
||||
2
.github/actionlint.yaml
vendored
2
.github/actionlint.yaml
vendored
@ -63,7 +63,7 @@ self-hosted-runner:
|
||||
- linux.rocm.gpu.gfx942.1
|
||||
- linux.rocm.gpu.gfx942.2
|
||||
- linux.rocm.gpu.gfx942.4
|
||||
- linux.rocm.gfx942.docker-cache
|
||||
- rocm-docker
|
||||
# Org wise AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors)
|
||||
- macos-m1-stable
|
||||
- macos-m1-14
|
||||
|
||||
2
.github/ci_commit_pins/audio.txt
vendored
2
.github/ci_commit_pins/audio.txt
vendored
@ -1 +1 @@
|
||||
ee1a1350eb37804b94334768f328144f058f14e9
|
||||
ad5816f0eee1c873df1b7d371c69f1f811a89387
|
||||
|
||||
2
.github/ci_commit_pins/vision.txt
vendored
2
.github/ci_commit_pins/vision.txt
vendored
@ -1 +1 @@
|
||||
2d82dc5caa336d179d9b46ac4a0fb8c43d84c5cc
|
||||
ccb801b88af136454798b945175c4c87e636ac33
|
||||
|
||||
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
@ -1 +1 @@
|
||||
94631807d22c09723dd006f7be5beb649d5f88d0
|
||||
e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a
|
||||
|
||||
1
.github/pytorch-probot.yml
vendored
1
.github/pytorch-probot.yml
vendored
@ -7,7 +7,6 @@ ciflow_push_tags:
|
||||
- ciflow/binaries
|
||||
- ciflow/binaries_libtorch
|
||||
- ciflow/binaries_wheel
|
||||
- ciflow/dynamo
|
||||
- ciflow/h100
|
||||
- ciflow/h100-cutlass-backend
|
||||
- ciflow/h100-distributed
|
||||
|
||||
2
.github/scripts/generate_pytorch_version.py
vendored
2
.github/scripts/generate_pytorch_version.py
vendored
@ -50,7 +50,7 @@ def get_tag() -> str:
|
||||
|
||||
def get_base_version() -> str:
|
||||
root = get_pytorch_root()
|
||||
dirty_version = Path(root / "version.txt").read_text().strip()
|
||||
dirty_version = open(root / "version.txt").read().strip()
|
||||
# Strips trailing a0 from version.txt, not too sure why it's there in the
|
||||
# first place
|
||||
return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)
|
||||
|
||||
7
.github/workflows/_binary-build-linux.yml
vendored
7
.github/workflows/_binary-build-linux.yml
vendored
@ -260,8 +260,11 @@ jobs:
|
||||
"${DOCKER_IMAGE}"
|
||||
)
|
||||
docker exec -t -w "${PYTORCH_ROOT}" "${container_name}" bash -c "bash .circleci/scripts/binary_populate_env.sh"
|
||||
# Unified build script for all architectures (x86_64, aarch64, s390x)
|
||||
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh"
|
||||
if [[ ${BUILD_ENVIRONMENT} == *"aarch64"* ]]; then
|
||||
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/aarch64_linux/aarch64_ci_build.sh"
|
||||
else
|
||||
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh"
|
||||
fi
|
||||
|
||||
- name: Chown artifacts
|
||||
if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' && inputs.build_environment != 'linux-s390x-binary-manywheel' }}
|
||||
|
||||
2
.github/workflows/_linux-test.yml
vendored
2
.github/workflows/_linux-test.yml
vendored
@ -326,7 +326,7 @@ jobs:
|
||||
SCCACHE_BUCKET: ${{ !contains(matrix.runner, 'b200') && 'ossci-compiler-cache-circleci-v2' || '' }}
|
||||
SCCACHE_REGION: ${{ !contains(matrix.runner, 'b200') && 'us-east-1' || '' }}
|
||||
SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }}
|
||||
DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
DOCKER_IMAGE: ${{ inputs.docker-image }}
|
||||
XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }}
|
||||
XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla
|
||||
PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }}
|
||||
|
||||
@ -1,73 +0,0 @@
|
||||
name: attention_op_microbenchmark
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- ciflow/op-benchmark/*
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
# Run at 06:00 UTC everyday
|
||||
- cron: 0 7 * * *
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
attn-microbenchmark-build:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
with:
|
||||
runner: linux.12xlarge.memory
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '8.0 9.0'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "attention_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.a100" },
|
||||
{ config: "attention_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.h100" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
attn-microbenchmark-test:
|
||||
name: attn-microbenchmark-test
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: attn-microbenchmark-build
|
||||
with:
|
||||
timeout-minutes: 500
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
|
||||
docker-image: ${{ needs.attn-microbenchmark-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.attn-microbenchmark-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
# B200 runner
|
||||
opmicrobenchmark-build-b200:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: opmicrobenchmark-build-b200
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
with:
|
||||
runner: linux.12xlarge.memory
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '10.0'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
opmicrobenchmark-test-b200:
|
||||
name: opmicrobenchmark-test-b200
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: opmicrobenchmark-build-b200
|
||||
with:
|
||||
timeout-minutes: 500
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
|
||||
docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }}
|
||||
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
secrets: inherit
|
||||
1
.github/workflows/b200-distributed.yml
vendored
1
.github/workflows/b200-distributed.yml
vendored
@ -37,7 +37,6 @@ jobs:
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
runner: linux.12xlarge.memory
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '10.0'
|
||||
|
||||
1
.github/workflows/b200-symm-mem.yml
vendored
1
.github/workflows/b200-symm-mem.yml
vendored
@ -37,7 +37,6 @@ jobs:
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
runner: linux.12xlarge.memory
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100-symm
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '10.0'
|
||||
|
||||
16
.github/workflows/docker-builds.yml
vendored
16
.github/workflows/docker-builds.yml
vendored
@ -119,22 +119,6 @@ jobs:
|
||||
with:
|
||||
docker-image: ${{ steps.build-docker-image.outputs.docker-image }}
|
||||
|
||||
- name: Generate output
|
||||
if: contains(matrix.docker-image-name, 'rocm')
|
||||
id: generate_output
|
||||
run: |
|
||||
docker_image_name="${{ matrix.docker-image-name }}"
|
||||
docker_image_tag="${{ steps.build-docker-image.outputs.docker-image }}"
|
||||
echo "${docker_image_name}=${docker_image_tag}" >> docker-builds-output-${docker_image_name}.txt
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4.4.0
|
||||
if: contains(matrix.docker-image-name, 'rocm')
|
||||
with:
|
||||
name: docker-builds-artifacts-${{ matrix.docker-image-name }}
|
||||
retention-days: 14
|
||||
path: ./docker-builds-output-${{ matrix.docker-image-name }}.txt
|
||||
|
||||
- uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
|
||||
name: Push to https://ghcr.io/
|
||||
id: push-to-ghcr-io
|
||||
|
||||
55
.github/workflows/docker-cache-mi300.yml
vendored
Normal file
55
.github/workflows/docker-cache-mi300.yml
vendored
Normal file
@ -0,0 +1,55 @@
|
||||
name: docker-cache-mi300
|
||||
|
||||
on:
|
||||
# run every 6 hours
|
||||
schedule:
|
||||
- cron: 0 0,6,12,18 * * *
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
docker-cache:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
runs-on: rocm-docker
|
||||
steps:
|
||||
- name: Checkout PyTorch
|
||||
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
|
||||
with:
|
||||
no-sudo: true
|
||||
|
||||
- name: configure aws credentials
|
||||
id: aws_creds
|
||||
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
|
||||
with:
|
||||
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
aws-region: us-east-1
|
||||
role-duration-seconds: 18000
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
continue-on-error: false
|
||||
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
|
||||
|
||||
- name: Calculate docker image
|
||||
id: calculate-docker-image
|
||||
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
|
||||
with:
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
||||
push: false
|
||||
|
||||
- name: Pull docker image
|
||||
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
|
||||
with:
|
||||
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
|
||||
- name: Tar and upload to S3 bucket
|
||||
run: |
|
||||
sudo docker save -o ~/docker-data/pytorch/pytorch_docker_image.tar ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
sudo rclone copy -P --s3-upload-concurrency 64 --s3-chunk-size 200M --s3-upload-cutoff 300M ~/docker-data/pytorch/pytorch_docker_image.tar oci:pytorchbucket0002/pytorch_docker_image --progress
|
||||
105
.github/workflows/docker-cache-rocm.yml
vendored
105
.github/workflows/docker-cache-rocm.yml
vendored
@ -1,105 +0,0 @@
|
||||
name: docker-cache-rocm
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: [docker-builds]
|
||||
branches: [main, release]
|
||||
types:
|
||||
- completed
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
actions: read
|
||||
|
||||
jobs:
|
||||
download-docker-builds-artifacts:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: download-docker-builds-artifacts
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
pytorch-linux-jammy-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}
|
||||
pytorch-linux-noble-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}
|
||||
pytorch-linux-jammy-rocm-n-py3-benchmarks: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}
|
||||
steps:
|
||||
- name: Download artifacts
|
||||
uses: actions/download-artifact@v4.1.7
|
||||
with:
|
||||
run-id: ${{ github.event.workflow_run.id }}
|
||||
path: ./docker-builds-artifacts
|
||||
merge-multiple: true
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Process artifacts
|
||||
id: process-artifacts
|
||||
run: |
|
||||
ls -R ./docker-builds-artifacts
|
||||
cat ./docker-builds-artifacts/*txt >> "${GITHUB_OUTPUT}"
|
||||
cat "${GITHUB_OUTPUT}"
|
||||
|
||||
docker-cache:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
needs: download-docker-builds-artifacts
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
runner: [linux.rocm.gfx942.docker-cache]
|
||||
docker-image: [
|
||||
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}",
|
||||
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}",
|
||||
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}"
|
||||
]
|
||||
runs-on: "${{ matrix.runner }}"
|
||||
steps:
|
||||
- name: debug
|
||||
run: |
|
||||
JSON_STRINGIFIED="${{ toJSON(needs.download-docker-builds-artifacts.outputs) }}"
|
||||
echo "Outputs of download-docker-builds-artifacts job: ${JSON_STRINGIFIED}"
|
||||
|
||||
- name: configure aws credentials
|
||||
id: aws_creds
|
||||
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
|
||||
with:
|
||||
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
aws-region: us-east-1
|
||||
role-duration-seconds: 18000
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
continue-on-error: false
|
||||
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
|
||||
|
||||
- name: Generate ghrc.io tag
|
||||
id: ghcr-io-tag
|
||||
run: |
|
||||
ecr_image="${{ matrix.docker-image }}"
|
||||
ghcr_image="ghcr.io/pytorch/ci-image:${ecr_image##*:}"
|
||||
echo "ghcr_image=${ghcr_image}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Pull docker image
|
||||
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
|
||||
with:
|
||||
docker-image: ${{ steps.ghcr-io-tag.outputs.ghcr_image }}
|
||||
|
||||
- name: Save as tarball
|
||||
run: |
|
||||
docker_image_tag=${{ matrix.docker-image }}
|
||||
docker_image_tag="${docker_image_tag#*:}" # Remove everything before and including first ":"
|
||||
docker_image_tag="${docker_image_tag%-*}" # Remove everything after and including last "-"
|
||||
ref_name=${{ github.event.workflow_run.head_branch }}
|
||||
if [[ $ref_name =~ "release/" ]]; then
|
||||
ref_suffix="release"
|
||||
elif [[ $ref_name == "main" ]]; then
|
||||
ref_suffix="main"
|
||||
else
|
||||
echo "Unexpected branch in ref_name: ${ref_name}" && exit 1
|
||||
fi
|
||||
docker tag ${{ steps.ghcr-io-tag.outputs.ghcr_image }} ${{ matrix.docker-image }}
|
||||
# mv is atomic operation, so we use intermediate tar.tmp file to prevent read-write contention
|
||||
docker save -o ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ${{ matrix.docker-image }}
|
||||
mv ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ~/pytorch-data/docker/${docker_image_tag}_${ref_suffix}.tar
|
||||
70
.github/workflows/dynamo-unittest.yml
vendored
70
.github/workflows/dynamo-unittest.yml
vendored
@ -1,70 +0,0 @@
|
||||
# Workflow: Dynamo Unit Test
|
||||
# runs unit tests for dynamo.
|
||||
name: dynamo-unittest
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- ciflow/dynamo/*
|
||||
workflow_call:
|
||||
schedule:
|
||||
- cron: 29 8 * * * # about 1:29am PDT
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
get-label-type:
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
opt_out_experiments: lf
|
||||
|
||||
dynamo-build:
|
||||
name: dynamo-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.11', '3.12']
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-py${{ matrix.python-version }}-clang12
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang12
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "dynamo_core", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
dynamo-test:
|
||||
name: dynamo-test
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: [get-label-type, dynamo-build]
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.11', '3.12']
|
||||
with:
|
||||
build-environment: linux-jammy-py${{ matrix.python-version }}-clang12
|
||||
docker-image: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang12
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "dynamo_core", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
]}
|
||||
secrets: inherit
|
||||
7
.github/workflows/test-b200.yml
vendored
7
.github/workflows/test-b200.yml
vendored
@ -5,9 +5,7 @@
|
||||
# Flow:
|
||||
# 1. Builds PyTorch with CUDA 12.8+ and sm100 architecture for B200
|
||||
# 2. Runs smoke tests on linux.dgx.b200 runner
|
||||
# 3. Tests executed are defined in .ci/pytorch/test.sh -> test_python_smoke_b200() function
|
||||
# - Includes matmul, scaled_matmul, FP8, and FlashAttention CuTe tests
|
||||
# - FlashAttention CuTe DSL is installed as part of test execution
|
||||
# 3. Tests executed are defined in .ci/pytorch/test.sh -> test_python_smoke() function
|
||||
#
|
||||
# Triggered by:
|
||||
# - Pull requests modifying this workflow file
|
||||
@ -54,7 +52,6 @@ jobs:
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
runner: linux.12xlarge.memory
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '10.0'
|
||||
@ -75,4 +72,4 @@ jobs:
|
||||
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build.outputs.test-matrix }}
|
||||
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
secrets: inherit
|
||||
secrets: inherit
|
||||
|
||||
83
.github/workflows/trunk-rocm-mi300.yml
vendored
83
.github/workflows/trunk-rocm-mi300.yml
vendored
@ -1,83 +0,0 @@
|
||||
name: trunk-rocm-mi300
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- release/*
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: 29 8 * * * # about 1:29am PDT
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
llm-td:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: before-test
|
||||
uses: ./.github/workflows/llm_td_retrieval.yml
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
target-determination:
|
||||
name: before-test
|
||||
uses: ./.github/workflows/target_determination.yml
|
||||
needs: llm-td
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
get-label-type:
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
|
||||
linux-jammy-rocm-py3_10-build:
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
||||
sync-tag: rocm-build
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
|
||||
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
|
||||
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
|
||||
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
|
||||
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
|
||||
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
|
||||
{ config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b" },
|
||||
{ config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b" },
|
||||
{ config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-rocm-py3_10-test:
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_rocm-test.yml
|
||||
needs:
|
||||
- linux-jammy-rocm-py3_10-build
|
||||
- target-determination
|
||||
with:
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
1
.github/workflows/upload-test-stats.yml
vendored
1
.github/workflows/upload-test-stats.yml
vendored
@ -5,7 +5,6 @@ on:
|
||||
workflows:
|
||||
- pull
|
||||
- trunk
|
||||
- trunk-rocm-mi300
|
||||
- periodic
|
||||
- periodic-rocm-mi200
|
||||
- periodic-rocm-mi300
|
||||
|
||||
330
.spin/cmds.py
330
.spin/cmds.py
@ -1,330 +0,0 @@
|
||||
import hashlib
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
import spin
|
||||
|
||||
|
||||
def file_digest(file, algorithm: str):
|
||||
try:
|
||||
return hashlib.file_digest(file, algorithm)
|
||||
except AttributeError:
|
||||
pass # Fallback to manual implementation below
|
||||
hash = hashlib.new(algorithm)
|
||||
while chunk := file.read(8192):
|
||||
hash.update(chunk)
|
||||
return hash
|
||||
|
||||
|
||||
def _hash_file(file):
|
||||
with open(file, "rb") as f:
|
||||
hash = file_digest(f, "sha256")
|
||||
return hash.hexdigest()
|
||||
|
||||
|
||||
def _hash_files(files):
|
||||
hashes = {file: _hash_file(file) for file in files}
|
||||
return hashes
|
||||
|
||||
|
||||
def _read_hashes(hash_file: Path):
|
||||
if not hash_file.exists():
|
||||
return {}
|
||||
with hash_file.open("r") as f:
|
||||
lines = f.readlines()
|
||||
hashes = {}
|
||||
for line in lines:
|
||||
hash = line[:64]
|
||||
file = line[66:].strip()
|
||||
hashes[file] = hash
|
||||
return hashes
|
||||
|
||||
|
||||
def _updated_hashes(hash_file, files_to_hash):
|
||||
old_hashes = _read_hashes(hash_file)
|
||||
new_hashes = _hash_files(files_to_hash)
|
||||
if new_hashes != old_hashes:
|
||||
return new_hashes
|
||||
return None
|
||||
|
||||
|
||||
@click.command()
|
||||
def regenerate_version():
|
||||
"""Regenerate version.py."""
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"tools.generate_torch_version",
|
||||
"--is-debug=false",
|
||||
]
|
||||
spin.util.run(cmd)
|
||||
|
||||
|
||||
TYPE_STUBS = [
|
||||
(
|
||||
"Pytorch type stubs",
|
||||
Path(".lintbin/.pytorch-type-stubs.sha256"),
|
||||
[
|
||||
"aten/src/ATen/native/native_functions.yaml",
|
||||
"aten/src/ATen/native/tags.yaml",
|
||||
"tools/autograd/deprecated.yaml",
|
||||
],
|
||||
[
|
||||
sys.executable,
|
||||
"-m",
|
||||
"tools.pyi.gen_pyi",
|
||||
"--native-functions-path",
|
||||
"aten/src/ATen/native/native_functions.yaml",
|
||||
"--tags-path",
|
||||
"aten/src/ATen/native/tags.yaml",
|
||||
"--deprecated-functions-path",
|
||||
"tools/autograd/deprecated.yaml",
|
||||
],
|
||||
),
|
||||
(
|
||||
"Datapipes type stubs",
|
||||
None,
|
||||
[],
|
||||
[
|
||||
sys.executable,
|
||||
"torch/utils/data/datapipes/gen_pyi.py",
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@click.command()
|
||||
def regenerate_type_stubs():
|
||||
"""Regenerate type stubs."""
|
||||
for name, hash_file, files_to_hash, cmd in TYPE_STUBS:
|
||||
if hash_file:
|
||||
if hashes := _updated_hashes(hash_file, files_to_hash):
|
||||
click.echo(
|
||||
f"Changes detected in type stub files for {name}. Regenerating..."
|
||||
)
|
||||
spin.util.run(cmd)
|
||||
hash_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
with hash_file.open("w") as f:
|
||||
for file, hash in hashes.items():
|
||||
f.write(f"{hash} {file}\n")
|
||||
click.echo("Type stubs and hashes updated.")
|
||||
else:
|
||||
click.echo(f"No changes detected in type stub files for {name}.")
|
||||
else:
|
||||
click.echo(f"No hash file for {name}. Regenerating...")
|
||||
spin.util.run(cmd)
|
||||
click.echo("Type stubs regenerated.")
|
||||
|
||||
|
||||
@click.command()
|
||||
def regenerate_clangtidy_files():
|
||||
"""Regenerate clang-tidy files."""
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"tools.linter.clang_tidy.generate_build_files",
|
||||
]
|
||||
spin.util.run(cmd)
|
||||
|
||||
|
||||
#: These linters are expected to need less than 3s cpu time total
|
||||
VERY_FAST_LINTERS = {
|
||||
"ATEN_CPU_GPU_AGNOSTIC",
|
||||
"BAZEL_LINTER",
|
||||
"C10_NODISCARD",
|
||||
"C10_UNUSED",
|
||||
"CALL_ONCE",
|
||||
"CMAKE_MINIMUM_REQUIRED",
|
||||
"CONTEXT_DECORATOR",
|
||||
"COPYRIGHT",
|
||||
"CUBINCLUDE",
|
||||
"DEPLOY_DETECTION",
|
||||
"ERROR_PRONE_ISINSTANCE",
|
||||
"EXEC",
|
||||
"HEADER_ONLY_LINTER",
|
||||
"IMPORT_LINTER",
|
||||
"INCLUDE",
|
||||
"LINTRUNNER_VERSION",
|
||||
"MERGE_CONFLICTLESS_CSV",
|
||||
"META_NO_CREATE_UNBACKED",
|
||||
"NEWLINE",
|
||||
"NOQA",
|
||||
"NO_WORKFLOWS_ON_FORK",
|
||||
"ONCE_FLAG",
|
||||
"PYBIND11_INCLUDE",
|
||||
"PYBIND11_SPECIALIZATION",
|
||||
"PYPIDEP",
|
||||
"PYPROJECT",
|
||||
"RAWCUDA",
|
||||
"RAWCUDADEVICE",
|
||||
"ROOT_LOGGING",
|
||||
"TABS",
|
||||
"TESTOWNERS",
|
||||
"TYPEIGNORE",
|
||||
"TYPENOSKIP",
|
||||
"WORKFLOWSYNC",
|
||||
}
|
||||
|
||||
|
||||
#: These linters are expected to take a few seconds, but less than 10s cpu time total
|
||||
FAST_LINTERS = {
|
||||
"CMAKE",
|
||||
"DOCSTRING_LINTER",
|
||||
"GHA",
|
||||
"NATIVEFUNCTIONS",
|
||||
"RUFF",
|
||||
"SET_LINTER",
|
||||
"SHELLCHECK",
|
||||
"SPACES",
|
||||
}
|
||||
|
||||
|
||||
#: These linters are expected to take more than 10s cpu time total;
|
||||
#: some need more than 1 hour.
|
||||
SLOW_LINTERS = {
|
||||
"ACTIONLINT",
|
||||
"CLANGFORMAT",
|
||||
"CLANGTIDY",
|
||||
"CODESPELL",
|
||||
"FLAKE8",
|
||||
"GB_REGISTRY",
|
||||
"PYFMT",
|
||||
"PYREFLY",
|
||||
"TEST_DEVICE_BIAS",
|
||||
"TEST_HAS_MAIN",
|
||||
}
|
||||
|
||||
|
||||
ALL_LINTERS = VERY_FAST_LINTERS | FAST_LINTERS | SLOW_LINTERS
|
||||
|
||||
|
||||
LINTRUNNER_CACHE_INFO = (
|
||||
Path(".lintbin/.lintrunner.sha256"),
|
||||
[
|
||||
"requirements.txt",
|
||||
"pyproject.toml",
|
||||
".lintrunner.toml",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
LINTRUNNER_BASE_CMD = [
|
||||
"uvx",
|
||||
"--python",
|
||||
"3.10",
|
||||
"lintrunner@0.12.7",
|
||||
]
|
||||
|
||||
|
||||
@click.command()
|
||||
def setup_lint():
|
||||
"""Set up lintrunner with current CI version."""
|
||||
cmd = LINTRUNNER_BASE_CMD + ["init"]
|
||||
subprocess.run(cmd, check=True, capture_output=True, text=True)
|
||||
|
||||
|
||||
def _check_linters():
|
||||
cmd = LINTRUNNER_BASE_CMD + ["list"]
|
||||
ret = spin.util.run(cmd, output=False, stderr=subprocess.PIPE)
|
||||
linters = {l.strip() for l in ret.stdout.decode().strip().split("\n")[1:]}
|
||||
unknown_linters = linters - ALL_LINTERS
|
||||
missing_linters = ALL_LINTERS - linters
|
||||
if unknown_linters:
|
||||
click.secho(
|
||||
f"Unknown linters found; please add them to the correct category "
|
||||
f"in .spin/cmds.py: {', '.join(unknown_linters)}",
|
||||
fg="yellow",
|
||||
)
|
||||
if missing_linters:
|
||||
click.secho(
|
||||
f"Missing linters found; please update the corresponding category "
|
||||
f"in .spin/cmds.py: {', '.join(missing_linters)}",
|
||||
fg="yellow",
|
||||
)
|
||||
return unknown_linters, missing_linters
|
||||
|
||||
|
||||
@spin.util.extend_command(
|
||||
setup_lint,
|
||||
doc=f"""
|
||||
If configuration has changed, update lintrunner.
|
||||
|
||||
Compares the stored old hashes of configuration files with new ones and
|
||||
performs setup via setup-lint if the hashes have changed.
|
||||
Hashes are stored in {LINTRUNNER_CACHE_INFO[0]}; the following files are
|
||||
considered: {", ".join(LINTRUNNER_CACHE_INFO[1])}.
|
||||
""",
|
||||
)
|
||||
@click.pass_context
|
||||
def lazy_setup_lint(ctx, parent_callback, **kwargs):
|
||||
if hashes := _updated_hashes(*LINTRUNNER_CACHE_INFO):
|
||||
click.echo(
|
||||
"Changes detected in lint configuration files. Setting up linting tools..."
|
||||
)
|
||||
parent_callback(**kwargs)
|
||||
hash_file = LINTRUNNER_CACHE_INFO[0]
|
||||
hash_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
with hash_file.open("w") as f:
|
||||
for file, hash in hashes.items():
|
||||
f.write(f"{hash} {file}\n")
|
||||
click.echo("Linting tools set up and hashes updated.")
|
||||
else:
|
||||
click.echo("No changes detected in lint configuration files. Skipping setup.")
|
||||
click.echo("Regenerating version...")
|
||||
ctx.invoke(regenerate_version)
|
||||
click.echo("Regenerating type stubs...")
|
||||
ctx.invoke(regenerate_type_stubs)
|
||||
click.echo("Done.")
|
||||
_check_linters()
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("-a", "--apply-patches", is_flag=True)
|
||||
@click.pass_context
|
||||
def lint(ctx, apply_patches, **kwargs):
|
||||
"""Lint all files."""
|
||||
ctx.invoke(lazy_setup_lint)
|
||||
all_files_linters = VERY_FAST_LINTERS | FAST_LINTERS
|
||||
changed_files_linters = SLOW_LINTERS
|
||||
cmd = LINTRUNNER_BASE_CMD
|
||||
if apply_patches:
|
||||
cmd += ["--apply-patches"]
|
||||
all_files_cmd = cmd + [
|
||||
"--take",
|
||||
",".join(all_files_linters),
|
||||
"--all-files",
|
||||
]
|
||||
spin.util.run(all_files_cmd)
|
||||
changed_files_cmd = cmd + [
|
||||
"--take",
|
||||
",".join(changed_files_linters),
|
||||
]
|
||||
spin.util.run(changed_files_cmd)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.pass_context
|
||||
def fixlint(ctx, **kwargs):
|
||||
"""Autofix all files."""
|
||||
ctx.invoke(lint, apply_patches=True)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("-a", "--apply-patches", is_flag=True)
|
||||
@click.pass_context
|
||||
def quicklint(ctx, apply_patches, **kwargs):
|
||||
"""Lint changed files."""
|
||||
ctx.invoke(lazy_setup_lint)
|
||||
cmd = LINTRUNNER_BASE_CMD
|
||||
if apply_patches:
|
||||
cmd += ["--apply-patches"]
|
||||
spin.util.run(cmd)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.pass_context
|
||||
def quickfix(ctx, **kwargs):
|
||||
"""Autofix changed files."""
|
||||
ctx.invoke(quicklint, apply_patches=True)
|
||||
@ -1,6 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/core/TensorAccessor.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/ArrayRef.h>
|
||||
#include <c10/util/Deprecated.h>
|
||||
@ -12,37 +11,252 @@
|
||||
|
||||
namespace at {
|
||||
|
||||
using torch::headeronly::DefaultPtrTraits;
|
||||
// The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor
|
||||
// is used to enable the __restrict__ keyword/modifier for the data
|
||||
// passed to cuda.
|
||||
template <typename T>
|
||||
struct DefaultPtrTraits {
|
||||
typedef T* PtrType;
|
||||
};
|
||||
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
using torch::headeronly::RestrictPtrTraits;
|
||||
template <typename T>
|
||||
struct RestrictPtrTraits {
|
||||
typedef T* __restrict__ PtrType;
|
||||
};
|
||||
#endif
|
||||
|
||||
// TensorAccessorBase and TensorAccessor are used for both CPU and CUDA tensors.
|
||||
// For CUDA tensors it is used in device code (only). This means that we restrict ourselves
|
||||
// to functions and types available there (e.g. IntArrayRef isn't).
|
||||
|
||||
// The PtrTraits argument is only relevant to cuda to support `__restrict__` pointers.
|
||||
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
||||
using TensorAccessorBase = torch::headeronly::detail::TensorAccessorBase<c10::IntArrayRef, T, N, PtrTraits, index_t>;
|
||||
class TensorAccessorBase {
|
||||
public:
|
||||
typedef typename PtrTraits<T>::PtrType PtrType;
|
||||
|
||||
C10_HOST_DEVICE TensorAccessorBase(
|
||||
PtrType data_,
|
||||
const index_t* sizes_,
|
||||
const index_t* strides_)
|
||||
: data_(data_), sizes_(sizes_), strides_(strides_) {}
|
||||
C10_HOST IntArrayRef sizes() const {
|
||||
return IntArrayRef(sizes_,N);
|
||||
}
|
||||
C10_HOST IntArrayRef strides() const {
|
||||
return IntArrayRef(strides_,N);
|
||||
}
|
||||
C10_HOST_DEVICE index_t stride(index_t i) const {
|
||||
return strides_[i];
|
||||
}
|
||||
C10_HOST_DEVICE index_t size(index_t i) const {
|
||||
return sizes_[i];
|
||||
}
|
||||
C10_HOST_DEVICE PtrType data() {
|
||||
return data_;
|
||||
}
|
||||
C10_HOST_DEVICE const PtrType data() const {
|
||||
return data_;
|
||||
}
|
||||
protected:
|
||||
PtrType data_;
|
||||
const index_t* sizes_;
|
||||
const index_t* strides_;
|
||||
};
|
||||
|
||||
// The `TensorAccessor` is typically instantiated for CPU `Tensor`s using
|
||||
// `Tensor.accessor<T, N>()`.
|
||||
// For CUDA `Tensor`s, `GenericPackedTensorAccessor` is used on the host and only
|
||||
// indexing on the device uses `TensorAccessor`s.
|
||||
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
||||
using TensorAccessor = torch::headeronly::detail::TensorAccessor<c10::IntArrayRef, T, N, PtrTraits, index_t>;
|
||||
class TensorAccessor : public TensorAccessorBase<T,N,PtrTraits,index_t> {
|
||||
public:
|
||||
typedef typename PtrTraits<T>::PtrType PtrType;
|
||||
|
||||
namespace detail {
|
||||
C10_HOST_DEVICE TensorAccessor(
|
||||
PtrType data_,
|
||||
const index_t* sizes_,
|
||||
const index_t* strides_)
|
||||
: TensorAccessorBase<T, N, PtrTraits, index_t>(data_,sizes_,strides_) {}
|
||||
|
||||
template <size_t N, typename index_t>
|
||||
struct IndexBoundsCheck {
|
||||
IndexBoundsCheck(index_t i) {
|
||||
TORCH_CHECK_INDEX(
|
||||
C10_HOST_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
|
||||
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
|
||||
}
|
||||
|
||||
C10_HOST_DEVICE const TensorAccessor<T, N-1, PtrTraits, index_t> operator[](index_t i) const {
|
||||
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, template <typename U> class PtrTraits, typename index_t>
|
||||
class TensorAccessor<T,1,PtrTraits,index_t> : public TensorAccessorBase<T,1,PtrTraits,index_t> {
|
||||
public:
|
||||
typedef typename PtrTraits<T>::PtrType PtrType;
|
||||
|
||||
C10_HOST_DEVICE TensorAccessor(
|
||||
PtrType data_,
|
||||
const index_t* sizes_,
|
||||
const index_t* strides_)
|
||||
: TensorAccessorBase<T, 1, PtrTraits, index_t>(data_,sizes_,strides_) {}
|
||||
C10_HOST_DEVICE T & operator[](index_t i) {
|
||||
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
|
||||
return this->data_[this->strides_[0]*i];
|
||||
}
|
||||
C10_HOST_DEVICE const T & operator[](index_t i) const {
|
||||
return this->data_[this->strides_[0]*i];
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// GenericPackedTensorAccessorBase and GenericPackedTensorAccessor are used on for CUDA `Tensor`s on the host
|
||||
// and as
|
||||
// In contrast to `TensorAccessor`s, they copy the strides and sizes on instantiation (on the host)
|
||||
// in order to transfer them on the device when calling kernels.
|
||||
// On the device, indexing of multidimensional tensors gives to `TensorAccessor`s.
|
||||
// Use RestrictPtrTraits as PtrTraits if you want the tensor's data pointer to be marked as __restrict__.
|
||||
// Instantiation from data, sizes, strides is only needed on the host and std::copy isn't available
|
||||
// on the device, so those functions are host only.
|
||||
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
||||
class GenericPackedTensorAccessorBase {
|
||||
public:
|
||||
typedef typename PtrTraits<T>::PtrType PtrType;
|
||||
C10_HOST GenericPackedTensorAccessorBase(
|
||||
PtrType data_,
|
||||
const index_t* sizes_,
|
||||
const index_t* strides_)
|
||||
: data_(data_) {
|
||||
std::copy(sizes_, sizes_ + N, std::begin(this->sizes_));
|
||||
std::copy(strides_, strides_ + N, std::begin(this->strides_));
|
||||
}
|
||||
|
||||
// if index_t is not int64_t, we want to have an int64_t constructor
|
||||
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
|
||||
C10_HOST GenericPackedTensorAccessorBase(
|
||||
PtrType data_,
|
||||
const source_index_t* sizes_,
|
||||
const source_index_t* strides_)
|
||||
: data_(data_) {
|
||||
for (const auto i : c10::irange(N)) {
|
||||
this->sizes_[i] = sizes_[i];
|
||||
this->strides_[i] = strides_[i];
|
||||
}
|
||||
}
|
||||
|
||||
C10_HOST_DEVICE index_t stride(index_t i) const {
|
||||
return strides_[i];
|
||||
}
|
||||
C10_HOST_DEVICE index_t size(index_t i) const {
|
||||
return sizes_[i];
|
||||
}
|
||||
C10_HOST_DEVICE PtrType data() {
|
||||
return data_;
|
||||
}
|
||||
C10_HOST_DEVICE const PtrType data() const {
|
||||
return data_;
|
||||
}
|
||||
protected:
|
||||
PtrType data_;
|
||||
// NOLINTNEXTLINE(*c-arrays*)
|
||||
index_t sizes_[N];
|
||||
// NOLINTNEXTLINE(*c-arrays*)
|
||||
index_t strides_[N];
|
||||
C10_HOST void bounds_check_(index_t i) const {
|
||||
TORCH_CHECK_INDEX(
|
||||
0 <= i && i < index_t{N},
|
||||
"Index ",
|
||||
i,
|
||||
" is not within bounds of a tensor of dimension ",
|
||||
N);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
||||
using GenericPackedTensorAccessorBase = torch::headeronly::detail::GenericPackedTensorAccessorBase<detail::IndexBoundsCheck<N, index_t>, T, N, PtrTraits, index_t>;
|
||||
class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase<T,N,PtrTraits,index_t> {
|
||||
public:
|
||||
typedef typename PtrTraits<T>::PtrType PtrType;
|
||||
|
||||
C10_HOST GenericPackedTensorAccessor(
|
||||
PtrType data_,
|
||||
const index_t* sizes_,
|
||||
const index_t* strides_)
|
||||
: GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
|
||||
|
||||
// if index_t is not int64_t, we want to have an int64_t constructor
|
||||
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
|
||||
C10_HOST GenericPackedTensorAccessor(
|
||||
PtrType data_,
|
||||
const source_index_t* sizes_,
|
||||
const source_index_t* strides_)
|
||||
: GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
|
||||
|
||||
C10_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
|
||||
index_t* new_sizes = this->sizes_ + 1;
|
||||
index_t* new_strides = this->strides_ + 1;
|
||||
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
|
||||
}
|
||||
|
||||
C10_DEVICE const TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) const {
|
||||
const index_t* new_sizes = this->sizes_ + 1;
|
||||
const index_t* new_strides = this->strides_ + 1;
|
||||
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
|
||||
}
|
||||
|
||||
/// Returns a PackedTensorAccessor of the same dimension after transposing the
|
||||
/// two dimensions given. Does not actually move elements; transposition is
|
||||
/// made by permuting the size/stride arrays. If the dimensions are not valid,
|
||||
/// asserts.
|
||||
C10_HOST GenericPackedTensorAccessor<T, N, PtrTraits, index_t> transpose(
|
||||
index_t dim1,
|
||||
index_t dim2) const {
|
||||
this->bounds_check_(dim1);
|
||||
this->bounds_check_(dim2);
|
||||
GenericPackedTensorAccessor<T, N, PtrTraits, index_t> result(
|
||||
this->data_, this->sizes_, this->strides_);
|
||||
std::swap(result.strides_[dim1], result.strides_[dim2]);
|
||||
std::swap(result.sizes_[dim1], result.sizes_[dim2]);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, template <typename U> class PtrTraits, typename index_t>
|
||||
class GenericPackedTensorAccessor<T,1,PtrTraits,index_t> : public GenericPackedTensorAccessorBase<T,1,PtrTraits,index_t> {
|
||||
public:
|
||||
typedef typename PtrTraits<T>::PtrType PtrType;
|
||||
C10_HOST GenericPackedTensorAccessor(
|
||||
PtrType data_,
|
||||
const index_t* sizes_,
|
||||
const index_t* strides_)
|
||||
: GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
|
||||
|
||||
// if index_t is not int64_t, we want to have an int64_t constructor
|
||||
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
|
||||
C10_HOST GenericPackedTensorAccessor(
|
||||
PtrType data_,
|
||||
const source_index_t* sizes_,
|
||||
const source_index_t* strides_)
|
||||
: GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
|
||||
|
||||
C10_DEVICE T & operator[](index_t i) {
|
||||
return this->data_[this->strides_[0] * i];
|
||||
}
|
||||
C10_DEVICE const T& operator[](index_t i) const {
|
||||
return this->data_[this->strides_[0]*i];
|
||||
}
|
||||
|
||||
// Same as in the general N-dimensional case, but note that in the
|
||||
// 1-dimensional case the returned PackedTensorAccessor will always be an
|
||||
// identical copy of the original
|
||||
C10_HOST GenericPackedTensorAccessor<T, 1, PtrTraits, index_t> transpose(
|
||||
index_t dim1,
|
||||
index_t dim2) const {
|
||||
this->bounds_check_(dim1);
|
||||
this->bounds_check_(dim2);
|
||||
return GenericPackedTensorAccessor<T, 1, PtrTraits, index_t>(
|
||||
this->data_, this->sizes_, this->strides_);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
||||
using GenericPackedTensorAccessor = torch::headeronly::detail::GenericPackedTensorAccessor<TensorAccessor<T, N-1, PtrTraits, index_t>, detail::IndexBoundsCheck<N, index_t>, T, N, PtrTraits, index_t>;
|
||||
|
||||
// Can't put this directly into the macro function args because of commas
|
||||
#define AT_X GenericPackedTensorAccessor<T, N, PtrTraits, index_t>
|
||||
|
||||
@ -245,9 +245,6 @@ class TORCH_API TensorBase {
|
||||
size_t weak_use_count() const noexcept {
|
||||
return impl_.weak_use_count();
|
||||
}
|
||||
bool is_uniquely_owned() const noexcept {
|
||||
return impl_.is_uniquely_owned();
|
||||
}
|
||||
|
||||
std::string toString() const;
|
||||
|
||||
|
||||
@ -18,8 +18,6 @@
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
|
||||
|
||||
namespace torch {
|
||||
class TORCH_API CustomClassHolder : public c10::intrusive_ptr_target {};
|
||||
namespace jit {
|
||||
@ -1632,6 +1630,4 @@ struct TORCH_API WeakOrStrongTypePtr {
|
||||
|
||||
} // namespace c10
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
#include <ATen/core/ivalue_inl.h> // IWYU pragma: keep
|
||||
|
||||
@ -29,8 +29,6 @@
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
struct Function;
|
||||
@ -2569,5 +2567,3 @@ TypePtr IValue::type() const {
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
@ -223,62 +223,6 @@ CONVERT_FROM_BF16_TEMPLATE(double)
|
||||
CONVERT_FROM_BF16_TEMPLATE(float16_t)
|
||||
#endif
|
||||
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
|
||||
// clang-[17, 20] crashes when autovectorizing static cast to bf16
|
||||
// Below is a workaround to have some vectorization
|
||||
// Works decently well for smaller int types
|
||||
template <typename from_type>
|
||||
inline void convertToBf16Impl(
|
||||
const from_type* __restrict src,
|
||||
c10::BFloat16* __restrict dst,
|
||||
uint64_t n) {
|
||||
bfloat16_t* dstPtr = reinterpret_cast<bfloat16_t*>(dst);
|
||||
uint64_t loopBound = n - (n % 16);
|
||||
uint64_t i = 0;
|
||||
for (; i < loopBound; i += 16) {
|
||||
float32x4_t a, b, c, d;
|
||||
a[0] = static_cast<float>(src[i]);
|
||||
a[1] = static_cast<float>(src[i + 1]);
|
||||
a[2] = static_cast<float>(src[i + 2]);
|
||||
a[3] = static_cast<float>(src[i + 3]);
|
||||
b[0] = static_cast<float>(src[i + 4]);
|
||||
b[1] = static_cast<float>(src[i + 5]);
|
||||
b[2] = static_cast<float>(src[i + 6]);
|
||||
b[3] = static_cast<float>(src[i + 7]);
|
||||
c[0] = static_cast<float>(src[i + 8]);
|
||||
c[1] = static_cast<float>(src[i + 9]);
|
||||
c[2] = static_cast<float>(src[i + 10]);
|
||||
c[3] = static_cast<float>(src[i + 11]);
|
||||
d[0] = static_cast<float>(src[i + 12]);
|
||||
d[1] = static_cast<float>(src[i + 13]);
|
||||
d[2] = static_cast<float>(src[i + 14]);
|
||||
d[3] = static_cast<float>(src[i + 15]);
|
||||
|
||||
vst1q_bf16(dstPtr + i, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(a), b));
|
||||
vst1q_bf16(dstPtr + i + 8, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(c), d));
|
||||
}
|
||||
|
||||
#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
|
||||
for (; i < n; i++) {
|
||||
float a = static_cast<float>(src[i]);
|
||||
dstPtr[i] = vcvth_bf16_f32(a);
|
||||
}
|
||||
}
|
||||
|
||||
#define CONVERT_TO_BF16_TEMPLATE(from_type) \
|
||||
template <> \
|
||||
inline void convert(const from_type* src, c10::BFloat16* dst, int64_t n) { \
|
||||
return convertToBf16Impl<from_type>(src, dst, n); \
|
||||
}
|
||||
|
||||
CONVERT_TO_BF16_TEMPLATE(uint8_t)
|
||||
CONVERT_TO_BF16_TEMPLATE(int8_t)
|
||||
CONVERT_TO_BF16_TEMPLATE(int16_t)
|
||||
CONVERT_TO_BF16_TEMPLATE(int32_t)
|
||||
|
||||
#endif
|
||||
|
||||
inline void convertBoolToBfloat16Impl(
|
||||
const bool* __restrict src,
|
||||
c10::BFloat16* __restrict dst,
|
||||
|
||||
@ -11,8 +11,6 @@
|
||||
#include <sleef.h>
|
||||
#endif
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
|
||||
|
||||
// Sleef offers vectorized versions of some transcedentals
|
||||
// such as sin, cos, tan etc..
|
||||
// However for now opting for STL, since we are not building
|
||||
@ -652,5 +650,3 @@ inline Vectorized<float> Vectorized<float>::erf() const {
|
||||
|
||||
} // namespace CPU_CAPABILITY
|
||||
} // namespace at::vec
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
#include <shared_mutex>
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cusparse.h>
|
||||
@ -89,13 +88,8 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
|
||||
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
|
||||
|
||||
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
|
||||
struct WorkspaceMapWithMutex {
|
||||
std::map<std::tuple<void*, void*>, at::DataPtr> map;
|
||||
std::shared_mutex mutex;
|
||||
};
|
||||
|
||||
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublas_handle_stream_to_workspace();
|
||||
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace();
|
||||
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
|
||||
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace();
|
||||
TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
|
||||
TORCH_CUDA_CPP_API size_t getCUDABlasLtWorkspaceSize();
|
||||
TORCH_CUDA_CPP_API void* getCUDABlasLtWorkspace();
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
||||
#include <ATen/cuda/CUDAGraph.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <ATen/cuda/MemPool.h>
|
||||
#include <ATen/Functions.h>
|
||||
#include <c10/cuda/CUDAFunctions.h>
|
||||
|
||||
@ -14,7 +13,7 @@ static bool _cuda_graphs_debug = false;
|
||||
MempoolId_t graph_pool_handle() {
|
||||
// Sets just the second value, to distinguish it from MempoolId_ts created from
|
||||
// cudaStreamGetCaptureInfo id_s in capture_begin.
|
||||
return at::cuda::MemPool::graph_pool_handle();
|
||||
return c10::cuda::MemPool::graph_pool_handle();
|
||||
}
|
||||
|
||||
/**
|
||||
@ -91,7 +90,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
|
||||
} else {
|
||||
// User did not ask us to share a mempool. Create graph pool handle using is_user_created=false.
|
||||
// Sets just the first value, to distinguish it from MempoolId_ts created by graph_pool_handle().
|
||||
mempool_id_ = at::cuda::MemPool::graph_pool_handle(false);
|
||||
mempool_id_ = c10::cuda::MemPool::graph_pool_handle(false);
|
||||
TORCH_INTERNAL_ASSERT(mempool_id_.first > 0);
|
||||
}
|
||||
|
||||
@ -175,24 +174,17 @@ void CUDAGraph::instantiate() {
|
||||
// Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
|
||||
// who prefer not to report error message through these arguments moving forward
|
||||
// (they prefer return value, or errors on api calls internal to the capture)
|
||||
// ROCM appears to fail with HIP error: invalid argument
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && !defined(USE_ROCM)
|
||||
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, cudaGraphInstantiateFlagUseNodePriority));
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
|
||||
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, 0));
|
||||
#else
|
||||
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
|
||||
#endif
|
||||
//Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory.
|
||||
//It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch.
|
||||
} else {
|
||||
#if !defined(USE_ROCM)
|
||||
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
|
||||
graph_,
|
||||
cudaGraphInstantiateFlagAutoFreeOnLaunch | cudaGraphInstantiateFlagUseNodePriority));
|
||||
#else
|
||||
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
|
||||
graph_,
|
||||
cudaGraphInstantiateFlagAutoFreeOnLaunch));
|
||||
#endif
|
||||
}
|
||||
has_graph_exec_ = true;
|
||||
}
|
||||
|
||||
@ -99,7 +99,7 @@ void destroyCublasHandle(cublasHandle_t handle) {
|
||||
// - Comments of @soumith copied from cuDNN handle pool implementation
|
||||
#ifdef NO_CUDNN_DESTROY_HANDLE
|
||||
#else
|
||||
cublasDestroy(handle);
|
||||
cublasDestroy(handle);
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -107,27 +107,19 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle
|
||||
|
||||
} // namespace
|
||||
|
||||
WorkspaceMapWithMutex& cublas_handle_stream_to_workspace() {
|
||||
static auto& instance = *new WorkspaceMapWithMutex;
|
||||
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
|
||||
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
|
||||
return instance;
|
||||
}
|
||||
|
||||
WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace() {
|
||||
static auto& instance = *new WorkspaceMapWithMutex;
|
||||
std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace() {
|
||||
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void clearCublasWorkspaces() {
|
||||
{
|
||||
auto& workspace = cublas_handle_stream_to_workspace();
|
||||
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
workspace.map.clear();
|
||||
}
|
||||
{
|
||||
auto& workspace = cublaslt_handle_stream_to_workspace();
|
||||
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
workspace.map.clear();
|
||||
}
|
||||
cublas_handle_stream_to_workspace().clear();
|
||||
cublaslt_handle_stream_to_workspace().clear();
|
||||
}
|
||||
|
||||
size_t parseChosenWorkspaceSize() {
|
||||
@ -241,38 +233,6 @@ at::DataPtr getNewCUDABlasLtWorkspace() {
|
||||
return c10::cuda::CUDACachingAllocator::get()->allocate(getCUDABlasLtWorkspaceSize());
|
||||
}
|
||||
|
||||
void setWorkspaceForHandle(cublasHandle_t handle, c10::cuda::CUDAStream stream) {
|
||||
cudaStream_t _stream = stream;
|
||||
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
|
||||
|
||||
auto& workspace = cublas_handle_stream_to_workspace();
|
||||
|
||||
size_t workspace_size = getChosenWorkspaceSize();
|
||||
|
||||
// Fast path: check if workspace already exists
|
||||
{
|
||||
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
auto workspace_it = workspace.map.find(key);
|
||||
if (workspace_it != workspace.map.end()) {
|
||||
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(
|
||||
handle, workspace_it->second.get(), workspace_size));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: allocate workspace outside the lock
|
||||
auto new_workspace = getNewWorkspace();
|
||||
|
||||
// Insert with lock (double-check in case another thread inserted while we
|
||||
// were allocating)
|
||||
{
|
||||
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
auto workspace_it = workspace.map.try_emplace(key, std::move(new_workspace)).first;
|
||||
TORCH_CUDABLAS_CHECK(
|
||||
cublasSetWorkspace(handle, workspace_it->second.get(), workspace_size));
|
||||
}
|
||||
}
|
||||
|
||||
void* getCUDABlasLtWorkspace() {
|
||||
#ifndef USE_ROCM
|
||||
static bool unified = c10::utils::check_env(TORCH_CUBLASLT_UNIFIED_WORKSPACE) == true;
|
||||
@ -281,10 +241,8 @@ void* getCUDABlasLtWorkspace() {
|
||||
auto stream = c10::cuda::getCurrentCUDAStream();
|
||||
cudaStream_t _stream = stream;
|
||||
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
|
||||
auto& workspace = at::cuda::cublas_handle_stream_to_workspace();
|
||||
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
auto workspace_it = workspace.map.find(key);
|
||||
TORCH_INTERNAL_ASSERT(workspace_it != workspace.map.end());
|
||||
auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
|
||||
TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
|
||||
return workspace_it->second.mutable_get();
|
||||
}
|
||||
#endif
|
||||
@ -292,29 +250,11 @@ void* getCUDABlasLtWorkspace() {
|
||||
auto stream = c10::cuda::getCurrentCUDAStream();
|
||||
cudaStream_t _stream = stream;
|
||||
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
|
||||
|
||||
auto& workspace = cublaslt_handle_stream_to_workspace();
|
||||
|
||||
// Fast path: check if workspace already exists
|
||||
{
|
||||
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
auto workspace_it = workspace.map.find(key);
|
||||
if (workspace_it != workspace.map.end()) {
|
||||
return workspace_it->second.mutable_get();
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: allocate workspace outside the lock
|
||||
auto new_workspace = getNewCUDABlasLtWorkspace();
|
||||
|
||||
// Insert with lock (double-check in case another thread inserted while we
|
||||
// were allocating)
|
||||
{
|
||||
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
auto workspace_it =
|
||||
workspace.map.try_emplace(key, std::move(new_workspace)).first;
|
||||
return workspace_it->second.mutable_get();
|
||||
auto workspace_it = cublaslt_handle_stream_to_workspace().find(key);
|
||||
if (workspace_it == cublaslt_handle_stream_to_workspace().end()) {
|
||||
workspace_it = cublaslt_handle_stream_to_workspace().insert(workspace_it, {key, getNewCUDABlasLtWorkspace()});
|
||||
}
|
||||
return workspace_it->second.mutable_get();
|
||||
}
|
||||
|
||||
cublasHandle_t getCurrentCUDABlasHandle() {
|
||||
@ -358,8 +298,13 @@ cublasHandle_t getCurrentCUDABlasHandle() {
|
||||
// will allocate memory dynamically (even if they're cheap) outside
|
||||
// PyTorch's CUDA caching allocator. It's possible that CCA used up
|
||||
// all the memory and cublas's cudaMallocAsync will return OOM
|
||||
setWorkspaceForHandle(handle, stream);
|
||||
|
||||
cudaStream_t _stream = stream;
|
||||
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
|
||||
auto workspace_it = cublas_handle_stream_to_workspace().find(key);
|
||||
if (workspace_it == cublas_handle_stream_to_workspace().end()) {
|
||||
workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});
|
||||
}
|
||||
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
|
||||
#if !defined(USE_ROCM)
|
||||
// On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
|
||||
// FP32 data type calculations based on the value of the allow_tf32 flag.
|
||||
|
||||
@ -1,69 +0,0 @@
|
||||
#include <ATen/core/CachingHostAllocator.h>
|
||||
#include <ATen/cuda/MemPool.h>
|
||||
|
||||
namespace at::cuda {
|
||||
|
||||
// uid_ is incremented when a user creates a MemPool,
|
||||
// for example: using graph_pool_handle() or c10::cuda::MemPool().
|
||||
//
|
||||
// uuid_ is incremented when CUDAGraph creates a MemPool
|
||||
// as a result of a user not providing a pool.
|
||||
//
|
||||
// MempoolId_t of {0, 0} is used to denote when no MemPool has been
|
||||
// passed to a function, either by user or CUDAGraphs. For example,
|
||||
// default value of MempoolId_t for capture_begin function is {0, 0}.
|
||||
// That's why uid_ and uuid_ start at 1.
|
||||
std::atomic<CaptureId_t> MemPool::uid_{1};
|
||||
std::atomic<CaptureId_t> MemPool::uuid_{1};
|
||||
|
||||
MemPool::MemPool(
|
||||
CUDACachingAllocator::CUDAAllocator* allocator,
|
||||
bool is_user_created,
|
||||
bool use_on_oom)
|
||||
: allocator_(allocator), is_user_created_(is_user_created) {
|
||||
if (is_user_created_) {
|
||||
id_ = {0, uid_++};
|
||||
} else {
|
||||
id_ = {uuid_++, 0};
|
||||
}
|
||||
device_ = c10::cuda::current_device();
|
||||
CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
|
||||
if (use_on_oom) {
|
||||
CUDACachingAllocator::setUseOnOOM(device_, id_);
|
||||
}
|
||||
}
|
||||
|
||||
MemPool::~MemPool() {
|
||||
// TORCH_INTERNAL_ASSERT(use_count() == 1);
|
||||
// We used to assert that TORCH_INTERNAL_ASSERT(use_count() == 1);
|
||||
// However, this assertion is not true if a memory pool is shared
|
||||
// with a cuda graph. That CUDAGraph will increase the use count
|
||||
// until it is reset.
|
||||
CUDACachingAllocator::releasePool(device_, id_);
|
||||
c10::cuda::CUDACachingAllocator::emptyCache(id_);
|
||||
}
|
||||
|
||||
MempoolId_t MemPool::id() {
|
||||
return id_;
|
||||
}
|
||||
|
||||
CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
|
||||
return allocator_;
|
||||
}
|
||||
|
||||
int MemPool::use_count() {
|
||||
return CUDACachingAllocator::getPoolUseCount(device_, id_);
|
||||
}
|
||||
|
||||
c10::DeviceIndex MemPool::device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
|
||||
if (is_user_created) {
|
||||
return {0, uid_++};
|
||||
}
|
||||
return {uuid_++, 0};
|
||||
}
|
||||
|
||||
} // namespace at::cuda
|
||||
@ -1,44 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/Allocator.h>
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
|
||||
namespace at::cuda {
|
||||
|
||||
// Keep BC only
|
||||
using c10::CaptureId_t;
|
||||
using c10::MempoolId_t;
|
||||
|
||||
// MemPool represents a pool of memory in a caching allocator. Currently,
|
||||
// it's just the ID of the pool object maintained in the CUDACachingAllocator.
|
||||
//
|
||||
// An allocator pointer can be passed to the MemPool to define how the
|
||||
// allocations should be done in the pool. For example: using a different
|
||||
// system allocator such as ncclMemAlloc.
|
||||
struct TORCH_CUDA_CPP_API MemPool {
|
||||
MemPool(
|
||||
c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
|
||||
bool is_user_created = true,
|
||||
bool use_on_oom = false);
|
||||
MemPool(const MemPool&) = delete;
|
||||
MemPool(MemPool&&) = default;
|
||||
MemPool& operator=(const MemPool&) = delete;
|
||||
MemPool& operator=(MemPool&&) = default;
|
||||
~MemPool();
|
||||
|
||||
MempoolId_t id();
|
||||
c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator();
|
||||
int use_count();
|
||||
c10::DeviceIndex device();
|
||||
static MempoolId_t graph_pool_handle(bool is_user_created = true);
|
||||
|
||||
private:
|
||||
static std::atomic<CaptureId_t> uid_;
|
||||
static std::atomic<CaptureId_t> uuid_;
|
||||
c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator_;
|
||||
bool is_user_created_;
|
||||
MempoolId_t id_;
|
||||
c10::DeviceIndex device_;
|
||||
};
|
||||
|
||||
} // namespace at::cuda
|
||||
@ -22,7 +22,6 @@ enum class MacOSVersion : uint32_t {
|
||||
MACOS_VER_15_0_PLUS,
|
||||
MACOS_VER_15_1_PLUS,
|
||||
MACOS_VER_15_2_PLUS,
|
||||
MACOS_VER_26_0_PLUS,
|
||||
};
|
||||
|
||||
//-----------------------------------------------------------------
|
||||
|
||||
@ -65,7 +65,6 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
|
||||
static bool _macos_15_0_plus = is_os_version_at_least(15, 0);
|
||||
static bool _macos_15_1_plus = is_os_version_at_least(15, 1);
|
||||
static bool _macos_15_2_plus = is_os_version_at_least(15, 2);
|
||||
static bool _macos_26_0_plus = is_os_version_at_least(26, 0);
|
||||
|
||||
switch (version) {
|
||||
case MacOSVersion::MACOS_VER_14_4_PLUS:
|
||||
@ -76,8 +75,6 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
|
||||
return _macos_15_1_plus;
|
||||
case MacOSVersion::MACOS_VER_15_2_PLUS:
|
||||
return _macos_15_2_plus;
|
||||
case MacOSVersion::MACOS_VER_26_0_PLUS:
|
||||
return _macos_26_0_plus;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -1087,8 +1087,7 @@ TORCH_IMPL_FUNC(index_copy_out)
|
||||
result.copy_(self);
|
||||
|
||||
// See Note [Enabling Deterministic Operations]
|
||||
if ((result.is_cuda() || result.is_xpu()) &&
|
||||
globalContext().deterministicAlgorithms()) {
|
||||
if (result.is_cuda() && globalContext().deterministicAlgorithms()) {
|
||||
torch::List<std::optional<Tensor>> indices;
|
||||
indices.resize(dim + 1);
|
||||
indices.set(dim, index);
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
#pragma once
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
|
||||
|
||||
namespace at::native {
|
||||
|
||||
// Used as an interface between the different BLAS-like libraries
|
||||
@ -23,5 +21,3 @@ static inline char to_blas(TransposeType trans) {
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
@ -904,11 +904,19 @@ Tensor mvlgamma(const Tensor& self, int64_t p) {
|
||||
return args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER);
|
||||
}
|
||||
|
||||
// since mvlgamma_ has different signature from its
|
||||
// out and functional variant, we explicitly
|
||||
// define it (instead of using structured kernel).
|
||||
Tensor& mvlgamma_(Tensor& self, int64_t p) {
|
||||
return at::mvlgamma_out(self, self, p);
|
||||
mvlgamma_check(self, p);
|
||||
Tensor args = native::arange(
|
||||
-p *HALF + HALF,
|
||||
HALF,
|
||||
HALF,
|
||||
optTypeMetaToScalarType(self.options().dtype_opt()),
|
||||
self.options().layout_opt(),
|
||||
self.options().device_opt(),
|
||||
self.options().pinned_memory_opt());
|
||||
args = args.add(self.unsqueeze(-1));
|
||||
const auto p2_sub_p = static_cast<double>(p * (p - 1));
|
||||
return self.copy_(args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER));
|
||||
}
|
||||
|
||||
Tensor& mvlgamma_out(const Tensor& self, int64_t p, Tensor& result) {
|
||||
|
||||
@ -296,7 +296,7 @@ template <typename scalar_t, typename res_scalar_t = scalar_t>
|
||||
bool launchGemmAndBiasCublasLt(
|
||||
// args contains result which is modified
|
||||
cublasCommonArgs& args,
|
||||
const std::optional<Tensor>& self,
|
||||
const Tensor& self,
|
||||
const Scalar& alpha,
|
||||
Activation activation = Activation::None
|
||||
) {
|
||||
@ -304,8 +304,12 @@ bool launchGemmAndBiasCublasLt(
|
||||
// or when it can be squeezed to 1D.
|
||||
// self_ptr == nullptr implies ignore bias epilogue
|
||||
// and use standard gemm-like API.
|
||||
const auto* self_ptr = self.has_value() ? self.value().const_data_ptr<scalar_t>() : static_cast<const scalar_t*>(nullptr);
|
||||
|
||||
const auto* self_ptr = [&]() -> auto {
|
||||
if (self.dim() == 1 || self.squeeze().dim() == 1) {
|
||||
return self.const_data_ptr<scalar_t>();
|
||||
}
|
||||
return static_cast<const scalar_t*>(nullptr);
|
||||
}();
|
||||
|
||||
const auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
@ -388,30 +392,35 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
|
||||
#ifdef USE_ROCM
|
||||
// Conditioned on the device index, which is not persistent
|
||||
disable_addmm_cuda_lt = disable_addmm_cuda_lt || isGloballyDisabledAddmmCudaLt(self.device());
|
||||
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
|
||||
#endif
|
||||
// Condition on the input
|
||||
disable_addmm_cuda_lt = disable_addmm_cuda_lt || !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha, activation);
|
||||
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha, activation) || disable_addmm_cuda_lt;
|
||||
// }
|
||||
|
||||
at::ScalarType scalar_type = mat1.scalar_type();
|
||||
bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
|
||||
|
||||
#ifdef USE_ROCM
|
||||
disable_addmm_cuda_lt = disable_addmm_cuda_lt || is_float_output_with_half_input;
|
||||
#endif
|
||||
|
||||
bool use_bias_ptr_lt = (self.dim() == 1) && !disable_addmm_cuda_lt;
|
||||
// for float output with half input cublasLT with bias produces wrong results
|
||||
use_bias_ptr_lt &= !is_float_output_with_half_input;
|
||||
|
||||
// Handle result/self shapes
|
||||
if (!result.is_same(self)) {
|
||||
at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
|
||||
|
||||
// We do not copy bias only when we need the bias ptr
|
||||
// We use bias ptr in the Lt path only when bias is 1D
|
||||
const auto use_bias_ptr_lt = (self.dim() == 1) && !disable_addmm_cuda_lt;
|
||||
const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
|
||||
if (!use_bias_ptr_lt) {
|
||||
// We do expand self even before
|
||||
// check for beta != 0.0 to make sure that
|
||||
// test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
|
||||
// runs green.
|
||||
return expand_size(self, result.sizes(), "addmm");
|
||||
}
|
||||
return c10::MaybeOwned<Tensor>::borrowed(self);
|
||||
}();
|
||||
// We do not copy bias only when we need the bias ptr
|
||||
if (beta.toComplexDouble() != 0.0 && !use_bias_ptr_lt) {
|
||||
// NOTE: self should broadcast over result
|
||||
at::native::copy_(result, *expand_size(self, result.sizes(), "addmm"));
|
||||
at::native::copy_(result, *self_maybe_expanded);
|
||||
}
|
||||
}
|
||||
|
||||
@ -459,7 +468,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
scalar_type,
|
||||
"addmm_cuda_lt",
|
||||
[&] {
|
||||
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, use_bias_ptr_lt ? std::make_optional(self) : std::nullopt, alpha, activation);
|
||||
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation);
|
||||
}
|
||||
);
|
||||
#endif
|
||||
@ -471,7 +480,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
scalar_type,
|
||||
"addmm_cuda_lt",
|
||||
[&] {
|
||||
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, use_bias_ptr_lt ? std::make_optional(self) : std::nullopt, alpha, activation);
|
||||
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation);
|
||||
}
|
||||
);
|
||||
} // end is_float_output_with_half_input
|
||||
@ -927,7 +936,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
|
||||
return _int_mm_out_cuda(self, mat2, result);
|
||||
}
|
||||
|
||||
static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const at::ScalarType out_dtype, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
|
||||
static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const at::ScalarType out_dtype, bool is_bmm, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
|
||||
// ref ATen/native/LinearAlgebra.cpp common_checks_baddbmm_bmm
|
||||
TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
|
||||
TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");
|
||||
@ -951,7 +960,7 @@ static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& bat
|
||||
(out_dtype == at::ScalarType::Float && (batch1.scalar_type() == at::ScalarType::Half || batch1.scalar_type() == at::ScalarType::BFloat16)),
|
||||
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
|
||||
|
||||
if (self_baddbmm.has_value()) {
|
||||
if (!is_bmm && self_baddbmm.has_value()) {
|
||||
const auto& self = self_baddbmm.value();
|
||||
TORCH_CHECK(self.dim() == 3, "self must be a 3D tensor");
|
||||
TORCH_CHECK(self.sizes() == output_size, "self must have the same shape as the output");
|
||||
@ -959,12 +968,15 @@ static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& bat
|
||||
}
|
||||
|
||||
Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype) {
|
||||
Tensor out = at::empty({batch1.size(0), batch1.size(1), batch2.size(2)}, batch1.options().dtype(out_dtype));
|
||||
IntArrayRef batch1_sizes = batch1.sizes();
|
||||
IntArrayRef batch2_sizes = batch2.sizes();
|
||||
|
||||
Tensor out = at::empty({batch1_sizes[0], batch1_sizes[1], batch2_sizes[2]}, batch1.options().dtype(out_dtype));
|
||||
return _bmm_out_dtype_cuda(batch1, batch2, out_dtype, out);
|
||||
}
|
||||
|
||||
Tensor& _bmm_out_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, Tensor &out) {
|
||||
baddbmm_bmm_out_dtype_checks(batch1, batch2, 0.0, 1.0, out_dtype);
|
||||
baddbmm_bmm_out_dtype_checks(batch1, batch2, 0.0, 1.0, out_dtype, true);
|
||||
Scalar beta(0.0);
|
||||
Scalar alpha(1.0);
|
||||
{
|
||||
@ -976,16 +988,14 @@ Tensor& _bmm_out_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at
|
||||
}
|
||||
|
||||
Tensor _baddbmm_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha) {
|
||||
TORCH_CHECK(self.scalar_type() == out_dtype || self.scalar_type() == batch1.dtype(),
|
||||
"self dtype must match either out_dtype or batch1 dtype");
|
||||
Tensor out = at::empty({batch1.size(0), batch1.size(1), batch2.size(2)}, batch1.options().dtype(out_dtype));
|
||||
return _baddbmm_out_dtype_cuda(self, batch1, batch2, out_dtype, beta, alpha, out);
|
||||
// We need to copy the tensor
|
||||
Tensor out = self.clone().to(self.options().dtype(out_dtype));
|
||||
|
||||
return _baddbmm_out_dtype_cuda(out, batch1, batch2, out_dtype, beta, alpha, out);
|
||||
}
|
||||
|
||||
Tensor& _baddbmm_out_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) {
|
||||
baddbmm_bmm_out_dtype_checks(batch1, batch2, beta, alpha, out_dtype, out);
|
||||
// We need to copy the tensor
|
||||
out.copy_(self);
|
||||
baddbmm_bmm_out_dtype_checks(batch1, batch2, beta, alpha, out_dtype, false, self);
|
||||
{
|
||||
NoNamesGuard guard;
|
||||
baddbmm_out_cuda_impl(out, out, batch1, batch2, beta, alpha);
|
||||
@ -1020,27 +1030,24 @@ Tensor& _mm_dtype_out_cuda(const Tensor& self, const Tensor& mat2, const at::Sca
|
||||
}
|
||||
|
||||
Tensor _addmm_dtype_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha) {
|
||||
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
|
||||
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
|
||||
Tensor result = at::empty({mat1.size(0), mat2.size(1)}, self.options().dtype(out_dtype));
|
||||
Tensor result = at::empty(self.sizes(), self.options().dtype(out_dtype));
|
||||
return _addmm_dtype_out_cuda(self, mat1, mat2, out_dtype, beta, alpha, result);
|
||||
}
|
||||
|
||||
Tensor& _addmm_dtype_out_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) {
|
||||
// repeat dimensionality checks for direct calls to `out` overload
|
||||
TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "self and mat2 must have the same dtype, but got ", self.scalar_type(), " and ", mat2.scalar_type());
|
||||
TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type());
|
||||
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
|
||||
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
|
||||
TORCH_CHECK(
|
||||
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
|
||||
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
|
||||
TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type());
|
||||
TORCH_CHECK(out_dtype == mat1.scalar_type() ||
|
||||
(out_dtype == at::ScalarType::Float && (mat1.scalar_type() == at::ScalarType::Half || mat1.scalar_type() == at::ScalarType::BFloat16)),
|
||||
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
|
||||
|
||||
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
|
||||
TORCH_CHECK(out_dtype == self.scalar_type() || self.scalar_type() == mat1.scalar_type(),
|
||||
"self dtype must match either out_dtype or mat1 dtype");
|
||||
TORCH_CHECK(out_dtype == self.scalar_type() ||
|
||||
(out_dtype == at::ScalarType::Float && (self.scalar_type() == at::ScalarType::Half || self.scalar_type() == at::ScalarType::BFloat16)),
|
||||
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
|
||||
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
|
||||
|
||||
addmm_out_cuda_impl(out, self, mat1, mat2, beta, alpha);
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/native/CompositeRandomAccessorCommon.h>
|
||||
#include <thrust/swap.h>
|
||||
#include <thrust/tuple.h>
|
||||
|
||||
namespace at { namespace native {
|
||||
|
||||
@ -75,52 +75,30 @@ static inline bool can_use_int32_nhwc(
|
||||
return true;
|
||||
}
|
||||
|
||||
static inline bool can_use_int32_nchw(
|
||||
int64_t nbatch, int64_t channels,
|
||||
int64_t height, int64_t width,
|
||||
int64_t pooled_height, int64_t pooled_width) {
|
||||
int64_t hw = height * width;
|
||||
return can_use_int32_nhwc(
|
||||
nbatch, channels, height, width,
|
||||
pooled_height, pooled_width,
|
||||
channels * hw, // in_stride_n
|
||||
hw, // in_stride_c
|
||||
width, // in_stride_h
|
||||
1 // in_stride_w
|
||||
);
|
||||
}
|
||||
|
||||
// kernels borrowed from Caffe
|
||||
template <typename scalar_t, typename index_t>
|
||||
__global__ void max_pool_forward_nchw(
|
||||
const index_t nthreads,
|
||||
const scalar_t* bottom_data,
|
||||
const int64_t channels,
|
||||
const int64_t height,
|
||||
const int64_t width,
|
||||
const int pooled_height,
|
||||
const int pooled_width,
|
||||
const int kernel_h, const int kernel_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int pad_h, const int pad_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
scalar_t* top_data,
|
||||
template <typename scalar_t>
|
||||
__global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom_data,
|
||||
const int64_t channels, const int64_t height,
|
||||
const int64_t width, const int pooled_height, const int pooled_width,
|
||||
const int kernel_h, const int kernel_w, const int stride_h,
|
||||
const int stride_w, const int pad_h, const int pad_w,
|
||||
const int dilation_h, const int dilation_w, scalar_t* top_data,
|
||||
int64_t* top_mask) {
|
||||
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
|
||||
index_t pw = index % pooled_width;
|
||||
index_t ph = (index / pooled_width) % pooled_height;
|
||||
index_t c = (index / pooled_width / pooled_height) % channels;
|
||||
index_t n = index / pooled_width / pooled_height / channels;
|
||||
index_t hstart = ph * stride_h - pad_h;
|
||||
index_t wstart = pw * stride_w - pad_w;
|
||||
index_t hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
|
||||
index_t wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
|
||||
CUDA_KERNEL_LOOP(index, nthreads) {
|
||||
int pw = index % pooled_width;
|
||||
int ph = (index / pooled_width) % pooled_height;
|
||||
int c = (index / pooled_width / pooled_height) % channels;
|
||||
int n = index / pooled_width / pooled_height / channels;
|
||||
int hstart = ph * stride_h - pad_h;
|
||||
int wstart = pw * stride_w - pad_w;
|
||||
int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
|
||||
int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
|
||||
while(hstart < 0)
|
||||
hstart += dilation_h;
|
||||
while(wstart < 0)
|
||||
wstart += dilation_w;
|
||||
scalar_t maxval = at::numeric_limits<scalar_t>::lower_bound(); // -Infinity
|
||||
index_t maxidx = hstart * width + wstart;
|
||||
int maxidx = hstart * width + wstart;
|
||||
const scalar_t* btm_data = bottom_data + (n * channels + c) * height * width;
|
||||
for (int h = hstart; h < hend; h += dilation_h) {
|
||||
for (int w = wstart; w < wend; w += dilation_w) {
|
||||
@ -273,39 +251,32 @@ __global__ void max_pool_forward_nhwc(
|
||||
|
||||
static constexpr int BLOCK_THREADS = 256;
|
||||
|
||||
template <typename scalar_t, typename accscalar_t, typename index_t>
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
#if defined (USE_ROCM)
|
||||
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 4)
|
||||
#else
|
||||
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 8)
|
||||
#endif
|
||||
__global__ void max_pool_backward_nchw(
|
||||
const scalar_t* top_diff,
|
||||
const int64_t* top_mask,
|
||||
const index_t num,
|
||||
const index_t channels,
|
||||
const index_t height,
|
||||
const index_t width,
|
||||
const index_t pooled_height,
|
||||
const index_t pooled_width,
|
||||
const int kernel_h, const int kernel_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int pad_h, const int pad_w,
|
||||
__global__ void max_pool_backward_nchw(const scalar_t* top_diff,
|
||||
const int64_t* top_mask, const int num, const int64_t channels,
|
||||
const int64_t height, const int64_t width, const int pooled_height,
|
||||
const int pooled_width, const int kernel_h, const int kernel_w,
|
||||
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
scalar_t* bottom_diff) {
|
||||
CUDA_KERNEL_LOOP_TYPE(index, height*width, index_t) {
|
||||
index_t h = index / width;
|
||||
index_t w = index - h * width;
|
||||
index_t phstart = p_start(h, pad_h, kernel_h, dilation_h, stride_h);
|
||||
index_t phend = p_end(h, pad_h, pooled_height, stride_h);
|
||||
index_t pwstart = p_start(w, pad_w, kernel_w, dilation_w, stride_w);
|
||||
index_t pwend = p_end(w, pad_w, pooled_width, stride_w);
|
||||
for (index_t n = blockIdx.y; n < num; n += gridDim.y) {
|
||||
for (index_t c = blockIdx.z; c < channels; c += gridDim.z) {
|
||||
CUDA_KERNEL_LOOP(index, height*width) {
|
||||
int h = index / width;
|
||||
int w = index - h * width;
|
||||
int phstart = p_start(h, pad_h, kernel_h, dilation_h, stride_h);
|
||||
int phend = p_end(h, pad_h, pooled_height, stride_h);
|
||||
int pwstart = p_start(w, pad_w, kernel_w, dilation_w, stride_w);
|
||||
int pwend = p_end(w, pad_w, pooled_width, stride_w);
|
||||
for (int n = blockIdx.y; n < num; n += gridDim.y) {
|
||||
for (int c = blockIdx.z; c < channels; c+= gridDim.z) {
|
||||
accscalar_t gradient = accscalar_t(0);
|
||||
index_t offset = (n * channels + c) * pooled_height * pooled_width;
|
||||
for (index_t ph = phstart; ph < phend; ++ph) {
|
||||
for (index_t pw = pwstart; pw < pwend; ++pw) {
|
||||
int offset = (n * channels + c) * pooled_height * pooled_width;
|
||||
for (int ph = phstart; ph < phend; ++ph) {
|
||||
for (int pw = pwstart; pw < pwend; ++pw) {
|
||||
if (top_mask[ph * pooled_width + pw + offset] == h * width + w) {
|
||||
gradient += static_cast<accscalar_t>(top_diff[ph * pooled_width + pw + offset]);
|
||||
}
|
||||
@ -498,6 +469,8 @@ const Tensor& indices) {
|
||||
const int64_t in_stride_h = input.stride(-2);
|
||||
const int64_t in_stride_w = input.stride(-1);
|
||||
|
||||
const int count = safe_downcast<int, int64_t>(output.numel());
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
|
||||
"max_pool2d_with_indices_out_cuda_frame",
|
||||
[&] {
|
||||
@ -580,42 +553,14 @@ const Tensor& indices) {
|
||||
break;
|
||||
}
|
||||
case MemoryFormat::Contiguous: {
|
||||
const int threads = std::min(
|
||||
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
|
||||
BLOCK_THREADS);
|
||||
const int64_t nthreads = output.numel();
|
||||
bool use_int32 = can_use_int32_nchw(
|
||||
nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
|
||||
const int maxGridX = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
|
||||
const int blocks = static_cast<int>(std::min<int64_t>(
|
||||
ceil_div(nthreads, static_cast<int64_t>(threads)),
|
||||
static_cast<int64_t>(maxGridX)));
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
if (use_int32) {
|
||||
max_pool_forward_nchw<scalar_t, int32_t>
|
||||
<<<blocks, threads, 0, stream>>>(
|
||||
static_cast<int32_t>(nthreads),
|
||||
input_data,
|
||||
static_cast<int32_t>(nInputPlane),
|
||||
static_cast<int32_t>(inputHeight),
|
||||
static_cast<int32_t>(inputWidth),
|
||||
static_cast<int32_t>(outputHeight),
|
||||
static_cast<int32_t>(outputWidth),
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
output_data, indices_data);
|
||||
} else {
|
||||
max_pool_forward_nchw<scalar_t, int64_t>
|
||||
<<<blocks, threads, 0, stream>>>(
|
||||
nthreads,
|
||||
input_data,
|
||||
nInputPlane,
|
||||
inputHeight,
|
||||
inputWidth,
|
||||
outputHeight,
|
||||
outputWidth,
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
output_data, indices_data);
|
||||
}
|
||||
const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
|
||||
BLOCK_THREADS);
|
||||
max_pool_forward_nchw<scalar_t>
|
||||
<<<ceil_div(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
count, input_data,
|
||||
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
output_data, indices_data);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
break;
|
||||
}
|
||||
@ -688,6 +633,8 @@ const Tensor& gradInput) {
|
||||
|
||||
gradInput.zero_();
|
||||
|
||||
int64_t count = input.numel();
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
|
||||
"max_pool2d_with_indices_out_cuda_frame",
|
||||
[&] {
|
||||
@ -745,45 +692,25 @@ const Tensor& gradInput) {
|
||||
break;
|
||||
}
|
||||
case MemoryFormat::Contiguous: {
|
||||
const int threads = std::min(
|
||||
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
|
||||
BLOCK_THREADS);
|
||||
const int imgcount = inputWidth * inputHeight;
|
||||
const int maxGridX = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
|
||||
const int maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
|
||||
const int maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
|
||||
const int blocks_x = std::min(ceil_div(imgcount, threads), maxGridX);
|
||||
dim3 grid(blocks_x, static_cast<unsigned>(std::min<int64_t>(nbatch, maxGridY)), static_cast<unsigned>(std::min<int64_t>(nInputPlane, maxGridZ)));
|
||||
bool use_int32 = can_use_int32_nchw(
|
||||
nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
if (use_int32) {
|
||||
max_pool_backward_nchw<scalar_t, accscalar_t, int32_t>
|
||||
<<<grid, threads, 0, stream>>>(
|
||||
gradOutput_data,
|
||||
indices_data,
|
||||
static_cast<int32_t>(nbatch),
|
||||
static_cast<int32_t>(nInputPlane),
|
||||
static_cast<int32_t>(inputHeight),
|
||||
static_cast<int32_t>(inputWidth),
|
||||
static_cast<int32_t>(outputHeight),
|
||||
static_cast<int32_t>(outputWidth),
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
gradInput_data);
|
||||
} else {
|
||||
max_pool_backward_nchw<scalar_t, accscalar_t, int64_t>
|
||||
<<<grid, threads, 0, stream>>>(
|
||||
gradOutput_data,
|
||||
indices_data,
|
||||
nbatch,
|
||||
nInputPlane,
|
||||
inputHeight,
|
||||
inputWidth,
|
||||
outputHeight,
|
||||
outputWidth,
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
gradInput_data);
|
||||
}
|
||||
int imgcount = inputWidth * inputHeight;
|
||||
dim3 grid;
|
||||
const int blocks = (imgcount + BLOCK_THREADS - 1) / BLOCK_THREADS;
|
||||
grid.x = blocks;
|
||||
grid.y = nbatch;
|
||||
uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
|
||||
if (maxGridY < grid.y) grid.y = maxGridY;
|
||||
grid.z = nInputPlane;
|
||||
uint64_t maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
|
||||
if (maxGridZ < grid.z) grid.z = maxGridZ;
|
||||
|
||||
max_pool_backward_nchw<scalar_t, accscalar_t>
|
||||
<<<grid, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
gradOutput_data,
|
||||
indices_data,
|
||||
nbatch,
|
||||
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
gradInput_data);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
break;
|
||||
}
|
||||
|
||||
@ -78,18 +78,9 @@ __global__ void EmbeddingBag_updateOutputKernel_max(
|
||||
scalar_t weightFeatMax = 0;
|
||||
int64_t bag_size_ = 0;
|
||||
int64_t maxWord = -1;
|
||||
|
||||
// Separate validation loop reduces register pressure in the main loop below.
|
||||
// No early exit (break) on invalid input as benchmarking shows it degrades performance.
|
||||
bool has_invalid_index = false;
|
||||
for (int64_t emb = begin; emb < end; emb++) {
|
||||
index_t input_idx = input[emb];
|
||||
has_invalid_index = has_invalid_index || (input_idx < 0 || input_idx >= numRows);
|
||||
}
|
||||
CUDA_KERNEL_ASSERT(!has_invalid_index && "Invalid input index in EmbeddingBag: index out of range [0, numRows)");
|
||||
|
||||
for (int64_t emb = begin; emb < end; emb++) {
|
||||
bool pad = (input[emb] == padding_idx);
|
||||
CUDA_KERNEL_ASSERT(input[emb] < numRows);
|
||||
const int64_t weightRow = input[emb] * weight_stride0;
|
||||
scalar_t weightValue = weightFeat[weightRow];
|
||||
if (bag_size_ == 0 || weightValue > weightFeatMax) {
|
||||
@ -138,19 +129,10 @@ __global__ void EmbeddingBag_updateOutputKernel_sum_mean(
|
||||
CUDA_KERNEL_ASSERT(end >= begin);
|
||||
accscalar_t weightFeatSum = 0;
|
||||
int64_t bag_size_ = 0;
|
||||
|
||||
// Separate validation loop reduces register pressure in the main loop below.
|
||||
// No early exit (break) on invalid input as benchmarking shows it degrades performance.
|
||||
bool has_invalid_index = false;
|
||||
for (int64_t emb = begin; emb < end; emb++) {
|
||||
index_t input_idx = input[emb];
|
||||
has_invalid_index = has_invalid_index || (input_idx < 0 || input_idx >= numRows);
|
||||
}
|
||||
CUDA_KERNEL_ASSERT(!has_invalid_index && "Invalid input index in EmbeddingBag: index out of range [0, numRows)");
|
||||
|
||||
for (int64_t emb = begin; emb < end; emb++) {
|
||||
index_t input_idx = input[emb];
|
||||
bool pad = (input_idx == padding_idx);
|
||||
CUDA_KERNEL_ASSERT(0 <= input_idx && input_idx < numRows);
|
||||
const int64_t weightRow = input_idx * weight_stride0;
|
||||
scalar_t weightValue = weightFeat[weightRow];
|
||||
weightValue = pad ? static_cast<scalar_t>(0) : weightValue;
|
||||
|
||||
@ -78,9 +78,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const SwizzleType swizzle_a,
|
||||
const SwizzleType& swizzle_a,
|
||||
const Tensor& scale_b,
|
||||
const SwizzleType swizzle_b,
|
||||
const SwizzleType& swizzle_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
Tensor& out) {
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
|
||||
@ -5,11 +5,69 @@
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
// ROCm 6.3 is planned to have these functions, but until then here they are.
|
||||
#if defined(USE_ROCM)
|
||||
#include <device_functions.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
#define ATOMICADD unsafeAtomicAdd
|
||||
|
||||
__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) {
|
||||
#if (defined(__gfx942__)) && \
|
||||
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16)
|
||||
typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2;
|
||||
static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw));
|
||||
union {
|
||||
__hip_bfloat162_raw bf162_raw;
|
||||
vec_short2 vs2;
|
||||
} u{static_cast<__hip_bfloat162_raw>(value)};
|
||||
u.vs2 = __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)address, u.vs2);
|
||||
return static_cast<__hip_bfloat162>(u.bf162_raw);
|
||||
#else
|
||||
static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw));
|
||||
union u_hold {
|
||||
__hip_bfloat162_raw h2r;
|
||||
unsigned int u32;
|
||||
};
|
||||
u_hold old_val, new_val;
|
||||
old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
do {
|
||||
new_val.h2r = __hadd2(old_val.h2r, value);
|
||||
} while (!__hip_atomic_compare_exchange_strong(
|
||||
(unsigned int*)address, &old_val.u32, new_val.u32,
|
||||
__ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
|
||||
return old_val.h2r;
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) {
|
||||
#if (defined(__gfx942__)) && \
|
||||
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16)
|
||||
// The api expects an ext_vector_type of half
|
||||
typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162;
|
||||
static_assert(sizeof(vec_fp162) == sizeof(__half2_raw));
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
vec_fp162 fp16;
|
||||
} u {static_cast<__half2_raw>(value)};
|
||||
u.fp16 = __builtin_amdgcn_flat_atomic_fadd_v2f16((vec_fp162*)address, u.fp16);
|
||||
return static_cast<__half2>(u.h2r);
|
||||
#else
|
||||
static_assert(sizeof(__half2_raw) == sizeof(unsigned int));
|
||||
union u_hold {
|
||||
__half2_raw h2r;
|
||||
unsigned int u32;
|
||||
};
|
||||
u_hold old_val, new_val;
|
||||
old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
do {
|
||||
new_val.h2r = __hadd2(old_val.h2r, value);
|
||||
} while (!__hip_atomic_compare_exchange_strong(
|
||||
(unsigned int*)address, &old_val.u32, new_val.u32,
|
||||
__ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
|
||||
return old_val.h2r;
|
||||
#endif
|
||||
}
|
||||
#define ATOMICADD preview_unsafeAtomicAdd
|
||||
#define NATIVE_ZERO_BF16 __float2bfloat16(0.0f)
|
||||
#else
|
||||
#define ATOMICADD atomicAdd
|
||||
|
||||
@ -2,250 +2,18 @@
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/cuda/Loops.cuh>
|
||||
#include <ATen/native/cuda/JitLoops.cuh>
|
||||
#include <ATen/native/cuda/jit_utils.h>
|
||||
#include <ATen/native/cuda/ScanUtils.cuh>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/BinaryOps.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <c10/util/MathConstants.h>
|
||||
#include <c10/util/complex.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
|
||||
// NOTE: CUDA on Windows requires that the enclosing function
|
||||
// of a __device__ lambda not have internal linkage.
|
||||
|
||||
namespace at::native {
|
||||
|
||||
// custom min and max to be used in logaddexp for complex arguments
|
||||
template <typename scalar_t, bool min>
|
||||
__host__ __device__ c10::complex<scalar_t> _logaddexp_minmax(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
|
||||
scalar_t xr = std::real(x);
|
||||
scalar_t yr = std::real(y);
|
||||
if (::isnan(yr) || (::isnan(std::imag(y)))) {
|
||||
return y;
|
||||
} else if (::isnan(xr) || (::isnan(std::imag(x)))) {
|
||||
return x;
|
||||
} else if (min) { // min
|
||||
return (xr < yr) ? x : y;
|
||||
} else { // max
|
||||
return (xr >= yr) ? x : y;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__host__ __device__ scalar_t _log_add_exp_helper(const scalar_t& x, const scalar_t& y) {
|
||||
// Reference : https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp
|
||||
// Using the original expression: `at::_isnan(y) ? y : std::min(x, y)` causes an error in ROCM
|
||||
const auto isnan_x = at::_isnan(x);
|
||||
const auto isnan_y = at::_isnan(y);
|
||||
scalar_t min = isnan_y ? y : (isnan_x ? x : std::min(x, y));
|
||||
scalar_t max = isnan_y ? y : (isnan_x ? x : std::max(x, y));
|
||||
if (min != max || ::isfinite(min)) {
|
||||
// nan will be propagated here
|
||||
return ::log1p(std::exp(min - max)) + max;
|
||||
} else {
|
||||
// special case to correctly handle infinite cases
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__host__ __device__ c10::complex<scalar_t> _fast_build_exp(const c10::complex<scalar_t>& x) {
|
||||
// complex exponential function, but implemented manually to get fast compilation time
|
||||
// this function only handles the case where the x is finite (not inf nor nan)
|
||||
const auto xreal = std::real(x);
|
||||
const auto ximag = std::imag(x);
|
||||
const auto exp_x_abs = std::exp(xreal);
|
||||
auto exp_x_real = exp_x_abs * std::cos(ximag);
|
||||
auto exp_x_imag = exp_x_abs * std::sin(ximag);
|
||||
return {exp_x_real, exp_x_imag};
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__host__ __device__ c10::complex<scalar_t> _fast_build_exp_inf(const c10::complex<scalar_t>& x) {
|
||||
// complex exponential function, but implemented manually to get fast compilation time
|
||||
// this function only handles the case where the real part of x is infinite
|
||||
const auto ximag = std::imag(x);
|
||||
constexpr auto exp_x_abs = std::numeric_limits<scalar_t>::infinity();
|
||||
if (!::isfinite(ximag)) { // add this to make consitent with std::exp(x+yi)
|
||||
return {exp_x_abs, std::numeric_limits<scalar_t>::quiet_NaN()};
|
||||
}
|
||||
const auto sin = std::sin(ximag);
|
||||
const auto cos = std::cos(ximag);
|
||||
// special case if the angle is exactly the multiple of pi/2
|
||||
auto exp_x_real = (cos == 0) ? (scalar_t)0.0 : exp_x_abs * cos;
|
||||
auto exp_x_imag = (sin == 0) ? (scalar_t)0.0 : exp_x_abs * sin;
|
||||
return {exp_x_real, exp_x_imag};
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__host__ __device__ c10::complex<scalar_t> _log_add_exp_helper(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
|
||||
c10::complex<scalar_t> min = _logaddexp_minmax<scalar_t, /*min=*/true>(x, y);
|
||||
c10::complex<scalar_t> max = _logaddexp_minmax<scalar_t, /*min=*/false>(x, y);
|
||||
scalar_t min_real = std::real(min);
|
||||
scalar_t max_real = std::real(max);
|
||||
|
||||
if (::isnan(min_real) || ::isnan(std::imag(min))) {
|
||||
// handling the "infectious" NaNs
|
||||
return {std::numeric_limits<scalar_t>::quiet_NaN(), std::numeric_limits<scalar_t>::quiet_NaN()};
|
||||
}
|
||||
else if ((!::isfinite(min_real)) && (min_real == max_real)) {
|
||||
if (min_real < 0) {
|
||||
// handle the -inf case, the imaginary part here does not really matter as the exp(value)
|
||||
// will be around 0.0 and the angle (i.e. the imaginary part) cannot be determined.
|
||||
// It does not matter if we're taking the exp of this value
|
||||
return min;
|
||||
} else {
|
||||
// handle the +inf case, we don't need the special precision for log1p for small values
|
||||
// and to avoid producing nan in case of real(max) == real(min) == +inf
|
||||
const auto exp_min = _fast_build_exp_inf(min);
|
||||
const auto exp_max = _fast_build_exp_inf(max);
|
||||
return ::log1p(exp_min + exp_max - 1); // log1p(x - 1) builds faster than log
|
||||
}
|
||||
} else {
|
||||
const auto minmax = min - max;
|
||||
c10::complex<scalar_t> exp_minmax;
|
||||
if (!::isfinite(minmax.real())) {
|
||||
exp_minmax = minmax.real() < 0 ? c10::complex<scalar_t>{0.0, 0.0} : _fast_build_exp_inf(minmax);
|
||||
} else {
|
||||
exp_minmax = _fast_build_exp(minmax);
|
||||
}
|
||||
return ::log1p(exp_minmax) + max;
|
||||
}
|
||||
}
|
||||
|
||||
// Complex logaddexp jiterator string
|
||||
const auto logaddexp_complex_string = jiterator_stringify(
|
||||
template<typename T>
|
||||
std::complex<T> log1p(const std::complex<T>& z)
|
||||
{
|
||||
using complex_t = std::complex<T>;
|
||||
T x = z.real();
|
||||
T y = z.imag();
|
||||
T zabs = abs(z);
|
||||
T theta = atan2(y, x + T(1));
|
||||
if (zabs < 0.5) {
|
||||
T r = x * (T(2) + x) + y * y;
|
||||
if (r == 0) { // handle underflow
|
||||
return complex_t(x, theta);
|
||||
}
|
||||
return complex_t(T(0.5) * std::log1p(r), theta);
|
||||
} else {
|
||||
T z0 = std::hypot(x + 1, y);
|
||||
return complex_t(log(z0), theta);
|
||||
}
|
||||
}
|
||||
|
||||
// separated _logaddexp_minmax into 2 different functions for jiterator_string
|
||||
template <typename T>
|
||||
std::complex<T> logaddexp_min(const std::complex<T>& x, const std::complex<T>& y) {
|
||||
T xr = x.real();
|
||||
T yr = y.real();
|
||||
if (isnan(yr) || isnan(y.imag())) {
|
||||
return y;
|
||||
} else if (isnan(xr) || isnan(x.imag())) {
|
||||
return x;
|
||||
} else {
|
||||
return (xr < yr) ? x : y;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::complex<T> logaddexp_max(const std::complex<T>& x, const std::complex<T>& y) {
|
||||
T xr = x.real();
|
||||
T yr = y.real();
|
||||
if (isnan(yr) || isnan(y.imag())) {
|
||||
return y;
|
||||
} else if (isnan(xr) || isnan(x.imag())) {
|
||||
return x;
|
||||
} else {
|
||||
return (xr >= yr) ? x : y;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::complex<T> fast_build_exp(const std::complex<T>& x) {
|
||||
const auto xreal = x.real();
|
||||
const auto ximag = x.imag();
|
||||
const auto exp_x_abs = exp(xreal);
|
||||
auto exp_x_real = exp_x_abs * cos(ximag);
|
||||
auto exp_x_imag = exp_x_abs * sin(ximag);
|
||||
return std::complex<T>(exp_x_real, exp_x_imag);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::complex<T> fast_build_exp_inf(const std::complex<T>& x) {
|
||||
using complex_t = std::complex<T>;
|
||||
const auto ximag = x.imag();
|
||||
const T exp_x_abs = INFINITY;
|
||||
if (!isfinite(ximag)) {
|
||||
return complex_t(exp_x_abs, NAN);
|
||||
}
|
||||
const auto sin_val = sin(ximag);
|
||||
const auto cos_val = cos(ximag);
|
||||
auto exp_x_real = (cos_val == T(0)) ? T(0) : exp_x_abs * cos_val;
|
||||
auto exp_x_imag = (sin_val == T(0)) ? T(0) : exp_x_abs * sin_val;
|
||||
return complex_t(exp_x_real, exp_x_imag);
|
||||
}
|
||||
|
||||
template <typename complex_t>
|
||||
complex_t logaddexp_complex(complex_t x, complex_t y) {
|
||||
using T = typename complex_t::value_type;
|
||||
complex_t min_val = logaddexp_min(x, y);
|
||||
complex_t max_val = logaddexp_max(x, y);
|
||||
T min_real = min_val.real();
|
||||
T max_real = max_val.real();
|
||||
|
||||
if (isnan(min_real) || isnan(min_val.imag())) {
|
||||
return complex_t(NAN, NAN);
|
||||
}
|
||||
else if ((!isfinite(min_real)) && (min_real == max_real)) {
|
||||
if (min_real < T(0)) {
|
||||
return min_val;
|
||||
} else {
|
||||
const auto exp_min = fast_build_exp_inf<T>(min_val);
|
||||
const auto exp_max = fast_build_exp_inf<T>(max_val);
|
||||
return log1p(exp_min + exp_max - complex_t(1, 0));
|
||||
}
|
||||
} else {
|
||||
const auto minmax = min_val - max_val;
|
||||
complex_t exp_minmax;
|
||||
if (!isfinite(minmax.real())) {
|
||||
exp_minmax = (minmax.real() < T(0)) ? complex_t(0, 0) : fast_build_exp_inf<T>(minmax);
|
||||
} else {
|
||||
exp_minmax = fast_build_exp<T>(minmax);
|
||||
}
|
||||
return log1p(exp_minmax) + max_val;
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
constexpr char logaddexp_complex_name[] = "logaddexp_complex";
|
||||
void logaddexp_kernel_cuda(TensorIteratorBase& iter) {
|
||||
if (at::isComplexType(iter.dtype())) {
|
||||
#if AT_USE_JITERATOR()
|
||||
AT_DISPATCH_COMPLEX_TYPES_AND(at::ScalarType::ComplexHalf, iter.dtype(), "logaddexp_cuda", [&]() {
|
||||
jitted_gpu_kernel<
|
||||
/*name=*/logaddexp_complex_name,
|
||||
/*return_dtype=*/scalar_t,
|
||||
/*common_dtype=*/scalar_t,
|
||||
/*arity=*/2>(iter, logaddexp_complex_string);
|
||||
});
|
||||
#else
|
||||
AT_DISPATCH_COMPLEX_TYPES_AND(at::ScalarType::ComplexHalf, iter.dtype(), "logaddexp_cuda", [&]() {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a_, scalar_t b_) -> scalar_t {
|
||||
const auto a = static_cast<opmath_t>(a_);
|
||||
const auto b = static_cast<opmath_t>(b_);
|
||||
return static_cast<scalar_t>(_log_add_exp_helper(a, b));
|
||||
});
|
||||
});
|
||||
#endif
|
||||
} else {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
ScalarType::BFloat16, ScalarType::Half,
|
||||
iter.dtype(), "logaddexp_cuda",
|
||||
[&]() {
|
||||
@ -261,7 +29,6 @@ void logaddexp_kernel_cuda(TensorIteratorBase& iter) {
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void logaddexp2_kernel_cuda(TensorIteratorBase& iter) {
|
||||
|
||||
@ -740,12 +740,7 @@ _scaled_rowwise_rowwise(
|
||||
TORCH_CHECK_VALUE(scale_a.numel() == mat_a.size(0) && scale_a.scalar_type() == kFloat, "scale_a must have ", mat_a.size(0), " Float elements, got ", scale_a.numel())
|
||||
TORCH_CHECK_VALUE(scale_b.numel() == mat_b.size(1) && scale_b.scalar_type() == kFloat, "scale_b must have ", mat_b.size(1), " Float elements, got ", scale_b.numel())
|
||||
|
||||
// if we have a scale of shape [256, 1] (say), then stride can be [1, 0] - handle this case
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(1) == 1 ||
|
||||
scale_a.size(1) == 1,
|
||||
"expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1)
|
||||
);
|
||||
TORCH_CHECK_VALUE(scale_a.stride(1) == 1, "expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1));
|
||||
TORCH_CHECK_VALUE(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
|
||||
|
||||
auto scaling_choice_a = ScalingType::RowWise;
|
||||
@ -1101,19 +1096,6 @@ _scaled_mxfp8_mxfp8(
|
||||
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
|
||||
}
|
||||
|
||||
void
|
||||
_check_mxfp4_support() {
|
||||
#ifndef USE_ROCM
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// Only on B200 GPUs
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
// B200 = 10.0, B300 = 10.3
|
||||
dprops->major == 10,
|
||||
"MXFP4 scaling only supported in CUDA for B200/B300"
|
||||
);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
Tensor&
|
||||
_scaled_mxfp4_mxfp4(
|
||||
@ -1126,7 +1108,6 @@ _scaled_mxfp4_mxfp4(
|
||||
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
|
||||
#else
|
||||
_check_mxfp4_support();
|
||||
// Restrictions:
|
||||
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
|
||||
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
|
||||
|
||||
@ -267,15 +267,15 @@ void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, con
|
||||
* outer dimensions, which contains several "inner rows").
|
||||
* Each thread processes a single inner row at a time.
|
||||
*/
|
||||
template<typename scalar_t, typename index_t, class BinaryOp>
|
||||
template<typename scalar_t, class BinaryOp>
|
||||
__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, const scalar_t *src_,
|
||||
const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size,
|
||||
const scalar_t init, BinaryOp binary_op)
|
||||
{
|
||||
for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
|
||||
for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
|
||||
const scalar_t *src = src_ + static_cast<index_t>(orow) * row_size * num_irows + irow;
|
||||
scalar_t *tgt = tgt_ + (index_t) orow * row_size * num_irows + irow;
|
||||
const scalar_t *src = src_ + orow * row_size * num_irows + irow;
|
||||
scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow;
|
||||
scalar_t acc = init;
|
||||
|
||||
for (uint32_t col = 0; col < row_size; ++col) {
|
||||
@ -409,15 +409,10 @@ __host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,
|
||||
check_fits_in_unsigned(num_irows, "num_irows");
|
||||
check_fits_in_unsigned(num_orows, "num_orows");
|
||||
check_fits_in_unsigned(row_size, "row_size");
|
||||
if (static_cast<size_t>(num_irows) * num_orows * row_size <= UINT_MAX) {
|
||||
tensor_kernel_scan_outer_dim<scalar_t, uint32_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
|
||||
tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
|
||||
num_orows, num_irows, row_size, init, binary_op);
|
||||
} else {
|
||||
tensor_kernel_scan_outer_dim<scalar_t, size_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
|
||||
num_orows, num_irows, row_size, init, binary_op);
|
||||
}
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
|
||||
@ -337,6 +337,10 @@ Tensor _convolution_out(
|
||||
TORCH_CHECK(
|
||||
3 == ndim || 4 == ndim || 5 == ndim,
|
||||
"convolution only supports 3D, 4D, 5D tensor");
|
||||
// get computation format for Conv/TransposedConv
|
||||
bool is_channels_last_suggested =
|
||||
use_channels_last_for_conv(input_r, weight_r);
|
||||
|
||||
Tensor input = input_r, weight = weight_r;
|
||||
// PyTorch does not support ChannelsLast1D case,
|
||||
// thus we need the transformation here
|
||||
@ -344,8 +348,13 @@ Tensor _convolution_out(
|
||||
input = view4d(input_r);
|
||||
weight = view4d(weight_r);
|
||||
}
|
||||
// get computation format for Conv/TransposedConv
|
||||
bool is_channels_last_suggested = use_channels_last_for_conv(input, weight);
|
||||
// ensure the input/weight/bias/output are congituous in desired format
|
||||
at::MemoryFormat mfmt = is_channels_last_suggested
|
||||
? get_cl_tag_by_ndim(input.ndimension())
|
||||
: at::MemoryFormat::Contiguous;
|
||||
auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r;
|
||||
input = input.contiguous(mfmt);
|
||||
weight = weight.contiguous(mfmt);
|
||||
|
||||
auto k = weight.ndimension();
|
||||
if (k == input.ndimension() + 1) {
|
||||
@ -379,14 +388,6 @@ Tensor _convolution_out(
|
||||
expand_param_if_needed(output_padding_, "output_padding", dim);
|
||||
params.groups = groups_;
|
||||
}
|
||||
|
||||
// ensure the input/weight/bias/output are congituous in desired format
|
||||
at::MemoryFormat mfmt = is_channels_last_suggested
|
||||
? get_cl_tag_by_ndim(input.ndimension())
|
||||
: at::MemoryFormat::Contiguous;
|
||||
auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r;
|
||||
input = input.contiguous(mfmt);
|
||||
weight = weight.contiguous(mfmt);
|
||||
check_shape_forward(input, weight, bias, params, true);
|
||||
|
||||
Tensor output;
|
||||
@ -513,9 +514,18 @@ Tensor convolution_overrideable(
|
||||
at::borrow_from_optional_tensor(bias_r_opt);
|
||||
const Tensor& bias_r = *bias_r_maybe_owned;
|
||||
|
||||
auto k = weight_r.ndimension();
|
||||
at::MemoryFormat backend_memory_format = at::MemoryFormat::Contiguous;
|
||||
if (xpu_conv_use_channels_last(input_r, weight_r)) {
|
||||
backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d
|
||||
: at::MemoryFormat::ChannelsLast;
|
||||
}
|
||||
Tensor input_c = input_r.contiguous(backend_memory_format);
|
||||
Tensor weight_c = weight_r.contiguous(backend_memory_format);
|
||||
|
||||
return _convolution(
|
||||
input_r,
|
||||
weight_r,
|
||||
input_c,
|
||||
weight_c,
|
||||
bias_r,
|
||||
stride_,
|
||||
padding_,
|
||||
|
||||
@ -1,342 +0,0 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/BlasBackend.h>
|
||||
#include <ATen/WrapDimUtilsMulti.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
|
||||
#include <ATen/native/xpu/Blas.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_addmm_activation_native.h>
|
||||
#include <ATen/ops/_efficientzerotensor.h>
|
||||
#include <ATen/ops/_scaled_mm_native.h>
|
||||
#include <ATen/ops/_unsafe_view_native.h>
|
||||
#include <ATen/ops/abs.h>
|
||||
#include <ATen/ops/addmm_native.h>
|
||||
#include <ATen/ops/addmv_native.h>
|
||||
#include <ATen/ops/baddbmm_native.h>
|
||||
#include <ATen/ops/bmm_native.h>
|
||||
#include <ATen/ops/copy_native.h>
|
||||
#include <ATen/ops/dot_native.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/empty_strided.h>
|
||||
#include <ATen/ops/gelu.h>
|
||||
#include <ATen/ops/max.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/mul.h>
|
||||
#include <ATen/ops/ones.h>
|
||||
#include <ATen/ops/relu.h>
|
||||
#include <ATen/ops/scalar_tensor_native.h>
|
||||
#include <ATen/ops/vdot_native.h>
|
||||
#endif
|
||||
|
||||
namespace at::native {
|
||||
|
||||
using at::blas::ScalingType;
|
||||
using at::blas::SwizzleType;
|
||||
|
||||
namespace {
|
||||
/*
|
||||
* Scaling Type Determination:
|
||||
* ---------------------------
|
||||
* Conditions and corresponding Scaling Types:
|
||||
*
|
||||
* - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`:
|
||||
* - Returns BlockWise (with additional size checks).
|
||||
*
|
||||
* - Else if scale.numel() == 1:
|
||||
* - Returns TensorWise.
|
||||
*
|
||||
* - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) ==
|
||||
* 1:
|
||||
* - Returns RowWise.
|
||||
*
|
||||
* - Otherwise:
|
||||
* - Returns Error.
|
||||
*/
|
||||
|
||||
bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
|
||||
return at::isFloat8Type(t.scalar_type()) &&
|
||||
scale.scalar_type() == at::kFloat && scale.numel() == 1;
|
||||
}
|
||||
|
||||
bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
|
||||
return (
|
||||
at::isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat &&
|
||||
scale.dim() == 2 && scale.size(0) == t.size(0) && scale.size(1) == 1 &&
|
||||
scale.is_contiguous());
|
||||
}
|
||||
|
||||
bool is_desired_scaling(
|
||||
const at::Tensor& t,
|
||||
const at::Tensor& scale,
|
||||
ScalingType desired_scaling) {
|
||||
auto result = desired_scaling == ScalingType::TensorWise
|
||||
? is_tensorwise_scaling(t, scale)
|
||||
: is_rowwise_scaling(t, scale);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::pair<ScalingType, ScalingType> get_joint_scaling(
|
||||
std::initializer_list<std::pair<ScalingType, ScalingType>> options,
|
||||
const at::Tensor& a,
|
||||
const at::Tensor& b,
|
||||
const at::Tensor& scale_a,
|
||||
const at::Tensor& scale_b) {
|
||||
for (auto [lhs, rhs] : options) {
|
||||
if (is_desired_scaling(a, scale_a, lhs) &&
|
||||
is_desired_scaling(b.t(), scale_b.t(), rhs)) {
|
||||
return {lhs, rhs};
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Invalid scaling configuration.\n"
|
||||
"- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
|
||||
"- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (",
|
||||
a.size(0),
|
||||
", 1) and scale_b should be (1, ",
|
||||
b.size(1),
|
||||
"), and both should be contiguous.\n"
|
||||
"Got a.dtype()=",
|
||||
a.scalar_type(),
|
||||
", scale_a.dtype()=",
|
||||
scale_a.scalar_type(),
|
||||
", scale_a.size()=",
|
||||
scale_a.sizes(),
|
||||
", scale_a.stride()=",
|
||||
scale_a.strides(),
|
||||
", ",
|
||||
"b.dtype()=",
|
||||
b.scalar_type(),
|
||||
", scale_b.dtype()=",
|
||||
scale_b.scalar_type(),
|
||||
", scale_b.size()=",
|
||||
scale_b.sizes(),
|
||||
" and scale_b.stride()=",
|
||||
scale_b.strides());
|
||||
}
|
||||
|
||||
Tensor& _scaled_gemm(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const ScalingType scaling_choice_a,
|
||||
const ScalingType scaling_choice_b,
|
||||
const std::optional<Tensor>& bias,
|
||||
const bool use_fast_accum,
|
||||
Tensor& out,
|
||||
const std::optional<Tensor>& alpha = std::nullopt) {
|
||||
// TODO: scale_result and alpha is not defined or used!
|
||||
std::optional<Tensor> scaled_result = std::nullopt;
|
||||
at::native::onednn::scaled_matmul(
|
||||
mat1,
|
||||
mat2,
|
||||
out,
|
||||
scale_a,
|
||||
scale_b,
|
||||
scaling_choice_a,
|
||||
scaling_choice_b,
|
||||
bias,
|
||||
scaled_result,
|
||||
use_fast_accum);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Computes matrix multiply + bias while applying scaling to input and output
|
||||
// matrices Scales are only applicable when matrices are of Float8 type and
|
||||
// assumed to be equal to 1.0 by default. If output matrix type is 16 or 32-bit
|
||||
// type, scale_result is not applied. Known limitations:
|
||||
// - Only works if mat1 is row-major and mat2 is column-major
|
||||
// - Only works if matrices sizes are divisible by 32
|
||||
// - If 1-dimensional tensors are used then scale_a should be size =
|
||||
// mat1.size(0)
|
||||
// and scale_b should have size = to mat2.size(1)
|
||||
// Arguments:
|
||||
// - `mat1`: the first operand of the matrix multiply, can be type
|
||||
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
|
||||
// - `mat2`: the second operand of the matrix multiply, can be type
|
||||
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
|
||||
// - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
|
||||
// - `out_dtype`: the output dtype, can either be a float8 or a higher
|
||||
// precision floating point type
|
||||
// - `scale_a`: a tensor with the inverse scale of `mat1`, whose
|
||||
// shape/strides/dtype depend on the scaling scheme
|
||||
// - `scale_b`: a tensor with the inverse scale of `mat2`, whose
|
||||
// shape/strides/dtype depend on the scaling scheme
|
||||
// - `scale_result`: a scalar tensor with the scale of the output, only
|
||||
// utilized if the output is a float8 type
|
||||
// - `use_fast_accum`: Not applicable for XPU. For now, it should always be
|
||||
// false.
|
||||
// - `out`: a reference to the output tensor
|
||||
|
||||
Tensor& _scaled_mm_out_xpu(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& scale_result,
|
||||
std::optional<c10::ScalarType> out_dtype,
|
||||
bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
// Note: fast_accum is not supported in XPU for now.
|
||||
TORCH_CHECK(!use_fast_accum, "fast_accum is not supported in XPU for now.");
|
||||
|
||||
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
|
||||
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
|
||||
|
||||
TORCH_CHECK(
|
||||
mat1.sizes()[1] == mat2.sizes()[0],
|
||||
"mat1 and mat2 shapes cannot be multiplied (",
|
||||
mat1.sizes()[0],
|
||||
"x",
|
||||
mat1.sizes()[1],
|
||||
" and ",
|
||||
mat2.sizes()[0],
|
||||
"x",
|
||||
mat2.sizes()[1],
|
||||
")");
|
||||
|
||||
// Check what type of scaling we are doing based on inputs. This list is
|
||||
// sorted by decreasing priority.
|
||||
|
||||
// List of supported datatypes for XPU with oneDNN:
|
||||
// https://uxlfoundation.github.io/oneDNN/dev_guide_matmul.html#data-types
|
||||
auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling(
|
||||
{
|
||||
std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
|
||||
std::make_pair(ScalingType::RowWise, ScalingType::RowWise),
|
||||
},
|
||||
mat1,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b);
|
||||
TORCH_CHECK(
|
||||
!scale_result ||
|
||||
(scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
|
||||
"scale_result must be a float scalar");
|
||||
TORCH_CHECK(
|
||||
!bias || bias->numel() == mat2.sizes()[1],
|
||||
"Bias must be size ",
|
||||
mat2.sizes()[1],
|
||||
" but got ",
|
||||
bias->numel());
|
||||
TORCH_CHECK(
|
||||
mat1.sizes()[1] % 16 == 0,
|
||||
"Expected trailing dimension of mat1 to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat1.sizes()[0],
|
||||
"x",
|
||||
mat1.sizes()[1],
|
||||
").");
|
||||
TORCH_CHECK(
|
||||
mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0,
|
||||
"mat2 shape (",
|
||||
mat2.sizes()[0],
|
||||
"x",
|
||||
mat2.sizes()[1],
|
||||
") must be divisible by 16");
|
||||
// Check types
|
||||
TORCH_CHECK(
|
||||
!out_dtype || *out_dtype == out.scalar_type(),
|
||||
"out_dtype must match output matrix type");
|
||||
TORCH_CHECK(
|
||||
at::isFloat8Type(mat1.scalar_type()),
|
||||
"Expected mat1 to be Float8 matrix got ",
|
||||
mat1.scalar_type());
|
||||
TORCH_CHECK(
|
||||
at::isFloat8Type(mat2.scalar_type()),
|
||||
"Expected mat2 to be Float8 matrix got ",
|
||||
mat2.scalar_type());
|
||||
// TODO: oneDNN Currently only supports e4m3 with group scales on BMG. Not
|
||||
// support 2D scales, only 1D. Needs to add more checks there.
|
||||
|
||||
if (bias) {
|
||||
TORCH_CHECK(
|
||||
bias->scalar_type() == kFloat ||
|
||||
bias->scalar_type() == c10::ScalarType::BFloat16 ||
|
||||
bias->scalar_type() == c10::ScalarType::Half,
|
||||
"Bias must be Float32 or BFloat16 or Half, but got ",
|
||||
bias->scalar_type());
|
||||
}
|
||||
|
||||
{
|
||||
auto bias_ = bias.value_or(Tensor());
|
||||
auto scale_result_ = scale_result.value_or(Tensor());
|
||||
|
||||
// NOLINTNEXTLINE(*c-array*)
|
||||
TensorArg targs[]{
|
||||
{out, "out", 0},
|
||||
{mat1, "mat1", 1},
|
||||
{mat2, "mat2", 2},
|
||||
{bias_, "bias", 3},
|
||||
{scale_a, "scale_a", 4},
|
||||
{scale_b, "scale_b", 5},
|
||||
{scale_result_, "scale_result", 6}};
|
||||
checkAllSameGPU(__func__, targs);
|
||||
}
|
||||
|
||||
// Validation checks have passed lets resize the output to actual size
|
||||
IntArrayRef mat1_sizes = mat1.sizes();
|
||||
IntArrayRef mat2_sizes = mat2.sizes();
|
||||
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
|
||||
|
||||
// If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm
|
||||
// kernels do not support this case).
|
||||
if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) {
|
||||
// `out` was created with `at::empty`. In the case where we are multiplying
|
||||
// MxK by KxN and K is the zero dim, we need to initialize here to properly
|
||||
// return a tensor of zeros.
|
||||
if (mat1_sizes[1] == 0) {
|
||||
out.zero_();
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
// TODO: Scale_result is not supported by now!!
|
||||
return _scaled_gemm(
|
||||
mat1,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
scaling_choice_a,
|
||||
scaling_choice_b,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
|
||||
|
||||
Tensor _scaled_mm_xpu(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& scale_result,
|
||||
std::optional<c10::ScalarType> out_dtype,
|
||||
bool use_fast_accum) {
|
||||
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
|
||||
Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
|
||||
return _scaled_mm_out_xpu(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
bias,
|
||||
scale_result,
|
||||
out_dtype,
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
@ -1,4 +1,3 @@
|
||||
#include <ATen/BlasBackend.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
@ -9,6 +8,7 @@
|
||||
#include <oneapi/dnnl/dnnl.hpp>
|
||||
|
||||
namespace at::native::onednn {
|
||||
|
||||
at::Tensor broadcast_bias2D(
|
||||
at::Tensor& dst,
|
||||
at::Tensor& bias,
|
||||
@ -328,236 +328,4 @@ void quantized_matmul(
|
||||
result.copy_(dst);
|
||||
}
|
||||
|
||||
// Describes how to configure oneDNN scales for a given role/ScalingType
|
||||
struct ScaleSpec {
|
||||
// specifies the way scale values will be applied to an ARG tensor.
|
||||
int mask;
|
||||
// specifies how scales are grouped along dimensions where
|
||||
// multiple scale factors are used.
|
||||
dnnl::memory::dims groups;
|
||||
// specifies data type for scale factors.
|
||||
dnnl::memory::data_type dtype;
|
||||
|
||||
// Helper to compute expected number of elements for scale tensors
|
||||
// arg_type: "src" for SRC (groups pattern {1, X}),
|
||||
// "wei" for WEIGHTS (groups pattern {X, 1})
|
||||
int64_t expected_numel(
|
||||
int64_t outer_dim,
|
||||
int64_t inner_dim,
|
||||
const std::string& arg_type) const {
|
||||
if (groups == dnnl::memory::dims{1, 1})
|
||||
return 1; // tensorwise scaling
|
||||
|
||||
TORCH_CHECK(
|
||||
arg_type == "src" || arg_type == "wei",
|
||||
"Expected arg_type to be 'src' or 'wei', but got '",
|
||||
arg_type,
|
||||
"'");
|
||||
|
||||
// For rowwise: SRC groups={1, K}, WEI groups={K, 1}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
(groups == dnnl::memory::dims{1, inner_dim} ||
|
||||
groups == dnnl::memory::dims{inner_dim, 1}),
|
||||
"The groups must be either {1, inner_dim} or {inner_dim, 1}. But got ",
|
||||
groups,
|
||||
".");
|
||||
return outer_dim;
|
||||
}
|
||||
|
||||
// Normalize an incoming scale tensor to contiguous storage and appropriate
|
||||
// dtype/view
|
||||
at::Tensor normalize(const at::Tensor& scale) const {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
dtype == dnnl::memory::data_type::f32,
|
||||
"tensor scale currently must be f32, but got scale dtype: ",
|
||||
scale.scalar_type());
|
||||
return scale.to(at::kFloat).contiguous();
|
||||
}
|
||||
};
|
||||
|
||||
// This function defines how to set scales mask and groups according to:
|
||||
// https://github.com/uxlfoundation/oneDNN/blob/main/tests/benchdnn/doc/knobs_attr.md#--attr-scales
|
||||
// The returned value will be used in
|
||||
// `set_scales(arg, mask, groups, data_type)`.
|
||||
inline ScaleSpec make_scale_spec(
|
||||
at::blas::ScalingType scaling_type,
|
||||
int64_t M,
|
||||
int64_t K,
|
||||
int64_t N,
|
||||
const std::string& arg_type) {
|
||||
TORCH_CHECK(
|
||||
arg_type == "src" || arg_type == "wei",
|
||||
"Expected arg_type to be 'src' or 'wei', but got '",
|
||||
arg_type,
|
||||
"'");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
(scaling_type == at::blas::ScalingType::TensorWise ||
|
||||
scaling_type == at::blas::ScalingType::RowWise),
|
||||
"Currently only support scaling_type for TensorWise or RowWise");
|
||||
int64_t dim = K; // Currently only K is used for grouping
|
||||
bool is_src = (arg_type == "src");
|
||||
if (scaling_type == at::blas::ScalingType::TensorWise) {
|
||||
// Scale tensorwise. The same as `--attr-scales=common`.
|
||||
// mask=0 : scale whole tensor
|
||||
// groups={1, 1}: indicates that there is only one group for scaling
|
||||
return {0, {1, 1}, dnnl::memory::data_type::f32};
|
||||
} else {
|
||||
// (scaling_type == at::blas::ScalingType::RowWise)
|
||||
// Scale RowWise. The same as `--attr-scales=per_dim_01`.
|
||||
// mask={(1 << 0) | (1 << 1)}: Scale on both dim0 and dim1
|
||||
// SRC: groups={1, K}, WEIGHTS: groups={K, 1}
|
||||
return {
|
||||
(1 << 0) | (1 << 1),
|
||||
is_src ? dnnl::memory::dims{1, dim} : dnnl::memory::dims{dim, 1},
|
||||
dnnl::memory::data_type::f32};
|
||||
}
|
||||
}
|
||||
|
||||
sycl::event scaled_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
Tensor& result,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
at::blas::ScalingType scaling_choice_a,
|
||||
at::blas::ScalingType scaling_choice_b,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& scale_result,
|
||||
bool use_fast_accum) {
|
||||
auto& engine = GpuEngineManager::Instance().get_engine();
|
||||
auto& stream = GpuStreamManager::Instance().get_stream();
|
||||
|
||||
// This function will do steps with following steps
|
||||
// 1. create memory descriptor
|
||||
// 2. call write_to_dnnl_memory() to actually write memory
|
||||
// 3. execute
|
||||
|
||||
const int64_t M = mat1.size(0);
|
||||
const int64_t K = mat1.size(1);
|
||||
const int64_t N = mat2.size(1);
|
||||
|
||||
// 1.1 Create memory descriptor
|
||||
dnnl::memory::desc src_md = get_onednn_md(mat1);
|
||||
dnnl::memory::desc weights_md = get_onednn_md(mat2);
|
||||
dnnl::memory::desc dst_md = get_onednn_md(result);
|
||||
|
||||
// scale_a and scale_b has already be checked in `is_desired_scaling()` call.
|
||||
// So we could directly get their memory desc and set later.
|
||||
dnnl::memory::desc scale_a_md = get_onednn_md(scale_a);
|
||||
dnnl::memory::desc scale_b_md = get_onednn_md(scale_b);
|
||||
|
||||
dnnl::memory::desc bias_md;
|
||||
bool with_bias = bias.has_value();
|
||||
at::Tensor possible_reshaped_bias = bias.value_or(at::Tensor());
|
||||
if (with_bias) {
|
||||
if (possible_reshaped_bias.dim() == 1) {
|
||||
possible_reshaped_bias =
|
||||
possible_reshaped_bias.reshape({1, possible_reshaped_bias.size(0)});
|
||||
bias_md = get_onednn_md(possible_reshaped_bias);
|
||||
} else {
|
||||
bias_md = get_onednn_md(possible_reshaped_bias);
|
||||
}
|
||||
}
|
||||
|
||||
// 1.2 Create primitive descriptor and set scales mask
|
||||
const ScaleSpec src_spec = make_scale_spec(scaling_choice_a, M, K, N, "src");
|
||||
const ScaleSpec wei_spec = make_scale_spec(scaling_choice_b, M, K, N, "wei");
|
||||
|
||||
dnnl::primitive_attr op_attr = dnnl::primitive_attr();
|
||||
|
||||
#if ONEDNN_SUPPORT_DETERMINISTIC
|
||||
if (at::globalContext().deterministicAlgorithms() ||
|
||||
at::globalContext().deterministicMkldnn())
|
||||
op_attr.set_deterministic(true);
|
||||
#endif
|
||||
|
||||
std::vector<int64_t> default_groups;
|
||||
op_attr.set_scales(
|
||||
DNNL_ARG_SRC, src_spec.mask, src_spec.groups, src_spec.dtype);
|
||||
op_attr.set_scales(
|
||||
DNNL_ARG_WEIGHTS, wei_spec.mask, wei_spec.groups, wei_spec.dtype);
|
||||
// scale_result tensor currently only supports scalar(TensorWise Scaling).
|
||||
bool with_dst_scale = scale_result && scale_result->defined();
|
||||
if (with_dst_scale) {
|
||||
op_attr.set_scales(DNNL_ARG_DST, 0, {1}, dnnl::memory::data_type::f32);
|
||||
}
|
||||
|
||||
op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
|
||||
// 1.3 Create the matmul primitive descriptor
|
||||
dnnl::matmul::primitive_desc matmul_pd = with_bias
|
||||
? dnnl::matmul::primitive_desc(
|
||||
engine, src_md, weights_md, bias_md, dst_md, op_attr)
|
||||
: dnnl::matmul::primitive_desc(
|
||||
engine, src_md, weights_md, dst_md, op_attr);
|
||||
|
||||
// 1.4 (Possible) Additional Checks
|
||||
// TODO: In case there are memory desc does not align with the actual tensor,
|
||||
// we might need to reorder weights similar to CPU's reorder_if_differ_in()
|
||||
// call. For example, weights not the same as matmul_pd.weights_desc(),
|
||||
|
||||
// 2. Prepare memory
|
||||
|
||||
// Create memory
|
||||
auto src_usr_m = make_onednn_memory(src_md, engine, mat1.data_ptr());
|
||||
auto weights_usr_m = make_onednn_memory(weights_md, engine, mat2.data_ptr());
|
||||
auto dst_usr_m = make_onednn_memory(dst_md, engine, result.data_ptr());
|
||||
dnnl::memory b_usr_m;
|
||||
if (with_bias) {
|
||||
b_usr_m =
|
||||
make_onednn_memory(bias_md, engine, possible_reshaped_bias.data_ptr());
|
||||
}
|
||||
|
||||
// Prepare runtime scale memories (flat 1-D views) using the specs
|
||||
auto make_scale_mem_from_spec = [&](const ScaleSpec& spec,
|
||||
int64_t expected_numel,
|
||||
const at::Tensor& scale_tensor) {
|
||||
at::Tensor prepared = spec.normalize(scale_tensor);
|
||||
TORCH_CHECK(
|
||||
prepared.numel() == expected_numel,
|
||||
"Scale buffer length mismatch. Expected ",
|
||||
expected_numel,
|
||||
", got ",
|
||||
prepared.numel());
|
||||
dnnl::memory::desc scale_md(
|
||||
{prepared.numel()}, spec.dtype, dnnl::memory::format_tag::x);
|
||||
return make_onednn_memory(scale_md, engine, prepared.data_ptr());
|
||||
};
|
||||
|
||||
auto scratchpad =
|
||||
make_onednn_memory(matmul_pd.scratchpad_desc(), engine, nullptr);
|
||||
|
||||
// 3. Setup Args for exec
|
||||
std::unordered_map<int, dnnl::memory> args;
|
||||
args.insert({DNNL_ARG_SRC, src_usr_m});
|
||||
args.insert({DNNL_ARG_WEIGHTS, weights_usr_m});
|
||||
args.insert({DNNL_ARG_DST, dst_usr_m});
|
||||
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
|
||||
if (with_bias) {
|
||||
args.insert({DNNL_ARG_BIAS, b_usr_m});
|
||||
}
|
||||
|
||||
// Attach runtime scales using specs
|
||||
auto src_sc_mem = make_scale_mem_from_spec(
|
||||
src_spec, src_spec.expected_numel(M, K, "src"), scale_a);
|
||||
auto wei_sc_mem = make_scale_mem_from_spec(
|
||||
wei_spec, wei_spec.expected_numel(N, K, "wei"), scale_b);
|
||||
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_sc_mem});
|
||||
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_sc_mem});
|
||||
if (with_dst_scale) {
|
||||
// Bind single f32 scalar as DST scale
|
||||
at::Tensor dst_scale_f32 = scale_result->to(at::kFloat).contiguous();
|
||||
dnnl::memory::desc dst_sc_md(
|
||||
{1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
|
||||
auto dst_sc_mem =
|
||||
make_onednn_memory(dst_sc_md, engine, dst_scale_f32.data_ptr());
|
||||
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_sc_mem});
|
||||
}
|
||||
|
||||
dnnl::matmul matmul_p = dnnl::matmul(matmul_pd);
|
||||
sycl::event matmul_fwd_event =
|
||||
dnnl::sycl_interop::execute(matmul_p, stream, args);
|
||||
return matmul_fwd_event;
|
||||
}
|
||||
|
||||
} // namespace at::native::onednn
|
||||
|
||||
@ -78,10 +78,6 @@ dnnl::memory::data_type get_onednn_dtype(
|
||||
return dnnl::memory::data_type::f32;
|
||||
case at::ScalarType::BFloat16:
|
||||
return dnnl::memory::data_type::bf16;
|
||||
case at::ScalarType::Float8_e4m3fn:
|
||||
return dnnl::memory::data_type::f8_e4m3;
|
||||
case at::ScalarType::Float8_e5m2:
|
||||
return dnnl::memory::data_type::f8_e5m2;
|
||||
default:
|
||||
if (!allow_undef) {
|
||||
TORCH_CHECK(
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/BlasBackend.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/Attr.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
|
||||
@ -203,16 +202,4 @@ void sdpa_backward(
|
||||
Tensor& grad_query,
|
||||
Tensor& grad_key,
|
||||
Tensor& grad_value);
|
||||
|
||||
sycl::event scaled_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
Tensor& result,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& scale_b,
|
||||
at::blas::ScalingType scaling_choice_a,
|
||||
at::blas::ScalingType scaling_choice_b,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& scale_result,
|
||||
bool use_fast_accum);
|
||||
} // namespace at::native::onednn
|
||||
|
||||
@ -82,7 +82,6 @@ NSArray<NSNumber*>* getTensorAxes(const TensorBase& t);
|
||||
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
|
||||
std::string getMPSShapeString(MPSShape* shape);
|
||||
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
|
||||
std::string to_hex_key(float);
|
||||
std::string getArrayRefString(const IntArrayRef s);
|
||||
// use has_storage() on the returned tensor to determine if src actually is a view
|
||||
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);
|
||||
|
||||
@ -301,10 +301,6 @@ std::string getArrayRefString(const IntArrayRef s) {
|
||||
return fmt::to_string(fmt::join(s, ","));
|
||||
}
|
||||
|
||||
std::string to_hex_key(float f) {
|
||||
return fmt::format("{:a}", f);
|
||||
}
|
||||
|
||||
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, bool exclude_shape) {
|
||||
fmt::basic_memory_buffer<char, 100> buffer;
|
||||
auto buf_iterator = std::back_inserter(buffer);
|
||||
|
||||
@ -40,7 +40,7 @@ inline c10::metal::opmath_t<T> matmul_inner(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (uint k = 0; k < TILE_DIM; k++) {
|
||||
sum += c10::metal::mul(A_tile[tid.y][k], B_tile[k][tid.x]);
|
||||
sum += A_tile[tid.y][k] * B_tile[k][tid.x];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
@ -96,9 +96,7 @@ kernel void addmm(
|
||||
auto bias =
|
||||
biasData[thread_id.y * strides[3].x + thread_id.x * strides[3].y];
|
||||
outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] =
|
||||
static_cast<T>(
|
||||
c10::metal::mul(alpha_beta[0], sum) +
|
||||
c10::metal::mul(alpha_beta[1], bias));
|
||||
static_cast<T>(alpha_beta[0] * sum + alpha_beta[1] * bias);
|
||||
}
|
||||
}
|
||||
|
||||
@ -834,10 +832,6 @@ INSTANTIATE_MM_OPS(float);
|
||||
INSTANTIATE_MM_OPS(half);
|
||||
INSTANTIATE_MM_OPS(bfloat);
|
||||
|
||||
// Complex MM
|
||||
INSTANTIATE_MM_OPS(float2);
|
||||
INSTANTIATE_MM_OPS(half2);
|
||||
|
||||
// Integral MM
|
||||
INSTANTIATE_MM_OPS(long);
|
||||
INSTANTIATE_MM_OPS(int);
|
||||
|
||||
@ -69,139 +69,75 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
|
||||
auto out = at::empty({batchSize, num_head, qSize, headSize}, query.options());
|
||||
auto attn = at::empty({batchSize, num_head, qSize, maxSeqLength}, query.options());
|
||||
auto scale_factor = sdp::calculate_scale(query, scale).expect_float();
|
||||
static const bool is_macOS_26_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_26_0_PLUS);
|
||||
@autoreleasepool {
|
||||
auto mkey = __func__ + getTensorsStringKey({query, key, value}) + ":" + std::to_string(is_causal) + ":" +
|
||||
std::to_string(attn_mask.has_value());
|
||||
auto cachedGraph =
|
||||
LookUpOrCreateCachedGraph<CachedGraph>(mkey, [&, q_ = query, k_ = key, v_ = value](auto mpsGraph, auto graph) {
|
||||
auto qTensor = mpsGraphRankedPlaceHolder(mpsGraph, q_);
|
||||
auto kTensor = mpsGraphRankedPlaceHolder(mpsGraph, k_);
|
||||
auto vTensor = mpsGraphRankedPlaceHolder(mpsGraph, v_);
|
||||
auto kT = [mpsGraph transposeTensor:kTensor dimension:2 withDimension:3 name:nil];
|
||||
auto scaleTensor = [mpsGraph constantWithScalar:scale_factor
|
||||
shape:getMPSShape({1})
|
||||
dataType:MPSDataTypeFloat32];
|
||||
|
||||
CachedGraph* cachedGraph;
|
||||
//if(is_macOS_26_0_or_newer) {
|
||||
if(true) {
|
||||
cachedGraph =
|
||||
LookUpOrCreateCachedGraph<CachedGraph>(mkey, [&, q_ = query, k_ = key, v_ = value](auto mpsGraph, auto graph) {
|
||||
auto qTensor = mpsGraphRankedPlaceHolder(mpsGraph, q_);
|
||||
auto kTensor = mpsGraphRankedPlaceHolder(mpsGraph, k_);
|
||||
auto vTensor = mpsGraphRankedPlaceHolder(mpsGraph, v_);
|
||||
auto maskedMM = [mpsGraph matrixMultiplicationWithPrimaryTensor:qTensor secondaryTensor:kT name:nil];
|
||||
|
||||
if (is_causal) {
|
||||
MPSShape* maskShape = @[@(qSize), @(maxSeqLength)];
|
||||
auto x = [mpsGraph coordinateAlongAxis:-1 withShape:@[@(qSize), @1] name:nil];
|
||||
auto y = [mpsGraph coordinateAlongAxis:-2 withShape:@[@1, @(maxSeqLength)] name:nil];
|
||||
auto isLess = [mpsGraph lessThanOrEqualToWithPrimaryTensor:x secondaryTensor:y name:nil];
|
||||
auto causalMask = [mpsGraph selectWithPredicateTensor:isLess
|
||||
truePredicateTensor:[mpsGraph constantWithScalar:0 dataType:qTensor.dataType]
|
||||
falsePredicateTensor:[mpsGraph constantWithScalar:-INFINITY dataType:qTensor.dataType]
|
||||
name:nil];
|
||||
graph->maskTensor = causalMask;
|
||||
} else if (attn_mask) {
|
||||
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
|
||||
}
|
||||
if (macOS15_0_plus && [maskedMM dataType] == MPSDataTypeFloat32) {
|
||||
// bug in MacOS15, without this trick SDPA leaks memory, adding 0.0f gets ignored(still takes SDPA sequence
|
||||
// path which leaks)
|
||||
auto oneTensor = [mpsGraph constantWithScalar:1e-20f shape:getMPSShape({1}) dataType:MPSDataTypeFloat32];
|
||||
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:oneTensor name:nil];
|
||||
}
|
||||
|
||||
// Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
|
||||
// Overwrites expected NANs in sm with zeros.
|
||||
// auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
|
||||
// auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
|
||||
// auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
|
||||
// auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
|
||||
// auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
|
||||
//
|
||||
// auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
|
||||
// MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
|
||||
// truePredicateTensor:zeroTensor
|
||||
// falsePredicateTensor:sm
|
||||
// name:nil];
|
||||
//
|
||||
// auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
|
||||
// upcasting to float32 if needed to improve precision when multiplying by the scale factor
|
||||
maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
|
||||
maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
|
||||
|
||||
MPSGraphTensor* output;
|
||||
if(graph->maskTensor != nil) {
|
||||
output = [mpsGraph scaledDotProductAttentionWithQueryTensor:qTensor
|
||||
keyTensor:kTensor
|
||||
valueTensor:vTensor
|
||||
maskTensor:graph->maskTensor
|
||||
scale:scale_factor
|
||||
name:@"MPSGraph SDPA"];
|
||||
} else {
|
||||
output = [mpsGraph scaledDotProductAttentionWithQueryTensor:qTensor
|
||||
keyTensor:kTensor
|
||||
valueTensor:vTensor
|
||||
scale:scale_factor
|
||||
name:@"MPSGraph SDPA"];
|
||||
}
|
||||
graph->qTensor = qTensor;
|
||||
graph->kTensor = kTensor;
|
||||
graph->vTensor = vTensor;
|
||||
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
|
||||
// graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
|
||||
});
|
||||
} else {
|
||||
cachedGraph =
|
||||
LookUpOrCreateCachedGraph<CachedGraph>(mkey, [&, q_ = query, k_ = key, v_ = value](auto mpsGraph, auto graph) {
|
||||
auto qTensor = mpsGraphRankedPlaceHolder(mpsGraph, q_);
|
||||
auto kTensor = mpsGraphRankedPlaceHolder(mpsGraph, k_);
|
||||
auto vTensor = mpsGraphRankedPlaceHolder(mpsGraph, v_);
|
||||
auto kT = [mpsGraph transposeTensor:kTensor dimension:2 withDimension:3 name:nil];
|
||||
auto scaleTensor = [mpsGraph constantWithScalar:scale_factor
|
||||
shape:getMPSShape({1})
|
||||
dataType:MPSDataTypeFloat32];
|
||||
if (is_causal) {
|
||||
auto causalMask = [mpsGraph constantWithScalar:1.0f
|
||||
shape:getMPSShape({qSize, maxSeqLength})
|
||||
dataType:MPSDataTypeBool];
|
||||
causalMask = [mpsGraph bandPartWithTensor:causalMask numLower:-1 numUpper:0 name:nil];
|
||||
auto minusInf = [mpsGraph constantWithScalar:-1e20 shape:maskedMM.shape dataType:maskedMM.dataType];
|
||||
maskedMM = [mpsGraph selectWithPredicateTensor:causalMask
|
||||
truePredicateTensor:maskedMM
|
||||
falsePredicateTensor:minusInf
|
||||
name:nil];
|
||||
} else if (attn_mask) {
|
||||
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
|
||||
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
|
||||
secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
|
||||
name:nil];
|
||||
}
|
||||
|
||||
auto maskedMM = [mpsGraph matrixMultiplicationWithPrimaryTensor:qTensor secondaryTensor:kT name:nil];
|
||||
// Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
|
||||
// Overwrites expected NANs in sm with zeros.
|
||||
auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
|
||||
auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
|
||||
auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
|
||||
auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
|
||||
auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
|
||||
|
||||
if (macOS15_0_plus && [maskedMM dataType] == MPSDataTypeFloat32) {
|
||||
// bug in MacOS15, without this trick SDPA leaks memory, adding 0.0f gets ignored(still takes SDPA sequence
|
||||
// path which leaks)
|
||||
auto oneTensor = [mpsGraph constantWithScalar:1e-20f shape:getMPSShape({1}) dataType:MPSDataTypeFloat32];
|
||||
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:oneTensor name:nil];
|
||||
}
|
||||
auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
|
||||
MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
|
||||
truePredicateTensor:zeroTensor
|
||||
falsePredicateTensor:sm
|
||||
name:nil];
|
||||
|
||||
// upcasting to float32 if needed to improve precision when multiplying by the scale factor
|
||||
maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
|
||||
maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
|
||||
|
||||
if (is_causal) {
|
||||
auto causalMask = [mpsGraph constantWithScalar:1.0f
|
||||
shape:getMPSShape({qSize, maxSeqLength})
|
||||
dataType:MPSDataTypeBool];
|
||||
causalMask = [mpsGraph bandPartWithTensor:causalMask numLower:-1 numUpper:0 name:nil];
|
||||
auto minusInf = [mpsGraph constantWithScalar:-1e20 shape:maskedMM.shape dataType:maskedMM.dataType];
|
||||
maskedMM = [mpsGraph selectWithPredicateTensor:causalMask
|
||||
truePredicateTensor:maskedMM
|
||||
falsePredicateTensor:minusInf
|
||||
name:nil];
|
||||
} else if (attn_mask) {
|
||||
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
|
||||
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
|
||||
secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
|
||||
name:nil];
|
||||
}
|
||||
|
||||
// Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
|
||||
// Overwrites expected NANs in sm with zeros.
|
||||
auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
|
||||
auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
|
||||
auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
|
||||
auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
|
||||
auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
|
||||
|
||||
auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
|
||||
MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
|
||||
truePredicateTensor:zeroTensor
|
||||
falsePredicateTensor:sm
|
||||
name:nil];
|
||||
|
||||
auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
|
||||
graph->qTensor = qTensor;
|
||||
graph->kTensor = kTensor;
|
||||
graph->vTensor = vTensor;
|
||||
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
|
||||
graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
|
||||
});
|
||||
}
|
||||
auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
|
||||
graph->qTensor = qTensor;
|
||||
graph->kTensor = kTensor;
|
||||
graph->vTensor = vTensor;
|
||||
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
|
||||
graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
|
||||
});
|
||||
auto qPlaceholder = Placeholder(cachedGraph->qTensor, query);
|
||||
auto kPlaceholder = Placeholder(cachedGraph->kTensor, key);
|
||||
auto vPlaceholder = Placeholder(cachedGraph->vTensor, value);
|
||||
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, out);
|
||||
// auto attnPlaceholder = Placeholder(cachedGraph->attnTensor, attn);
|
||||
auto attnPlaceholder = Placeholder(cachedGraph->attnTensor, attn);
|
||||
NSDictionary* feeds = nil;
|
||||
if (!attn_mask) {
|
||||
feeds = dictionaryFromPlaceholders(qPlaceholder, kPlaceholder, vPlaceholder);
|
||||
@ -209,8 +145,7 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
|
||||
auto mPlaceholder = Placeholder(cachedGraph->maskTensor, *attn_mask);
|
||||
feeds = dictionaryFromPlaceholders(qPlaceholder, kPlaceholder, vPlaceholder, mPlaceholder);
|
||||
}
|
||||
// NSDictionary* outs = dictionaryFromPlaceholders(outputPlaceholder, attnPlaceholder);
|
||||
NSDictionary* outs = dictionaryFromPlaceholders(outputPlaceholder);
|
||||
NSDictionary* outs = dictionaryFromPlaceholders(outputPlaceholder, attnPlaceholder);
|
||||
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outs);
|
||||
}
|
||||
|
||||
|
||||
@ -121,7 +121,7 @@ Tensor& do_metal_addmm(const Tensor& self,
|
||||
const Scalar& alpha,
|
||||
const Scalar& beta,
|
||||
const Tensor& bias) {
|
||||
if (beta.isFloatingPoint() && alpha.isFloatingPoint() && beta.toDouble() == 0 && alpha.toDouble() == 1) {
|
||||
if (beta.toDouble() == 0 && alpha.toDouble() == 1) {
|
||||
return do_metal_mm(self, other, output);
|
||||
}
|
||||
auto stream = getCurrentMPSStream();
|
||||
@ -147,15 +147,13 @@ Tensor& do_metal_addmm(const Tensor& self,
|
||||
std::array<int64_t, 2> i64;
|
||||
std::array<int32_t, 2> i32;
|
||||
std::array<float, 2> f32;
|
||||
std::array<c10::complex<float>, 2> c64;
|
||||
} alpha_beta{};
|
||||
} alpha_beta;
|
||||
if (output.scalar_type() == kLong) {
|
||||
alpha_beta.i64 = {alpha.toLong(), beta.toLong()};
|
||||
} else if (c10::isIntegralType(output.scalar_type(), true)) {
|
||||
alpha_beta.i32 = {alpha.toInt(), beta.toInt()};
|
||||
} else if (c10::isComplexType(output.scalar_type())) {
|
||||
alpha_beta.c64 = {alpha.toComplexFloat(), beta.toComplexFloat()};
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(c10::isFloatingType(output.scalar_type()));
|
||||
alpha_beta.f32 = {alpha.toFloat(), beta.toFloat()};
|
||||
}
|
||||
constexpr uint32_t TILE_DIM = 16; // fastest performance from tests on multiple macs
|
||||
@ -192,16 +190,10 @@ std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* gr
|
||||
bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) {
|
||||
static bool always_use_metal = c10::utils::has_env("PYTORCH_MPS_PREFER_METAL");
|
||||
constexpr auto max_stride_size = 32768;
|
||||
constexpr auto max_complex_inner_size = 2048;
|
||||
static bool is_macos_14_4_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS);
|
||||
if (always_use_metal || c10::isIntegralType(self.scalar_type(), true)) {
|
||||
return true;
|
||||
}
|
||||
// multiplicationWithPrimaryTensor: returns incorrect results if inner size exceeds 2048
|
||||
// See https://github.com/pytorch/pytorch/issues/167727#issuecomment-3529308548
|
||||
if (c10::isComplexType(self.scalar_type()) && self.size(1) > max_complex_inner_size) {
|
||||
return true;
|
||||
}
|
||||
return !is_macos_14_4_or_newer &&
|
||||
(self.stride(0) > max_stride_size || self.stride(1) > max_stride_size || self.size(0) > max_stride_size ||
|
||||
self.size(1) > max_stride_size || other.stride(0) > max_stride_size || other.stride(1) > max_stride_size ||
|
||||
|
||||
@ -91,30 +91,25 @@ static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
|
||||
#include <ATen/native/mps/Repeat_metallib.h>
|
||||
#endif
|
||||
|
||||
Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
|
||||
TORCH_CHECK(repeat.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
|
||||
template <typename index_t>
|
||||
void computeRepeatIndices(const index_t* repeat_ptr,
|
||||
const int64_t* cumsum_ptr,
|
||||
index_t* result_ptr,
|
||||
int64_t size,
|
||||
int64_t result_size) {
|
||||
id<MTLBuffer> repeatBuffer = reinterpret_cast<id<MTLBuffer>>(repeat_ptr);
|
||||
id<MTLBuffer> cumsumBuffer = reinterpret_cast<id<MTLBuffer>>(cumsum_ptr);
|
||||
id<MTLBuffer> resultBuffer = reinterpret_cast<id<MTLBuffer>>(result_ptr);
|
||||
TORCH_CHECK(repeatBuffer && cumsumBuffer && resultBuffer);
|
||||
|
||||
std::string scalar_type;
|
||||
if (repeat.scalar_type() == kInt) {
|
||||
if constexpr (std::is_same_v<index_t, int32_t>) {
|
||||
scalar_type = "int32_t";
|
||||
} else if (repeat.scalar_type() == kLong) {
|
||||
} else if constexpr (std::is_same_v<index_t, int64_t>) {
|
||||
scalar_type = "int64_t";
|
||||
} else {
|
||||
TORCH_CHECK(false, "repeats has to be Long or Int tensor");
|
||||
TORCH_CHECK(false, "repeat_interleave: unsupported indexing data type");
|
||||
}
|
||||
if (repeat.size(0) == 0) {
|
||||
return at::empty_like(repeat, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
}
|
||||
Tensor repeat_ = repeat.contiguous();
|
||||
Tensor cumsum = repeat.cumsum(0);
|
||||
int64_t total = 0;
|
||||
if (output_size.has_value()) {
|
||||
total = output_size.value();
|
||||
} else {
|
||||
total = cumsum[-1].item<int64_t>();
|
||||
TORCH_CHECK((repeat >= 0).all().item<uint8_t>(), "repeats can not be negative");
|
||||
}
|
||||
|
||||
auto result = at::empty({total}, repeat.options());
|
||||
|
||||
MPSStream* mpsStream = getCurrentMPSStream();
|
||||
dispatch_sync(mpsStream->queue(), ^() {
|
||||
@ -126,13 +121,20 @@ Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output
|
||||
getMPSProfiler().beginProfileKernel(pipelineState, "repeat_interleave:" + scalar_type, false);
|
||||
|
||||
[computeEncoder setComputePipelineState:pipelineState];
|
||||
mps::mtl_setArgs(computeEncoder, repeat_, cumsum, result, repeat.size(0));
|
||||
mps::mtl_dispatch1DJob(computeEncoder, pipelineState, repeat.size(0));
|
||||
mps::mtl_setArgs(computeEncoder, repeatBuffer, cumsumBuffer, resultBuffer, size);
|
||||
mps::mtl_dispatch1DJob(computeEncoder, pipelineState, size);
|
||||
|
||||
getMPSProfiler().endProfileKernel(pipelineState);
|
||||
}
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
|
||||
Tensor output;
|
||||
AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {
|
||||
output = repeat_interleave_common<index_t, computeRepeatIndices<index_t>>(repeat, output_size);
|
||||
});
|
||||
return output;
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/TensorCompare.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <algorithm>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
@ -90,21 +89,13 @@ static void check_min_max_dims(const OptionalTensorRef clamp_opt, const Tensor&
|
||||
auto clamp_shape = clamp_opt->sizes();
|
||||
auto input_shape = input_t.sizes();
|
||||
|
||||
if (num_clamp_dims > num_input_dims) {
|
||||
auto leading_dims = num_clamp_dims - num_input_dims;
|
||||
for (int64_t i = 0; i < leading_dims; ++i) {
|
||||
TORCH_CHECK(clamp_shape[i] == 1,
|
||||
op_name + ": clamp tensor leading shape must be 1 to broadcast with input tensor");
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(num_clamp_dims <= num_input_dims,
|
||||
op_name + ": clamp tensor number of dims must not be greater than that of input tensor")
|
||||
|
||||
auto clamp_idx = num_clamp_dims - 1;
|
||||
auto input_idx = num_input_dims - 1;
|
||||
auto common_dims = std::min(num_clamp_dims, num_input_dims);
|
||||
for (int64_t i = 0; i < common_dims; ++i)
|
||||
for (int i = 0; i < num_clamp_dims; i++)
|
||||
// One of the indices is allowed to be 1; will be handled by broadcast
|
||||
TORCH_CHECK(clamp_shape[clamp_idx - i] == input_shape[input_idx - i] || clamp_shape[clamp_idx - i] == 1 ||
|
||||
input_shape[input_idx - i] == 1,
|
||||
TORCH_CHECK(clamp_shape[num_clamp_dims - 1 - i] == input_shape[num_input_dims - 1 - i] ||
|
||||
clamp_shape[num_clamp_dims - 1 - i] == 1 || input_shape[num_input_dims - 1 - i] == 1,
|
||||
op_name + ": clamp tensor trailing shape must match input tensor")
|
||||
}
|
||||
}
|
||||
@ -145,6 +136,9 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
|
||||
|
||||
auto result_type = output_t.scalar_type();
|
||||
|
||||
IntArrayRef new_min_shape;
|
||||
IntArrayRef new_max_shape;
|
||||
|
||||
auto num_min_dims = min_opt->dim();
|
||||
auto num_max_dims = max_opt->dim();
|
||||
auto num_input_dims = input_t.dim();
|
||||
@ -152,32 +146,24 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
|
||||
std::vector<int64_t> new_min_arr(num_input_dims);
|
||||
std::vector<int64_t> new_max_arr(num_input_dims);
|
||||
|
||||
if (has_min && num_min_dims < num_input_dims) {
|
||||
fill_new_shape(num_input_dims, num_min_dims, new_min_arr.data(), min_opt->sizes());
|
||||
new_min_shape = IntArrayRef(new_min_arr);
|
||||
}
|
||||
|
||||
if (has_max && num_max_dims < num_input_dims) {
|
||||
fill_new_shape(num_input_dims, num_max_dims, new_max_arr.data(), max_opt->sizes());
|
||||
new_max_shape = IntArrayRef(new_max_arr);
|
||||
}
|
||||
|
||||
Tensor min_opt_tensor;
|
||||
Tensor max_opt_tensor;
|
||||
|
||||
auto reshape_clamp_tensor = [&](const OptionalTensorRef clamp_tensor_ref,
|
||||
int64_t num_clamp_dims,
|
||||
std::vector<int64_t>& new_shape_storage) -> Tensor {
|
||||
IntArrayRef clamp_shape = clamp_tensor_ref->sizes();
|
||||
bool requires_view = false;
|
||||
|
||||
if (num_clamp_dims > num_input_dims) {
|
||||
clamp_shape = clamp_shape.slice(num_clamp_dims - num_input_dims);
|
||||
requires_view = true;
|
||||
} else if (num_clamp_dims < num_input_dims) {
|
||||
fill_new_shape(num_input_dims, num_clamp_dims, new_shape_storage.data(), clamp_shape);
|
||||
clamp_shape = IntArrayRef(new_shape_storage);
|
||||
requires_view = true;
|
||||
}
|
||||
|
||||
return requires_view ? (*clamp_tensor_ref).view(clamp_shape) : *clamp_tensor_ref;
|
||||
};
|
||||
|
||||
if (has_min) {
|
||||
min_opt_tensor = reshape_clamp_tensor(min_opt, num_min_dims, new_min_arr);
|
||||
min_opt_tensor = (num_min_dims < num_input_dims) ? (*min_opt).view(new_min_shape) : *min_opt;
|
||||
}
|
||||
if (has_max) {
|
||||
max_opt_tensor = reshape_clamp_tensor(max_opt, num_max_dims, new_max_arr);
|
||||
max_opt_tensor = (num_max_dims < num_input_dims) ? (*max_opt).view(new_max_shape) : *max_opt;
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
@ -258,8 +244,8 @@ static void clamp_scalar_out_mps(const Tensor& input_t,
|
||||
|
||||
@autoreleasepool {
|
||||
// the optional min/max refs could affect how we build the cached graph
|
||||
std::string key = op_name + (has_min ? ("_min:" + to_hex_key(min_scalar)) : "") +
|
||||
(has_max ? ("_max:" + to_hex_key(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
|
||||
std::string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") +
|
||||
(has_max ? ("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
if (has_min)
|
||||
newCachedGraph->minTensor = [mpsGraph constantWithScalar:min_scalar
|
||||
|
||||
@ -4225,7 +4225,7 @@
|
||||
MTIA: mm_out_mtia
|
||||
MPS: mm_out_mps
|
||||
XPU: mm_out_xpu
|
||||
SparseCPU, SparseCUDA, SparseMPS: _sparse_mm_out
|
||||
SparseCPU, SparseCUDA: _sparse_mm_out
|
||||
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm_out
|
||||
|
||||
- func: mm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor
|
||||
@ -7518,7 +7518,7 @@
|
||||
- func: _sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor
|
||||
variants: method
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA, SparseMPS: sparse_mask_projection
|
||||
SparseCPU, SparseCUDA: sparse_mask_projection
|
||||
autogen: _sparse_mask_projection.out
|
||||
|
||||
- func: _to_cpu(Tensor[] tensors) -> Tensor[]
|
||||
|
||||
@ -30,12 +30,10 @@
|
||||
|
||||
#include <thrust/binary_search.h>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/distance.h>
|
||||
#include <thrust/iterator/constant_iterator.h>
|
||||
#include <thrust/scan.h>
|
||||
#include <thrust/sequence.h>
|
||||
#include <thrust/sort.h>
|
||||
#include <thrust/system/cuda/execution_policy.h>
|
||||
#include <thrust/iterator/constant_iterator.h>
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cusparse.h>
|
||||
@ -49,7 +47,6 @@
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <thrust/copy.h>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/distance.h>
|
||||
#include <thrust/for_each.h>
|
||||
#include <thrust/functional.h>
|
||||
#include <thrust/gather.h>
|
||||
|
||||
@ -445,33 +445,6 @@ static SparseTensor& mul_out_dense_sparse_mps(
|
||||
return out;
|
||||
}
|
||||
|
||||
static std::tuple<Tensor, Tensor, int64_t> mps_intersect_binary_search(
|
||||
const Tensor& A_keys,
|
||||
const Tensor& B_keys,
|
||||
int64_t lenA,
|
||||
int64_t lenB,
|
||||
bool boolean_flag) {
|
||||
|
||||
auto stream = getCurrentMPSStream();
|
||||
auto outA_idx = at::empty({lenA}, A_keys.options().dtype(at::kLong));
|
||||
auto outB_idx = at::empty({lenA}, A_keys.options().dtype(at::kLong));
|
||||
auto counter = at::zeros({1}, A_keys.options().dtype(at::kInt));
|
||||
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
|
||||
auto enc = stream->commandEncoder();
|
||||
[enc setComputePipelineState:pso];
|
||||
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
|
||||
static_cast<uint32_t>(lenB), boolean_flag);
|
||||
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
|
||||
}
|
||||
});
|
||||
|
||||
const auto match_count = static_cast<int64_t>(counter.item<int32_t>());
|
||||
return std::make_tuple(std::move(outA_idx), std::move(outB_idx), match_count);
|
||||
}
|
||||
|
||||
|
||||
SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTensor& r_) {
|
||||
TORCH_CHECK(r_.is_mps(), "mul: expected 'out' to be MPS, but got ", r_.device());
|
||||
@ -550,10 +523,22 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
|
||||
auto A_keys = A_is_lhs ? lhs_keys : rhs_keys;
|
||||
auto B_keys = A_is_lhs ? rhs_keys : lhs_keys;
|
||||
|
||||
auto [outA_idx, outB_idx, M_int64] = mps_intersect_binary_search(
|
||||
A_keys, B_keys, lenA, lenB, A_is_lhs);
|
||||
auto outA_idx = at::empty({lenA}, at::device(device).dtype(kLong));
|
||||
auto outB_idx = at::empty({lenA}, at::device(device).dtype(kLong));
|
||||
auto counter = at::zeros({1}, at::device(device).dtype(kInt));
|
||||
|
||||
const auto M = static_cast<uint32_t>(M_int64); // number of structural matches
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
|
||||
auto enc = stream->commandEncoder();
|
||||
[enc setComputePipelineState:pso];
|
||||
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
|
||||
static_cast<uint32_t>(lenB), A_is_lhs);
|
||||
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
|
||||
}
|
||||
});
|
||||
|
||||
const uint32_t M = counter.item<int32_t>(); // number of structural matches
|
||||
|
||||
r_.resize_as_(lhs);
|
||||
|
||||
@ -777,14 +762,6 @@ SparseTensor& add_out_sparse_mps(const SparseTensor& self,
|
||||
|
||||
using OptTensor = std::optional<Tensor>;
|
||||
|
||||
static Tensor create_sparse_output_values(
|
||||
const Tensor& template_values,
|
||||
int64_t output_nnz,
|
||||
ScalarType dtype) {
|
||||
auto out_val_sizes = template_values.sizes().vec();
|
||||
out_val_sizes[0] = output_nnz;
|
||||
return at::zeros(out_val_sizes, template_values.options().dtype(dtype));
|
||||
}
|
||||
|
||||
static void sparse_mask_apply_out_mps_kernel(
|
||||
Tensor& result,
|
||||
@ -806,9 +783,9 @@ static void sparse_mask_apply_out_mps_kernel(
|
||||
auto src = src_in.coalesce();
|
||||
auto mask = coalesce_mask ? mask_in.coalesce() : mask_in;
|
||||
|
||||
const auto src_nnz = src._nnz();
|
||||
const auto mask_nnz = mask._nnz();
|
||||
const auto sd = src.sparse_dim();
|
||||
const int64_t src_nnz = src._nnz();
|
||||
const int64_t mask_nnz = mask._nnz();
|
||||
const int64_t sd = src.sparse_dim();
|
||||
result.sparse_resize_(mask.sizes(), mask.sparse_dim(), mask.dense_dim());
|
||||
|
||||
auto commonDtype = at::result_type(src, mask);
|
||||
@ -837,27 +814,53 @@ static void sparse_mask_apply_out_mps_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
auto mask_indices = mask._indices().contiguous();
|
||||
auto src_values = src._values().to(commonDtype).contiguous();
|
||||
auto out_values = create_sparse_output_values(src_values, mask_nnz, commonDtype);
|
||||
|
||||
if (src_nnz == 0) {
|
||||
alias_into_sparse(result, mask_indices, out_values);
|
||||
auto out_indices = mask._indices().contiguous();
|
||||
auto src_values = src._values().to(commonDtype);
|
||||
auto out_val_sizes = src_values.sizes().vec();
|
||||
out_val_sizes[0] = mask_nnz;
|
||||
auto out_values = at::zeros(out_val_sizes, src_values.options());
|
||||
alias_into_sparse(result, out_indices, out_values);
|
||||
result._coalesced_(mask.is_coalesced());
|
||||
return;
|
||||
}
|
||||
|
||||
auto mask_keys = flatten_indices(mask._indices().contiguous(), mask.sizes().slice(0, sd)).contiguous();
|
||||
auto src_keys = flatten_indices(src._indices().contiguous(), src.sizes().slice(0, sd)).contiguous();
|
||||
auto mask_indices = mask._indices().contiguous();
|
||||
auto src_indices = src._indices().contiguous();
|
||||
auto src_values = src._values().to(commonDtype).contiguous();
|
||||
|
||||
const auto A_is_src = (src_nnz <= mask_nnz);
|
||||
const auto lenA = A_is_src ? src_nnz : mask_nnz;
|
||||
const auto lenB = A_is_src ? mask_nnz : src_nnz;
|
||||
auto mask_keys = flatten_indices(mask_indices, mask.sizes().slice(0, sd)).contiguous();
|
||||
auto src_keys = flatten_indices(src_indices, src.sizes().slice(0, sd)).contiguous();
|
||||
|
||||
const bool A_is_src = (src_nnz <= mask_nnz);
|
||||
const int64_t lenA = A_is_src ? src_nnz : mask_nnz;
|
||||
const int64_t lenB = A_is_src ? mask_nnz : src_nnz;
|
||||
auto A_keys = A_is_src ? src_keys : mask_keys;
|
||||
auto B_keys = A_is_src ? mask_keys : src_keys;
|
||||
|
||||
auto [outA_idx, outB_idx, M] = mps_intersect_binary_search(
|
||||
A_keys, B_keys, lenA, lenB, A_is_src);
|
||||
const auto device = result.device();
|
||||
auto stream = getCurrentMPSStream();
|
||||
|
||||
auto outA_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
|
||||
auto outB_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
|
||||
auto counter = at::zeros({1}, at::device(device).dtype(at::kInt));
|
||||
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
|
||||
auto enc = stream->commandEncoder();
|
||||
[enc setComputePipelineState:pso];
|
||||
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
|
||||
static_cast<uint32_t>(lenB), A_is_src);
|
||||
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
|
||||
}
|
||||
});
|
||||
|
||||
const int64_t M = static_cast<int64_t>(counter.item<int32_t>());
|
||||
|
||||
auto out_val_sizes = src_values.sizes().vec();
|
||||
out_val_sizes[0] = mask_nnz;
|
||||
auto out_values = at::zeros(out_val_sizes, src_values.options());
|
||||
|
||||
if (M > 0) {
|
||||
auto src_match = outA_idx.narrow(0, 0, M);
|
||||
@ -875,70 +878,6 @@ static void sparse_mask_apply_out_mps_kernel(
|
||||
result._coalesced_(mask.is_coalesced());
|
||||
}
|
||||
|
||||
static void sparse_mask_projection_out_mps_kernel(
|
||||
Tensor& result,
|
||||
const Tensor& lhs,
|
||||
const Tensor& rhs,
|
||||
const OptTensor& /*x_hash_opt*/,
|
||||
bool accumulate_matches) {
|
||||
|
||||
TORCH_CHECK(lhs.is_sparse() && rhs.is_sparse(), "sparse_mask_projection: expected sparse COO");
|
||||
TORCH_CHECK(lhs.is_mps() && rhs.is_mps(), "sparse_mask_projection: expected MPS tensors");
|
||||
TORCH_CHECK(lhs.sparse_dim() == rhs.sparse_dim(), "sparse_dim mismatch");
|
||||
|
||||
auto lhs_c = lhs.coalesce();
|
||||
auto rhs_c = rhs.coalesce();
|
||||
|
||||
const auto sd = lhs_c.sparse_dim();
|
||||
const auto lhs_nnz = lhs_c._nnz();
|
||||
const auto rhs_nnz = rhs_c._nnz();
|
||||
|
||||
auto commonDtype = at::result_type(lhs_c, rhs_c);
|
||||
TORCH_CHECK(canCast(commonDtype, result.scalar_type()),
|
||||
"Can't convert ", commonDtype, " to output ", result.scalar_type());
|
||||
|
||||
result.sparse_resize_(lhs.sizes(), lhs.sparse_dim(), lhs.dense_dim());
|
||||
|
||||
auto lhs_indices = lhs_c._indices().contiguous();
|
||||
auto rhs_values = rhs_c._values().to(commonDtype).contiguous();
|
||||
auto out_values = create_sparse_output_values(rhs_values, lhs_nnz, commonDtype);
|
||||
|
||||
if (lhs_nnz > 0 && rhs_nnz > 0) {
|
||||
auto lhs_keys = flatten_indices(lhs_indices, lhs_c.sizes().slice(0, sd)).contiguous();
|
||||
auto rhs_keys = flatten_indices(rhs_c._indices().contiguous(), rhs_c.sizes().slice(0, sd)).contiguous();
|
||||
|
||||
const auto A_is_lhs = (lhs_nnz <= rhs_nnz);
|
||||
const auto lenA = A_is_lhs ? lhs_nnz : rhs_nnz;
|
||||
const auto lenB = A_is_lhs ? rhs_nnz : lhs_nnz;
|
||||
auto A_keys = A_is_lhs ? lhs_keys : rhs_keys;
|
||||
auto B_keys = A_is_lhs ? rhs_keys : lhs_keys;
|
||||
|
||||
auto [outA_idx, outB_idx, M] = mps_intersect_binary_search(
|
||||
A_keys, B_keys, lenA, lenB, A_is_lhs);
|
||||
|
||||
if (M > 0) {
|
||||
auto idx_in_A = outA_idx.narrow(0, 0, M);
|
||||
auto idx_in_B = outB_idx.narrow(0, 0, M);
|
||||
auto idx_in_lhs = A_is_lhs ? idx_in_A : idx_in_B;
|
||||
auto idx_in_rhs = A_is_lhs ? idx_in_B : idx_in_A;
|
||||
|
||||
const auto view_cols = rhs_values.numel() / std::max<int64_t>(rhs_nnz, 1);
|
||||
auto rhs_rows = rhs_values.index_select(0, idx_in_rhs).contiguous();
|
||||
auto rhs_rows_2d = rhs_rows.view({M, view_cols});
|
||||
auto out_2d = out_values.view({lhs_nnz, view_cols});
|
||||
|
||||
if (accumulate_matches) {
|
||||
out_2d.index_add_(0, idx_in_lhs, rhs_rows_2d);
|
||||
} else {
|
||||
out_2d.index_copy_(0, idx_in_lhs, rhs_rows_2d);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
alias_into_sparse(result, lhs._indices(), out_values);
|
||||
result._coalesced_(lhs.is_coalesced());
|
||||
}
|
||||
|
||||
static void sparse_mask_intersection_out_mps_kernel(
|
||||
Tensor& result,
|
||||
const Tensor& lhs,
|
||||
@ -1063,5 +1002,4 @@ Tensor sparse_sparse_matmul_mps(const Tensor& mat1_, const Tensor& mat2_) {
|
||||
}
|
||||
|
||||
REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
|
||||
REGISTER_MPS_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_mps_kernel);
|
||||
} // namespace at::native
|
||||
@ -61,7 +61,6 @@ list(APPEND ATen_CUDA_TEST_SRCS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_math_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_cub_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_cublas_handle_pool_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp
|
||||
|
||||
@ -1,77 +0,0 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
// Test concurrent access to getCurrentCUDABlasHandle and getCUDABlasLtWorkspace
|
||||
// to verify that the data race fix is working correctly
|
||||
|
||||
TEST(CUDABlasHandlePoolTest, ConcurrentGetAndClearWorkspaces) {
|
||||
if (!at::cuda::is_available()) {
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int num_accessor_threads = 15;
|
||||
constexpr int num_clear_threads = 5;
|
||||
constexpr int iterations_per_thread = 50;
|
||||
|
||||
std::atomic<bool> stop{false};
|
||||
std::atomic<int> error_count{0};
|
||||
std::vector<std::thread> threads;
|
||||
threads.reserve(num_accessor_threads + num_clear_threads);
|
||||
|
||||
// Launch accessor threads
|
||||
for (int i = 0; i < num_accessor_threads; ++i) {
|
||||
threads.emplace_back([&stop, &error_count]() {
|
||||
try {
|
||||
at::cuda::CUDAGuard device_guard(0);
|
||||
|
||||
while (!stop.load(std::memory_order_relaxed)) {
|
||||
const auto handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
const auto workspace = at::cuda::getCUDABlasLtWorkspace();
|
||||
|
||||
if (handle == nullptr || workspace == nullptr) {
|
||||
error_count++;
|
||||
}
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
error_count++;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Launch threads that clear workspaces
|
||||
for (int i = 0; i < num_clear_threads; ++i) {
|
||||
threads.emplace_back([&error_count]() {
|
||||
try {
|
||||
for (int j = 0; j < iterations_per_thread; ++j) {
|
||||
at::cuda::clearCublasWorkspaces();
|
||||
std::this_thread::yield();
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
error_count++;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Let them run for a bit
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
stop.store(true, std::memory_order_relaxed);
|
||||
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
|
||||
EXPECT_EQ(error_count.load(), 0);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
c10::cuda::CUDACachingAllocator::init(1);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@ -1,3 +1,191 @@
|
||||
#pragma once
|
||||
#include <ATen/xpu/XPUContext.h>
|
||||
#include <c10/xpu/XPUEvent.h>
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace at::xpu {
|
||||
|
||||
/*
|
||||
* XPUEvent are movable not copyable wrappers around SYCL event. XPUEvent are
|
||||
* constructed lazily when first recorded. It has a device, and this device is
|
||||
* acquired from the first recording stream. Later streams that record the event
|
||||
* must match the same device.
|
||||
*
|
||||
* Currently, XPUEvent does NOT support to export an inter-process event from
|
||||
* another process via inter-process communication(IPC). So it means that
|
||||
* inter-process communication for event handles between different processes is
|
||||
* not available. This could impact some applications that rely on cross-process
|
||||
* synchronization and communication.
|
||||
*/
|
||||
struct TORCH_XPU_API XPUEvent {
|
||||
// Constructors
|
||||
XPUEvent(bool enable_timing = false) noexcept
|
||||
: enable_timing_{enable_timing} {}
|
||||
|
||||
~XPUEvent() {
|
||||
if (isCreated()) {
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_deletion(
|
||||
at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
XPUEvent(const XPUEvent&) = delete;
|
||||
XPUEvent& operator=(const XPUEvent&) = delete;
|
||||
|
||||
XPUEvent(XPUEvent&& other) = default;
|
||||
XPUEvent& operator=(XPUEvent&& other) = default;
|
||||
|
||||
operator sycl::event&() const {
|
||||
return event();
|
||||
}
|
||||
|
||||
std::optional<at::Device> device() const {
|
||||
if (isCreated()) {
|
||||
return at::Device(at::kXPU, device_index_);
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool isCreated() const {
|
||||
return (event_.get() != nullptr);
|
||||
}
|
||||
|
||||
DeviceIndex device_index() const {
|
||||
return device_index_;
|
||||
}
|
||||
|
||||
sycl::event& event() const {
|
||||
return *event_;
|
||||
}
|
||||
|
||||
bool query() const {
|
||||
using namespace sycl::info;
|
||||
if (!isCreated()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return event().get_info<event::command_execution_status>() ==
|
||||
event_command_status::complete;
|
||||
}
|
||||
|
||||
void record() {
|
||||
record(getCurrentXPUStream());
|
||||
}
|
||||
|
||||
void recordOnce(const XPUStream& stream) {
|
||||
if (!isCreated()) {
|
||||
record(stream);
|
||||
}
|
||||
}
|
||||
|
||||
void record(const XPUStream& stream) {
|
||||
if (!isCreated()) {
|
||||
device_index_ = stream.device_index();
|
||||
assignEvent(stream.queue());
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_creation(
|
||||
at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
device_index_ == stream.device_index(),
|
||||
"Event device ",
|
||||
device_index_,
|
||||
" does not match recording stream's device ",
|
||||
stream.device_index(),
|
||||
".");
|
||||
reassignEvent(stream.queue());
|
||||
}
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_record(
|
||||
at::kXPU,
|
||||
reinterpret_cast<uintptr_t>(event_.get()),
|
||||
reinterpret_cast<uintptr_t>(&stream.queue()));
|
||||
}
|
||||
}
|
||||
|
||||
void block(const XPUStream& stream) {
|
||||
if (isCreated()) {
|
||||
std::vector<sycl::event> event_list{event()};
|
||||
// Make this stream wait until event_ is completed.
|
||||
stream.queue().ext_oneapi_submit_barrier(event_list);
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_wait(
|
||||
at::kXPU,
|
||||
reinterpret_cast<uintptr_t>(event_.get()),
|
||||
reinterpret_cast<uintptr_t>(&stream.queue()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
double elapsed_time(const XPUEvent& other) const {
|
||||
TORCH_CHECK(
|
||||
isCreated() && other.isCreated(),
|
||||
"Both events must be recorded before calculating elapsed time.");
|
||||
TORCH_CHECK(
|
||||
query() && other.query(),
|
||||
"Both events must be completed before calculating elapsed time.");
|
||||
TORCH_CHECK(
|
||||
enable_timing_ && other.enable_timing_,
|
||||
"Both events must be created with argument 'enable_timing=True'.");
|
||||
|
||||
#if SYCL_COMPILER_VERSION < 20250000
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"elapsed_time of XPUEvent requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
|
||||
#endif
|
||||
|
||||
using namespace sycl::info::event_profiling;
|
||||
// Block until both of the recorded events are completed.
|
||||
uint64_t end_time_ns = other.event().get_profiling_info<command_end>();
|
||||
uint64_t start_time_ns = event().get_profiling_info<command_end>();
|
||||
// Return the eplased time in milliseconds.
|
||||
return 1e-6 *
|
||||
(static_cast<double>(end_time_ns) - static_cast<double>(start_time_ns));
|
||||
}
|
||||
|
||||
void synchronize() const {
|
||||
if (isCreated()) {
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_synchronization(
|
||||
at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
|
||||
}
|
||||
event().wait_and_throw();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void assignEvent(sycl::queue& queue) {
|
||||
#if SYCL_COMPILER_VERSION >= 20250000
|
||||
if (enable_timing_) {
|
||||
event_ = std::make_unique<sycl::event>(
|
||||
sycl::ext::oneapi::experimental::submit_profiling_tag(queue));
|
||||
} else {
|
||||
event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
|
||||
}
|
||||
#else
|
||||
event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
|
||||
#endif
|
||||
}
|
||||
|
||||
void reassignEvent(sycl::queue& queue) {
|
||||
event_.reset();
|
||||
assignEvent(queue);
|
||||
}
|
||||
|
||||
bool enable_timing_ = false;
|
||||
DeviceIndex device_index_ = -1;
|
||||
// Only need to track the last event, as events in an in-order queue are
|
||||
// executed sequentially.
|
||||
std::unique_ptr<sycl::event> event_;
|
||||
};
|
||||
|
||||
} // namespace at::xpu
|
||||
|
||||
@ -10,13 +10,6 @@
|
||||
...
|
||||
}
|
||||
|
||||
{
|
||||
ignore_empty_generic_uninitialised_conditional_jump
|
||||
Memcheck:Cond
|
||||
fun:_ZN2at6detail13empty_genericEN3c108ArrayRefIlEEPNS1_9AllocatorENS1_14DispatchKeySetENS1_10ScalarTypeESt8optionalINS1_12MemoryFormatEE
|
||||
...
|
||||
}
|
||||
|
||||
{
|
||||
Cond_cuda
|
||||
Memcheck:Cond
|
||||
|
||||
@ -50,7 +50,6 @@ def check_accuracy(actual_csv, expected_csv, expected_filename):
|
||||
"mobilenet_v2",
|
||||
"pytorch_CycleGAN_and_pix2pix",
|
||||
"pytorch_stargan",
|
||||
"repvgg_a2",
|
||||
"resnet152",
|
||||
"resnet18",
|
||||
"resnet50",
|
||||
|
||||
@ -9,61 +9,28 @@ def check_perf_csv(filename, threshold, threshold_scale):
|
||||
"""
|
||||
Basic performance checking.
|
||||
"""
|
||||
try:
|
||||
df = pd.read_csv(filename)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: File {filename} not found")
|
||||
sys.exit(1)
|
||||
|
||||
effective_threshold = threshold * threshold_scale
|
||||
print(f"Checking {filename} (speedup threshold >= {effective_threshold:.2f}x)\n")
|
||||
df = pd.read_csv(filename)
|
||||
|
||||
failed = []
|
||||
for _, row in df.iterrows():
|
||||
model_name = row["name"]
|
||||
speedup = float(row["speedup"])
|
||||
abs_latency = float(row["abs_latency"])
|
||||
compilation_latency = float(row["compilation_latency"])
|
||||
compression_ratio = float(row["compression_ratio"])
|
||||
eager_peak_mem = float(row["eager_peak_mem"])
|
||||
dynamo_peak_mem = float(row["dynamo_peak_mem"])
|
||||
speedup = row["speedup"]
|
||||
if speedup < threshold * threshold_scale:
|
||||
failed.append(model_name)
|
||||
|
||||
perf_summary = f"{model_name:34} speedup={speedup:.3f}x"
|
||||
if pd.notna(abs_latency):
|
||||
perf_summary += f", latency={abs_latency:.1f} ms/iter"
|
||||
if pd.notna(compilation_latency):
|
||||
perf_summary += f", compile={compilation_latency:.3f}s"
|
||||
if pd.notna(compression_ratio):
|
||||
perf_summary += f", mem_ratio={1 / compression_ratio:.2f}x"
|
||||
if pd.notna(eager_peak_mem) and pd.notna(dynamo_peak_mem):
|
||||
perf_summary += (
|
||||
f" (eager={eager_peak_mem:.1f} GB, dynamo={dynamo_peak_mem:.1f} GB)"
|
||||
)
|
||||
|
||||
if speedup < effective_threshold:
|
||||
failed.append((model_name, speedup))
|
||||
|
||||
print(perf_summary)
|
||||
print(f"{model_name:34} {speedup}")
|
||||
|
||||
if failed:
|
||||
print(
|
||||
textwrap.dedent(
|
||||
f"""
|
||||
Error {len(failed)} model(s) performance regressed
|
||||
{" ".join([name for name, _ in failed])}
|
||||
Error {len(failed)} models performance regressed
|
||||
{" ".join(failed)}
|
||||
"""
|
||||
)
|
||||
)
|
||||
for name, sp in sorted(failed, key=lambda x: x[1]):
|
||||
pct_from_target = (sp / effective_threshold - 1.0) * 100.0
|
||||
print(
|
||||
f" - {name}: {sp:.3f}x (< {effective_threshold:.2f}x; {pct_from_target:.1f}% from target)"
|
||||
)
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(
|
||||
f"\nAll {len(df)} model(s) passed threshold check (>= {effective_threshold:.2f}x)"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -77,7 +44,7 @@ if __name__ == "__main__":
|
||||
"-s",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="multiply threshold by this value to relax the check",
|
||||
help="multiple threshold by this value to relax the check",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
check_perf_csv(args.file, args.threshold, args.threshold_scale)
|
||||
|
||||
@ -10,7 +10,7 @@ beit_base_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
@ -66,7 +66,7 @@ visformer_small,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,fail_accuracy,7
|
||||
vit_base_patch14_dinov2.lvd142m,pass,7
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -50,7 +50,7 @@ nfnet_l0,pass,7
|
||||
|
||||
|
||||
|
||||
repvgg_a2,pass,7
|
||||
repvgg_a2,fail_accuracy,7
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -2288,9 +2288,11 @@ class BenchmarkRunner:
|
||||
)
|
||||
):
|
||||
is_same = False
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
# Sometimes torch.allclose may throw RuntimeError
|
||||
is_same = False
|
||||
exception_string = str(e)
|
||||
accuracy_status = f"fail_exception: {exception_string}"
|
||||
return record_status(accuracy_status, dynamo_start_stats=start_stats)
|
||||
|
||||
if not is_same:
|
||||
accuracy_status = "eager_two_runs_differ"
|
||||
@ -2379,9 +2381,7 @@ class BenchmarkRunner:
|
||||
print(
|
||||
f"Load model outputs from {self.args.compare_model_outputs_with} to compare"
|
||||
)
|
||||
saved_result = torch.load(
|
||||
self.args.compare_model_outputs_with, weights_only=False
|
||||
)
|
||||
saved_result = torch.load(self.args.compare_model_outputs_with)
|
||||
is_bitwise_same = bitwise_same(saved_result, new_result)
|
||||
if not is_bitwise_same:
|
||||
print(
|
||||
@ -2409,9 +2409,11 @@ class BenchmarkRunner:
|
||||
force_max_multiplier=force_max_multiplier,
|
||||
):
|
||||
is_same = False
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
# Sometimes torch.allclose may throw RuntimeError
|
||||
is_same = False
|
||||
exception_string = str(e)
|
||||
accuracy_status = f"fail_exception: {exception_string}"
|
||||
return record_status(accuracy_status, dynamo_start_stats=start_stats)
|
||||
|
||||
if not is_same:
|
||||
if self.args.skip_accuracy_check:
|
||||
|
||||
@ -1,62 +0,0 @@
|
||||
import sys
|
||||
|
||||
from benchmark_base import BenchmarkBase
|
||||
|
||||
import torch
|
||||
from torch.distributed._tensor import DTensor, Replicate
|
||||
from torch.testing._internal.distributed.fake_pg import FakeStore
|
||||
|
||||
|
||||
class BenchmarkDTensorDispatch(BenchmarkBase):
|
||||
def __init__(self, operator, world_size) -> None:
|
||||
super().__init__(
|
||||
category=f"dtensor_dispatch_{operator}",
|
||||
device="cuda",
|
||||
)
|
||||
self.world_size = world_size
|
||||
|
||||
def name(self) -> str:
|
||||
prefix = f"{self.category()}"
|
||||
return prefix
|
||||
|
||||
def description(self) -> str:
|
||||
return f"DTensor dispatch time for {self.category()}"
|
||||
|
||||
def _prepare_once(self) -> None:
|
||||
self.mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
"cuda", (self.world_size,), mesh_dim_names=("dp",)
|
||||
)
|
||||
self.a = DTensor.from_local(
|
||||
torch.ones(10, 10, device=self.device()), self.mesh, [Replicate()]
|
||||
)
|
||||
self.b = DTensor.from_local(
|
||||
torch.ones(10, 10, device=self.device()), self.mesh, [Replicate()]
|
||||
)
|
||||
|
||||
def _prepare(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class BenchmarkDetach(BenchmarkDTensorDispatch):
|
||||
def __init__(self, world_size) -> None:
|
||||
super().__init__(operator="detach", world_size=world_size)
|
||||
|
||||
def _work(self) -> None:
|
||||
self.a.detach()
|
||||
|
||||
|
||||
def main():
|
||||
world_size = 256
|
||||
fake_store = FakeStore()
|
||||
torch.distributed.init_process_group(
|
||||
"fake", store=fake_store, rank=0, world_size=world_size
|
||||
)
|
||||
result_path = sys.argv[1]
|
||||
BenchmarkDetach(world_size).enable_instruction_count().collect_all().append_results(
|
||||
result_path
|
||||
)
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -189,10 +189,6 @@ skip:
|
||||
- hf_Whisper
|
||||
- hf_distil_whisper
|
||||
- timm_vision_transformer_large
|
||||
# https://github.com/pytorch/pytorch/issues/167895
|
||||
- stable_diffusion
|
||||
- stable_diffusion_text_encoder
|
||||
- stable_diffusion_unet
|
||||
|
||||
device:
|
||||
cpu:
|
||||
|
||||
@ -125,17 +125,6 @@ AttentionType = Literal[
|
||||
]
|
||||
DtypeString = Literal["bfloat16", "float16", "float32"]
|
||||
SpeedupType = Literal["fwd", "bwd"]
|
||||
# Operator Name mapping
|
||||
backend_to_operator_name = {
|
||||
"math": "math attention kernel",
|
||||
"efficient": "efficient attention kernel",
|
||||
"cudnn": "cudnn attention kernel",
|
||||
"fav2": "flash attention 2 kernel",
|
||||
"fav3": "flash attention 3 kernel",
|
||||
"fakv": "flash attention kv cache kernel",
|
||||
"og-eager": "eager attention kernel",
|
||||
"flex": "flex attention kernel",
|
||||
}
|
||||
|
||||
|
||||
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
|
||||
@ -1276,14 +1265,12 @@ def _output_json_for_dashboard(
|
||||
model: ModelInfo
|
||||
metric: MetricInfo
|
||||
|
||||
operator_name = backend_to_operator_name.get(backend, backend)
|
||||
|
||||
# Benchmark extra info
|
||||
benchmark_extra_info = {
|
||||
"input_config": input_config,
|
||||
"device": device,
|
||||
"arch": device_arch,
|
||||
"operator_name": operator_name,
|
||||
"operator_name": backend,
|
||||
"attn_type": config.attn_type,
|
||||
"shape": str(config.shape),
|
||||
"max_autotune": config.max_autotune,
|
||||
@ -1301,7 +1288,7 @@ def _output_json_for_dashboard(
|
||||
type="attention-benchmark",
|
||||
origins=["pytorch"],
|
||||
extra_info={
|
||||
"operator_name": operator_name,
|
||||
"operator_name": backend,
|
||||
"attn_type": config.attn_type,
|
||||
},
|
||||
),
|
||||
@ -1328,7 +1315,7 @@ def _output_json_for_dashboard(
|
||||
type="attention-benchmark",
|
||||
origins=["pytorch"],
|
||||
extra_info={
|
||||
"operator_name": operator_name,
|
||||
"operator_name": backend,
|
||||
},
|
||||
),
|
||||
metric=MetricInfo(
|
||||
@ -1354,7 +1341,7 @@ def _output_json_for_dashboard(
|
||||
type="attention-benchmark",
|
||||
origins=["pytorch"],
|
||||
extra_info={
|
||||
"operator_name": operator_name,
|
||||
"operator_name": backend,
|
||||
},
|
||||
),
|
||||
metric=MetricInfo(
|
||||
@ -1384,7 +1371,7 @@ def _output_json_for_dashboard(
|
||||
type="attention-benchmark",
|
||||
origins=["pytorch"],
|
||||
extra_info={
|
||||
"operator_name": operator_name,
|
||||
"operator_name": backend,
|
||||
},
|
||||
),
|
||||
metric=MetricInfo(
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
# These load paths point to different files in internal and OSS environment
|
||||
|
||||
load("@bazel_skylib//lib:paths.bzl", "paths")
|
||||
load("//tools/build_defs:cell_defs.bzl", "get_fbsource_cell")
|
||||
load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
|
||||
load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
|
||||
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
|
||||
@ -591,9 +590,6 @@ def pt_operator_query_codegen(
|
||||
pt_allow_forced_schema_registration = True,
|
||||
compatible_with = [],
|
||||
apple_sdks = None):
|
||||
if get_fbsource_cell() == "fbcode":
|
||||
return
|
||||
|
||||
oplist_dir_name = name + "_pt_oplist"
|
||||
|
||||
# @lint-ignore BUCKLINT
|
||||
@ -869,9 +865,6 @@ def define_buck_targets(
|
||||
pt_xplat_cxx_library = fb_xplat_cxx_library,
|
||||
c2_fbandroid_xplat_compiler_flags = [],
|
||||
labels = []):
|
||||
if get_fbsource_cell() == "fbcode":
|
||||
return
|
||||
|
||||
# @lint-ignore BUCKLINT
|
||||
fb_native.filegroup(
|
||||
name = "metal_build_srcs",
|
||||
|
||||
@ -19,17 +19,6 @@
|
||||
|
||||
namespace c10 {
|
||||
|
||||
using CaptureId_t = unsigned long long;
|
||||
// first is set if the instance is created by CUDAGraph::capture_begin.
|
||||
// second is set if the instance is created by at::cuda::graph_pool_handle.
|
||||
using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>;
|
||||
|
||||
struct MempoolIdHash {
|
||||
std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
|
||||
return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
|
||||
}
|
||||
};
|
||||
|
||||
// A DataPtr is a unique pointer (with an attached deleter and some
|
||||
// context for the deleter) to some memory, which also records what
|
||||
// device is for its data.
|
||||
|
||||
@ -99,10 +99,7 @@ struct C10_API DeviceAllocator : public c10::Allocator {
|
||||
|
||||
// Return the free memory size and total memory size in bytes for the
|
||||
// specified device.
|
||||
virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "getMemoryInfo is not implemented for this allocator yet.");
|
||||
}
|
||||
virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) = 0;
|
||||
};
|
||||
|
||||
// This function is used to get the DeviceAllocator for a specific device type
|
||||
|
||||
@ -44,7 +44,7 @@ struct C10_API SafePyObject {
|
||||
(*other.pyinterpreter_)->incref(other.data_);
|
||||
}
|
||||
if (data_ != nullptr) {
|
||||
(*pyinterpreter_)->decref(data_);
|
||||
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
|
||||
}
|
||||
data_ = other.data_;
|
||||
pyinterpreter_ = other.pyinterpreter_;
|
||||
@ -53,7 +53,7 @@ struct C10_API SafePyObject {
|
||||
|
||||
~SafePyObject() {
|
||||
if (data_ != nullptr) {
|
||||
(*pyinterpreter_)->decref(data_);
|
||||
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -27,13 +27,26 @@
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
|
||||
|
||||
namespace c10 {
|
||||
|
||||
// See [dtype Macros note] in torch/headeronly/core/ScalarType.h
|
||||
// regarding macros.
|
||||
|
||||
template <typename T>
|
||||
struct CppTypeToScalarType;
|
||||
|
||||
#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \
|
||||
template <> \
|
||||
struct CppTypeToScalarType<cpp_type> \
|
||||
: std:: \
|
||||
integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \
|
||||
};
|
||||
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
|
||||
|
||||
#undef SPECIALIZE_CppTypeToScalarType
|
||||
|
||||
#define DEFINE_CONSTANT(_, name) \
|
||||
constexpr ScalarType k##name = ScalarType::name;
|
||||
|
||||
@ -92,6 +105,13 @@ inline bool isComplexType(ScalarType t) {
|
||||
t == ScalarType::ComplexDouble);
|
||||
}
|
||||
|
||||
inline bool isQIntType(ScalarType t) {
|
||||
// Don't forget to extend this when adding new QInt types
|
||||
return t == ScalarType::QInt8 || t == ScalarType::QUInt8 ||
|
||||
t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 ||
|
||||
t == ScalarType::QUInt2x4;
|
||||
}
|
||||
|
||||
inline bool isBitsType(ScalarType t) {
|
||||
return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 ||
|
||||
t == ScalarType::Bits4x2 || t == ScalarType::Bits8 ||
|
||||
@ -185,12 +205,6 @@ inline bool isSignedType(ScalarType t) {
|
||||
break;
|
||||
// Do not add default here, but rather define behavior of every new entry
|
||||
// here. `-Wswitch-enum` would raise a warning in those cases.
|
||||
// TODO: get PyTorch to adopt exhaustive switches by default with a way to
|
||||
// opt specific switches to being non-exhaustive.
|
||||
// Exhaustive:
|
||||
// `-Wswitch-enum`, `-Wswitch-default`, `-Wno-covered-switch-default`
|
||||
// Non-Exhaustive:
|
||||
// `-Wno-switch-enum`, `-Wswitch-default`, `-Wcovered-switch-default`
|
||||
}
|
||||
TORCH_CHECK(false, "Unknown ScalarType ", t);
|
||||
#undef CASE_ISSIGNED
|
||||
|
||||
@ -48,30 +48,6 @@ void warnDeprecatedDataPtr() {
|
||||
TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid.");
|
||||
}
|
||||
|
||||
void StorageImpl::incref_pyobject() const {
|
||||
// Because intrusive_ptr incref uses relaxed memory order, we need to
|
||||
// do an acquire fence to ensure that the kHasPyObject bit was
|
||||
// observed before the load of the PyObject* below.
|
||||
// NB: This is a no-op on x86/x86-64
|
||||
std::atomic_thread_fence(std::memory_order_acquire);
|
||||
|
||||
PyObject* obj = pyobj_slot_.load_pyobj();
|
||||
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
|
||||
}
|
||||
|
||||
void StorageImpl::decref_pyobject() const {
|
||||
PyObject* obj = pyobj_slot_.load_pyobj();
|
||||
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
|
||||
}
|
||||
|
||||
bool StorageImpl::try_incref_pyobject() const {
|
||||
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
|
||||
if (C10_UNLIKELY(!interp)) {
|
||||
return false;
|
||||
}
|
||||
return (*interp)->try_incref(pyobj_slot_);
|
||||
}
|
||||
|
||||
void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
|
||||
// Allowlist verification.
|
||||
// Only if the devicetype is in the allowlist,
|
||||
|
||||
@ -105,12 +105,6 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
||||
data_ptr_.clear();
|
||||
}
|
||||
|
||||
void incref_pyobject() const override final;
|
||||
|
||||
void decref_pyobject() const override final;
|
||||
|
||||
bool try_incref_pyobject() const override final;
|
||||
|
||||
size_t nbytes() const {
|
||||
// OK to do this instead of maybe_as_int as nbytes is guaranteed positive
|
||||
TORCH_CHECK(!size_bytes_is_heap_allocated_);
|
||||
@ -376,18 +370,4 @@ C10_API c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
|
||||
bool resizable,
|
||||
std::optional<at::Device> device_opt);
|
||||
|
||||
namespace detail {
|
||||
|
||||
#ifndef C10_MOBILE
|
||||
template <class T>
|
||||
struct TargetTraits<
|
||||
T,
|
||||
std::enable_if_t<
|
||||
std::is_base_of_v<c10::StorageImpl, std::remove_cv_t<T>>>> {
|
||||
static constexpr bool can_have_pyobject = true;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace c10
|
||||
|
||||
@ -277,6 +277,7 @@ void TensorImpl::release_resources() {
|
||||
if (storage_) {
|
||||
storage_ = {};
|
||||
}
|
||||
pyobj_slot_.maybe_destroy_pyobj();
|
||||
}
|
||||
|
||||
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
|
||||
@ -988,30 +989,6 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
|
||||
}
|
||||
}
|
||||
|
||||
void TensorImpl::incref_pyobject() const {
|
||||
// Because intrusive_ptr incref uses relaxed memory order, we need to
|
||||
// do an acquire fence to ensure that the kHasPyObject bit was
|
||||
// observed before the load of the PyObject* below.
|
||||
// NB: This is a no-op on x86/x86-64
|
||||
std::atomic_thread_fence(std::memory_order_acquire);
|
||||
|
||||
PyObject* obj = pyobj_slot_.load_pyobj();
|
||||
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
|
||||
}
|
||||
|
||||
void TensorImpl::decref_pyobject() const {
|
||||
PyObject* obj = pyobj_slot_.load_pyobj();
|
||||
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
|
||||
}
|
||||
|
||||
bool TensorImpl::try_incref_pyobject() const {
|
||||
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
|
||||
if (C10_UNLIKELY(!interp)) {
|
||||
return false;
|
||||
}
|
||||
return (*interp)->try_incref(pyobj_slot_);
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
|
||||
namespace {
|
||||
|
||||
@ -57,8 +57,6 @@ C10_DECLARE_bool(caffe2_keep_on_shrink);
|
||||
// respect caffe2_keep_on_shrink.
|
||||
C10_DECLARE_int64(caffe2_max_keep_on_shrink_memory);
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
|
||||
|
||||
namespace at {
|
||||
class Tensor;
|
||||
class TensorBase;
|
||||
@ -2178,12 +2176,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
return &pyobj_slot_;
|
||||
}
|
||||
|
||||
void incref_pyobject() const override final;
|
||||
|
||||
void decref_pyobject() const override final;
|
||||
|
||||
bool try_incref_pyobject() const override final;
|
||||
|
||||
private:
|
||||
// See NOTE [std::optional operator usage in CUDA]
|
||||
// We probably don't want to expose this publicly until
|
||||
@ -3085,19 +3077,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
friend class C10_TensorImpl_Size_Check_Dummy_Class;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
#ifndef C10_MOBILE
|
||||
template <class T>
|
||||
struct TargetTraits<
|
||||
T,
|
||||
std::enable_if_t<std::is_base_of_v<c10::TensorImpl, std::remove_cv_t<T>>>> {
|
||||
static constexpr bool can_have_pyobject = true;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Note [TensorImpl size constraints]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// Changed the size of TensorImpl? If the size went down, good for
|
||||
@ -3324,5 +3303,3 @@ static_assert(
|
||||
#undef C10_GCC_VERSION_MINOR
|
||||
|
||||
} // namespace c10
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
@ -11,11 +11,8 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
|
||||
|
||||
void incref(PyObject* pyobj) const override {} // do nothing
|
||||
|
||||
void decref(PyObject* pyobj) const override {} // do nothing
|
||||
|
||||
bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override {
|
||||
return false;
|
||||
}
|
||||
void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
|
||||
} // do nothing
|
||||
|
||||
#define PANIC(m) \
|
||||
TORCH_INTERNAL_ASSERT( \
|
||||
@ -23,10 +20,6 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
|
||||
"attempted to call " #m \
|
||||
" on a Tensor with nontrivial PyObject after corresponding interpreter died")
|
||||
|
||||
size_t refcnt(PyObject* pyobj) const override {
|
||||
PANIC(refcnt);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<TensorImpl> detach(const TensorImpl* self) const override {
|
||||
PANIC(detach);
|
||||
}
|
||||
|
||||
@ -18,9 +18,6 @@ namespace c10 {
|
||||
struct IValue;
|
||||
class OperatorHandle;
|
||||
struct TensorImpl;
|
||||
namespace impl {
|
||||
struct PyObjectSlot;
|
||||
} // namespace impl
|
||||
} // namespace c10
|
||||
|
||||
namespace torch::jit {
|
||||
@ -129,12 +126,9 @@ struct C10_API PyInterpreterVTable {
|
||||
|
||||
// Run Py_INCREF on a PyObject.
|
||||
virtual void incref(PyObject* pyobj) const = 0;
|
||||
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call.
|
||||
virtual void decref(PyObject* pyobj) const = 0;
|
||||
// Run PyUnstable_TryIncRef on a PyObject if it's not NULL.
|
||||
virtual bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const = 0;
|
||||
// Run Py_REFCNT on a PyObject.
|
||||
virtual size_t refcnt(PyObject* pyobj) const = 0;
|
||||
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
|
||||
// See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
|
||||
virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;
|
||||
|
||||
// Perform a detach by deferring to the __torch_dispatch__ implementation of
|
||||
// detach, which will also arrange for the PyObject to get copied in this
|
||||
|
||||
56
c10/core/impl/PyObjectSlot.cpp
Normal file
56
c10/core/impl/PyObjectSlot.cpp
Normal file
@ -0,0 +1,56 @@
|
||||
#include <c10/core/impl/PyObjectSlot.h>
|
||||
|
||||
namespace c10::impl {
|
||||
|
||||
PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
|
||||
|
||||
PyObjectSlot::~PyObjectSlot() {
|
||||
maybe_destroy_pyobj();
|
||||
}
|
||||
|
||||
void PyObjectSlot::maybe_destroy_pyobj() {
|
||||
if (owns_pyobj()) {
|
||||
TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
|
||||
TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
|
||||
(*pyobj_interpreter_.load(std::memory_order_acquire))
|
||||
->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
|
||||
// NB: this destructor can only be entered when there are no
|
||||
// references to this C++ object (obviously), NOR any references
|
||||
// to the PyObject (if there are references to the PyObject,
|
||||
// then the PyObject holds an owning reference to the tensor).
|
||||
// So it is OK to clear pyobj_ here as it is impossible for it to
|
||||
// be used again (modulo weak reference races)
|
||||
pyobj_ = nullptr; // for safety
|
||||
}
|
||||
}
|
||||
|
||||
PyInterpreter* PyObjectSlot::pyobj_interpreter() {
|
||||
return pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
return reinterpret_cast<PyObject*>(
|
||||
reinterpret_cast<uintptr_t>(pyobj_) & ~0x1ULL);
|
||||
}
|
||||
|
||||
PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
|
||||
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
if (interpreter) {
|
||||
return *interpreter;
|
||||
}
|
||||
TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set");
|
||||
}
|
||||
|
||||
bool PyObjectSlot::owns_pyobj() {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
return reinterpret_cast<uintptr_t>(pyobj_) & 1;
|
||||
}
|
||||
|
||||
void PyObjectSlot::set_owns_pyobj(bool b) {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
pyobj_ = reinterpret_cast<PyObject*>(
|
||||
reinterpret_cast<uintptr_t>(_unchecked_untagged_pyobj()) | b);
|
||||
}
|
||||
|
||||
} // namespace c10::impl
|
||||
@ -8,58 +8,117 @@
|
||||
|
||||
#include <atomic>
|
||||
|
||||
namespace torch::utils {
|
||||
class PyObjectPreservation;
|
||||
}
|
||||
|
||||
namespace c10::impl {
|
||||
|
||||
struct C10_API PyObjectSlot {
|
||||
public:
|
||||
PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
|
||||
PyObjectSlot();
|
||||
|
||||
~PyObjectSlot();
|
||||
|
||||
void maybe_destroy_pyobj();
|
||||
|
||||
// Associate the TensorImpl with the specified PyObject, and, if necessary,
|
||||
// also tag the interpreter.
|
||||
//
|
||||
// NB: This lives in a header so that we can inline away the switch on status
|
||||
//
|
||||
// NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
|
||||
// PyObject if necessary!
|
||||
void init_pyobj(PyObject* pyobj) {
|
||||
pyobj_interpreter_.store(
|
||||
getGlobalPyInterpreter(), std::memory_order_relaxed);
|
||||
pyobj_ = pyobj;
|
||||
}
|
||||
|
||||
// Query the PyObject interpreter. This may return null if there is no
|
||||
// interpreter.
|
||||
PyInterpreter* pyobj_interpreter() const {
|
||||
return pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
// interpreter. This is racy!
|
||||
PyInterpreter* pyobj_interpreter();
|
||||
|
||||
PyObject* _unchecked_untagged_pyobj() const;
|
||||
|
||||
// Test the interpreter tag. If tagged for the current interpreter, return
|
||||
// a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
|
||||
// returns a nullopt. If it is definitely invalid, raises an error.
|
||||
//
|
||||
// If `ignore_hermetic_tls` is false and this function is called from a
|
||||
// hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
|
||||
// nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
|
||||
// context is ignored, allowing you to check the interpreter tag of a
|
||||
// nonhermetic PyObject from within a hermetic context. This is necessary
|
||||
// because there are some cases where the deallocator function of a
|
||||
// nonhermetic PyObject is called from within a hermetic context, so it must
|
||||
// be properly treated as a nonhermetic PyObject.
|
||||
//
|
||||
// NB: this lives in header so that we can avoid actually creating the
|
||||
// std::optional
|
||||
|
||||
// @todo alban: I'm not too sure what's going on here, we can probably delete
|
||||
// it but it's worthwhile making sure
|
||||
std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const {
|
||||
impl::PyInterpreter* interpreter =
|
||||
pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
if (interpreter == nullptr) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
return _unchecked_untagged_pyobj();
|
||||
}
|
||||
}
|
||||
|
||||
PyInterpreter& load_pyobj_interpreter() const {
|
||||
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
interpreter, "cannot access PyObject for Tensor - no interpreter set");
|
||||
return *interpreter;
|
||||
}
|
||||
PyInterpreter& load_pyobj_interpreter() const;
|
||||
|
||||
PyObject* load_pyobj() const {
|
||||
return pyobj_.load(std::memory_order_acquire);
|
||||
}
|
||||
bool owns_pyobj();
|
||||
|
||||
void store_pyobj(PyObject* obj) {
|
||||
pyobj_.store(obj, std::memory_order_release);
|
||||
}
|
||||
|
||||
bool has_unique_reference() const {
|
||||
PyObject* pyobj = load_pyobj();
|
||||
return pyobj != nullptr && load_pyobj_interpreter()->refcnt(pyobj) == 1;
|
||||
}
|
||||
|
||||
void clear() {
|
||||
pyobj_.store(nullptr, std::memory_order_relaxed);
|
||||
pyobj_interpreter_.store(nullptr, std::memory_order_relaxed);
|
||||
}
|
||||
void set_owns_pyobj(bool b);
|
||||
|
||||
private:
|
||||
// This is now always the global interpreter if the PyObject is set.
|
||||
// Maybe we can remove this field some day...
|
||||
// This field contains the interpreter tag for this object. See
|
||||
// Note [Python interpreter tag] for general context
|
||||
//
|
||||
// Note [Memory ordering on Python interpreter tag]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// What memory_order do we need when accessing this atomic? We don't
|
||||
// need a single total modification order (as provided by
|
||||
// memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
|
||||
// transition from -1 to some positive integer and never changes afterwards.
|
||||
// Because there is only one modification, it trivially already has a total
|
||||
// modification order (e.g., we don't need fences or locked instructions on
|
||||
// x86)
|
||||
//
|
||||
// In fact, one could make a reasonable argument that relaxed reads are OK,
|
||||
// due to the presence of external locking (GIL) to ensure that interactions
|
||||
// with other data structures are still correctly synchronized, so that
|
||||
// we fall in the "Single-Location Data Structures" case as described in
|
||||
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
|
||||
// However, on x86, it doesn't matter if I use acquire or relaxed on the load
|
||||
// as I get the same assembly in both cases. So I just use the more
|
||||
// conservative acquire (which will impede compiler optimizations but I don't
|
||||
// care)
|
||||
std::atomic<PyInterpreter*> pyobj_interpreter_;
|
||||
|
||||
// The PyObject representing this Tensor or nullptr. Ownership is managed
|
||||
// by intrusive_ptr. By the time the PyObjectSlot is destroyed, this
|
||||
// reference is already dead.
|
||||
std::atomic<PyObject*> pyobj_;
|
||||
|
||||
friend class torch::utils::PyObjectPreservation;
|
||||
// This field contains a reference to a PyObject representing this Tensor.
|
||||
// If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new
|
||||
// PyObject for it and set this field. This field does not have to be
|
||||
// protected by an atomic as it is only allowed to be accessed when you hold
|
||||
// the GIL, or during destruction of the tensor.
|
||||
//
|
||||
// When a PyObject dies, you are obligated to clear this field
|
||||
// (otherwise, you will try to use-after-free the pyobj); this currently
|
||||
// occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp
|
||||
//
|
||||
// NB: Ordinarily, this should not be a strong reference, as if the
|
||||
// PyObject owns the Tensor, this would create a reference cycle.
|
||||
// However, sometimes this ownership flips. To track who owns
|
||||
// who, this has a single pointer tag indicating whether or not the
|
||||
// C++ object owns the PyObject (the common case, zero, means PyObject
|
||||
// owns the C++ object); see _unchecked_untagged_pyobj for raw access
|
||||
// or check_pyobj for checked access. See references to PyObject
|
||||
// resurrection in torch/csrc/autograd/python_variable.cpp
|
||||
PyObject* pyobj_;
|
||||
};
|
||||
|
||||
} // namespace c10::impl
|
||||
|
||||
@ -1012,6 +1012,12 @@ PrivatePoolState::PrivatePoolState(
|
||||
}
|
||||
}
|
||||
|
||||
struct MempoolIdHash {
|
||||
std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
|
||||
return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
|
||||
}
|
||||
};
|
||||
|
||||
cudaError_t allocPrimitive(void** ptr, size_t size, AllocParams& p) {
|
||||
if (p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator()) {
|
||||
*ptr = p.pool->owner_PrivatePool->allocator()->raw_alloc(size);
|
||||
@ -4504,3 +4510,66 @@ std::atomic<CUDAAllocator*> allocator;
|
||||
static BackendStaticInitializer backend_static_initializer;
|
||||
} // namespace cuda::CUDACachingAllocator
|
||||
} // namespace c10
|
||||
|
||||
namespace c10::cuda {
|
||||
|
||||
// uid_ is incremented when a user creates a MemPool,
|
||||
// for example: using graph_pool_handle() or c10::cuda::MemPool().
|
||||
//
|
||||
// uuid_ is incremented when CUDAGraph creates a MemPool
|
||||
// as a result of a user not providing a pool.
|
||||
//
|
||||
// MempoolId_t of {0, 0} is used to denote when no MemPool has been
|
||||
// passed to a function, either by user or CUDAGraphs. For example,
|
||||
// default value of MempoolId_t for capture_begin function is {0, 0}.
|
||||
// That's why uid_ and uuid_ start at 1.
|
||||
std::atomic<CaptureId_t> MemPool::uid_{1};
|
||||
std::atomic<CaptureId_t> MemPool::uuid_{1};
|
||||
|
||||
MemPool::MemPool(
|
||||
CUDACachingAllocator::CUDAAllocator* allocator,
|
||||
bool is_user_created,
|
||||
bool use_on_oom)
|
||||
: allocator_(allocator), is_user_created_(is_user_created) {
|
||||
if (is_user_created_) {
|
||||
id_ = {0, uid_++};
|
||||
} else {
|
||||
id_ = {uuid_++, 0};
|
||||
}
|
||||
device_ = c10::cuda::current_device();
|
||||
CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
|
||||
if (use_on_oom) {
|
||||
CUDACachingAllocator::setUseOnOOM(device_, id_);
|
||||
}
|
||||
}
|
||||
|
||||
MemPool::~MemPool() {
|
||||
TORCH_INTERNAL_ASSERT(use_count() == 1);
|
||||
CUDACachingAllocator::releasePool(device_, id_);
|
||||
c10::cuda::CUDACachingAllocator::emptyCache(id_);
|
||||
}
|
||||
|
||||
MempoolId_t MemPool::id() {
|
||||
return id_;
|
||||
}
|
||||
|
||||
CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
|
||||
return allocator_;
|
||||
}
|
||||
|
||||
int MemPool::use_count() {
|
||||
return CUDACachingAllocator::getPoolUseCount(device_, id_);
|
||||
}
|
||||
|
||||
c10::DeviceIndex MemPool::device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
|
||||
if (is_user_created) {
|
||||
return {0, uid_++};
|
||||
}
|
||||
return {uuid_++, 0};
|
||||
}
|
||||
|
||||
} // namespace c10::cuda
|
||||
|
||||
@ -562,7 +562,41 @@ inline std::string getUserMetadata() {
|
||||
} // namespace c10::cuda::CUDACachingAllocator
|
||||
|
||||
namespace c10::cuda {
|
||||
|
||||
// Keep BC only
|
||||
using c10::CaptureId_t;
|
||||
using c10::MempoolId_t;
|
||||
|
||||
// MemPool represents a pool of memory in a caching allocator. Currently,
|
||||
// it's just the ID of the pool object maintained in the CUDACachingAllocator.
|
||||
//
|
||||
// An allocator pointer can be passed to the MemPool to define how the
|
||||
// allocations should be done in the pool. For example: using a different
|
||||
// system allocator such as ncclMemAlloc.
|
||||
struct C10_CUDA_API MemPool {
|
||||
MemPool(
|
||||
CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
|
||||
bool is_user_created = true,
|
||||
bool use_on_oom = false);
|
||||
MemPool(const MemPool&) = delete;
|
||||
MemPool(MemPool&&) = default;
|
||||
MemPool& operator=(const MemPool&) = delete;
|
||||
MemPool& operator=(MemPool&&) = default;
|
||||
~MemPool();
|
||||
|
||||
MempoolId_t id();
|
||||
CUDACachingAllocator::CUDAAllocator* allocator();
|
||||
int use_count();
|
||||
c10::DeviceIndex device();
|
||||
static MempoolId_t graph_pool_handle(bool is_user_created = true);
|
||||
|
||||
private:
|
||||
static std::atomic<CaptureId_t> uid_;
|
||||
static std::atomic<CaptureId_t> uuid_;
|
||||
CUDACachingAllocator::CUDAAllocator* allocator_;
|
||||
bool is_user_created_;
|
||||
MempoolId_t id_;
|
||||
c10::DeviceIndex device_;
|
||||
};
|
||||
|
||||
} // namespace c10::cuda
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user