mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-14 14:15:07 +08:00
Compare commits
191 Commits
documentat
...
gh/mlazos/
| Author | SHA1 | Date | |
|---|---|---|---|
| 405e7fa2d8 | |||
| cdb9f8ea4d | |||
| e5eb89e111 | |||
| b5e0e6932a | |||
| f87cb783da | |||
| 6ea779188c | |||
| 460c7e196c | |||
| 7aac506cdc | |||
| 374ee9e867 | |||
| 698aa0f3e5 | |||
| d3ca4a3a4f | |||
| c940b1fbbc | |||
| 4de24bcc56 | |||
| f2d0a472ef | |||
| 9ae0ecec7d | |||
| ce4f31f662 | |||
| 2c846bb614 | |||
| 8c86ccfbc9 | |||
| 8f96e7bc1d | |||
| 782fc3c72b | |||
| 1a67403fc6 | |||
| 3d801a4c01 | |||
| 2034ca99ae | |||
| 480b4ff882 | |||
| f570e589da | |||
| f9851af59b | |||
| 8824cc6a88 | |||
| eeebf9f664 | |||
| d9a50bf9a8 | |||
| 2984331c87 | |||
| 9b68682df2 | |||
| 8f5f89c9a0 | |||
| 8919f69362 | |||
| 19c867873a | |||
| e3dadb1d36 | |||
| c9b09a31e8 | |||
| 35571fe94b | |||
| 485f2b607a | |||
| 0c5d5c7e9a | |||
| 5f98a0363a | |||
| 2d739001d3 | |||
| 273babeec3 | |||
| a76dd6b7c6 | |||
| 2fa18d1545 | |||
| 537167aa1e | |||
| 0dac408f43 | |||
| 158e72427b | |||
| 0184ef291d | |||
| 2ca428c721 | |||
| 1311385f9d | |||
| 5f0a5b8f87 | |||
| 74e85c6944 | |||
| a6a0379b9c | |||
| a95eee68d9 | |||
| 2ad70c9446 | |||
| bc09a84150 | |||
| 760c901c9a | |||
| d105e3a198 | |||
| ed79693706 | |||
| 10a1578408 | |||
| bdb37536be | |||
| dd7a45abc0 | |||
| 7557e38e32 | |||
| c5d91d9e3e | |||
| a32832682c | |||
| 4f6aae35fd | |||
| 4cff8b5e07 | |||
| 4714eb7021 | |||
| 780e32524c | |||
| 6bf51de533 | |||
| d33d125c94 | |||
| dc8bb52f77 | |||
| 9997e853e9 | |||
| 2a09f6e02e | |||
| bf380fbd4c | |||
| 148fd9a522 | |||
| 7bb8d8c200 | |||
| 5ce4a8b49f | |||
| 7dd56474f2 | |||
| 3260bf3b19 | |||
| 05c6a06b2b | |||
| 25e9d8124c | |||
| bc882f8284 | |||
| edd365ed4a | |||
| 1366a2fa55 | |||
| 91f0c5a9da | |||
| 67390692c5 | |||
| 1debfd44fd | |||
| cdf0a9c21f | |||
| 115016f1a2 | |||
| 971e6ca434 | |||
| e8d411e7f7 | |||
| 2e5233d7bd | |||
| 514dd96376 | |||
| 9ae62fcc18 | |||
| ae71b0e163 | |||
| 5b6ff8148d | |||
| 1f7e4343e7 | |||
| b21856f5fc | |||
| 259ba0ecab | |||
| 051f1fe8e3 | |||
| ee387c43fe | |||
| 3a944661d6 | |||
| 56034074ca | |||
| 8def619bbe | |||
| 61883a5787 | |||
| d8ada1ee76 | |||
| fe841a1db4 | |||
| b65829b84f | |||
| b0e0ae97ba | |||
| f44a1ddcb2 | |||
| 184e2cbc89 | |||
| 416421c7c4 | |||
| bd99ae3315 | |||
| ce8672c24f | |||
| 402c465030 | |||
| 573a79fffa | |||
| 4945180468 | |||
| 1df723e6f5 | |||
| f9b81e23e4 | |||
| ffe6cc39c7 | |||
| db1f3f6901 | |||
| 43041f0a43 | |||
| dc00842b81 | |||
| f1a129a6d0 | |||
| fad48ffa62 | |||
| 3e7a66fae1 | |||
| 5f0a563dc8 | |||
| 678915d5f1 | |||
| daed97afff | |||
| 53947adb1f | |||
| c297b02f12 | |||
| bd24774f50 | |||
| 525eb9fab9 | |||
| 7886070fc5 | |||
| 87d17e9dee | |||
| 53422e6bc8 | |||
| c34b743eac | |||
| db250fa895 | |||
| 52231a7974 | |||
| cf71c53eae | |||
| f9caae42ed | |||
| 52a6b5a4cc | |||
| 94f6f79e27 | |||
| 5676de1157 | |||
| 2ca0b3f70a | |||
| b06453c7cf | |||
| f0fa39a7e4 | |||
| b5142f74f9 | |||
| a14452bfce | |||
| 619f329a4b | |||
| 7a48db0809 | |||
| 406f2943d2 | |||
| c3bc56c8b4 | |||
| b2be4d24c0 | |||
| 8d5cceeb6a | |||
| f6331192b4 | |||
| f8d408d24a | |||
| 5a85b6eaf8 | |||
| e3d6896d08 | |||
| 9d9e7c7b1c | |||
| 4c3721fe70 | |||
| 8ef4099313 | |||
| de773364be | |||
| 47da714b8b | |||
| 69ab1f93e4 | |||
| 232baa33b3 | |||
| 6f0182495f | |||
| 7da82b84e2 | |||
| cda7604434 | |||
| 6ca8cc6edf | |||
| bb37483464 | |||
| 2751b1d3c3 | |||
| fe0bb7cf60 | |||
| cf63b212e3 | |||
| 17e70ae459 | |||
| ad7db3617e | |||
| 5320ca3725 | |||
| 3e4faca130 | |||
| 0c2f206ded | |||
| 6cf21fa331 | |||
| cdc8460f2c | |||
| 86130aa2ca | |||
| 9491830c79 | |||
| 04a85b4c21 | |||
| a4437d76f0 | |||
| 3ea829a337 | |||
| 3966b5ad05 | |||
| f6a79b2a4a | |||
| 2fcf41dd8e | |||
| 31ccd8f13e |
@ -30,7 +30,6 @@ into a tarball, with the following structure:
|
||||
More specifically, `build_magma.sh` copies over the relevant files from the `package_files` directory depending on the ROCm version.
|
||||
Outputted binaries should be in the `output` folder.
|
||||
|
||||
|
||||
## Pushing
|
||||
|
||||
Packages can be uploaded to an S3 bucket using:
|
||||
|
||||
@ -96,7 +96,6 @@ function pip_build_and_install() {
|
||||
python3 -m pip wheel \
|
||||
--no-build-isolation \
|
||||
--no-deps \
|
||||
--no-use-pep517 \
|
||||
-w "${wheel_dir}" \
|
||||
"${build_target}"
|
||||
fi
|
||||
@ -308,6 +307,28 @@ function install_torchao() {
|
||||
pip_build_and_install "git+https://github.com/pytorch/ao.git@${commit}" dist/ao
|
||||
}
|
||||
|
||||
function install_flash_attn_cute() {
|
||||
echo "Installing FlashAttention CuTe from GitHub..."
|
||||
# Grab latest main til we have a pinned commit
|
||||
local flash_attn_commit
|
||||
flash_attn_commit=$(git ls-remote https://github.com/Dao-AILab/flash-attention.git HEAD | cut -f1)
|
||||
|
||||
# Clone the repo to a temporary directory
|
||||
rm -rf flash-attention-build
|
||||
git clone --depth 1 --recursive https://github.com/Dao-AILab/flash-attention.git flash-attention-build
|
||||
|
||||
pushd flash-attention-build
|
||||
git checkout "${flash_attn_commit}"
|
||||
|
||||
# Install only the 'cute' sub-directory
|
||||
pip_install -e flash_attn/cute/
|
||||
popd
|
||||
|
||||
# remove the local repo
|
||||
rm -rf flash-attention-build
|
||||
echo "FlashAttention CuTe installation complete."
|
||||
}
|
||||
|
||||
function print_sccache_stats() {
|
||||
echo 'PyTorch Build Statistics'
|
||||
sccache --show-stats
|
||||
|
||||
@ -100,6 +100,337 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _compile_and_extract_symbols(
|
||||
cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None
|
||||
) -> list[str]:
|
||||
"""
|
||||
Helper to compile a C++ file and extract all symbols.
|
||||
|
||||
Args:
|
||||
cpp_content: C++ source code to compile
|
||||
compile_flags: Compilation flags
|
||||
exclude_list: List of symbol names to exclude. Defaults to ["main"].
|
||||
|
||||
Returns:
|
||||
List of all symbols found in the object file (excluding those in exclude_list).
|
||||
"""
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
if exclude_list is None:
|
||||
exclude_list = ["main"]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmppath = Path(tmpdir)
|
||||
cpp_file = tmppath / "test.cpp"
|
||||
obj_file = tmppath / "test.o"
|
||||
|
||||
cpp_file.write_text(cpp_content)
|
||||
|
||||
result = subprocess.run(
|
||||
compile_flags + [str(cpp_file), "-o", str(obj_file)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Compilation failed: {result.stderr}")
|
||||
|
||||
symbols = get_symbols(str(obj_file))
|
||||
|
||||
# Return all symbol names, excluding those in the exclude list
|
||||
return [name for _addr, _stype, name in symbols if name not in exclude_list]
|
||||
|
||||
|
||||
def check_stable_only_symbols(install_root: Path) -> None:
|
||||
"""
|
||||
Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.
|
||||
|
||||
This approach tests:
|
||||
1. WITHOUT macros -> many torch symbols exposed
|
||||
2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
|
||||
3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
|
||||
4. WITH both macros -> zero torch symbols (all hidden)
|
||||
"""
|
||||
include_dir = install_root / "include"
|
||||
assert include_dir.exists(), f"Expected {include_dir} to be present"
|
||||
|
||||
test_cpp_content = """
|
||||
// Main torch C++ API headers
|
||||
#include <torch/torch.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
// ATen tensor library
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
// Core c10 headers (commonly used)
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/core/TensorOptions.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
int main() { return 0; }
|
||||
"""
|
||||
|
||||
base_compile_flags = [
|
||||
"g++",
|
||||
"-std=c++17",
|
||||
f"-I{include_dir}",
|
||||
f"-I{include_dir}/torch/csrc/api/include",
|
||||
"-c", # Compile only, don't link
|
||||
]
|
||||
|
||||
# Compile WITHOUT any macros
|
||||
symbols_without = _compile_and_extract_symbols(
|
||||
cpp_content=test_cpp_content,
|
||||
compile_flags=base_compile_flags,
|
||||
)
|
||||
|
||||
# We expect constexpr symbols, inline functions used by other headers etc.
|
||||
# to produce symbols
|
||||
num_symbols_without = len(symbols_without)
|
||||
print(f"Found {num_symbols_without} symbols without any macros defined")
|
||||
assert num_symbols_without != 0, (
|
||||
"Expected a non-zero number of symbols without any macros"
|
||||
)
|
||||
|
||||
# Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
|
||||
compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"]
|
||||
|
||||
symbols_with_stable_only = _compile_and_extract_symbols(
|
||||
cpp_content=test_cpp_content,
|
||||
compile_flags=compile_flags_with_stable_only,
|
||||
)
|
||||
|
||||
num_symbols_with_stable_only = len(symbols_with_stable_only)
|
||||
assert num_symbols_with_stable_only == 0, (
|
||||
f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}"
|
||||
)
|
||||
|
||||
# Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
|
||||
compile_flags_with_target_version = base_compile_flags + [
|
||||
"-DTORCH_TARGET_VERSION=1"
|
||||
]
|
||||
|
||||
symbols_with_target_version = _compile_and_extract_symbols(
|
||||
cpp_content=test_cpp_content,
|
||||
compile_flags=compile_flags_with_target_version,
|
||||
)
|
||||
|
||||
num_symbols_with_target_version = len(symbols_with_target_version)
|
||||
assert num_symbols_with_target_version == 0, (
|
||||
f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}"
|
||||
)
|
||||
|
||||
# Compile WITH both macros (expect 0 symbols)
|
||||
compile_flags_with_both = base_compile_flags + [
|
||||
"-DTORCH_STABLE_ONLY",
|
||||
"-DTORCH_TARGET_VERSION=1",
|
||||
]
|
||||
|
||||
symbols_with_both = _compile_and_extract_symbols(
|
||||
cpp_content=test_cpp_content,
|
||||
compile_flags=compile_flags_with_both,
|
||||
)
|
||||
|
||||
num_symbols_with_both = len(symbols_with_both)
|
||||
assert num_symbols_with_both == 0, (
|
||||
f"Expected no symbols with both macros, but found {num_symbols_with_both}"
|
||||
)
|
||||
|
||||
|
||||
def check_stable_api_symbols(install_root: Path) -> None:
|
||||
"""
|
||||
Test that stable API headers still expose symbols with TORCH_STABLE_ONLY.
|
||||
The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols
|
||||
"""
|
||||
include_dir = install_root / "include"
|
||||
assert include_dir.exists(), f"Expected {include_dir} to be present"
|
||||
|
||||
stable_dir = include_dir / "torch" / "csrc" / "stable"
|
||||
assert stable_dir.exists(), f"Expected {stable_dir} to be present"
|
||||
|
||||
stable_headers = list(stable_dir.rglob("*.h"))
|
||||
if not stable_headers:
|
||||
raise RuntimeError("Could not find any stable headers")
|
||||
|
||||
includes = []
|
||||
for header in stable_headers:
|
||||
rel_path = header.relative_to(include_dir)
|
||||
includes.append(f"#include <{rel_path.as_posix()}>")
|
||||
|
||||
includes_str = "\n".join(includes)
|
||||
test_stable_content = f"""
|
||||
{includes_str}
|
||||
int main() {{ return 0; }}
|
||||
"""
|
||||
|
||||
compile_flags = [
|
||||
"g++",
|
||||
"-std=c++17",
|
||||
f"-I{include_dir}",
|
||||
f"-I{include_dir}/torch/csrc/api/include",
|
||||
"-c",
|
||||
"-DTORCH_STABLE_ONLY",
|
||||
]
|
||||
|
||||
symbols_stable = _compile_and_extract_symbols(
|
||||
cpp_content=test_stable_content,
|
||||
compile_flags=compile_flags,
|
||||
)
|
||||
num_symbols_stable = len(symbols_stable)
|
||||
print(f"Found {num_symbols_stable} symbols in torch/csrc/stable")
|
||||
assert num_symbols_stable > 0, (
|
||||
f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
|
||||
f"but found {num_symbols_stable} symbols"
|
||||
)
|
||||
|
||||
|
||||
def check_headeronly_symbols(install_root: Path) -> None:
|
||||
"""
|
||||
Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY.
|
||||
"""
|
||||
include_dir = install_root / "include"
|
||||
assert include_dir.exists(), f"Expected {include_dir} to be present"
|
||||
|
||||
# Find all headers in torch/headeronly
|
||||
headeronly_dir = include_dir / "torch" / "headeronly"
|
||||
assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present"
|
||||
headeronly_headers = list(headeronly_dir.rglob("*.h"))
|
||||
if not headeronly_headers:
|
||||
raise RuntimeError("Could not find any headeronly headers")
|
||||
|
||||
# Filter out platform-specific headers that may not compile everywhere
|
||||
platform_specific_keywords = [
|
||||
"cpu/vec",
|
||||
]
|
||||
|
||||
filtered_headers = []
|
||||
for header in headeronly_headers:
|
||||
rel_path = header.relative_to(include_dir).as_posix()
|
||||
if not any(
|
||||
keyword in rel_path.lower() for keyword in platform_specific_keywords
|
||||
):
|
||||
filtered_headers.append(header)
|
||||
|
||||
includes = []
|
||||
for header in filtered_headers:
|
||||
rel_path = header.relative_to(include_dir)
|
||||
includes.append(f"#include <{rel_path.as_posix()}>")
|
||||
|
||||
includes_str = "\n".join(includes)
|
||||
test_headeronly_content = f"""
|
||||
{includes_str}
|
||||
int main() {{ return 0; }}
|
||||
"""
|
||||
|
||||
compile_flags = [
|
||||
"g++",
|
||||
"-std=c++17",
|
||||
f"-I{include_dir}",
|
||||
f"-I{include_dir}/torch/csrc/api/include",
|
||||
"-c",
|
||||
"-DTORCH_STABLE_ONLY",
|
||||
]
|
||||
|
||||
symbols_headeronly = _compile_and_extract_symbols(
|
||||
cpp_content=test_headeronly_content,
|
||||
compile_flags=compile_flags,
|
||||
)
|
||||
num_symbols_headeronly = len(symbols_headeronly)
|
||||
print(f"Found {num_symbols_headeronly} symbols in torch/headeronly")
|
||||
assert num_symbols_headeronly > 0, (
|
||||
f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
|
||||
f"but found {num_symbols_headeronly} symbols"
|
||||
)
|
||||
|
||||
|
||||
def check_aoti_shim_symbols(install_root: Path) -> None:
|
||||
"""
|
||||
Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY.
|
||||
"""
|
||||
include_dir = install_root / "include"
|
||||
assert include_dir.exists(), f"Expected {include_dir} to be present"
|
||||
|
||||
# There are no constexpr symbols etc., so we need to actually use functions
|
||||
# so that some symbols are found.
|
||||
test_shim_content = """
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
int main() {
|
||||
int32_t (*fp1)() = &aoti_torch_device_type_cpu;
|
||||
int32_t (*fp2)() = &aoti_torch_dtype_float32;
|
||||
(void)fp1; (void)fp2;
|
||||
return 0;
|
||||
}
|
||||
"""
|
||||
|
||||
compile_flags = [
|
||||
"g++",
|
||||
"-std=c++17",
|
||||
f"-I{include_dir}",
|
||||
f"-I{include_dir}/torch/csrc/api/include",
|
||||
"-c",
|
||||
"-DTORCH_STABLE_ONLY",
|
||||
]
|
||||
|
||||
symbols_shim = _compile_and_extract_symbols(
|
||||
cpp_content=test_shim_content,
|
||||
compile_flags=compile_flags,
|
||||
)
|
||||
num_symbols_shim = len(symbols_shim)
|
||||
assert num_symbols_shim > 0, (
|
||||
f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
|
||||
f"but found {num_symbols_shim} symbols"
|
||||
)
|
||||
|
||||
|
||||
def check_stable_c_shim_symbols(install_root: Path) -> None:
|
||||
"""
|
||||
Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY.
|
||||
"""
|
||||
include_dir = install_root / "include"
|
||||
assert include_dir.exists(), f"Expected {include_dir} to be present"
|
||||
|
||||
# Check if the stable C shim exists
|
||||
stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
|
||||
if not stable_shim.exists():
|
||||
raise RuntimeError("Could not find stable c shim")
|
||||
|
||||
# There are no constexpr symbols etc., so we need to actually use functions
|
||||
# so that some symbols are found.
|
||||
test_stable_shim_content = """
|
||||
#include <torch/csrc/stable/c/shim.h>
|
||||
int main() {
|
||||
// Reference stable C API functions to create undefined symbols
|
||||
AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
|
||||
AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
|
||||
(void)fp1; (void)fp2;
|
||||
return 0;
|
||||
}
|
||||
"""
|
||||
|
||||
compile_flags = [
|
||||
"g++",
|
||||
"-std=c++17",
|
||||
f"-I{include_dir}",
|
||||
f"-I{include_dir}/torch/csrc/api/include",
|
||||
"-c",
|
||||
"-DTORCH_STABLE_ONLY",
|
||||
]
|
||||
|
||||
symbols_stable_shim = _compile_and_extract_symbols(
|
||||
cpp_content=test_stable_shim_content,
|
||||
compile_flags=compile_flags,
|
||||
)
|
||||
num_symbols_stable_shim = len(symbols_stable_shim)
|
||||
assert num_symbols_stable_shim > 0, (
|
||||
f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
|
||||
f"but found {num_symbols_stable_shim} symbols"
|
||||
)
|
||||
|
||||
|
||||
def check_lib_symbols_for_abi_correctness(lib: str) -> None:
|
||||
print(f"lib: {lib}")
|
||||
cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)
|
||||
@ -129,6 +460,13 @@ def main() -> None:
|
||||
check_lib_symbols_for_abi_correctness(libtorch_cpu_path)
|
||||
check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path)
|
||||
|
||||
# Check symbols when TORCH_STABLE_ONLY is defined
|
||||
check_stable_only_symbols(install_root)
|
||||
check_stable_api_symbols(install_root)
|
||||
check_headeronly_symbols(install_root)
|
||||
check_aoti_shim_symbols(install_root)
|
||||
check_stable_c_shim_symbols(install_root)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -353,6 +353,17 @@ def test_linalg(device="cpu") -> None:
|
||||
torch.linalg.svd(A)
|
||||
|
||||
|
||||
def test_sdpa(device="cpu", dtype=torch.float16) -> None:
|
||||
"""Regression test for https://github.com/pytorch/pytorch/issues/167602
|
||||
Without nvrtc_builtins on CuDNN-9.13 on CUDA-13 fails with ` No valid execution plans built.`
|
||||
"""
|
||||
print(f"Testing SDPA on {device} using type {dtype}")
|
||||
k, q, v = torch.rand(3, 1, 16, 77, 64, dtype=dtype, device=device).unbind(0)
|
||||
attn = torch.rand(1, 1, 77, 77, dtype=dtype, device=device)
|
||||
rc = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn)
|
||||
assert rc.isnan().any().item() is False
|
||||
|
||||
|
||||
def smoke_test_compile(device: str = "cpu") -> None:
|
||||
supported_dtypes = [torch.float16, torch.float32, torch.float64]
|
||||
|
||||
@ -489,10 +500,12 @@ def main() -> None:
|
||||
smoke_test_conv2d()
|
||||
test_linalg()
|
||||
test_numpy()
|
||||
test_sdpa()
|
||||
|
||||
if is_cuda_system:
|
||||
test_linalg("cuda")
|
||||
test_cuda_gds_errors_captured()
|
||||
test_sdpa("cuda")
|
||||
|
||||
if options.package == "all":
|
||||
smoke_test_modules()
|
||||
|
||||
@ -344,8 +344,18 @@ test_python_smoke() {
|
||||
}
|
||||
|
||||
test_python_smoke_b200() {
|
||||
# Targeted smoke tests for B200 - staged approach to avoid too many failures
|
||||
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
|
||||
# Targeted smoke tests for B200 including FlashAttention CuTe coverage
|
||||
install_flash_attn_cute
|
||||
time python test/run_test.py \
|
||||
--include \
|
||||
test_matmul_cuda \
|
||||
test_scaled_matmul_cuda \
|
||||
inductor/test_fp8 \
|
||||
nn/attention/test_fa4 \
|
||||
nn/attention/test_open_registry \
|
||||
inductor/test_flex_flash \
|
||||
$PYTHON_TEST_EXTRA_OPTION \
|
||||
--upload-artifacts-while-running
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
@ -1670,6 +1680,22 @@ test_operator_microbenchmark() {
|
||||
done
|
||||
}
|
||||
|
||||
test_attention_microbenchmark() {
|
||||
TEST_REPORTS_DIR=$(pwd)/test/test-reports
|
||||
mkdir -p "$TEST_REPORTS_DIR"
|
||||
TEST_DIR=$(pwd)
|
||||
|
||||
# Install attention-gym dependency
|
||||
echo "Installing attention-gym..."
|
||||
python -m pip install git+https://github.com/meta-pytorch/attention-gym.git@main
|
||||
pip show triton
|
||||
|
||||
cd "${TEST_DIR}"/benchmarks/transformer
|
||||
|
||||
$TASKSET python score_mod.py --config configs/config_basic.yaml \
|
||||
--output-json-for-dashboard "${TEST_REPORTS_DIR}/attention_microbenchmark.json"
|
||||
}
|
||||
|
||||
if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
|
||||
(cd test && python -c "import torch; print(torch.__config__.show())")
|
||||
(cd test && python -c "import torch; print(torch.__config__.parallel_info())")
|
||||
@ -1727,6 +1753,8 @@ elif [[ "${TEST_CONFIG}" == *operator_benchmark* ]]; then
|
||||
fi
|
||||
elif [[ "${TEST_CONFIG}" == *operator_microbenchmark* ]]; then
|
||||
test_operator_microbenchmark
|
||||
elif [[ "${TEST_CONFIG}" == *attention_microbenchmark* ]]; then
|
||||
test_attention_microbenchmark
|
||||
elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
|
||||
test_inductor_distributed
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
|
||||
|
||||
2
.github/actionlint.yaml
vendored
2
.github/actionlint.yaml
vendored
@ -63,7 +63,7 @@ self-hosted-runner:
|
||||
- linux.rocm.gpu.gfx942.1
|
||||
- linux.rocm.gpu.gfx942.2
|
||||
- linux.rocm.gpu.gfx942.4
|
||||
- rocm-docker
|
||||
- linux.rocm.gfx942.docker-cache
|
||||
# Org wise AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors)
|
||||
- macos-m1-stable
|
||||
- macos-m1-14
|
||||
|
||||
2
.github/ci_commit_pins/audio.txt
vendored
2
.github/ci_commit_pins/audio.txt
vendored
@ -1 +1 @@
|
||||
ad5816f0eee1c873df1b7d371c69f1f811a89387
|
||||
07b6cbde121417a70e4dc871adb6d27030e0ce3f
|
||||
|
||||
2
.github/ci_commit_pins/vision.txt
vendored
2
.github/ci_commit_pins/vision.txt
vendored
@ -1 +1 @@
|
||||
ccb801b88af136454798b945175c4c87e636ac33
|
||||
acccf86477759b2d3500f1ae1be065f7b1e409ec
|
||||
|
||||
13
.github/labeler.yml
vendored
13
.github/labeler.yml
vendored
@ -165,3 +165,16 @@
|
||||
- torch/_inductor/kernel/mm.py
|
||||
- test/inductor/test_max_autotune.py
|
||||
- third_party/fbgemm
|
||||
|
||||
"ciflow/mps":
|
||||
- aten/src/ATen/mps/**
|
||||
- aten/src/ATen/native/mps/**
|
||||
- torch/_inductor/codegen/mps.py
|
||||
- test/test_mps.py
|
||||
- test/inductor/test_mps_basic.py
|
||||
|
||||
"ciflow/h100-symm-mem":
|
||||
- torch/csrc/distributed/c10d/symm_mem/**
|
||||
- torch/distributed/_symmetric_memory/**
|
||||
- test/distributed/**/*mem*
|
||||
- test/distributed/**/*mem*/**
|
||||
|
||||
3
.github/scripts/lintrunner.sh
vendored
3
.github/scripts/lintrunner.sh
vendored
@ -34,6 +34,9 @@ python3 torch/utils/data/datapipes/gen_pyi.py
|
||||
# Also check generated pyi files
|
||||
find torch -name '*.pyi' -exec git add --force -- "{}" +
|
||||
|
||||
# Print current environment
|
||||
python3 -m pip freeze
|
||||
|
||||
RC=0
|
||||
# Run lintrunner on all files
|
||||
if ! lintrunner --force-color --tee-json=lint.json ${ADDITIONAL_LINTRUNNER_ARGS} 2> /dev/null; then
|
||||
|
||||
73
.github/workflows/attention_op_microbenchmark.yml
vendored
Normal file
73
.github/workflows/attention_op_microbenchmark.yml
vendored
Normal file
@ -0,0 +1,73 @@
|
||||
name: attention_op_microbenchmark
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- ciflow/op-benchmark/*
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
# Run at 06:00 UTC everyday
|
||||
- cron: 0 7 * * *
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
attn-microbenchmark-build:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
with:
|
||||
runner: linux.12xlarge.memory
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '8.0 9.0'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "attention_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.a100" },
|
||||
{ config: "attention_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.h100" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
attn-microbenchmark-test:
|
||||
name: attn-microbenchmark-test
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: attn-microbenchmark-build
|
||||
with:
|
||||
timeout-minutes: 500
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
|
||||
docker-image: ${{ needs.attn-microbenchmark-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.attn-microbenchmark-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
# B200 runner
|
||||
opmicrobenchmark-build-b200:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: opmicrobenchmark-build-b200
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
with:
|
||||
runner: linux.12xlarge.memory
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '10.0'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
opmicrobenchmark-test-b200:
|
||||
name: opmicrobenchmark-test-b200
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: opmicrobenchmark-build-b200
|
||||
with:
|
||||
timeout-minutes: 500
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
|
||||
docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }}
|
||||
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
secrets: inherit
|
||||
16
.github/workflows/docker-builds.yml
vendored
16
.github/workflows/docker-builds.yml
vendored
@ -119,6 +119,22 @@ jobs:
|
||||
with:
|
||||
docker-image: ${{ steps.build-docker-image.outputs.docker-image }}
|
||||
|
||||
- name: Generate output
|
||||
if: contains(matrix.docker-image-name, 'rocm')
|
||||
id: generate_output
|
||||
run: |
|
||||
docker_image_name="${{ matrix.docker-image-name }}"
|
||||
docker_image_tag="${{ steps.build-docker-image.outputs.docker-image }}"
|
||||
echo "${docker_image_name}=${docker_image_tag}" >> docker-builds-output-${docker_image_name}.txt
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4.4.0
|
||||
if: contains(matrix.docker-image-name, 'rocm')
|
||||
with:
|
||||
name: docker-builds-artifacts-${{ matrix.docker-image-name }}
|
||||
retention-days: 14
|
||||
path: ./docker-builds-output-${{ matrix.docker-image-name }}.txt
|
||||
|
||||
- uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
|
||||
name: Push to https://ghcr.io/
|
||||
id: push-to-ghcr-io
|
||||
|
||||
55
.github/workflows/docker-cache-mi300.yml
vendored
55
.github/workflows/docker-cache-mi300.yml
vendored
@ -1,55 +0,0 @@
|
||||
name: docker-cache-mi300
|
||||
|
||||
on:
|
||||
# run every 6 hours
|
||||
schedule:
|
||||
- cron: 0 0,6,12,18 * * *
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
docker-cache:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
runs-on: rocm-docker
|
||||
steps:
|
||||
- name: Checkout PyTorch
|
||||
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
|
||||
with:
|
||||
no-sudo: true
|
||||
|
||||
- name: configure aws credentials
|
||||
id: aws_creds
|
||||
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
|
||||
with:
|
||||
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
aws-region: us-east-1
|
||||
role-duration-seconds: 18000
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
continue-on-error: false
|
||||
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
|
||||
|
||||
- name: Calculate docker image
|
||||
id: calculate-docker-image
|
||||
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
|
||||
with:
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
||||
push: false
|
||||
|
||||
- name: Pull docker image
|
||||
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
|
||||
with:
|
||||
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
|
||||
- name: Tar and upload to S3 bucket
|
||||
run: |
|
||||
sudo docker save -o ~/docker-data/pytorch/pytorch_docker_image.tar ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
sudo rclone copy -P --s3-upload-concurrency 64 --s3-chunk-size 200M --s3-upload-cutoff 300M ~/docker-data/pytorch/pytorch_docker_image.tar oci:pytorchbucket0002/pytorch_docker_image --progress
|
||||
105
.github/workflows/docker-cache-rocm.yml
vendored
Normal file
105
.github/workflows/docker-cache-rocm.yml
vendored
Normal file
@ -0,0 +1,105 @@
|
||||
name: docker-cache-rocm
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: [docker-builds]
|
||||
branches: [main, release]
|
||||
types:
|
||||
- completed
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
actions: read
|
||||
|
||||
jobs:
|
||||
download-docker-builds-artifacts:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: download-docker-builds-artifacts
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
pytorch-linux-jammy-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}
|
||||
pytorch-linux-noble-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}
|
||||
pytorch-linux-jammy-rocm-n-py3-benchmarks: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}
|
||||
steps:
|
||||
- name: Download artifacts
|
||||
uses: actions/download-artifact@v4.1.7
|
||||
with:
|
||||
run-id: ${{ github.event.workflow_run.id }}
|
||||
path: ./docker-builds-artifacts
|
||||
merge-multiple: true
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Process artifacts
|
||||
id: process-artifacts
|
||||
run: |
|
||||
ls -R ./docker-builds-artifacts
|
||||
cat ./docker-builds-artifacts/*txt >> "${GITHUB_OUTPUT}"
|
||||
cat "${GITHUB_OUTPUT}"
|
||||
|
||||
docker-cache:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
needs: download-docker-builds-artifacts
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
runner: [linux.rocm.gfx942.docker-cache]
|
||||
docker-image: [
|
||||
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}",
|
||||
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}",
|
||||
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}"
|
||||
]
|
||||
runs-on: "${{ matrix.runner }}"
|
||||
steps:
|
||||
- name: debug
|
||||
run: |
|
||||
JSON_STRINGIFIED="${{ toJSON(needs.download-docker-builds-artifacts.outputs) }}"
|
||||
echo "Outputs of download-docker-builds-artifacts job: ${JSON_STRINGIFIED}"
|
||||
|
||||
- name: configure aws credentials
|
||||
id: aws_creds
|
||||
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
|
||||
with:
|
||||
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
aws-region: us-east-1
|
||||
role-duration-seconds: 18000
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
continue-on-error: false
|
||||
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
|
||||
|
||||
- name: Generate ghrc.io tag
|
||||
id: ghcr-io-tag
|
||||
run: |
|
||||
ecr_image="${{ matrix.docker-image }}"
|
||||
ghcr_image="ghcr.io/pytorch/ci-image:${ecr_image##*:}"
|
||||
echo "ghcr_image=${ghcr_image}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Pull docker image
|
||||
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
|
||||
with:
|
||||
docker-image: ${{ steps.ghcr-io-tag.outputs.ghcr_image }}
|
||||
|
||||
- name: Save as tarball
|
||||
run: |
|
||||
docker_image_tag=${{ matrix.docker-image }}
|
||||
docker_image_tag="${docker_image_tag#*:}" # Remove everything before and including first ":"
|
||||
docker_image_tag="${docker_image_tag%-*}" # Remove everything after and including last "-"
|
||||
ref_name=${{ github.event.workflow_run.head_branch }}
|
||||
if [[ $ref_name =~ "release/" ]]; then
|
||||
ref_suffix="release"
|
||||
elif [[ $ref_name == "main" ]]; then
|
||||
ref_suffix="main"
|
||||
else
|
||||
echo "Unexpected branch in ref_name: ${ref_name}" && exit 1
|
||||
fi
|
||||
docker tag ${{ steps.ghcr-io-tag.outputs.ghcr_image }} ${{ matrix.docker-image }}
|
||||
# mv is atomic operation, so we use intermediate tar.tmp file to prevent read-write contention
|
||||
docker save -o ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ${{ matrix.docker-image }}
|
||||
mv ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ~/pytorch-data/docker/${docker_image_tag}_${ref_suffix}.tar
|
||||
1
.github/workflows/h100-distributed.yml
vendored
1
.github/workflows/h100-distributed.yml
vendored
@ -37,7 +37,6 @@ jobs:
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
runner: "linux.c7i.12xlarge"
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-dist
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '9.0'
|
||||
|
||||
2
.github/workflows/inductor-rocm-mi200.yml
vendored
2
.github/workflows/inductor-rocm-mi200.yml
vendored
@ -1,4 +1,4 @@
|
||||
name: inductor-rocm
|
||||
name: inductor-rocm-mi200
|
||||
|
||||
on:
|
||||
schedule:
|
||||
|
||||
8
.github/workflows/nightly.yml
vendored
8
.github/workflows/nightly.yml
vendored
@ -5,9 +5,11 @@ on:
|
||||
- cron: 0 0 * * *
|
||||
push:
|
||||
tags:
|
||||
# NOTE: Doc build pipelines should only get triggered on release candidate builds
|
||||
# Release candidate tags look like: v1.11.0-rc1
|
||||
- v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
|
||||
# NOTE: Doc build pipelines should only get triggered on:
|
||||
# Major or minor release candidates builds
|
||||
- v[0-9]+.[0-9]+.0+-rc[0-9]+
|
||||
# Final RC for major, minor and patch releases
|
||||
- v[0-9]+.[0-9]+.[0-9]+
|
||||
- ciflow/nightly/*
|
||||
workflow_dispatch:
|
||||
|
||||
|
||||
2
.github/workflows/rocm-mi200.yml
vendored
2
.github/workflows/rocm-mi200.yml
vendored
@ -1,4 +1,4 @@
|
||||
name: rocm
|
||||
name: rocm-mi200
|
||||
|
||||
on:
|
||||
push:
|
||||
|
||||
4
.github/workflows/test-b200.yml
vendored
4
.github/workflows/test-b200.yml
vendored
@ -5,7 +5,9 @@
|
||||
# Flow:
|
||||
# 1. Builds PyTorch with CUDA 12.8+ and sm100 architecture for B200
|
||||
# 2. Runs smoke tests on linux.dgx.b200 runner
|
||||
# 3. Tests executed are defined in .ci/pytorch/test.sh -> test_python_smoke() function
|
||||
# 3. Tests executed are defined in .ci/pytorch/test.sh -> test_python_smoke_b200() function
|
||||
# - Includes matmul, scaled_matmul, FP8, and FlashAttention CuTe tests
|
||||
# - FlashAttention CuTe DSL is installed as part of test execution
|
||||
#
|
||||
# Triggered by:
|
||||
# - Pull requests modifying this workflow file
|
||||
|
||||
1
.github/workflows/test-h100.yml
vendored
1
.github/workflows/test-h100.yml
vendored
@ -41,7 +41,6 @@ jobs:
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
runner: linux.12xlarge.memory
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '9.0'
|
||||
|
||||
83
.github/workflows/trunk-rocm-mi300.yml
vendored
Normal file
83
.github/workflows/trunk-rocm-mi300.yml
vendored
Normal file
@ -0,0 +1,83 @@
|
||||
name: trunk-rocm-mi300
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- release/*
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: 29 8 * * * # about 1:29am PDT
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
llm-td:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: before-test
|
||||
uses: ./.github/workflows/llm_td_retrieval.yml
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
target-determination:
|
||||
name: before-test
|
||||
uses: ./.github/workflows/target_determination.yml
|
||||
needs: llm-td
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
get-label-type:
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
|
||||
linux-jammy-rocm-py3_10-build:
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
||||
sync-tag: rocm-build
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
|
||||
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
|
||||
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
|
||||
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
|
||||
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
|
||||
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
|
||||
{ config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b" },
|
||||
{ config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b" },
|
||||
{ config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-rocm-py3_10-test:
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_rocm-test.yml
|
||||
needs:
|
||||
- linux-jammy-rocm-py3_10-build
|
||||
- target-determination
|
||||
with:
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
1
.github/workflows/upload-test-stats.yml
vendored
1
.github/workflows/upload-test-stats.yml
vendored
@ -5,6 +5,7 @@ on:
|
||||
workflows:
|
||||
- pull
|
||||
- trunk
|
||||
- trunk-rocm-mi300
|
||||
- periodic
|
||||
- periodic-rocm-mi200
|
||||
- periodic-rocm-mi300
|
||||
|
||||
@ -186,6 +186,8 @@ include_patterns = [
|
||||
'aten/src/ATen/native/nested/cuda/*.h',
|
||||
'aten/src/ATen/native/nested/*.cpp',
|
||||
'aten/src/ATen/native/nested/*.h',
|
||||
'aten/src/ATen/xpu/**/*.h',
|
||||
'aten/src/ATen/xpu/**/*.cpp',
|
||||
'c10/**/*.cpp',
|
||||
'c10/**/*.h',
|
||||
'torch/*.h',
|
||||
|
||||
@ -736,6 +736,44 @@ if(NOT DEFINED USE_BLAS)
|
||||
set(USE_BLAS ON)
|
||||
endif()
|
||||
|
||||
# Prioritized Text Linker Optimization
|
||||
if(USE_PRIORITIZED_TEXT_FOR_LD)
|
||||
|
||||
set(LINKER_SCRIPT_FILE_IN "${CMAKE_SOURCE_DIR}/cmake/prioritized_text.txt")
|
||||
set(LINKER_SCRIPT_FILE_OUT "${CMAKE_SOURCE_DIR}/cmake/linker_script.ld")
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python_EXECUTABLE}
|
||||
${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py
|
||||
--filein "${LINKER_SCRIPT_FILE_IN}"
|
||||
--fout "${LINKER_SCRIPT_FILE_OUT}"
|
||||
RESULT_VARIABLE _gen_result
|
||||
OUTPUT_VARIABLE _gen_output
|
||||
ERROR_VARIABLE _gen_error
|
||||
)
|
||||
|
||||
if(NOT _gen_result EQUAL 0)
|
||||
message(FATAL_ERROR
|
||||
"Failed to generate linker script:\n${_gen_output}\n${_gen_error}")
|
||||
endif()
|
||||
|
||||
append_cxx_flag_if_supported("-ffunction-sections" CMAKE_CXX_FLAGS)
|
||||
append_cxx_flag_if_supported("-fdata-sections" CMAKE_CXX_FLAGS)
|
||||
append_c_flag_if_supported("-ffunction-sections" CMAKE_C_FLAGS)
|
||||
append_c_flag_if_supported("-fdata-sections" CMAKE_C_FLAGS)
|
||||
|
||||
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -T${LINKER_SCRIPT_FILE_OUT}")
|
||||
set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -T${LINKER_SCRIPT_FILE_OUT}")
|
||||
|
||||
else()
|
||||
if(LINUX AND CPU_AARCH64)
|
||||
message(WARNING [[
|
||||
It is strongly recommend to enable linker script optimization for all AArch64 Linux builds.
|
||||
To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
|
||||
]])
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Build libtorch mobile library, which contains ATen/TH ops and native support
|
||||
# for TorchScript model, but doesn't contain not-yet-unified caffe2 ops;
|
||||
if(INTERN_BUILD_MOBILE)
|
||||
@ -1402,9 +1440,6 @@ if(BUILD_JNI)
|
||||
add_subdirectory(android/pytorch_android)
|
||||
endif()
|
||||
|
||||
include(cmake/Summary.cmake)
|
||||
caffe2_print_configuration_summary()
|
||||
|
||||
# Parse custom debug info
|
||||
if(DEFINED USE_CUSTOM_DEBINFO)
|
||||
string(REPLACE ";" " " SOURCE_FILES "${USE_CUSTOM_DEBINFO}")
|
||||
@ -1444,56 +1479,5 @@ if(BUILD_BUNDLE_PTXAS AND USE_CUDA)
|
||||
DESTINATION "${CMAKE_INSTALL_BINDIR}")
|
||||
endif()
|
||||
|
||||
if(USE_PRIORITIZED_TEXT_FOR_LD)
|
||||
add_compile_options(
|
||||
$<$<COMPILE_LANGUAGE:C,CXX>:-ffunction-sections>
|
||||
$<$<COMPILE_LANGUAGE:C,CXX>:-fdata-sections>
|
||||
)
|
||||
set(LINKER_SCRIPT_FILE_OUT "${CMAKE_SOURCE_DIR}/cmake/linker_script.ld")
|
||||
set(LINKER_SCRIPT_FILE_IN "${CMAKE_SOURCE_DIR}/cmake/prioritized_text.txt")
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT "${LINKER_SCRIPT_FILE_OUT}"
|
||||
COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py --filein "${LINKER_SCRIPT_FILE_IN}" --fout "${LINKER_SCRIPT_FILE_OUT}"
|
||||
DEPENDS ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py "${LINKER_SCRIPT_FILE_IN}"
|
||||
COMMENT "Generating prioritized text linker files"
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
add_custom_target(generate_linker_script DEPENDS "${LINKER_SCRIPT_FILE_OUT}")
|
||||
|
||||
if(BUILD_PYTHON)
|
||||
set(LINKER_OPT_TARGETS torch_python)
|
||||
endif()
|
||||
|
||||
if(NOT BUILD_LIBTORCHLESS)
|
||||
list(APPEND LINKER_OPT_TARGETS torch_cpu c10)
|
||||
if(USE_CUDA)
|
||||
list(APPEND LINKER_OPT_TARGETS torch_cuda c10_cuda)
|
||||
endif()
|
||||
if(USE_XPU)
|
||||
list(APPEND LINKER_OPT_TARGETS torch_xpu c10_xpu)
|
||||
endif()
|
||||
if(USE_ROCM)
|
||||
list(APPEND LINKER_OPT_TARGETS torch_hip c10_hip)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
foreach(tgt IN LISTS LINKER_OPT_TARGETS)
|
||||
if(TARGET ${tgt})
|
||||
add_dependencies("${tgt}" generate_linker_script)
|
||||
target_link_options_if_supported(${tgt} "-T,${LINKER_SCRIPT_FILE_OUT}")
|
||||
set_property(TARGET ${tgt} APPEND PROPERTY LINK_DEPENDS "${LINKER_SCRIPT_FILE_OUT}")
|
||||
else()
|
||||
message(WARNING "Requested target '${tgt}' for linker script optimization was not found.")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
else()
|
||||
if(LINUX AND CPU_AARCH64)
|
||||
message(WARNING [[
|
||||
It is strongly recommend to enable linker script optimization for all AArch64 Linux builds.
|
||||
To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
|
||||
]])
|
||||
endif()
|
||||
endif()
|
||||
include(cmake/Summary.cmake)
|
||||
caffe2_print_configuration_summary()
|
||||
|
||||
2
LICENSE
2
LICENSE
@ -37,7 +37,7 @@ Copyright (c) 2024 Tri Dao.
|
||||
All rights reserved.
|
||||
|
||||
All contributions by Arm:
|
||||
Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates
|
||||
Copyright (c) 2021, 2023-2025 Arm Limited and/or its affiliates
|
||||
|
||||
All contributions from Caffe:
|
||||
Copyright(c) 2013, 2014, 2015, the respective contributors
|
||||
|
||||
@ -18,6 +18,8 @@ Please report security issues using https://github.com/pytorch/pytorch/security/
|
||||
|
||||
All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
|
||||
|
||||
**Note on crashes and out of bounds access**: PyTorch is a computational framework that performs operations on behalf of the caller. Like many low-level libraries, PyTorch generally does not validate all inputs to every function—the responsibility for providing valid arguments lies with the calling code. While crashes and out of bounds memory access should be reported as bugs, they are generally not considered security vulnerabilities in PyTorch's threat model.
|
||||
|
||||
Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
|
||||
|
||||
https://www.facebook.com/whitehat
|
||||
|
||||
@ -18,6 +18,8 @@
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
|
||||
|
||||
namespace torch {
|
||||
class TORCH_API CustomClassHolder : public c10::intrusive_ptr_target {};
|
||||
namespace jit {
|
||||
@ -1630,4 +1632,6 @@ struct TORCH_API WeakOrStrongTypePtr {
|
||||
|
||||
} // namespace c10
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
#include <ATen/core/ivalue_inl.h> // IWYU pragma: keep
|
||||
|
||||
@ -29,6 +29,8 @@
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
struct Function;
|
||||
@ -2567,3 +2569,5 @@ TypePtr IValue::type() const {
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
@ -11,6 +11,8 @@
|
||||
#include <sleef.h>
|
||||
#endif
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
|
||||
|
||||
// Sleef offers vectorized versions of some transcedentals
|
||||
// such as sin, cos, tan etc..
|
||||
// However for now opting for STL, since we are not building
|
||||
@ -650,3 +652,5 @@ inline Vectorized<float> Vectorized<float>::erf() const {
|
||||
|
||||
} // namespace CPU_CAPABILITY
|
||||
} // namespace at::vec
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
||||
#include <ATen/cuda/CUDAGraph.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <ATen/cuda/MemPool.h>
|
||||
#include <ATen/Functions.h>
|
||||
#include <c10/cuda/CUDAFunctions.h>
|
||||
|
||||
@ -13,7 +14,7 @@ static bool _cuda_graphs_debug = false;
|
||||
MempoolId_t graph_pool_handle() {
|
||||
// Sets just the second value, to distinguish it from MempoolId_ts created from
|
||||
// cudaStreamGetCaptureInfo id_s in capture_begin.
|
||||
return c10::cuda::MemPool::graph_pool_handle();
|
||||
return at::cuda::MemPool::graph_pool_handle();
|
||||
}
|
||||
|
||||
/**
|
||||
@ -90,7 +91,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
|
||||
} else {
|
||||
// User did not ask us to share a mempool. Create graph pool handle using is_user_created=false.
|
||||
// Sets just the first value, to distinguish it from MempoolId_ts created by graph_pool_handle().
|
||||
mempool_id_ = c10::cuda::MemPool::graph_pool_handle(false);
|
||||
mempool_id_ = at::cuda::MemPool::graph_pool_handle(false);
|
||||
TORCH_INTERNAL_ASSERT(mempool_id_.first > 0);
|
||||
}
|
||||
|
||||
|
||||
69
aten/src/ATen/cuda/MemPool.cpp
Normal file
69
aten/src/ATen/cuda/MemPool.cpp
Normal file
@ -0,0 +1,69 @@
|
||||
#include <ATen/core/CachingHostAllocator.h>
|
||||
#include <ATen/cuda/MemPool.h>
|
||||
|
||||
namespace at::cuda {
|
||||
|
||||
// uid_ is incremented when a user creates a MemPool,
|
||||
// for example: using graph_pool_handle() or c10::cuda::MemPool().
|
||||
//
|
||||
// uuid_ is incremented when CUDAGraph creates a MemPool
|
||||
// as a result of a user not providing a pool.
|
||||
//
|
||||
// MempoolId_t of {0, 0} is used to denote when no MemPool has been
|
||||
// passed to a function, either by user or CUDAGraphs. For example,
|
||||
// default value of MempoolId_t for capture_begin function is {0, 0}.
|
||||
// That's why uid_ and uuid_ start at 1.
|
||||
std::atomic<CaptureId_t> MemPool::uid_{1};
|
||||
std::atomic<CaptureId_t> MemPool::uuid_{1};
|
||||
|
||||
MemPool::MemPool(
|
||||
CUDACachingAllocator::CUDAAllocator* allocator,
|
||||
bool is_user_created,
|
||||
bool use_on_oom)
|
||||
: allocator_(allocator), is_user_created_(is_user_created) {
|
||||
if (is_user_created_) {
|
||||
id_ = {0, uid_++};
|
||||
} else {
|
||||
id_ = {uuid_++, 0};
|
||||
}
|
||||
device_ = c10::cuda::current_device();
|
||||
CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
|
||||
if (use_on_oom) {
|
||||
CUDACachingAllocator::setUseOnOOM(device_, id_);
|
||||
}
|
||||
}
|
||||
|
||||
MemPool::~MemPool() {
|
||||
// TORCH_INTERNAL_ASSERT(use_count() == 1);
|
||||
// We used to assert that TORCH_INTERNAL_ASSERT(use_count() == 1);
|
||||
// However, this assertion is not true if a memory pool is shared
|
||||
// with a cuda graph. That CUDAGraph will increase the use count
|
||||
// until it is reset.
|
||||
CUDACachingAllocator::releasePool(device_, id_);
|
||||
c10::cuda::CUDACachingAllocator::emptyCache(id_);
|
||||
}
|
||||
|
||||
MempoolId_t MemPool::id() {
|
||||
return id_;
|
||||
}
|
||||
|
||||
CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
|
||||
return allocator_;
|
||||
}
|
||||
|
||||
int MemPool::use_count() {
|
||||
return CUDACachingAllocator::getPoolUseCount(device_, id_);
|
||||
}
|
||||
|
||||
c10::DeviceIndex MemPool::device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
|
||||
if (is_user_created) {
|
||||
return {0, uid_++};
|
||||
}
|
||||
return {uuid_++, 0};
|
||||
}
|
||||
|
||||
} // namespace at::cuda
|
||||
44
aten/src/ATen/cuda/MemPool.h
Normal file
44
aten/src/ATen/cuda/MemPool.h
Normal file
@ -0,0 +1,44 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/Allocator.h>
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
|
||||
namespace at::cuda {
|
||||
|
||||
// Keep BC only
|
||||
using c10::CaptureId_t;
|
||||
using c10::MempoolId_t;
|
||||
|
||||
// MemPool represents a pool of memory in a caching allocator. Currently,
|
||||
// it's just the ID of the pool object maintained in the CUDACachingAllocator.
|
||||
//
|
||||
// An allocator pointer can be passed to the MemPool to define how the
|
||||
// allocations should be done in the pool. For example: using a different
|
||||
// system allocator such as ncclMemAlloc.
|
||||
struct TORCH_CUDA_CPP_API MemPool {
|
||||
MemPool(
|
||||
c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
|
||||
bool is_user_created = true,
|
||||
bool use_on_oom = false);
|
||||
MemPool(const MemPool&) = delete;
|
||||
MemPool(MemPool&&) = default;
|
||||
MemPool& operator=(const MemPool&) = delete;
|
||||
MemPool& operator=(MemPool&&) = default;
|
||||
~MemPool();
|
||||
|
||||
MempoolId_t id();
|
||||
c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator();
|
||||
int use_count();
|
||||
c10::DeviceIndex device();
|
||||
static MempoolId_t graph_pool_handle(bool is_user_created = true);
|
||||
|
||||
private:
|
||||
static std::atomic<CaptureId_t> uid_;
|
||||
static std::atomic<CaptureId_t> uuid_;
|
||||
c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator_;
|
||||
bool is_user_created_;
|
||||
MempoolId_t id_;
|
||||
c10::DeviceIndex device_;
|
||||
};
|
||||
|
||||
} // namespace at::cuda
|
||||
@ -55,14 +55,6 @@ struct numeric_limits<int8_t> {
|
||||
static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<uint16_t> {
|
||||
static inline __host__ __device__ uint16_t lowest() { return 0; }
|
||||
static inline __host__ __device__ uint16_t max() { return UINT16_MAX; }
|
||||
static inline __host__ __device__ uint16_t lower_bound() { return 0; }
|
||||
static inline __host__ __device__ uint16_t upper_bound() { return UINT16_MAX; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<int16_t> {
|
||||
static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
|
||||
@ -71,14 +63,6 @@ struct numeric_limits<int16_t> {
|
||||
static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<uint32_t> {
|
||||
static inline __host__ __device__ uint32_t lowest() { return 0; }
|
||||
static inline __host__ __device__ uint32_t max() { return UINT32_MAX; }
|
||||
static inline __host__ __device__ uint32_t lower_bound() { return 0; }
|
||||
static inline __host__ __device__ uint32_t upper_bound() { return UINT32_MAX; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<int32_t> {
|
||||
static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
|
||||
@ -87,21 +71,6 @@ struct numeric_limits<int32_t> {
|
||||
static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<uint64_t> {
|
||||
#ifdef _MSC_VER
|
||||
static inline __host__ __device__ uint64_t lowest() { return 0; }
|
||||
static inline __host__ __device__ uint64_t max() { return _UI64_MAX; }
|
||||
static inline __host__ __device__ uint64_t lower_bound() { return 0; }
|
||||
static inline __host__ __device__ uint64_t upper_bound() { return _UI64_MAX; }
|
||||
#else
|
||||
static inline __host__ __device__ uint64_t lowest() { return 0; }
|
||||
static inline __host__ __device__ uint64_t max() { return UINT64_MAX; }
|
||||
static inline __host__ __device__ uint64_t lower_bound() { return 0; }
|
||||
static inline __host__ __device__ uint64_t upper_bound() { return UINT64_MAX; }
|
||||
#endif
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<int64_t> {
|
||||
#ifdef _MSC_VER
|
||||
|
||||
@ -440,7 +440,7 @@ bool MPSHeapAllocatorImpl::release_cached_buffers() {
|
||||
// we need to release the lock temporarily as synchronizing may cause deadlock with completion handlers.
|
||||
m_mutex.unlock();
|
||||
auto stream = getDefaultMPSStream();
|
||||
dispatch_sync(stream->queue(), ^() {
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
stream->synchronize(SyncType::COMMIT_AND_WAIT);
|
||||
});
|
||||
m_mutex.lock();
|
||||
|
||||
@ -110,6 +110,9 @@ class TORCH_API MPSStream {
|
||||
return _stream;
|
||||
}
|
||||
|
||||
MTLBuffer_t getErrorBuffer();
|
||||
void checkLastError();
|
||||
|
||||
private:
|
||||
Stream _stream;
|
||||
MTLCommandQueue_t _commandQueue = nil;
|
||||
@ -121,6 +124,8 @@ class TORCH_API MPSStream {
|
||||
dispatch_queue_t _serialQueue = nullptr;
|
||||
// CommitAndContinue is enabled by default
|
||||
bool _enableCommitAndContinue = true;
|
||||
// Buffer that contains last raised error
|
||||
MTLBuffer_t _errorBuffer = nil;
|
||||
|
||||
// use synchronize() to access any of these commit functions outside MPSStream
|
||||
void commit();
|
||||
@ -155,4 +160,7 @@ class TORCH_API MPSStreamImpl {
|
||||
MPSStreamImpl();
|
||||
};
|
||||
|
||||
#ifdef __OBJC__
|
||||
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());
|
||||
#endif
|
||||
} // namespace at::mps
|
||||
|
||||
@ -3,13 +3,13 @@
|
||||
#include <ATen/mps/MPSAllocatorInterface.h>
|
||||
#include <ATen/mps/MPSProfiler.h>
|
||||
#include <ATen/mps/MPSStream.h>
|
||||
#include <c10/metal/error.h>
|
||||
|
||||
@interface MPSGraphExecutionDescriptor ()
|
||||
@property(readwrite, atomic) BOOL enableCommitAndContinue;
|
||||
@end
|
||||
|
||||
namespace at::mps {
|
||||
|
||||
//-----------------------------------------------------------------
|
||||
// MPSStream
|
||||
//-----------------------------------------------------------------
|
||||
@ -30,6 +30,10 @@ MPSStream::MPSStream(Stream stream) : _stream(stream) {
|
||||
// Choose level which optimizes for GPU
|
||||
_compilationDescriptor.optimizationLevel = MPSGraphOptimizationLevel0;
|
||||
_executionDescriptor.compilationDescriptor = _compilationDescriptor;
|
||||
|
||||
_errorBuffer = [MPSDevice::getInstance()->device() newBufferWithLength:sizeof(c10::metal::ErrorMessages)
|
||||
options:MTLResourceStorageModeShared];
|
||||
std::memset([_errorBuffer contents], 0, 1024);
|
||||
}
|
||||
|
||||
MPSStream::~MPSStream() {
|
||||
@ -38,6 +42,8 @@ MPSStream::~MPSStream() {
|
||||
[_executionDescriptor release];
|
||||
[_compilationDescriptor release];
|
||||
_executionDescriptor = nil;
|
||||
[_errorBuffer release];
|
||||
_errorBuffer = nil;
|
||||
_compilationDescriptor = nil;
|
||||
|
||||
assert(_commandBuffer == nil);
|
||||
@ -104,6 +110,7 @@ void MPSStream::commitAndWait() {
|
||||
[_prevCommandBuffer waitUntilCompleted];
|
||||
[_prevCommandBuffer release];
|
||||
_prevCommandBuffer = nil;
|
||||
checkLastError();
|
||||
}
|
||||
|
||||
if (_commandBuffer) {
|
||||
@ -111,6 +118,7 @@ void MPSStream::commitAndWait() {
|
||||
[_commandBuffer waitUntilCompleted];
|
||||
[_commandBuffer release];
|
||||
_commandBuffer = nil;
|
||||
checkLastError();
|
||||
}
|
||||
}
|
||||
|
||||
@ -153,7 +161,7 @@ void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t
|
||||
if (length == 0) {
|
||||
return;
|
||||
}
|
||||
dispatch_sync(_serialQueue, ^() {
|
||||
dispatch_sync_with_rethrow(_serialQueue, ^() {
|
||||
@autoreleasepool {
|
||||
endKernelCoalescing();
|
||||
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
|
||||
@ -183,7 +191,7 @@ void MPSStream::copy(id<MTLBuffer> srcBuffer,
|
||||
size_t dstOffset,
|
||||
uint64_t profileId,
|
||||
SyncType syncType) {
|
||||
dispatch_sync(_serialQueue, ^() {
|
||||
dispatch_sync_with_rethrow(_serialQueue, ^() {
|
||||
@autoreleasepool {
|
||||
endKernelCoalescing();
|
||||
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
|
||||
@ -236,7 +244,7 @@ void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDicti
|
||||
auto& profiler = getMPSProfiler();
|
||||
const bool isGraphProfilingEnabled = profiler.isOperationProfilingEnabled();
|
||||
|
||||
dispatch_sync(_serialQueue, ^() {
|
||||
dispatch_sync_with_rethrow(_serialQueue, ^() {
|
||||
endKernelCoalescing();
|
||||
if (isGraphProfilingEnabled) {
|
||||
// this function call is only relevant for interval-based Signposts
|
||||
@ -266,6 +274,24 @@ void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDicti
|
||||
});
|
||||
}
|
||||
|
||||
id<MTLBuffer> MPSStream::getErrorBuffer() {
|
||||
return _errorBuffer;
|
||||
}
|
||||
|
||||
void MPSStream::checkLastError() {
|
||||
auto msgs = reinterpret_cast<c10::metal::ErrorMessages*>([_errorBuffer contents]);
|
||||
const auto& msg = msgs->msg[0];
|
||||
if (!msgs) {
|
||||
return;
|
||||
}
|
||||
unsigned int count = 0;
|
||||
std::swap(count, msgs->count);
|
||||
if (!count) {
|
||||
return;
|
||||
}
|
||||
throw c10::AcceleratorError({msg.func, msg.file, msg.line}, 1, msg.message);
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------
|
||||
// MPSStreamImpl
|
||||
//-----------------------------------------------------------------
|
||||
@ -289,4 +315,19 @@ MPSStream* getDefaultMPSStream() {
|
||||
return MPSStreamImpl::getInstance();
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
|
||||
__block std::optional<std::exception_ptr> block_exception;
|
||||
dispatch_sync(queue, ^() {
|
||||
try {
|
||||
block();
|
||||
} catch (...) {
|
||||
block_exception = std::current_exception();
|
||||
}
|
||||
});
|
||||
if (block_exception) {
|
||||
std::rethrow_exception(*block_exception);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace at::mps
|
||||
|
||||
@ -1936,7 +1936,7 @@ static bool should_fold(const Tensor& tensor1, const Tensor& tensor2, bool has_o
|
||||
|
||||
// We order the tensors. t1 will be the larger tensor
|
||||
// We can always transpose tensor2 as the dimensions are always >= 1 (precondition from matmul)
|
||||
// and tensor1_larger iff tensor2.dim() > tensor1.dim(9
|
||||
// and tensor1_larger iff tensor2.dim() > tensor1.dim()
|
||||
const auto t1 = tensor1_larger ? MaybeOwned<Tensor>::borrowed(tensor1)
|
||||
: MaybeOwned<Tensor>::owned(tensor2.mT());
|
||||
const int64_t dim_t1 = t1->dim();
|
||||
@ -1948,20 +1948,11 @@ static bool should_fold(const Tensor& tensor1, const Tensor& tensor2, bool has_o
|
||||
return false;
|
||||
}
|
||||
|
||||
// In this case we *do* incur in an extra copy to avoid creating an unnecessary large tensor in the backward
|
||||
// Suppose we don't fold here. Let t1.shape = [b, m, n] t2.shape = [n, k] like in a transformer
|
||||
// t2 will be expanded to a tensor of shape [b, n, k] and then we do t1.bmm(t2_expanded)
|
||||
// The issue appears in the backward.
|
||||
// The output gradient g of this operation would have shape [b, m, k]
|
||||
// The backward wrt. t2 of bmm would be given by t1.mH @ g, which has shape [b, n, k]
|
||||
// Then, the backward of expand is simply `sum(0)`. As such, we are instantiating a tensor
|
||||
// of shape [b, n, k] unnecessarily, which may cause a large memory footprint, and in the
|
||||
// worst case, an OOM
|
||||
bool t2_requires_grad = tensor1_larger ? tensor2.requires_grad() : tensor1.requires_grad();
|
||||
if (t2_requires_grad && !has_out) {
|
||||
// We should be checking !at::GradMode::is_enabled(), but apparently
|
||||
// this regresses performance in some cases:
|
||||
// https://github.com/pytorch/pytorch/issues/118548#issuecomment-1916022394
|
||||
// If we require a gradient, we should fold to minimize backward memory usage - even if this
|
||||
// leads to a copy in forward because is needed in backward,
|
||||
// only time we avoid this strict pre-allocated memory usage (has_out = True)
|
||||
bool requires_grad = tensor1.requires_grad() || tensor2.requires_grad();
|
||||
if (requires_grad && !has_out) {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@ -142,6 +142,7 @@ Tensor _pack_padded_sequence_backward_symint(const Tensor& grad, c10::SymIntArra
|
||||
std::tuple<Tensor, Tensor> _pad_packed_sequence(const Tensor& data, const Tensor& _batch_sizes, bool batch_first, const Scalar& padding_value, int64_t total_length) {
|
||||
auto batch_sizes_t = _batch_sizes.contiguous();
|
||||
checkLongTensor(batch_sizes_t);
|
||||
TORCH_CHECK(batch_sizes_t.numel() > 0, "batch_sizes can not be empty");
|
||||
|
||||
int64_t * batch_sizes = batch_sizes_t.data_ptr<int64_t>();
|
||||
int64_t max_batch_size = batch_sizes[0];
|
||||
|
||||
@ -23,6 +23,7 @@
|
||||
#include <ATen/ops/_aminmax_native.h>
|
||||
#include <ATen/ops/_assert_async_native.h>
|
||||
#include <ATen/ops/_assert_scalar_native.h>
|
||||
#include <ATen/ops/_async_error_native.h>
|
||||
#include <ATen/ops/_functional_assert_async_native.h>
|
||||
#include <ATen/ops/_functional_assert_scalar_native.h>
|
||||
#include <ATen/ops/_make_per_tensor_quantized_tensor.h>
|
||||
@ -479,6 +480,14 @@ Tensor isfinite(const Tensor& self) {
|
||||
});
|
||||
}
|
||||
|
||||
void _async_error(std::string_view msg) {
|
||||
TORCH_CHECK(0, msg);
|
||||
}
|
||||
|
||||
void _async_error_meta(std::string_view msg) {
|
||||
// Do NOT error, it's an async error!
|
||||
}
|
||||
|
||||
void _assert_async_cpu(const Tensor& self) {
|
||||
TORCH_CHECK(
|
||||
native::is_nonzero(self),
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
#pragma once
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
|
||||
|
||||
namespace at::native {
|
||||
|
||||
// Used as an interface between the different BLAS-like libraries
|
||||
@ -21,3 +23,5 @@ static inline char to_blas(TransposeType trans) {
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
#include <ATen/native/ReduceOpsUtils.h>
|
||||
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Dispatch_v2.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/TensorIterator.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
@ -79,12 +78,12 @@ void min_all_kernel_impl(Tensor& result, const Tensor& input) {
|
||||
reduce_all_impl<int64_t>(result, input, upper_bound<int64_t>(),
|
||||
[=](int64_t a, int64_t b) -> int64_t { return min_impl(a, b); });
|
||||
} else {
|
||||
AT_DISPATCH_V2(input.scalar_type(), "min_all", AT_WRAP([&] {
|
||||
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "min_all", [&] {
|
||||
using Vec = Vectorized<opmath_type<scalar_t>>;
|
||||
reduce_all_impl_vec<scalar_t>(result, input, upper_bound<scalar_t>(),
|
||||
[=] (scalar_t a , scalar_t b) -> scalar_t { return min_impl(a, b); },
|
||||
[=](Vec a, Vec b) -> Vec { return minimum(a, b); });
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -104,12 +103,12 @@ void max_all_kernel_impl(Tensor& result, const Tensor& input) {
|
||||
reduce_all_impl<int64_t>(result, input, lower_bound<int64_t>(),
|
||||
[=](int64_t a, int64_t b) -> int64_t { return max_impl(a, b); });
|
||||
} else {
|
||||
AT_DISPATCH_V2(input.scalar_type(), "max_all", AT_WRAP([&] {
|
||||
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "max_all", [&] {
|
||||
using Vec = Vectorized<opmath_type<scalar_t>>;
|
||||
reduce_all_impl_vec<scalar_t>(result, input, lower_bound<scalar_t>(),
|
||||
[=] (scalar_t a , scalar_t b) -> scalar_t { return max_impl(a, b); },
|
||||
[=](Vec a, Vec b) -> Vec { return maximum(a, b); });
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -200,7 +199,7 @@ void aminmax_allreduce_kernel(
|
||||
}
|
||||
);
|
||||
} else {
|
||||
AT_DISPATCH_V2(input.scalar_type(), "aminmax_cpu", AT_WRAP([&] {
|
||||
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "aminmax_cpu", [&] {
|
||||
using Vec = Vectorized<opmath_type<scalar_t>>;
|
||||
using scalar_t_pair = std::pair<scalar_t, scalar_t>;
|
||||
reduce_all_impl_vec_two_outputs<scalar_t>(
|
||||
@ -215,7 +214,7 @@ void aminmax_allreduce_kernel(
|
||||
[=](Vec a, Vec b) -> Vec { return minimum(a, b); },
|
||||
[=](Vec a, Vec b) -> Vec { return maximum(a, b); }
|
||||
);
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Dispatch_v2.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
#include <ATen/cpu/vec/functional.h>
|
||||
@ -348,35 +347,34 @@ struct MinValuesOps: public at::native::MinOps<scalar_t> {
|
||||
};
|
||||
|
||||
void min_values_kernel_impl(TensorIterator& iter) {
|
||||
// This case is special because of Vectorized<int64_t> does not
|
||||
// handle upper_bound<int64_t>().
|
||||
// See: https://github.com/pytorch/pytorch/issues/43254
|
||||
if (iter.dtype() == kLong || iter.dtype() == kUInt64) {
|
||||
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
|
||||
binary_kernel_reduce(
|
||||
iter,
|
||||
MinValuesOps<scalar_t>{},
|
||||
std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
|
||||
}), kLong, kUInt64);
|
||||
if (iter.dtype() == kLong) {
|
||||
// This case is special because of Vectorized<int64_t> does not
|
||||
// handle upper_bound<int64_t>().
|
||||
// See: https://github.com/pytorch/pytorch/issues/43254
|
||||
using scalar_t = int64_t;
|
||||
binary_kernel_reduce(
|
||||
iter,
|
||||
MinValuesOps<scalar_t>{},
|
||||
std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
|
||||
return;
|
||||
}
|
||||
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cpu", [&iter] {
|
||||
binary_kernel_reduce_vec(
|
||||
iter,
|
||||
[](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
|
||||
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return minimum(a, b); },
|
||||
static_cast<double>(upper_bound<scalar_t>()));
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
});
|
||||
}
|
||||
|
||||
void max_values_kernel_impl(TensorIterator& iter) {
|
||||
AT_DISPATCH_V2(iter.dtype(), "max_values_cpu", AT_WRAP([&iter] {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cpu", [&iter] {
|
||||
binary_kernel_reduce_vec(
|
||||
iter,
|
||||
[](scalar_t a, scalar_t b) -> scalar_t { return max_impl(a, b); },
|
||||
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return maximum(a, b); },
|
||||
lower_bound<scalar_t>());
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
});
|
||||
}
|
||||
|
||||
void argmax_kernel_impl(TensorIterator &iter) {
|
||||
|
||||
@ -11,7 +11,6 @@
|
||||
#include <vector>
|
||||
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Dispatch_v2.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <ATen/TensorIterator.h>
|
||||
@ -107,7 +106,7 @@ void min_kernel_impl(
|
||||
bool keepdim) {
|
||||
int64_t self_dim_size = ensure_nonempty_size(self, dim);
|
||||
|
||||
AT_DISPATCH_V2(self.scalar_type(), "min_cpu", AT_WRAP([&] {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "min_cpu", [&] {
|
||||
compare_base_kernel<scalar_t>(result, indice, self, dim, keepdim, [&] (
|
||||
scalar_t* result_data, int64_t* indice_data,
|
||||
const scalar_t* self_data, auto self_dim_stride) {
|
||||
@ -129,7 +128,7 @@ void min_kernel_impl(
|
||||
*indice_data = index;
|
||||
}
|
||||
);
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool);
|
||||
});
|
||||
}
|
||||
|
||||
void max_kernel_impl(
|
||||
@ -140,7 +139,7 @@ void max_kernel_impl(
|
||||
bool keepdim) {
|
||||
int64_t self_dim_size = ensure_nonempty_size(self, dim);
|
||||
|
||||
AT_DISPATCH_V2(self.scalar_type(), "max_cpu", AT_WRAP([&] {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "max_cpu", [&] {
|
||||
compare_base_kernel<scalar_t>(result, indice, self, dim, keepdim, [&] (
|
||||
scalar_t* result_data, int64_t* indice_data,
|
||||
const scalar_t* self_data, auto self_dim_stride) {
|
||||
@ -162,7 +161,7 @@ void max_kernel_impl(
|
||||
*indice_data = index;
|
||||
}
|
||||
);
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool);
|
||||
});
|
||||
}
|
||||
|
||||
void aminmax_kernel(
|
||||
@ -187,7 +186,7 @@ void aminmax_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
AT_DISPATCH_V2(self.scalar_type(), "aminmax_cpu", AT_WRAP([&] {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "aminmax_cpu", [&] {
|
||||
compare_base_kernel<scalar_t, scalar_t>(min_result, max_result, self, wrap_dim, keepdim, [&] (
|
||||
scalar_t* min_result_data, scalar_t* max_result_data,
|
||||
const scalar_t* self_data, auto self_dim_stride) {
|
||||
@ -210,7 +209,7 @@ void aminmax_kernel(
|
||||
*max_result_data = max_number;
|
||||
}
|
||||
);
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half);
|
||||
});
|
||||
}
|
||||
|
||||
void where_kernel_impl(TensorIterator &iter) {
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/native/CompositeRandomAccessorCommon.h>
|
||||
#include <thrust/swap.h>
|
||||
#include <thrust/tuple.h>
|
||||
|
||||
namespace at { namespace native {
|
||||
|
||||
@ -669,9 +669,12 @@ std::optional<c10::ScalarType> out_dtype) {
|
||||
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
|
||||
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
|
||||
bool use_fast_path = false;
|
||||
// On non CK system(w/ ROCm), make sure use_fast_path is false
|
||||
#if defined(USE_ROCM_CK_GEMM)
|
||||
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
|
||||
use_fast_path = true;
|
||||
}
|
||||
#endif //USE_ROCM_CK_GEMM
|
||||
#endif
|
||||
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
@ -680,7 +683,11 @@ std::optional<c10::ScalarType> out_dtype) {
|
||||
#ifndef USE_ROCM
|
||||
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
|
||||
#else
|
||||
#if defined(USE_ROCM_CK_GEMM)
|
||||
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
|
||||
#else
|
||||
TORCH_WARN("ROCm: Group Gemm through CK not selected.");
|
||||
#endif //USE_ROCM_CK_GEMM
|
||||
#endif
|
||||
} else {
|
||||
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
#define TORCH_ASSERT_NO_OPERATORS
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Dispatch_v2.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/ReduceAllOps.h>
|
||||
@ -29,22 +28,22 @@ void _min_max_values_kernel_cuda_impl(TensorIterator& iter) {
|
||||
}
|
||||
|
||||
void aminmax_allreduce_launch_kernel(TensorIterator& iter) {
|
||||
AT_DISPATCH_V2(
|
||||
iter.input_dtype(), "aminmax_all_cuda", AT_WRAP([&] {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(
|
||||
kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_all_cuda", [&] {
|
||||
_min_max_values_kernel_cuda_impl<scalar_t>(iter);
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
});
|
||||
}
|
||||
|
||||
void aminmax_launch_kernel(TensorIterator& iter) {
|
||||
AT_DISPATCH_V2(
|
||||
iter.input_dtype(), "aminmax_cuda", AT_WRAP([&]() {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(
|
||||
kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_cuda", [&]() {
|
||||
gpu_reduce_kernel<scalar_t, scalar_t>(
|
||||
iter,
|
||||
MinMaxOps<scalar_t, scalar_t, int32_t>{},
|
||||
thrust::pair<scalar_t, scalar_t>(
|
||||
at::numeric_limits<scalar_t>::upper_bound(),
|
||||
at::numeric_limits<scalar_t>::lower_bound()));
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
#define TORCH_ASSERT_NO_OPERATORS
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Dispatch_v2.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/ReduceAllOps.h>
|
||||
@ -34,27 +33,27 @@ void max_values_kernel_cuda_impl(TensorIterator& iter) {
|
||||
}
|
||||
|
||||
void max_values_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_V2(
|
||||
iter.dtype(), "max_values_cuda", AT_WRAP([&]() {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(
|
||||
kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cuda", [&]() {
|
||||
max_values_kernel_cuda_impl<scalar_t>(iter);
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
});
|
||||
}
|
||||
|
||||
void max_launch_kernel(TensorIterator& iter) {
|
||||
AT_DISPATCH_V2(
|
||||
iter.input_dtype(), "max_cuda", AT_WRAP([&]() {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(
|
||||
kBFloat16, kHalf, kBool, iter.input_dtype(), "max_cuda", [&]() {
|
||||
gpu_reduce_kernel<scalar_t, scalar_t>(
|
||||
iter,
|
||||
MaxOps<scalar_t>{},
|
||||
thrust::pair<scalar_t, int64_t>(
|
||||
at::numeric_limits<scalar_t>::lower_bound(), 0));
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
});
|
||||
}
|
||||
|
||||
void max_all_launch_kernel(TensorIterator &iter) {
|
||||
AT_DISPATCH_V2(iter.input_dtype(), "max_all_cuda", AT_WRAP([&] {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "max_all_cuda", [&] {
|
||||
max_values_kernel_cuda_impl<scalar_t>(iter);
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
});
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(max_values_stub, &max_values_kernel_cuda)
|
||||
|
||||
@ -12,7 +12,6 @@
|
||||
#include <ATen/NumericUtils.h>
|
||||
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Dispatch_v2.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <ATen/cuda/NumericLimits.cuh>
|
||||
|
||||
@ -34,24 +33,24 @@ void min_values_kernel_cuda_impl(TensorIterator& iter) {
|
||||
}
|
||||
|
||||
void min_values_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cuda", [&]() {
|
||||
min_values_kernel_cuda_impl<scalar_t>(iter);
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
});
|
||||
}
|
||||
|
||||
void min_launch_kernel(TensorIterator &iter) {
|
||||
AT_DISPATCH_V2(iter.input_dtype(), "min_cuda", AT_WRAP([&]() {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_cuda", [&]() {
|
||||
gpu_reduce_kernel<scalar_t, scalar_t>(
|
||||
iter,
|
||||
MinOps<scalar_t>{},
|
||||
thrust::pair<scalar_t, int64_t>(at::numeric_limits<scalar_t>::upper_bound(), 0));
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
});
|
||||
}
|
||||
|
||||
void min_all_launch_kernel(TensorIterator &iter) {
|
||||
AT_DISPATCH_V2(iter.input_dtype(), "min_all_cuda", AT_WRAP([&] {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_all_cuda", [&] {
|
||||
min_values_kernel_cuda_impl<scalar_t>(iter);
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
});
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(min_values_stub, &min_values_kernel_cuda)
|
||||
|
||||
@ -267,15 +267,15 @@ void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, con
|
||||
* outer dimensions, which contains several "inner rows").
|
||||
* Each thread processes a single inner row at a time.
|
||||
*/
|
||||
template<typename scalar_t, class BinaryOp>
|
||||
template<typename scalar_t, typename index_t, class BinaryOp>
|
||||
__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, const scalar_t *src_,
|
||||
const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size,
|
||||
const scalar_t init, BinaryOp binary_op)
|
||||
{
|
||||
for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
|
||||
for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
|
||||
const scalar_t *src = src_ + orow * row_size * num_irows + irow;
|
||||
scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow;
|
||||
const scalar_t *src = src_ + static_cast<index_t>(orow) * row_size * num_irows + irow;
|
||||
scalar_t *tgt = tgt_ + (index_t) orow * row_size * num_irows + irow;
|
||||
scalar_t acc = init;
|
||||
|
||||
for (uint32_t col = 0; col < row_size; ++col) {
|
||||
@ -409,10 +409,15 @@ __host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,
|
||||
check_fits_in_unsigned(num_irows, "num_irows");
|
||||
check_fits_in_unsigned(num_orows, "num_orows");
|
||||
check_fits_in_unsigned(row_size, "row_size");
|
||||
|
||||
tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
if (static_cast<size_t>(num_irows) * num_orows * row_size <= UINT_MAX) {
|
||||
tensor_kernel_scan_outer_dim<scalar_t, uint32_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
|
||||
num_orows, num_irows, row_size, init, binary_op);
|
||||
} else {
|
||||
tensor_kernel_scan_outer_dim<scalar_t, size_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
|
||||
num_orows, num_irows, row_size, init, binary_op);
|
||||
}
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
|
||||
@ -40,8 +40,6 @@ using namespace at::mps;
|
||||
|
||||
namespace at::native::mps {
|
||||
|
||||
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());
|
||||
|
||||
struct MPSScalar {
|
||||
id<MTLBuffer> getMTLBuffer() const {
|
||||
return __builtin_bit_cast(id<MTLBuffer>, buffer.get());
|
||||
|
||||
@ -53,21 +53,6 @@
|
||||
@end
|
||||
|
||||
namespace at::native::mps {
|
||||
|
||||
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
|
||||
__block std::optional<std::exception_ptr> block_exception;
|
||||
dispatch_sync(queue, ^() {
|
||||
try {
|
||||
block();
|
||||
} catch (...) {
|
||||
block_exception = std::current_exception();
|
||||
}
|
||||
});
|
||||
if (block_exception) {
|
||||
std::rethrow_exception(*block_exception);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes distance from lowest to highest element offset in given tensor.
|
||||
*/
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
#include <c10/metal/atomic.h>
|
||||
#include <c10/metal/error.h>
|
||||
#include <c10/metal/indexing.h>
|
||||
#include <metal_stdlib>
|
||||
|
||||
@ -31,10 +32,24 @@ OffsetT index_apply_indices(
|
||||
constant IndexAB* indices,
|
||||
constant int64_t* sizes,
|
||||
constant int64_t* strides,
|
||||
uint num_indices) {
|
||||
uint num_indices,
|
||||
thread bool& error,
|
||||
device ErrorMessages* error_buf) {
|
||||
OffsetT rc = offs.x;
|
||||
for (uint i = 0; i < num_indices; i++) {
|
||||
auto idx = indices[i].indexArray[offs.y];
|
||||
if (idx < -sizes[i] || idx >= sizes[i]) {
|
||||
TORCH_REPORT_ERROR(
|
||||
error_buf,
|
||||
"index ",
|
||||
idx,
|
||||
" is out of bounds for dimension ",
|
||||
i,
|
||||
" with size ",
|
||||
sizes[i]);
|
||||
error = true;
|
||||
break;
|
||||
}
|
||||
if (idx < 0) {
|
||||
idx += sizes[i];
|
||||
}
|
||||
@ -55,6 +70,7 @@ kernel void index_select(
|
||||
constant int64_t* index_sizes,
|
||||
constant int64_t* index_strides,
|
||||
constant uint4& ndim_nindices_numel,
|
||||
device ErrorMessages* error_buffer,
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
const auto ndim = ndim_nindices_numel.x;
|
||||
const auto num_indices = ndim_nindices_numel.y;
|
||||
@ -65,8 +81,19 @@ kernel void index_select(
|
||||
indices_strides,
|
||||
ndim,
|
||||
thread_index);
|
||||
bool error = false;
|
||||
auto input_offs = index_apply_indices<OffsetT>(
|
||||
offs.yz, indices, index_sizes, index_strides, num_indices);
|
||||
offs.yz,
|
||||
indices,
|
||||
index_sizes,
|
||||
index_strides,
|
||||
num_indices,
|
||||
error,
|
||||
error_buffer);
|
||||
if (error) {
|
||||
output[offs.x / sizeof(T)] = 0;
|
||||
return;
|
||||
}
|
||||
output[offs.x / sizeof(T)] = input[input_offs / sizeof(T)];
|
||||
}
|
||||
|
||||
@ -82,7 +109,9 @@ inline void index_put_impl(
|
||||
constant int64_t* index_sizes,
|
||||
constant int64_t* index_strides,
|
||||
constant uint4& ndim_nindices_numel,
|
||||
device ErrorMessages* error_buffer,
|
||||
uint thread_index) {
|
||||
bool error = false;
|
||||
const auto ndim = ndim_nindices_numel.x;
|
||||
const auto num_indices = ndim_nindices_numel.y;
|
||||
const auto offs = index_get_offsets(
|
||||
@ -93,7 +122,16 @@ inline void index_put_impl(
|
||||
ndim,
|
||||
thread_index);
|
||||
auto output_offs = index_apply_indices<OffsetT>(
|
||||
offs.xz, indices, index_sizes, index_strides, num_indices);
|
||||
offs.xz,
|
||||
indices,
|
||||
index_sizes,
|
||||
index_strides,
|
||||
num_indices,
|
||||
error,
|
||||
error_buffer);
|
||||
if (error) {
|
||||
return;
|
||||
}
|
||||
output[output_offs / sizeof(T)] = input[offs.y / sizeof(T)];
|
||||
}
|
||||
|
||||
@ -109,6 +147,7 @@ kernel void index_put(
|
||||
constant int64_t* index_sizes,
|
||||
constant int64_t* index_strides,
|
||||
constant uint4& ndim_nindices_numel,
|
||||
device ErrorMessages* error_buffer,
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
index_put_impl(
|
||||
output,
|
||||
@ -121,6 +160,7 @@ kernel void index_put(
|
||||
index_sizes,
|
||||
index_strides,
|
||||
ndim_nindices_numel,
|
||||
error_buffer,
|
||||
thread_index);
|
||||
}
|
||||
|
||||
@ -136,6 +176,7 @@ kernel void index_put_serial(
|
||||
constant int64_t* index_sizes,
|
||||
constant int64_t* index_strides,
|
||||
constant uint4& ndim_nindices_numel,
|
||||
device ErrorMessages* error_buffer,
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
(void)thread_index; // Suppress unused vairable varning
|
||||
for (uint idx = 0; idx < ndim_nindices_numel.z; ++idx) {
|
||||
@ -150,6 +191,7 @@ kernel void index_put_serial(
|
||||
index_sizes,
|
||||
index_strides,
|
||||
ndim_nindices_numel,
|
||||
error_buffer,
|
||||
idx);
|
||||
}
|
||||
}
|
||||
@ -166,6 +208,7 @@ kernel void index_put_accumulate(
|
||||
constant int64_t* index_sizes,
|
||||
constant int64_t* index_strides,
|
||||
constant uint4& ndim_nindices_numel,
|
||||
device ErrorMessages* error_buffer,
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
const auto ndim = ndim_nindices_numel.x;
|
||||
const auto num_indices = ndim_nindices_numel.y;
|
||||
@ -176,8 +219,18 @@ kernel void index_put_accumulate(
|
||||
indices_strides,
|
||||
ndim,
|
||||
thread_index);
|
||||
bool error = false;
|
||||
auto output_offs = index_apply_indices<OffsetT>(
|
||||
offs.xz, indices, index_sizes, index_strides, num_indices);
|
||||
offs.xz,
|
||||
indices,
|
||||
index_sizes,
|
||||
index_strides,
|
||||
num_indices,
|
||||
error,
|
||||
error_buffer);
|
||||
if (error) {
|
||||
return;
|
||||
}
|
||||
AtomicType<T>::atomic_add(
|
||||
reinterpret_cast<device AtomicType_t<T>*>(output),
|
||||
output_offs / sizeof(T),
|
||||
@ -197,6 +250,7 @@ kernel void index_put_accumulate(
|
||||
constant int64_t* index_sizes, \
|
||||
constant int64_t* index_strides, \
|
||||
constant uint4& ndim_nindices_numel, \
|
||||
device ErrorMessages* error_buffer, \
|
||||
uint thread_index [[thread_position_in_grid]])
|
||||
|
||||
#define REGISTER_INDEX_OP_ALL_DTYPES(OP_NAME) \
|
||||
|
||||
@ -220,7 +220,7 @@ Tensor _embedding_bag_dense_backward_mps(const Tensor& output_grad,
|
||||
auto num_threads = (params.mode == EmbeddingBagMode::MAX) ? output_grad.numel() : num_indices * params.feature_size;
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
mps::dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
|
||||
auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("embedding_bag_backward_{}_{}",
|
||||
@ -273,7 +273,7 @@ Tensor _embedding_bag_per_sample_weights_backward_mps(const Tensor& output_grad,
|
||||
auto num_threads = num_indices * feature_size;
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
mps::dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
|
||||
auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("embedding_bag_per_sample_weights_backward_{}_{}",
|
||||
|
||||
@ -179,7 +179,8 @@ static void dispatch_index_kernel(TensorIteratorBase& iter,
|
||||
iter.strides(2),
|
||||
index_size,
|
||||
index_stride,
|
||||
ndim_nindiees);
|
||||
ndim_nindiees,
|
||||
mpsStream->getErrorBuffer());
|
||||
mtl_dispatch1DJob(computeEncoder, indexSelectPSO, serial ? 1 : iter.numel());
|
||||
});
|
||||
}
|
||||
@ -299,7 +300,7 @@ static Tensor& nonzero_out_native_mps(const Tensor& self, Tensor& out_) {
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
using CachedGraph = MPSUnaryCachedGraph;
|
||||
|
||||
dispatch_sync(stream->queue(), ^() {
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
stream->synchronize(SyncType::COMMIT_AND_WAIT);
|
||||
});
|
||||
int64_t total_nonzero = at::count_nonzero(self).item<int64_t>();
|
||||
@ -384,7 +385,7 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) {
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
using CachedGraph = MPSUnaryCachedGraph;
|
||||
|
||||
dispatch_sync(stream->queue(), ^() {
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
stream->synchronize(SyncType::COMMIT_AND_WAIT);
|
||||
});
|
||||
int64_t total_nonzero = at::count_nonzero(self).item<int64_t>();
|
||||
|
||||
@ -923,7 +923,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_mps(const Tensor& input,
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "Not implemented for long on MPS");
|
||||
@autoreleasepool {
|
||||
mps::dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
// which kernel variant to use based on the normalized axis N size
|
||||
const int N_READS = 4;
|
||||
auto metalType = mps::scalarToMetalTypeString(input);
|
||||
|
||||
@ -192,6 +192,11 @@
|
||||
CompositeExplicitAutograd: _assert_tensor_metadata
|
||||
Meta: _assert_tensor_metadata_meta_symint
|
||||
|
||||
- func: _async_error(str msg) -> ()
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: _async_error
|
||||
Meta: _async_error_meta
|
||||
|
||||
- func: _print(str s) -> ()
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: _print
|
||||
@ -7513,7 +7518,7 @@
|
||||
- func: _sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor
|
||||
variants: method
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA: sparse_mask_projection
|
||||
SparseCPU, SparseCUDA, SparseMPS: sparse_mask_projection
|
||||
autogen: _sparse_mask_projection.out
|
||||
|
||||
- func: _to_cpu(Tensor[] tensors) -> Tensor[]
|
||||
|
||||
@ -30,10 +30,12 @@
|
||||
|
||||
#include <thrust/binary_search.h>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/distance.h>
|
||||
#include <thrust/iterator/constant_iterator.h>
|
||||
#include <thrust/scan.h>
|
||||
#include <thrust/sequence.h>
|
||||
#include <thrust/sort.h>
|
||||
#include <thrust/system/cuda/execution_policy.h>
|
||||
#include <thrust/iterator/constant_iterator.h>
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cusparse.h>
|
||||
@ -47,6 +49,7 @@
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <thrust/copy.h>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/distance.h>
|
||||
#include <thrust/for_each.h>
|
||||
#include <thrust/functional.h>
|
||||
#include <thrust/gather.h>
|
||||
|
||||
@ -445,6 +445,33 @@ static SparseTensor& mul_out_dense_sparse_mps(
|
||||
return out;
|
||||
}
|
||||
|
||||
static std::tuple<Tensor, Tensor, int64_t> mps_intersect_binary_search(
|
||||
const Tensor& A_keys,
|
||||
const Tensor& B_keys,
|
||||
int64_t lenA,
|
||||
int64_t lenB,
|
||||
bool boolean_flag) {
|
||||
|
||||
auto stream = getCurrentMPSStream();
|
||||
auto outA_idx = at::empty({lenA}, A_keys.options().dtype(at::kLong));
|
||||
auto outB_idx = at::empty({lenA}, A_keys.options().dtype(at::kLong));
|
||||
auto counter = at::zeros({1}, A_keys.options().dtype(at::kInt));
|
||||
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
|
||||
auto enc = stream->commandEncoder();
|
||||
[enc setComputePipelineState:pso];
|
||||
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
|
||||
static_cast<uint32_t>(lenB), boolean_flag);
|
||||
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
|
||||
}
|
||||
});
|
||||
|
||||
const auto match_count = static_cast<int64_t>(counter.item<int32_t>());
|
||||
return std::make_tuple(std::move(outA_idx), std::move(outB_idx), match_count);
|
||||
}
|
||||
|
||||
|
||||
SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTensor& r_) {
|
||||
TORCH_CHECK(r_.is_mps(), "mul: expected 'out' to be MPS, but got ", r_.device());
|
||||
@ -523,22 +550,10 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
|
||||
auto A_keys = A_is_lhs ? lhs_keys : rhs_keys;
|
||||
auto B_keys = A_is_lhs ? rhs_keys : lhs_keys;
|
||||
|
||||
auto outA_idx = at::empty({lenA}, at::device(device).dtype(kLong));
|
||||
auto outB_idx = at::empty({lenA}, at::device(device).dtype(kLong));
|
||||
auto counter = at::zeros({1}, at::device(device).dtype(kInt));
|
||||
auto [outA_idx, outB_idx, M_int64] = mps_intersect_binary_search(
|
||||
A_keys, B_keys, lenA, lenB, A_is_lhs);
|
||||
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
|
||||
auto enc = stream->commandEncoder();
|
||||
[enc setComputePipelineState:pso];
|
||||
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
|
||||
static_cast<uint32_t>(lenB), A_is_lhs);
|
||||
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
|
||||
}
|
||||
});
|
||||
|
||||
const uint32_t M = counter.item<int32_t>(); // number of structural matches
|
||||
const auto M = static_cast<uint32_t>(M_int64); // number of structural matches
|
||||
|
||||
r_.resize_as_(lhs);
|
||||
|
||||
@ -762,6 +777,14 @@ SparseTensor& add_out_sparse_mps(const SparseTensor& self,
|
||||
|
||||
using OptTensor = std::optional<Tensor>;
|
||||
|
||||
static Tensor create_sparse_output_values(
|
||||
const Tensor& template_values,
|
||||
int64_t output_nnz,
|
||||
ScalarType dtype) {
|
||||
auto out_val_sizes = template_values.sizes().vec();
|
||||
out_val_sizes[0] = output_nnz;
|
||||
return at::zeros(out_val_sizes, template_values.options().dtype(dtype));
|
||||
}
|
||||
|
||||
static void sparse_mask_apply_out_mps_kernel(
|
||||
Tensor& result,
|
||||
@ -783,9 +806,9 @@ static void sparse_mask_apply_out_mps_kernel(
|
||||
auto src = src_in.coalesce();
|
||||
auto mask = coalesce_mask ? mask_in.coalesce() : mask_in;
|
||||
|
||||
const int64_t src_nnz = src._nnz();
|
||||
const int64_t mask_nnz = mask._nnz();
|
||||
const int64_t sd = src.sparse_dim();
|
||||
const auto src_nnz = src._nnz();
|
||||
const auto mask_nnz = mask._nnz();
|
||||
const auto sd = src.sparse_dim();
|
||||
result.sparse_resize_(mask.sizes(), mask.sparse_dim(), mask.dense_dim());
|
||||
|
||||
auto commonDtype = at::result_type(src, mask);
|
||||
@ -814,53 +837,27 @@ static void sparse_mask_apply_out_mps_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
auto mask_indices = mask._indices().contiguous();
|
||||
auto src_values = src._values().to(commonDtype).contiguous();
|
||||
auto out_values = create_sparse_output_values(src_values, mask_nnz, commonDtype);
|
||||
|
||||
if (src_nnz == 0) {
|
||||
auto out_indices = mask._indices().contiguous();
|
||||
auto src_values = src._values().to(commonDtype);
|
||||
auto out_val_sizes = src_values.sizes().vec();
|
||||
out_val_sizes[0] = mask_nnz;
|
||||
auto out_values = at::zeros(out_val_sizes, src_values.options());
|
||||
alias_into_sparse(result, out_indices, out_values);
|
||||
alias_into_sparse(result, mask_indices, out_values);
|
||||
result._coalesced_(mask.is_coalesced());
|
||||
return;
|
||||
}
|
||||
|
||||
auto mask_indices = mask._indices().contiguous();
|
||||
auto src_indices = src._indices().contiguous();
|
||||
auto src_values = src._values().to(commonDtype).contiguous();
|
||||
auto mask_keys = flatten_indices(mask._indices().contiguous(), mask.sizes().slice(0, sd)).contiguous();
|
||||
auto src_keys = flatten_indices(src._indices().contiguous(), src.sizes().slice(0, sd)).contiguous();
|
||||
|
||||
auto mask_keys = flatten_indices(mask_indices, mask.sizes().slice(0, sd)).contiguous();
|
||||
auto src_keys = flatten_indices(src_indices, src.sizes().slice(0, sd)).contiguous();
|
||||
|
||||
const bool A_is_src = (src_nnz <= mask_nnz);
|
||||
const int64_t lenA = A_is_src ? src_nnz : mask_nnz;
|
||||
const int64_t lenB = A_is_src ? mask_nnz : src_nnz;
|
||||
const auto A_is_src = (src_nnz <= mask_nnz);
|
||||
const auto lenA = A_is_src ? src_nnz : mask_nnz;
|
||||
const auto lenB = A_is_src ? mask_nnz : src_nnz;
|
||||
auto A_keys = A_is_src ? src_keys : mask_keys;
|
||||
auto B_keys = A_is_src ? mask_keys : src_keys;
|
||||
|
||||
const auto device = result.device();
|
||||
auto stream = getCurrentMPSStream();
|
||||
|
||||
auto outA_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
|
||||
auto outB_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
|
||||
auto counter = at::zeros({1}, at::device(device).dtype(at::kInt));
|
||||
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
|
||||
auto enc = stream->commandEncoder();
|
||||
[enc setComputePipelineState:pso];
|
||||
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
|
||||
static_cast<uint32_t>(lenB), A_is_src);
|
||||
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
|
||||
}
|
||||
});
|
||||
|
||||
const int64_t M = static_cast<int64_t>(counter.item<int32_t>());
|
||||
|
||||
auto out_val_sizes = src_values.sizes().vec();
|
||||
out_val_sizes[0] = mask_nnz;
|
||||
auto out_values = at::zeros(out_val_sizes, src_values.options());
|
||||
auto [outA_idx, outB_idx, M] = mps_intersect_binary_search(
|
||||
A_keys, B_keys, lenA, lenB, A_is_src);
|
||||
|
||||
if (M > 0) {
|
||||
auto src_match = outA_idx.narrow(0, 0, M);
|
||||
@ -878,6 +875,70 @@ static void sparse_mask_apply_out_mps_kernel(
|
||||
result._coalesced_(mask.is_coalesced());
|
||||
}
|
||||
|
||||
static void sparse_mask_projection_out_mps_kernel(
|
||||
Tensor& result,
|
||||
const Tensor& lhs,
|
||||
const Tensor& rhs,
|
||||
const OptTensor& /*x_hash_opt*/,
|
||||
bool accumulate_matches) {
|
||||
|
||||
TORCH_CHECK(lhs.is_sparse() && rhs.is_sparse(), "sparse_mask_projection: expected sparse COO");
|
||||
TORCH_CHECK(lhs.is_mps() && rhs.is_mps(), "sparse_mask_projection: expected MPS tensors");
|
||||
TORCH_CHECK(lhs.sparse_dim() == rhs.sparse_dim(), "sparse_dim mismatch");
|
||||
|
||||
auto lhs_c = lhs.coalesce();
|
||||
auto rhs_c = rhs.coalesce();
|
||||
|
||||
const auto sd = lhs_c.sparse_dim();
|
||||
const auto lhs_nnz = lhs_c._nnz();
|
||||
const auto rhs_nnz = rhs_c._nnz();
|
||||
|
||||
auto commonDtype = at::result_type(lhs_c, rhs_c);
|
||||
TORCH_CHECK(canCast(commonDtype, result.scalar_type()),
|
||||
"Can't convert ", commonDtype, " to output ", result.scalar_type());
|
||||
|
||||
result.sparse_resize_(lhs.sizes(), lhs.sparse_dim(), lhs.dense_dim());
|
||||
|
||||
auto lhs_indices = lhs_c._indices().contiguous();
|
||||
auto rhs_values = rhs_c._values().to(commonDtype).contiguous();
|
||||
auto out_values = create_sparse_output_values(rhs_values, lhs_nnz, commonDtype);
|
||||
|
||||
if (lhs_nnz > 0 && rhs_nnz > 0) {
|
||||
auto lhs_keys = flatten_indices(lhs_indices, lhs_c.sizes().slice(0, sd)).contiguous();
|
||||
auto rhs_keys = flatten_indices(rhs_c._indices().contiguous(), rhs_c.sizes().slice(0, sd)).contiguous();
|
||||
|
||||
const auto A_is_lhs = (lhs_nnz <= rhs_nnz);
|
||||
const auto lenA = A_is_lhs ? lhs_nnz : rhs_nnz;
|
||||
const auto lenB = A_is_lhs ? rhs_nnz : lhs_nnz;
|
||||
auto A_keys = A_is_lhs ? lhs_keys : rhs_keys;
|
||||
auto B_keys = A_is_lhs ? rhs_keys : lhs_keys;
|
||||
|
||||
auto [outA_idx, outB_idx, M] = mps_intersect_binary_search(
|
||||
A_keys, B_keys, lenA, lenB, A_is_lhs);
|
||||
|
||||
if (M > 0) {
|
||||
auto idx_in_A = outA_idx.narrow(0, 0, M);
|
||||
auto idx_in_B = outB_idx.narrow(0, 0, M);
|
||||
auto idx_in_lhs = A_is_lhs ? idx_in_A : idx_in_B;
|
||||
auto idx_in_rhs = A_is_lhs ? idx_in_B : idx_in_A;
|
||||
|
||||
const auto view_cols = rhs_values.numel() / std::max<int64_t>(rhs_nnz, 1);
|
||||
auto rhs_rows = rhs_values.index_select(0, idx_in_rhs).contiguous();
|
||||
auto rhs_rows_2d = rhs_rows.view({M, view_cols});
|
||||
auto out_2d = out_values.view({lhs_nnz, view_cols});
|
||||
|
||||
if (accumulate_matches) {
|
||||
out_2d.index_add_(0, idx_in_lhs, rhs_rows_2d);
|
||||
} else {
|
||||
out_2d.index_copy_(0, idx_in_lhs, rhs_rows_2d);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
alias_into_sparse(result, lhs._indices(), out_values);
|
||||
result._coalesced_(lhs.is_coalesced());
|
||||
}
|
||||
|
||||
static void sparse_mask_intersection_out_mps_kernel(
|
||||
Tensor& result,
|
||||
const Tensor& lhs,
|
||||
@ -1002,4 +1063,5 @@ Tensor sparse_sparse_matmul_mps(const Tensor& mat1_, const Tensor& mat2_) {
|
||||
}
|
||||
|
||||
REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
|
||||
REGISTER_MPS_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_mps_kernel);
|
||||
} // namespace at::native
|
||||
@ -1,191 +1,3 @@
|
||||
#pragma once
|
||||
#include <ATen/xpu/XPUContext.h>
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace at::xpu {
|
||||
|
||||
/*
|
||||
* XPUEvent are movable not copyable wrappers around SYCL event. XPUEvent are
|
||||
* constructed lazily when first recorded. It has a device, and this device is
|
||||
* acquired from the first recording stream. Later streams that record the event
|
||||
* must match the same device.
|
||||
*
|
||||
* Currently, XPUEvent does NOT support to export an inter-process event from
|
||||
* another process via inter-process communication(IPC). So it means that
|
||||
* inter-process communication for event handles between different processes is
|
||||
* not available. This could impact some applications that rely on cross-process
|
||||
* synchronization and communication.
|
||||
*/
|
||||
struct TORCH_XPU_API XPUEvent {
|
||||
// Constructors
|
||||
XPUEvent(bool enable_timing = false) noexcept
|
||||
: enable_timing_{enable_timing} {}
|
||||
|
||||
~XPUEvent() {
|
||||
if (isCreated()) {
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_deletion(
|
||||
at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
XPUEvent(const XPUEvent&) = delete;
|
||||
XPUEvent& operator=(const XPUEvent&) = delete;
|
||||
|
||||
XPUEvent(XPUEvent&& other) = default;
|
||||
XPUEvent& operator=(XPUEvent&& other) = default;
|
||||
|
||||
operator sycl::event&() const {
|
||||
return event();
|
||||
}
|
||||
|
||||
std::optional<at::Device> device() const {
|
||||
if (isCreated()) {
|
||||
return at::Device(at::kXPU, device_index_);
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool isCreated() const {
|
||||
return (event_.get() != nullptr);
|
||||
}
|
||||
|
||||
DeviceIndex device_index() const {
|
||||
return device_index_;
|
||||
}
|
||||
|
||||
sycl::event& event() const {
|
||||
return *event_;
|
||||
}
|
||||
|
||||
bool query() const {
|
||||
using namespace sycl::info;
|
||||
if (!isCreated()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return event().get_info<event::command_execution_status>() ==
|
||||
event_command_status::complete;
|
||||
}
|
||||
|
||||
void record() {
|
||||
record(getCurrentXPUStream());
|
||||
}
|
||||
|
||||
void recordOnce(const XPUStream& stream) {
|
||||
if (!isCreated()) {
|
||||
record(stream);
|
||||
}
|
||||
}
|
||||
|
||||
void record(const XPUStream& stream) {
|
||||
if (!isCreated()) {
|
||||
device_index_ = stream.device_index();
|
||||
assignEvent(stream.queue());
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_creation(
|
||||
at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
device_index_ == stream.device_index(),
|
||||
"Event device ",
|
||||
device_index_,
|
||||
" does not match recording stream's device ",
|
||||
stream.device_index(),
|
||||
".");
|
||||
reassignEvent(stream.queue());
|
||||
}
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_record(
|
||||
at::kXPU,
|
||||
reinterpret_cast<uintptr_t>(event_.get()),
|
||||
reinterpret_cast<uintptr_t>(&stream.queue()));
|
||||
}
|
||||
}
|
||||
|
||||
void block(const XPUStream& stream) {
|
||||
if (isCreated()) {
|
||||
std::vector<sycl::event> event_list{event()};
|
||||
// Make this stream wait until event_ is completed.
|
||||
stream.queue().ext_oneapi_submit_barrier(event_list);
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_wait(
|
||||
at::kXPU,
|
||||
reinterpret_cast<uintptr_t>(event_.get()),
|
||||
reinterpret_cast<uintptr_t>(&stream.queue()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
double elapsed_time(const XPUEvent& other) const {
|
||||
TORCH_CHECK(
|
||||
isCreated() && other.isCreated(),
|
||||
"Both events must be recorded before calculating elapsed time.");
|
||||
TORCH_CHECK(
|
||||
query() && other.query(),
|
||||
"Both events must be completed before calculating elapsed time.");
|
||||
TORCH_CHECK(
|
||||
enable_timing_ && other.enable_timing_,
|
||||
"Both events must be created with argument 'enable_timing=True'.");
|
||||
|
||||
#if SYCL_COMPILER_VERSION < 20250000
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"elapsed_time of XPUEvent requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
|
||||
#endif
|
||||
|
||||
using namespace sycl::info::event_profiling;
|
||||
// Block until both of the recorded events are completed.
|
||||
uint64_t end_time_ns = other.event().get_profiling_info<command_end>();
|
||||
uint64_t start_time_ns = event().get_profiling_info<command_end>();
|
||||
// Return the eplased time in milliseconds.
|
||||
return 1e-6 *
|
||||
(static_cast<double>(end_time_ns) - static_cast<double>(start_time_ns));
|
||||
}
|
||||
|
||||
void synchronize() const {
|
||||
if (isCreated()) {
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_synchronization(
|
||||
at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
|
||||
}
|
||||
event().wait_and_throw();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void assignEvent(sycl::queue& queue) {
|
||||
#if SYCL_COMPILER_VERSION >= 20250000
|
||||
if (enable_timing_) {
|
||||
event_ = std::make_unique<sycl::event>(
|
||||
sycl::ext::oneapi::experimental::submit_profiling_tag(queue));
|
||||
} else {
|
||||
event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
|
||||
}
|
||||
#else
|
||||
event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
|
||||
#endif
|
||||
}
|
||||
|
||||
void reassignEvent(sycl::queue& queue) {
|
||||
event_.reset();
|
||||
assignEvent(queue);
|
||||
}
|
||||
|
||||
bool enable_timing_ = false;
|
||||
DeviceIndex device_index_ = -1;
|
||||
// Only need to track the last event, as events in an in-order queue are
|
||||
// executed sequentially.
|
||||
std::unique_ptr<sycl::event> event_;
|
||||
};
|
||||
|
||||
} // namespace at::xpu
|
||||
#include <c10/xpu/XPUEvent.h>
|
||||
|
||||
@ -50,6 +50,7 @@ def check_accuracy(actual_csv, expected_csv, expected_filename):
|
||||
"mobilenet_v2",
|
||||
"pytorch_CycleGAN_and_pix2pix",
|
||||
"pytorch_stargan",
|
||||
"repvgg_a2",
|
||||
"resnet152",
|
||||
"resnet18",
|
||||
"resnet50",
|
||||
|
||||
@ -10,7 +10,7 @@ beit_base_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7
|
||||
|
||||
|
||||
|
||||
@ -66,7 +66,7 @@ visformer_small,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,7
|
||||
vit_base_patch14_dinov2.lvd142m,fail_accuracy,7
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -50,7 +50,7 @@ nfnet_l0,pass,7
|
||||
|
||||
|
||||
|
||||
repvgg_a2,fail_accuracy,7
|
||||
repvgg_a2,pass,7
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -952,7 +952,7 @@ def latency_experiment_summary(suite_name, args, model, timings, **kwargs):
|
||||
first_fields.append(kwargs["tag"])
|
||||
headers = first_headers + ["speedup", "abs_latency"]
|
||||
row = first_fields + [float(speedup), median[1] * 1000]
|
||||
msg = f"{speedup:.3f}x"
|
||||
msg = f"{median[0] * 1000} ms, {median[1] * 1000} ms, {speedup:.3f}x"
|
||||
if args.baseline:
|
||||
headers.extend(
|
||||
[
|
||||
@ -1010,7 +1010,7 @@ def latency_experiment_summary(suite_name, args, model, timings, **kwargs):
|
||||
# Hypothetically you can use this from other places, but it's currently
|
||||
# inaccessible, and when this assert fails you need to update the
|
||||
# event_name here to account for the other cases you are using this
|
||||
assert args.quantization is not None
|
||||
assert any([args.quantization, args.optimus])
|
||||
output_signpost(
|
||||
dict(zip(headers, row)),
|
||||
args,
|
||||
@ -2288,11 +2288,9 @@ class BenchmarkRunner:
|
||||
)
|
||||
):
|
||||
is_same = False
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
# Sometimes torch.allclose may throw RuntimeError
|
||||
exception_string = str(e)
|
||||
accuracy_status = f"fail_exception: {exception_string}"
|
||||
return record_status(accuracy_status, dynamo_start_stats=start_stats)
|
||||
is_same = False
|
||||
|
||||
if not is_same:
|
||||
accuracy_status = "eager_two_runs_differ"
|
||||
@ -2409,11 +2407,9 @@ class BenchmarkRunner:
|
||||
force_max_multiplier=force_max_multiplier,
|
||||
):
|
||||
is_same = False
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
# Sometimes torch.allclose may throw RuntimeError
|
||||
exception_string = str(e)
|
||||
accuracy_status = f"fail_exception: {exception_string}"
|
||||
return record_status(accuracy_status, dynamo_start_stats=start_stats)
|
||||
is_same = False
|
||||
|
||||
if not is_same:
|
||||
if self.args.skip_accuracy_check:
|
||||
@ -2587,6 +2583,9 @@ class BenchmarkRunner:
|
||||
**experiment_kwargs,
|
||||
)
|
||||
|
||||
# reset dynamo
|
||||
torch._dynamo.reset()
|
||||
|
||||
if self.args.export_aot_inductor:
|
||||
optimized_model_iter_fn = optimize_ctx
|
||||
else:
|
||||
@ -2950,7 +2949,7 @@ class BenchmarkRunner:
|
||||
status = self.check_tolerance(name, model, example_inputs, optimize_ctx)
|
||||
print(status)
|
||||
elif self.args.performance:
|
||||
if self.args.backend == "torchao":
|
||||
if self.args.backend in ["torchao", "optimus"]:
|
||||
status = self.run_performance_test_non_alternate(
|
||||
name, model, example_inputs, optimize_ctx, experiment, tag
|
||||
)
|
||||
@ -3526,6 +3525,12 @@ def parse_args(args=None):
|
||||
action="store_true",
|
||||
help="Measure speedup with TorchInductor",
|
||||
)
|
||||
group.add_argument(
|
||||
"--optimus",
|
||||
choices=["vertical_opt", "horizontal_opt", "all"],
|
||||
default=None,
|
||||
help="Measure speedup of Optimus with TorchInductor baseline",
|
||||
)
|
||||
group.add_argument(
|
||||
"--quantization",
|
||||
choices=[
|
||||
@ -3783,6 +3788,9 @@ def run(runner, args, original_dir=None):
|
||||
if args.inductor:
|
||||
assert args.backend is None
|
||||
args.backend = "inductor"
|
||||
if args.optimus:
|
||||
assert args.backend is None
|
||||
args.backend = "optimus"
|
||||
if args.quantization:
|
||||
assert args.backend is None
|
||||
args.backend = "torchao"
|
||||
@ -4067,10 +4075,22 @@ def run(runner, args, original_dir=None):
|
||||
|
||||
runner.model_iter_fn = model_iter_fn_and_mark_step
|
||||
optimize_ctx = torchao_optimize_ctx(args.quantization)
|
||||
elif args.backend == "optimus":
|
||||
from .optimus import get_baseline_ctx, get_optimus_optimize_ctx
|
||||
|
||||
baseline_ctx = get_baseline_ctx(
|
||||
nopython=args.nopython, inductor_compile_mode=args.inductor_compile_mode
|
||||
)
|
||||
runner.model_iter_fn = baseline_ctx(runner.model_iter_fn)
|
||||
optimize_ctx = get_optimus_optimize_ctx(
|
||||
args.optimus, args.nopython, args.inductor_compile_mode
|
||||
)
|
||||
else:
|
||||
optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
|
||||
experiment = (
|
||||
speedup_experiment if args.backend != "torchao" else latency_experiment
|
||||
speedup_experiment
|
||||
if args.backend not in ["torchao", "optimus"]
|
||||
else latency_experiment
|
||||
)
|
||||
if args.accuracy:
|
||||
output_filename = f"accuracy_{args.backend}.csv"
|
||||
@ -4091,7 +4111,12 @@ def run(runner, args, original_dir=None):
|
||||
if args.only in runner.disable_cudagraph_models:
|
||||
args.disable_cudagraphs = True
|
||||
|
||||
if args.inductor or args.backend == "inductor" or args.export_aot_inductor:
|
||||
if (
|
||||
args.inductor
|
||||
or args.backend == "inductor"
|
||||
or args.export_aot_inductor
|
||||
or args.backend == "optimus"
|
||||
):
|
||||
inductor_config.triton.cudagraphs = not args.disable_cudagraphs
|
||||
inductor_config.triton.persistent_reductions = (
|
||||
not args.disable_persistent_reductions
|
||||
|
||||
62
benchmarks/dynamo/optimus.py
Normal file
62
benchmarks/dynamo/optimus.py
Normal file
@ -0,0 +1,62 @@
|
||||
import functools
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_baseline_ctx(nopython, inductor_compile_mode):
|
||||
return functools.partial(
|
||||
torch.compile,
|
||||
backend="inductor",
|
||||
fullgraph=nopython,
|
||||
mode=inductor_compile_mode,
|
||||
)
|
||||
|
||||
|
||||
def get_optimus_optimize_ctx(config, nopython, inductor_compile_mode):
|
||||
if config == "vertical_opt":
|
||||
optimus_inductor_config = {
|
||||
"pre_grad_fusion_options": {
|
||||
"normalization_pass": {},
|
||||
"merge_splits_pass": {},
|
||||
"split_cat_pass": {},
|
||||
"unbind_stack_pass": {},
|
||||
"unbind_cat_to_view_pass": {},
|
||||
}
|
||||
}
|
||||
elif config == "horizontal_opt":
|
||||
optimus_inductor_config = {
|
||||
"pre_grad_fusion_options": {
|
||||
"normalization_pass": {},
|
||||
"batch_linear": {},
|
||||
"batch_layernorm": {},
|
||||
},
|
||||
}
|
||||
elif config == "all":
|
||||
optimus_inductor_config = {
|
||||
"pre_grad_fusion_options": {
|
||||
"normalization_pass": {},
|
||||
"batch_linear": {},
|
||||
"batch_layernorm": {},
|
||||
"merge_splits_pass": {},
|
||||
"split_cat_pass": {},
|
||||
"unbind_stack_pass": {},
|
||||
"unbind_cat_to_view_pass": {},
|
||||
},
|
||||
}
|
||||
else:
|
||||
raise RuntimeError(f"Unknown optimus config: {config}")
|
||||
|
||||
def _inner(fn):
|
||||
if "pre_grad_fusion_options" in optimus_inductor_config:
|
||||
torch._inductor.config.pre_grad_fusion_options = optimus_inductor_config[
|
||||
"pre_grad_fusion_options"
|
||||
]
|
||||
if "post_grad_fusion_options" in optimus_inductor_config:
|
||||
torch._inductor.config.post_grad_fusion_options = optimus_inductor_config[
|
||||
"post_grad_fusion_options"
|
||||
]
|
||||
return torch.compile(
|
||||
fn, backend="inductor", fullgraph=nopython, mode=inductor_compile_mode
|
||||
)
|
||||
|
||||
return _inner
|
||||
@ -2,6 +2,7 @@ import csv
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# This script takes the logs produced by the benchmark scripts (e.g.,
|
||||
@ -15,8 +16,7 @@ import sys
|
||||
# This script is not very well written, feel free to rewrite it as necessary
|
||||
|
||||
assert len(sys.argv) == 2
|
||||
|
||||
full_log = open(sys.argv[1]).read()
|
||||
full_log = Path(sys.argv[1]).read_text()
|
||||
|
||||
# If the log contains a gist URL, extract it so we can include it in the CSV
|
||||
gist_url = ""
|
||||
|
||||
62
benchmarks/dynamo/pr_time_benchmarks/benchmarks/dtensor.py
Normal file
62
benchmarks/dynamo/pr_time_benchmarks/benchmarks/dtensor.py
Normal file
@ -0,0 +1,62 @@
|
||||
import sys
|
||||
|
||||
from benchmark_base import BenchmarkBase
|
||||
|
||||
import torch
|
||||
from torch.distributed._tensor import DTensor, Replicate
|
||||
from torch.testing._internal.distributed.fake_pg import FakeStore
|
||||
|
||||
|
||||
class BenchmarkDTensorDispatch(BenchmarkBase):
|
||||
def __init__(self, operator, world_size) -> None:
|
||||
super().__init__(
|
||||
category=f"dtensor_dispatch_{operator}",
|
||||
device="cuda",
|
||||
)
|
||||
self.world_size = world_size
|
||||
|
||||
def name(self) -> str:
|
||||
prefix = f"{self.category()}"
|
||||
return prefix
|
||||
|
||||
def description(self) -> str:
|
||||
return f"DTensor dispatch time for {self.category()}"
|
||||
|
||||
def _prepare_once(self) -> None:
|
||||
self.mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
"cuda", (self.world_size,), mesh_dim_names=("dp",)
|
||||
)
|
||||
self.a = DTensor.from_local(
|
||||
torch.ones(10, 10, device=self.device()), self.mesh, [Replicate()]
|
||||
)
|
||||
self.b = DTensor.from_local(
|
||||
torch.ones(10, 10, device=self.device()), self.mesh, [Replicate()]
|
||||
)
|
||||
|
||||
def _prepare(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class BenchmarkDetach(BenchmarkDTensorDispatch):
|
||||
def __init__(self, world_size) -> None:
|
||||
super().__init__(operator="detach", world_size=world_size)
|
||||
|
||||
def _work(self) -> None:
|
||||
self.a.detach()
|
||||
|
||||
|
||||
def main():
|
||||
world_size = 256
|
||||
fake_store = FakeStore()
|
||||
torch.distributed.init_process_group(
|
||||
"fake", store=fake_store, rank=0, world_size=world_size
|
||||
)
|
||||
result_path = sys.argv[1]
|
||||
BenchmarkDetach(world_size).enable_instruction_count().collect_all().append_results(
|
||||
result_path
|
||||
)
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -484,24 +484,106 @@ PyTorch,sum,sum_R256_V512_dim0_contiguousTrue_cpu,short,False,50.954394,0.000000
|
||||
PyTorch,sum,sum_R256_V512_dim0_contiguousFalse_cpu,short,False,57.957757,0.000000
|
||||
PyTorch,sum,sum_R256_V512_dim1_contiguousTrue_cpu,short,False,53.592068,0.000000
|
||||
PyTorch,sum,sum_R256_V512_dim1_contiguousFalse_cpu,short,False,51.339726,0.000000
|
||||
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M8_N16_cpu,short,False,7.040985,0.000000
|
||||
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M8_N64_cpu,short,False,7.168604,0.000000
|
||||
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M8_N128_cpu,short,False,7.434442,0.000000
|
||||
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M16_N16_cpu,short,False,7.078318,0.000000
|
||||
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M16_N64_cpu,short,False,7.426670,0.000000
|
||||
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M16_N128_cpu,short,False,7.679027,0.000000
|
||||
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M32_N16_cpu,short,False,7.281365,0.000000
|
||||
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M32_N64_cpu,short,False,7.682783,0.000000
|
||||
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M32_N128_cpu,short,False,8.381938,0.000000
|
||||
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M8_N16_cpu,short,False,7.039854,0.000000
|
||||
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M8_N64_cpu,short,False,7.399855,0.000000
|
||||
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M8_N128_cpu,short,False,7.715193,0.000000
|
||||
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M16_N16_cpu,short,False,7.255140,0.000000
|
||||
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M16_N64_cpu,short,False,7.753522,0.000000
|
||||
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M16_N128_cpu,short,False,8.364281,0.000000
|
||||
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M32_N16_cpu,short,False,7.476377,0.000000
|
||||
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M32_N64_cpu,short,False,8.458564,0.000000
|
||||
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M32_N128_cpu,short,False,9.391939,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,0.927,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.uint8,short,False,6.261,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int8,short,False,6.351,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int16,short,False,6.177,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int32,short,False,6.333,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int64,short,False,6.588,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.float16,short,False,8.117,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.bfloat16,short,False,9.358,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.float32,short,False,7.844,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.float64,short,False,8.097,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.bool,short,False,6.159,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,0.926,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int8,short,False,6.192,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int16,short,False,6.276,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,6.461,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int64,short,False,6.524,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.float16,short,False,8.136,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.bfloat16,short,False,6.854,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.float32,short,False,6.446,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.float64,short,False,6.829,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.bool,short,False,6.088,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.uint8,short,False,6.059,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int8,short,False,0.922,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int16,short,False,6.263,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int32,short,False,6.330,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int64,short,False,6.688,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.float16,short,False,8.176,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.bfloat16,short,False,6.959,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.float32,short,False,6.430,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.float64,short,False,6.818,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.bool,short,False,6.350,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.uint8,short,False,6.221,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int8,short,False,6.193,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int16,short,False,0.922,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int32,short,False,6.263,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int64,short,False,6.525,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.float16,short,False,7.960,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.bfloat16,short,False,6.801,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.float32,short,False,6.594,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.float64,short,False,7.089,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.bool,short,False,6.498,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,6.358,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int8,short,False,6.390,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int16,short,False,6.415,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,0.925,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int64,short,False,6.657,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.float16,short,False,7.954,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.bfloat16,short,False,6.930,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.float32,short,False,6.737,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.float64,short,False,6.948,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.bool,short,False,6.757,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.uint8,short,False,6.402,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int8,short,False,6.550,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int16,short,False,6.518,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int32,short,False,6.766,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int64,short,False,0.929,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.float16,short,False,8.557,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.bfloat16,short,False,9.045,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.float32,short,False,7.672,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.float64,short,False,7.276,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.bool,short,False,6.414,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.uint8,short,False,7.736,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int8,short,False,7.889,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int16,short,False,8.170,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int32,short,False,7.783,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int64,short,False,7.743,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.float16,short,False,0.927,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.bfloat16,short,False,7.018,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.float32,short,False,8.428,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.float64,short,False,6.767,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.bool,short,False,6.479,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.uint8,short,False,7.827,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int8,short,False,6.450,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int16,short,False,6.320,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int32,short,False,6.385,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int64,short,False,8.119,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.float16,short,False,8.063,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.bfloat16,short,False,0.925,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.float32,short,False,8.629,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.float64,short,False,6.638,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.bool,short,False,6.425,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.uint8,short,False,7.803,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int8,short,False,6.502,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int16,short,False,6.429,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int32,short,False,6.549,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int64,short,False,7.749,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.float16,short,False,7.301,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.bfloat16,short,False,7.682,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,0.930,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.float64,short,False,6.738,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.bool,short,False,6.798,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.uint8,short,False,6.506,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int8,short,False,6.494,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int16,short,False,6.668,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int32,short,False,6.696,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int64,short,False,7.115,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.float16,short,False,7.910,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.bfloat16,short,False,7.410,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.float32,short,False,6.868,0.000000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.float64,short,False,0.924,0.000000
|
||||
PyTorch,addcmul,addcmul_M1_N2_cpu_dtypetorch.float32,short,False,4.461410,0.000000
|
||||
PyTorch,addcmul,addcmul_M1_N2_cpu_dtypetorch.bfloat16,short,False,4.560082,0.000000
|
||||
PyTorch,addcmul,addcmul_M32_N64_cpu_dtypetorch.float32,short,False,5.141248,0.000000
|
||||
|
||||
|
@ -4,74 +4,84 @@ import torch
|
||||
|
||||
|
||||
tensor_conversion_short_configs = op_bench.cross_product_configs(
|
||||
M=(
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
),
|
||||
N=(
|
||||
16,
|
||||
64,
|
||||
128,
|
||||
),
|
||||
M=[32],
|
||||
N=[128],
|
||||
device=["cpu", "cuda"],
|
||||
dtype_one=[
|
||||
torch.bool,
|
||||
torch.uint8,
|
||||
torch.int8,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
torch.half,
|
||||
torch.bfloat16,
|
||||
torch.float,
|
||||
torch.double,
|
||||
],
|
||||
dtype_two=[
|
||||
torch.bool,
|
||||
torch.uint8,
|
||||
torch.int8,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
torch.half,
|
||||
torch.bfloat16,
|
||||
torch.float,
|
||||
torch.double,
|
||||
],
|
||||
tags=["short"],
|
||||
)
|
||||
|
||||
tensor_conversion_long_configs = op_bench.cross_product_configs(
|
||||
M=(
|
||||
64,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
),
|
||||
N=(
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
2048,
|
||||
),
|
||||
M=[1024],
|
||||
N=[1024],
|
||||
device=["cpu", "cuda"],
|
||||
dtype_one=[
|
||||
torch.bool,
|
||||
torch.uint8,
|
||||
torch.int8,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
torch.half,
|
||||
torch.bfloat16,
|
||||
torch.float,
|
||||
torch.double,
|
||||
],
|
||||
dtype_two=[
|
||||
torch.bool,
|
||||
torch.uint8,
|
||||
torch.int8,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
torch.half,
|
||||
torch.bfloat16,
|
||||
torch.float,
|
||||
torch.double,
|
||||
],
|
||||
tags=["long"],
|
||||
)
|
||||
|
||||
|
||||
class FloatToHalfTensorConversionBenchmark(op_bench.TorchBenchmarkBase):
|
||||
def init(self, M, N, device):
|
||||
class TensorConversionBenchmark(op_bench.TorchBenchmarkBase):
|
||||
def init(self, M, N, dtype_one, dtype_two, device):
|
||||
self.inputs = {
|
||||
"input": torch.rand(
|
||||
M, N, device=device, requires_grad=False, dtype=torch.float
|
||||
)
|
||||
).to(dtype=dtype_one)
|
||||
}
|
||||
self.dtype_one = dtype_one
|
||||
self.dtype_two = dtype_two
|
||||
|
||||
def forward(self, input):
|
||||
return input.to(torch.half)
|
||||
return input.to(dtype=self.dtype_two)
|
||||
|
||||
|
||||
class HalfToFloatTensorConversionBenchmark(op_bench.TorchBenchmarkBase):
|
||||
def init(self, M, N, device):
|
||||
self.inputs = {
|
||||
"input": torch.rand(
|
||||
M, N, device=device, requires_grad=False, dtype=torch.half
|
||||
)
|
||||
}
|
||||
|
||||
def forward(self, input):
|
||||
return input.to(torch.float)
|
||||
|
||||
|
||||
op_bench.generate_pt_test(
|
||||
tensor_conversion_short_configs, FloatToHalfTensorConversionBenchmark
|
||||
)
|
||||
op_bench.generate_pt_test(
|
||||
tensor_conversion_long_configs, FloatToHalfTensorConversionBenchmark
|
||||
)
|
||||
op_bench.generate_pt_test(
|
||||
tensor_conversion_short_configs, HalfToFloatTensorConversionBenchmark
|
||||
)
|
||||
op_bench.generate_pt_test(
|
||||
tensor_conversion_long_configs, HalfToFloatTensorConversionBenchmark
|
||||
)
|
||||
op_bench.generate_pt_test(tensor_conversion_short_configs, TensorConversionBenchmark)
|
||||
op_bench.generate_pt_test(tensor_conversion_long_configs, TensorConversionBenchmark)
|
||||
|
||||
if __name__ == "__main__":
|
||||
op_bench.benchmark_runner.main()
|
||||
|
||||
@ -349,24 +349,106 @@ PyTorch,sum,sum_R256_V512_dim0_contiguousTrue_cpu,short,FALSE,12.5841
|
||||
PyTorch,sum,sum_R256_V512_dim0_contiguousFALSE_cpu,short,FALSE,20.8765
|
||||
PyTorch,sum,sum_R256_V512_dim1_contiguousTrue_cpu,short,FALSE,15.4414
|
||||
PyTorch,sum,sum_R256_V512_dim1_contiguousFALSE_cpu,short,FALSE,15.3287
|
||||
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M8_N16_cpu,short,FALSE,5.0499
|
||||
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M8_N64_cpu,short,FALSE,5.3229
|
||||
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M8_N128_cpu,short,FALSE,5.4418
|
||||
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M16_N16_cpu,short,FALSE,5.0868
|
||||
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M16_N64_cpu,short,FALSE,5.4495
|
||||
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M16_N128_cpu,short,FALSE,5.5578
|
||||
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M32_N16_cpu,short,FALSE,5.2631
|
||||
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M32_N64_cpu,short,FALSE,5.5646
|
||||
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M32_N128_cpu,short,FALSE,5.7898
|
||||
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M8_N16_cpu,short,FALSE,5.0228
|
||||
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M8_N64_cpu,short,FALSE,5.3692
|
||||
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M8_N128_cpu,short,FALSE,5.4006
|
||||
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M16_N16_cpu,short,FALSE,5.1107
|
||||
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M16_N64_cpu,short,FALSE,5.4119
|
||||
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M16_N128_cpu,short,FALSE,5.5583
|
||||
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M32_N16_cpu,short,FALSE,5.3818
|
||||
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M32_N64_cpu,short,FALSE,5.5742
|
||||
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M32_N128_cpu,short,FALSE,6.8414
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,0.797
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.uint8,short,False,6.071
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int8,short,False,6.031
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int16,short,False,6.243
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int32,short,False,7.231
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int64,short,False,7.791
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.float16,short,False,12.661
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.bfloat16,short,False,11.225
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.float32,short,False,9.772
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.float64,short,False,9.872
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.bool,short,False,6.033
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,0.781
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int8,short,False,6.060
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int16,short,False,6.180
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.258
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int64,short,False,7.758
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.float16,short,False,10.504
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.bfloat16,short,False,6.749
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.float32,short,False,7.679
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.float64,short,False,7.797
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.bool,short,False,6.019
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.uint8,short,False,6.079
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int8,short,False,0.785
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int16,short,False,6.188
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int32,short,False,7.288
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int64,short,False,7.770
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.float16,short,False,10.466
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.bfloat16,short,False,6.676
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.float32,short,False,7.736
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.float64,short,False,7.780
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.bool,short,False,6.130
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.uint8,short,False,6.221
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int8,short,False,6.101
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int16,short,False,0.791
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int32,short,False,6.254
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int64,short,False,7.733
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.float16,short,False,10.562
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.bfloat16,short,False,6.704
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.float32,short,False,7.819
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.float64,short,False,8.276
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.bool,short,False,6.361
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,6.364
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int8,short,False,6.309
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int16,short,False,6.362
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,0.791
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int64,short,False,7.746
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.float16,short,False,9.462
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.bfloat16,short,False,6.678
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.float32,short,False,7.827
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.float64,short,False,8.200
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.bool,short,False,6.925
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.uint8,short,False,6.947
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int8,short,False,6.962
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int16,short,False,6.906
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int32,short,False,7.664
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int64,short,False,0.782
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.float16,short,False,10.528
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.bfloat16,short,False,10.123
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.float32,short,False,9.234
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.float64,short,False,8.694
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.bool,short,False,12.653
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.uint8,short,False,9.348
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int8,short,False,8.774
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int16,short,False,9.063
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int32,short,False,10.012
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int64,short,False,13.641
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.float16,short,False,0.788
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.bfloat16,short,False,13.757
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.float32,short,False,7.170
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.float64,short,False,12.511
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.bool,short,False,6.516
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.uint8,short,False,8.539
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int8,short,False,6.483
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int16,short,False,6.468
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int32,short,False,7.752
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int64,short,False,9.868
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.float16,short,False,10.556
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.bfloat16,short,False,0.792
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.float32,short,False,7.577
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.float64,short,False,8.267
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.bool,short,False,6.819
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.uint8,short,False,7.715
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int8,short,False,6.754
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int16,short,False,6.825
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int32,short,False,7.790
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int64,short,False,9.219
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.float16,short,False,5.977
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.bfloat16,short,False,7.069
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,0.794
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.float64,short,False,8.301
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.bool,short,False,7.401
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.uint8,short,False,7.843
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int8,short,False,7.117
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int16,short,False,7.170
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int32,short,False,8.000
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int64,short,False,9.284
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.float16,short,False,7.179
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.bfloat16,short,False,7.645
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.float32,short,False,7.988
|
||||
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.float64,short,False,0.792
|
||||
PyTorch,relu,"relu_dims(3,4,5)_contigFALSE_inplaceFALSE_dtypetorch.quint8",short,FALSE,9.4657
|
||||
PyTorch,relu,"relu_dims(3,4,5)_contigFALSE_inplaceFALSE_dtypetorch.qint8",short,FALSE,9.4625
|
||||
PyTorch,relu,"relu_dims(3,4,5)_contigFALSE_inplaceFALSE_dtypetorch.qint32",short,FALSE,9.4165
|
||||
|
||||
|
@ -83,10 +83,13 @@ if __name__ == "__main__":
|
||||
|
||||
if args.outfile == "stdout":
|
||||
outfile = sys.stdout
|
||||
need_close = False
|
||||
elif args.outfile == "stderr":
|
||||
outfile = sys.stderr
|
||||
need_close = False
|
||||
else:
|
||||
outfile = open(args.outfile, "a")
|
||||
need_close = True
|
||||
|
||||
test_count = args.test_count
|
||||
m = args.m
|
||||
@ -147,3 +150,5 @@ if __name__ == "__main__":
|
||||
time,
|
||||
file=outfile,
|
||||
)
|
||||
if need_close:
|
||||
outfile.close()
|
||||
|
||||
@ -82,10 +82,13 @@ if __name__ == "__main__":
|
||||
|
||||
if args.outfile == "stdout":
|
||||
outfile = sys.stdout
|
||||
need_close = False
|
||||
elif args.outfile == "stderr":
|
||||
outfile = sys.stderr
|
||||
need_close = False
|
||||
else:
|
||||
outfile = open(args.outfile, "a")
|
||||
need_close = True
|
||||
|
||||
test_count = args.test_count
|
||||
m = args.m
|
||||
@ -132,3 +135,5 @@ if __name__ == "__main__":
|
||||
time_csr,
|
||||
file=outfile,
|
||||
)
|
||||
if need_close:
|
||||
outfile.close()
|
||||
|
||||
@ -179,10 +179,13 @@ if __name__ == "__main__":
|
||||
|
||||
if args.outfile == "stdout":
|
||||
outfile = sys.stdout
|
||||
need_close = False
|
||||
elif args.outfile == "stderr":
|
||||
outfile = sys.stderr
|
||||
need_close = False
|
||||
else:
|
||||
outfile = open(args.outfile, "a")
|
||||
need_close = True
|
||||
|
||||
ops = args.ops.split(",")
|
||||
|
||||
@ -434,3 +437,5 @@ if __name__ == "__main__":
|
||||
if op not in {"bsr_scatter_mm6", "bsr_dense_mm_with_meta"}:
|
||||
# Break on operations that do not consume parameters
|
||||
break
|
||||
if need_close:
|
||||
outfile.close()
|
||||
|
||||
@ -125,6 +125,17 @@ AttentionType = Literal[
|
||||
]
|
||||
DtypeString = Literal["bfloat16", "float16", "float32"]
|
||||
SpeedupType = Literal["fwd", "bwd"]
|
||||
# Operator Name mapping
|
||||
backend_to_operator_name = {
|
||||
"math": "math attention kernel",
|
||||
"efficient": "efficient attention kernel",
|
||||
"cudnn": "cudnn attention kernel",
|
||||
"fav2": "flash attention 2 kernel",
|
||||
"fav3": "flash attention 3 kernel",
|
||||
"fakv": "flash attention kv cache kernel",
|
||||
"og-eager": "eager attention kernel",
|
||||
"flex": "flex attention kernel",
|
||||
}
|
||||
|
||||
|
||||
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
|
||||
@ -1265,12 +1276,14 @@ def _output_json_for_dashboard(
|
||||
model: ModelInfo
|
||||
metric: MetricInfo
|
||||
|
||||
operator_name = backend_to_operator_name.get(backend, backend)
|
||||
|
||||
# Benchmark extra info
|
||||
benchmark_extra_info = {
|
||||
"input_config": input_config,
|
||||
"device": device,
|
||||
"arch": device_arch,
|
||||
"operator_name": backend,
|
||||
"operator_name": operator_name,
|
||||
"attn_type": config.attn_type,
|
||||
"shape": str(config.shape),
|
||||
"max_autotune": config.max_autotune,
|
||||
@ -1288,7 +1301,7 @@ def _output_json_for_dashboard(
|
||||
type="attention-benchmark",
|
||||
origins=["pytorch"],
|
||||
extra_info={
|
||||
"operator_name": backend,
|
||||
"operator_name": operator_name,
|
||||
"attn_type": config.attn_type,
|
||||
},
|
||||
),
|
||||
@ -1315,7 +1328,7 @@ def _output_json_for_dashboard(
|
||||
type="attention-benchmark",
|
||||
origins=["pytorch"],
|
||||
extra_info={
|
||||
"operator_name": backend,
|
||||
"operator_name": operator_name,
|
||||
},
|
||||
),
|
||||
metric=MetricInfo(
|
||||
@ -1341,7 +1354,7 @@ def _output_json_for_dashboard(
|
||||
type="attention-benchmark",
|
||||
origins=["pytorch"],
|
||||
extra_info={
|
||||
"operator_name": backend,
|
||||
"operator_name": operator_name,
|
||||
},
|
||||
),
|
||||
metric=MetricInfo(
|
||||
@ -1371,7 +1384,7 @@ def _output_json_for_dashboard(
|
||||
type="attention-benchmark",
|
||||
origins=["pytorch"],
|
||||
extra_info={
|
||||
"operator_name": backend,
|
||||
"operator_name": operator_name,
|
||||
},
|
||||
),
|
||||
metric=MetricInfo(
|
||||
|
||||
@ -19,6 +19,17 @@
|
||||
|
||||
namespace c10 {
|
||||
|
||||
using CaptureId_t = unsigned long long;
|
||||
// first is set if the instance is created by CUDAGraph::capture_begin.
|
||||
// second is set if the instance is created by at::cuda::graph_pool_handle.
|
||||
using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>;
|
||||
|
||||
struct MempoolIdHash {
|
||||
std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
|
||||
return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
|
||||
}
|
||||
};
|
||||
|
||||
// A DataPtr is a unique pointer (with an attached deleter and some
|
||||
// context for the deleter) to some memory, which also records what
|
||||
// device is for its data.
|
||||
|
||||
@ -99,7 +99,10 @@ struct C10_API DeviceAllocator : public c10::Allocator {
|
||||
|
||||
// Return the free memory size and total memory size in bytes for the
|
||||
// specified device.
|
||||
virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) = 0;
|
||||
virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "getMemoryInfo is not implemented for this allocator yet.");
|
||||
}
|
||||
};
|
||||
|
||||
// This function is used to get the DeviceAllocator for a specific device type
|
||||
|
||||
@ -27,6 +27,7 @@
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
|
||||
|
||||
namespace c10 {
|
||||
|
||||
@ -205,6 +206,12 @@ inline bool isSignedType(ScalarType t) {
|
||||
break;
|
||||
// Do not add default here, but rather define behavior of every new entry
|
||||
// here. `-Wswitch-enum` would raise a warning in those cases.
|
||||
// TODO: get PyTorch to adopt exhaustive switches by default with a way to
|
||||
// opt specific switches to being non-exhaustive.
|
||||
// Exhaustive:
|
||||
// `-Wswitch-enum`, `-Wswitch-default`, `-Wno-covered-switch-default`
|
||||
// Non-Exhaustive:
|
||||
// `-Wno-switch-enum`, `-Wswitch-default`, `-Wcovered-switch-default`
|
||||
}
|
||||
TORCH_CHECK(false, "Unknown ScalarType ", t);
|
||||
#undef CASE_ISSIGNED
|
||||
|
||||
@ -57,6 +57,8 @@ C10_DECLARE_bool(caffe2_keep_on_shrink);
|
||||
// respect caffe2_keep_on_shrink.
|
||||
C10_DECLARE_int64(caffe2_max_keep_on_shrink_memory);
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default")
|
||||
|
||||
namespace at {
|
||||
class Tensor;
|
||||
class TensorBase;
|
||||
@ -3303,3 +3305,5 @@ static_assert(
|
||||
#undef C10_GCC_VERSION_MINOR
|
||||
|
||||
} // namespace c10
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
@ -1012,12 +1012,6 @@ PrivatePoolState::PrivatePoolState(
|
||||
}
|
||||
}
|
||||
|
||||
struct MempoolIdHash {
|
||||
std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
|
||||
return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
|
||||
}
|
||||
};
|
||||
|
||||
cudaError_t allocPrimitive(void** ptr, size_t size, AllocParams& p) {
|
||||
if (p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator()) {
|
||||
*ptr = p.pool->owner_PrivatePool->allocator()->raw_alloc(size);
|
||||
@ -4510,66 +4504,3 @@ std::atomic<CUDAAllocator*> allocator;
|
||||
static BackendStaticInitializer backend_static_initializer;
|
||||
} // namespace cuda::CUDACachingAllocator
|
||||
} // namespace c10
|
||||
|
||||
namespace c10::cuda {
|
||||
|
||||
// uid_ is incremented when a user creates a MemPool,
|
||||
// for example: using graph_pool_handle() or c10::cuda::MemPool().
|
||||
//
|
||||
// uuid_ is incremented when CUDAGraph creates a MemPool
|
||||
// as a result of a user not providing a pool.
|
||||
//
|
||||
// MempoolId_t of {0, 0} is used to denote when no MemPool has been
|
||||
// passed to a function, either by user or CUDAGraphs. For example,
|
||||
// default value of MempoolId_t for capture_begin function is {0, 0}.
|
||||
// That's why uid_ and uuid_ start at 1.
|
||||
std::atomic<CaptureId_t> MemPool::uid_{1};
|
||||
std::atomic<CaptureId_t> MemPool::uuid_{1};
|
||||
|
||||
MemPool::MemPool(
|
||||
CUDACachingAllocator::CUDAAllocator* allocator,
|
||||
bool is_user_created,
|
||||
bool use_on_oom)
|
||||
: allocator_(allocator), is_user_created_(is_user_created) {
|
||||
if (is_user_created_) {
|
||||
id_ = {0, uid_++};
|
||||
} else {
|
||||
id_ = {uuid_++, 0};
|
||||
}
|
||||
device_ = c10::cuda::current_device();
|
||||
CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
|
||||
if (use_on_oom) {
|
||||
CUDACachingAllocator::setUseOnOOM(device_, id_);
|
||||
}
|
||||
}
|
||||
|
||||
MemPool::~MemPool() {
|
||||
TORCH_INTERNAL_ASSERT(use_count() == 1);
|
||||
CUDACachingAllocator::releasePool(device_, id_);
|
||||
c10::cuda::CUDACachingAllocator::emptyCache(id_);
|
||||
}
|
||||
|
||||
MempoolId_t MemPool::id() {
|
||||
return id_;
|
||||
}
|
||||
|
||||
CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
|
||||
return allocator_;
|
||||
}
|
||||
|
||||
int MemPool::use_count() {
|
||||
return CUDACachingAllocator::getPoolUseCount(device_, id_);
|
||||
}
|
||||
|
||||
c10::DeviceIndex MemPool::device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
|
||||
if (is_user_created) {
|
||||
return {0, uid_++};
|
||||
}
|
||||
return {uuid_++, 0};
|
||||
}
|
||||
|
||||
} // namespace c10::cuda
|
||||
|
||||
@ -562,41 +562,7 @@ inline std::string getUserMetadata() {
|
||||
} // namespace c10::cuda::CUDACachingAllocator
|
||||
|
||||
namespace c10::cuda {
|
||||
|
||||
// Keep BC only
|
||||
using c10::CaptureId_t;
|
||||
using c10::MempoolId_t;
|
||||
|
||||
// MemPool represents a pool of memory in a caching allocator. Currently,
|
||||
// it's just the ID of the pool object maintained in the CUDACachingAllocator.
|
||||
//
|
||||
// An allocator pointer can be passed to the MemPool to define how the
|
||||
// allocations should be done in the pool. For example: using a different
|
||||
// system allocator such as ncclMemAlloc.
|
||||
struct C10_CUDA_API MemPool {
|
||||
MemPool(
|
||||
CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
|
||||
bool is_user_created = true,
|
||||
bool use_on_oom = false);
|
||||
MemPool(const MemPool&) = delete;
|
||||
MemPool(MemPool&&) = default;
|
||||
MemPool& operator=(const MemPool&) = delete;
|
||||
MemPool& operator=(MemPool&&) = default;
|
||||
~MemPool();
|
||||
|
||||
MempoolId_t id();
|
||||
CUDACachingAllocator::CUDAAllocator* allocator();
|
||||
int use_count();
|
||||
c10::DeviceIndex device();
|
||||
static MempoolId_t graph_pool_handle(bool is_user_created = true);
|
||||
|
||||
private:
|
||||
static std::atomic<CaptureId_t> uid_;
|
||||
static std::atomic<CaptureId_t> uuid_;
|
||||
CUDACachingAllocator::CUDAAllocator* allocator_;
|
||||
bool is_user_created_;
|
||||
MempoolId_t id_;
|
||||
c10::DeviceIndex device_;
|
||||
};
|
||||
|
||||
} // namespace c10::cuda
|
||||
|
||||
@ -295,11 +295,19 @@ DeviceAssertionsData* CUDAKernelLaunchRegistry::
|
||||
C10_CUDA_CHECK_WO_DSA(
|
||||
cudaMallocManaged(&uvm_assertions_ptr, sizeof(DeviceAssertionsData)));
|
||||
|
||||
#if CUDART_VERSION >= 13000
|
||||
cudaMemLocation cpuDevice;
|
||||
cpuDevice.type = cudaMemLocationTypeDevice;
|
||||
cpuDevice.id = cudaCpuDeviceId;
|
||||
#else
|
||||
const auto cpuDevice = cudaCpuDeviceId;
|
||||
#endif
|
||||
|
||||
C10_CUDA_CHECK_WO_DSA(cudaMemAdvise(
|
||||
uvm_assertions_ptr,
|
||||
sizeof(DeviceAssertionsData),
|
||||
cudaMemAdviseSetPreferredLocation,
|
||||
cudaCpuDeviceId));
|
||||
cpuDevice));
|
||||
|
||||
// GPU will establish direct mapping of data in CPU memory, no page faults
|
||||
// will be generated
|
||||
@ -307,7 +315,7 @@ DeviceAssertionsData* CUDAKernelLaunchRegistry::
|
||||
uvm_assertions_ptr,
|
||||
sizeof(DeviceAssertionsData),
|
||||
cudaMemAdviseSetAccessedBy,
|
||||
cudaCpuDeviceId));
|
||||
cpuDevice));
|
||||
|
||||
// Initialize the memory from the CPU; otherwise, pages may have to be created
|
||||
// on demand. We think that UVM documentation indicates that first access may
|
||||
|
||||
111
c10/metal/error.h
Normal file
111
c10/metal/error.h
Normal file
@ -0,0 +1,111 @@
|
||||
#pragma once
|
||||
#include <c10/metal/common.h>
|
||||
|
||||
namespace c10 {
|
||||
namespace metal {
|
||||
C10_METAL_CONSTEXPR unsigned error_message_count = 30;
|
||||
struct ErrorMessage {
|
||||
char file[128];
|
||||
char func[128];
|
||||
char message[250];
|
||||
unsigned int line;
|
||||
};
|
||||
|
||||
struct ErrorMessages {
|
||||
#ifdef __METAL__
|
||||
::metal::atomic<unsigned int> count;
|
||||
#else
|
||||
unsigned int count;
|
||||
#endif
|
||||
ErrorMessage msg[error_message_count];
|
||||
};
|
||||
|
||||
#ifdef __METAL__
|
||||
namespace detail {
|
||||
static uint strncpy(device char* dst, constant const char* src, unsigned len) {
|
||||
uint i = 0;
|
||||
while (src[i] != 0 && i < len - 1) {
|
||||
dst[i] = src[i];
|
||||
i++;
|
||||
}
|
||||
dst[i] = 0;
|
||||
return i;
|
||||
}
|
||||
|
||||
inline uint print_arg(
|
||||
device char* ptr,
|
||||
unsigned len,
|
||||
constant const char* arg) {
|
||||
return strncpy(ptr, arg, len);
|
||||
}
|
||||
|
||||
// Returns number length as string in base10
|
||||
static inline uint base10_length(long num) {
|
||||
uint rc = 1;
|
||||
if (num < 0) {
|
||||
num = -num;
|
||||
rc += 1;
|
||||
}
|
||||
while (num > 9) {
|
||||
num /= 10;
|
||||
rc++;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
// Converts signed integer to string
|
||||
inline uint print_arg(device char* ptr, unsigned len, long arg) {
|
||||
const auto arg_len = base10_length(arg);
|
||||
if (arg_len >= len)
|
||||
return 0;
|
||||
if (arg < 0) {
|
||||
ptr[0] = '-';
|
||||
arg = -arg;
|
||||
}
|
||||
uint idx = 1;
|
||||
do {
|
||||
ptr[arg_len - idx] = '0' + (arg % 10);
|
||||
arg /= 10;
|
||||
idx++;
|
||||
} while (arg > 0);
|
||||
ptr[arg_len] = 0;
|
||||
return arg_len;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void print_args(device char* ptr, unsigned len, T arg) {
|
||||
print_arg(ptr, len, arg);
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
inline void print_args(device char* ptr, unsigned len, T arg, Args... args) {
|
||||
const auto rc = print_arg(ptr, len, arg);
|
||||
print_args(ptr + rc, len - rc, args...);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename... Args>
|
||||
static void report_error(
|
||||
device ErrorMessages* msgs,
|
||||
constant const char* file,
|
||||
int line,
|
||||
constant const char* func,
|
||||
Args... args) {
|
||||
const auto idx =
|
||||
atomic_fetch_add_explicit(&msgs->count, 1, ::metal::memory_order_relaxed);
|
||||
if (idx >= error_message_count) {
|
||||
return;
|
||||
}
|
||||
device auto* msg = &msgs->msg[idx];
|
||||
detail::strncpy(msg->file, file, 128);
|
||||
detail::strncpy(msg->func, func, 128);
|
||||
detail::print_args(msg->message, 250, args...);
|
||||
msg->line = line;
|
||||
}
|
||||
|
||||
#define TORCH_REPORT_ERROR(buf, ...) \
|
||||
::c10::metal::report_error(buf, __FILE__, __LINE__, __func__, __VA_ARGS__)
|
||||
#endif
|
||||
} // namespace metal
|
||||
} // namespace c10
|
||||
@ -1 +0,0 @@
|
||||
#include <c10/util/Metaprogramming.h>
|
||||
@ -1,224 +1 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/TypeList.h>
|
||||
#include <type_traits>
|
||||
|
||||
namespace c10::guts {
|
||||
|
||||
/**
|
||||
* Access information about result type or arguments from a function type.
|
||||
* Example:
|
||||
* using A = function_traits<int (float, double)>::return_type // A == int
|
||||
* using A = function_traits<int (float, double)>::parameter_types::tuple_type
|
||||
* // A == tuple<float, double>
|
||||
*/
|
||||
template <class Func>
|
||||
struct function_traits {
|
||||
static_assert(
|
||||
!std::is_same_v<Func, Func>,
|
||||
"In function_traits<Func>, Func must be a plain function type.");
|
||||
};
|
||||
template <class Result, class... Args>
|
||||
struct function_traits<Result(Args...)> {
|
||||
using func_type = Result(Args...);
|
||||
using return_type = Result;
|
||||
using parameter_types = typelist::typelist<Args...>;
|
||||
static constexpr auto number_of_parameters = sizeof...(Args);
|
||||
};
|
||||
|
||||
/**
|
||||
* infer_function_traits: creates a `function_traits` type for a simple
|
||||
* function (pointer) or functor (lambda/struct). Currently does not support
|
||||
* class methods.
|
||||
*/
|
||||
|
||||
template <typename Functor>
|
||||
struct infer_function_traits {
|
||||
using type = function_traits<
|
||||
c10::guts::detail::strip_class_t<decltype(&Functor::operator())>>;
|
||||
};
|
||||
|
||||
template <typename Result, typename... Args>
|
||||
struct infer_function_traits<Result (*)(Args...)> {
|
||||
using type = function_traits<Result(Args...)>;
|
||||
};
|
||||
|
||||
template <typename Result, typename... Args>
|
||||
struct infer_function_traits<Result(Args...)> {
|
||||
using type = function_traits<Result(Args...)>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using infer_function_traits_t = typename infer_function_traits<T>::type;
|
||||
|
||||
/**
|
||||
* make_function_traits: creates a `function_traits` type given a Return type
|
||||
* and a typelist of Argument types
|
||||
*
|
||||
* Example:
|
||||
* bool f(int, int);
|
||||
*
|
||||
* infer_function_traits_t<f> == make_function_traits_t<bool,
|
||||
* typelist::typelist<int, int>>
|
||||
*/
|
||||
template <typename Result, typename ArgList>
|
||||
struct make_function_traits {
|
||||
static_assert(
|
||||
false_t<ArgList>::value,
|
||||
"In guts::make_function_traits<Result, TypeList>, the ArgList argument must be typelist<...>.");
|
||||
};
|
||||
|
||||
template <typename Result, typename... Args>
|
||||
struct make_function_traits<Result, typelist::typelist<Args...>> {
|
||||
using type = function_traits<Result(Args...)>;
|
||||
};
|
||||
|
||||
template <typename Result, typename ArgList>
|
||||
using make_function_traits_t =
|
||||
typename make_function_traits<Result, ArgList>::type;
|
||||
|
||||
/**
|
||||
* make_offset_index_sequence<Start, N>
|
||||
* Like make_index_sequence<N>, but starting from Start instead of 0.
|
||||
*
|
||||
* Example:
|
||||
* make_offset_index_sequence<10, 3> == std::index_sequence<10, 11, 12>
|
||||
*/
|
||||
template <size_t Start, size_t N, size_t... Is>
|
||||
struct make_offset_index_sequence_impl
|
||||
: make_offset_index_sequence_impl<Start, N - 1, Start + N - 1, Is...> {
|
||||
static_assert(
|
||||
static_cast<int>(Start) >= 0,
|
||||
"make_offset_index_sequence: Start < 0");
|
||||
static_assert(static_cast<int>(N) >= 0, "make_offset_index_sequence: N < 0");
|
||||
};
|
||||
|
||||
template <size_t Start, size_t... Is>
|
||||
struct make_offset_index_sequence_impl<Start, 0, Is...> {
|
||||
typedef std::index_sequence<Is...> type;
|
||||
};
|
||||
|
||||
template <size_t Start, size_t N>
|
||||
using make_offset_index_sequence =
|
||||
typename make_offset_index_sequence_impl<Start, N>::type;
|
||||
|
||||
/**
|
||||
* Use tuple_elements to extract a position-indexed subset of elements
|
||||
* from the argument tuple into a result tuple.
|
||||
*
|
||||
* Example:
|
||||
* std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
|
||||
* std::tuple<int, double> result = tuple_elements(t, std::index_sequence<0,
|
||||
* 2>());
|
||||
*/
|
||||
template <class Tuple, size_t... Is>
|
||||
constexpr auto tuple_elements(Tuple t, std::index_sequence<Is...> /*unused*/) {
|
||||
return std::tuple<std::tuple_element_t<Is, Tuple>...>(std::get<Is>(t)...);
|
||||
}
|
||||
|
||||
/**
|
||||
* Use tuple_take to extract the first or last n elements from the argument
|
||||
* tuple into a result tuple.
|
||||
*
|
||||
* Example:
|
||||
* std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
|
||||
* std::tuple<int, const char*> first_two = tuple_take<decltype(t), 2>(t);
|
||||
* std::tuple<const char*, double> last_two = tuple_take<decltype(t), -2>(t);
|
||||
*/
|
||||
template <class Tuple, int N, class Enable = void>
|
||||
struct TupleTake {};
|
||||
|
||||
template <class Tuple, int N>
|
||||
struct TupleTake<Tuple, N, std::enable_if_t<N >= 0, void>> {
|
||||
static auto call(Tuple t) {
|
||||
constexpr size_t size = std::tuple_size<Tuple>();
|
||||
static_assert(N <= size, "tuple_take: N > size");
|
||||
return tuple_elements(t, std::make_index_sequence<N>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <class Tuple, int N>
|
||||
struct TupleTake < Tuple,
|
||||
N, std::enable_if_t<N<0, void>> {
|
||||
static auto call(Tuple t) {
|
||||
constexpr size_t size = std::tuple_size<Tuple>();
|
||||
static_assert(-N <= size, "tuple_take: -N > size");
|
||||
return tuple_elements(t, make_offset_index_sequence<size + N, -N>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <class Tuple, int N>
|
||||
auto tuple_take(Tuple t) {
|
||||
return TupleTake<Tuple, N>::call(t);
|
||||
}
|
||||
|
||||
/**
|
||||
* Use tuple_slice to extract a contiguous subtuple from the argument.
|
||||
*
|
||||
* Example:
|
||||
* std::tuple<int, const char*, double, bool> t = std::make_tuple(0,
|
||||
* "HEY", 2.0, false); std::tuple<int, const char*> middle_two =
|
||||
* tuple_slice<decltype(t), 1, 2>(t);
|
||||
*/
|
||||
template <class Tuple, size_t Start, size_t N>
|
||||
constexpr auto tuple_slice(Tuple t) {
|
||||
constexpr size_t size = std::tuple_size<Tuple>();
|
||||
static_assert(Start + N <= size, "tuple_slice: Start + N > size");
|
||||
return tuple_elements(t, make_offset_index_sequence<Start, N>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* Use tuple_map to run a mapping function over a tuple to get a new tuple.
|
||||
*
|
||||
* Example 1:
|
||||
* auto result = tuple_map(std::tuple<int32_t, int32_t, int32_t>(3, 4, 5), []
|
||||
* (int32_t a) -> int16_t {return a+1;});
|
||||
* // result == std::tuple<int16_t, int16_t, int16_t>(4, 5, 6)
|
||||
*
|
||||
* Example 2:
|
||||
* struct Mapper {
|
||||
* std::string operator()(int32_t a) const {
|
||||
* return std::to_string(a);
|
||||
* }
|
||||
* int64_t operator()(const std::string& a) const {
|
||||
* return atoi(a.c_str());
|
||||
* }
|
||||
* };
|
||||
* auto result = tuple_map(std::tuple<int32_t, std::string>(3, "4"),
|
||||
* Mapper());
|
||||
* // result == std::tuple<std::string, int64_t>("3", 4)
|
||||
*
|
||||
* Example 3:
|
||||
* struct A final {
|
||||
* int32_t func() {
|
||||
* return 5;
|
||||
* }
|
||||
* };
|
||||
* struct B final {
|
||||
* std::string func() {
|
||||
* return "5";
|
||||
* }
|
||||
* };
|
||||
* auto result = tuple_map(std::make_tuple(A(), B()), [] (auto a) { return
|
||||
* a.func(); });
|
||||
* // result == std::tuple<int32_t, std::string>(5, "5");
|
||||
*/
|
||||
namespace detail {
|
||||
template <class Mapper, class... Args, size_t... Indices>
|
||||
auto tuple_map(
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
|
||||
std::tuple<Args...>&& tuple,
|
||||
const Mapper& mapper,
|
||||
std::index_sequence<Indices...> /*unused*/) {
|
||||
return std::tuple<decltype(mapper(std::forward<Args>(std::get<Indices>(
|
||||
tuple))))...>(mapper(std::forward<Args>(std::get<Indices>(tuple)))...);
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
template <class Mapper, class... Args>
|
||||
auto tuple_map(std::tuple<Args...>&& tuple, const Mapper& mapper) {
|
||||
return detail::tuple_map(
|
||||
std::move(tuple), mapper, std::index_sequence_for<Args...>());
|
||||
}
|
||||
|
||||
} // namespace c10::guts
|
||||
#include <torch/headeronly/util/Metaprogramming.h>
|
||||
|
||||
@ -1,515 +1 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/TypeTraits.h>
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
namespace c10::guts {
|
||||
|
||||
template <class... T>
|
||||
struct false_t : std::false_type {};
|
||||
template <template <class> class... T>
|
||||
struct false_higher_t : std::false_type {};
|
||||
|
||||
namespace typelist {
|
||||
|
||||
/**
|
||||
* Type holding a list of types for compile time type computations
|
||||
*/
|
||||
template <class... Items>
|
||||
struct typelist final {
|
||||
public:
|
||||
typelist() = delete; // not for instantiation
|
||||
};
|
||||
|
||||
/**
|
||||
* Returns the number of types in a typelist
|
||||
* Example:
|
||||
* 3 == size<typelist<int, int, double>>::value
|
||||
*/
|
||||
template <class TypeList>
|
||||
struct size final {
|
||||
static_assert(
|
||||
false_t<TypeList>::value,
|
||||
"In typelist::size<T>, T must be typelist<...>.");
|
||||
};
|
||||
template <class... Types>
|
||||
struct size<typelist<Types...>> final {
|
||||
static constexpr size_t value = sizeof...(Types);
|
||||
};
|
||||
|
||||
/**
|
||||
* Transforms a list of types into a tuple holding these types.
|
||||
* Example:
|
||||
* std::tuple<int, string> == to_tuple_t<typelist<int, string>>
|
||||
*/
|
||||
template <class TypeList>
|
||||
struct to_tuple final {
|
||||
static_assert(
|
||||
false_t<TypeList>::value,
|
||||
"In typelist::to_tuple<T>, T must be typelist<...>.");
|
||||
};
|
||||
template <class... Types>
|
||||
struct to_tuple<typelist<Types...>> final {
|
||||
using type = std::tuple<Types...>;
|
||||
};
|
||||
template <class TypeList>
|
||||
using to_tuple_t = typename to_tuple<TypeList>::type;
|
||||
|
||||
/**
|
||||
* Creates a typelist containing the types of a given tuple.
|
||||
* Example:
|
||||
* typelist<int, string> == from_tuple_t<std::tuple<int, string>>
|
||||
*/
|
||||
template <class Tuple>
|
||||
struct from_tuple final {
|
||||
static_assert(
|
||||
false_t<Tuple>::value,
|
||||
"In typelist::from_tuple<T>, T must be std::tuple<...>.");
|
||||
};
|
||||
template <class... Types>
|
||||
struct from_tuple<std::tuple<Types...>> final {
|
||||
using type = typelist<Types...>;
|
||||
};
|
||||
template <class Tuple>
|
||||
using from_tuple_t = typename from_tuple<Tuple>::type;
|
||||
|
||||
/**
|
||||
* Concatenates multiple type lists.
|
||||
* Example:
|
||||
* typelist<int, string, int> == concat_t<typelist<int, string>,
|
||||
* typelist<int>>
|
||||
*/
|
||||
template <class... TypeLists>
|
||||
struct concat final {
|
||||
static_assert(
|
||||
false_t<TypeLists...>::value,
|
||||
"In typelist::concat<T1, ...>, the T arguments each must be typelist<...>.");
|
||||
};
|
||||
template <class... Head1Types, class... Head2Types, class... TailLists>
|
||||
struct concat<typelist<Head1Types...>, typelist<Head2Types...>, TailLists...>
|
||||
final {
|
||||
using type =
|
||||
typename concat<typelist<Head1Types..., Head2Types...>, TailLists...>::
|
||||
type;
|
||||
};
|
||||
template <class... HeadTypes>
|
||||
struct concat<typelist<HeadTypes...>> final {
|
||||
using type = typelist<HeadTypes...>;
|
||||
};
|
||||
template <>
|
||||
struct concat<> final {
|
||||
using type = typelist<>;
|
||||
};
|
||||
template <class... TypeLists>
|
||||
using concat_t = typename concat<TypeLists...>::type;
|
||||
|
||||
/**
|
||||
* Filters the types in a type list by a type trait.
|
||||
* Examples:
|
||||
* typelist<int&, const string&&> == filter_t<std::is_reference,
|
||||
* typelist<void, string, int&, bool, const string&&, int>>
|
||||
*/
|
||||
template <template <class> class Condition, class TypeList>
|
||||
struct filter final {
|
||||
static_assert(
|
||||
false_t<TypeList>::value,
|
||||
"In typelist::filter<Condition, TypeList>, the TypeList argument must be typelist<...>.");
|
||||
};
|
||||
template <template <class> class Condition, class Head, class... Tail>
|
||||
struct filter<Condition, typelist<Head, Tail...>> final {
|
||||
static_assert(
|
||||
is_type_condition<Condition>::value,
|
||||
"In typelist::filter<Condition, TypeList>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
|
||||
using type = std::conditional_t<
|
||||
Condition<Head>::value,
|
||||
concat_t<
|
||||
typelist<Head>,
|
||||
typename filter<Condition, typelist<Tail...>>::type>,
|
||||
typename filter<Condition, typelist<Tail...>>::type>;
|
||||
};
|
||||
template <template <class> class Condition>
|
||||
struct filter<Condition, typelist<>> final {
|
||||
static_assert(
|
||||
is_type_condition<Condition>::value,
|
||||
"In typelist::filter<Condition, TypeList>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
|
||||
using type = typelist<>;
|
||||
};
|
||||
template <template <class> class Condition, class TypeList>
|
||||
using filter_t = typename filter<Condition, TypeList>::type;
|
||||
|
||||
/**
|
||||
* Counts how many types in the list fulfill a type trait
|
||||
* Examples:
|
||||
* 2 == count_if<std::is_reference, typelist<void, string, int&, bool, const
|
||||
* string&&, int>>
|
||||
*/
|
||||
template <template <class> class Condition, class TypeList>
|
||||
struct count_if final {
|
||||
static_assert(
|
||||
is_type_condition<Condition>::value,
|
||||
"In typelist::count_if<Condition, TypeList>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
|
||||
static_assert(
|
||||
is_instantiation_of<typelist, TypeList>::value,
|
||||
"In typelist::count_if<Condition, TypeList>, the TypeList argument must be typelist<...>.");
|
||||
// TODO Direct implementation might be faster
|
||||
static constexpr size_t value = size<filter_t<Condition, TypeList>>::value;
|
||||
};
|
||||
|
||||
/**
|
||||
* Checks if a typelist contains a certain type.
|
||||
* Examples:
|
||||
* contains<typelist<int, string>, string> == true_type
|
||||
* contains<typelist<int, string>, double> == false_type
|
||||
*/
|
||||
namespace detail {
|
||||
template <class TypeList, class Type, class Enable = void>
|
||||
struct contains {};
|
||||
template <class Type>
|
||||
struct contains<typelist<>, Type, void> : std::false_type {};
|
||||
template <class Type, class Head, class... Tail>
|
||||
struct contains<
|
||||
typelist<Head, Tail...>,
|
||||
Type,
|
||||
std::enable_if_t<std::is_same_v<Head, Type>>> : std::true_type {};
|
||||
template <class Type, class Head, class... Tail>
|
||||
struct contains<
|
||||
typelist<Head, Tail...>,
|
||||
Type,
|
||||
std::enable_if_t<!std::is_same_v<Head, Type>>>
|
||||
: contains<typelist<Tail...>, Type> {};
|
||||
} // namespace detail
|
||||
template <class TypeList, class Type>
|
||||
using contains = typename detail::contains<TypeList, Type>::type;
|
||||
|
||||
/**
|
||||
* Returns true iff the type trait is true for all types in the type list
|
||||
* Examples:
|
||||
* true == all<std::is_reference, typelist<int&, const float&&, const
|
||||
* MyClass&>>::value false == all<std::is_reference, typelist<int&, const
|
||||
* float&&, MyClass>>::value
|
||||
*/
|
||||
template <template <class> class Condition, class TypeList>
|
||||
struct all {
|
||||
static_assert(
|
||||
false_t<TypeList>::value,
|
||||
"In typelist::all<Condition, TypeList>, the TypeList argument must be typelist<...>.");
|
||||
};
|
||||
template <template <class> class Condition, class... Types>
|
||||
struct all<Condition, typelist<Types...>>
|
||||
: std::conjunction<Condition<Types>...> {
|
||||
static_assert(
|
||||
is_type_condition<Condition>::value,
|
||||
"In typelist::all<Condition, TypeList>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
|
||||
};
|
||||
|
||||
/**
|
||||
* Returns true iff the type trait is true for any type in the type list
|
||||
* Examples:
|
||||
* true == true_for_any_type<std::is_reference, typelist<int, const
|
||||
* float&&, const MyClass>>::value false ==
|
||||
* true_for_any_type<std::is_reference, typelist<int, const float,
|
||||
* MyClass>>::value
|
||||
*/
|
||||
template <template <class> class Condition, class TypeList>
|
||||
struct true_for_any_type final {
|
||||
static_assert(
|
||||
false_t<TypeList>::value,
|
||||
"In typelist::true_for_any_type<Condition, TypeList>, the TypeList argument must be typelist<...>.");
|
||||
};
|
||||
template <template <class> class Condition, class... Types>
|
||||
struct true_for_any_type<Condition, typelist<Types...>> final
|
||||
: std::disjunction<Condition<Types>...> {
|
||||
static_assert(
|
||||
is_type_condition<Condition>::value,
|
||||
"In typelist::true_for_any_type<Condition, TypeList>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
|
||||
};
|
||||
|
||||
/**
|
||||
* Maps types of a type list using a type trait
|
||||
* Example:
|
||||
* typelist<int&, double&, string&> == map_t<std::add_lvalue_reference_t,
|
||||
* typelist<int, double, string>>
|
||||
*/
|
||||
template <template <class> class Mapper, class TypeList>
|
||||
struct map final {
|
||||
static_assert(
|
||||
false_t<TypeList>::value,
|
||||
"In typelist::map<Mapper, TypeList>, the TypeList argument must be typelist<...>.");
|
||||
};
|
||||
template <template <class> class Mapper, class... Types>
|
||||
struct map<Mapper, typelist<Types...>> final {
|
||||
using type = typelist<Mapper<Types>...>;
|
||||
};
|
||||
template <template <class> class Mapper, class TypeList>
|
||||
using map_t = typename map<Mapper, TypeList>::type;
|
||||
|
||||
/**
|
||||
* Returns the first element of a type list.
|
||||
* Example:
|
||||
* int == head_t<typelist<int, string>>
|
||||
*/
|
||||
template <class TypeList>
|
||||
struct head final {
|
||||
static_assert(
|
||||
false_t<TypeList>::value,
|
||||
"In typelist::head<T>, the T argument must be typelist<...>.");
|
||||
};
|
||||
template <class Head, class... Tail>
|
||||
struct head<typelist<Head, Tail...>> final {
|
||||
using type = Head;
|
||||
};
|
||||
template <class TypeList>
|
||||
using head_t = typename head<TypeList>::type;
|
||||
|
||||
/**
|
||||
* Returns the first element of a type list, or the specified default if the
|
||||
* type list is empty. Example: int == head_t<bool, typelist<int, string>>
|
||||
* bool == head_t<bool, typelist<>>
|
||||
*/
|
||||
template <class Default, class TypeList>
|
||||
struct head_with_default final {
|
||||
using type = Default;
|
||||
};
|
||||
template <class Default, class Head, class... Tail>
|
||||
struct head_with_default<Default, typelist<Head, Tail...>> final {
|
||||
using type = Head;
|
||||
};
|
||||
template <class Default, class TypeList>
|
||||
using head_with_default_t = typename head_with_default<Default, TypeList>::type;
|
||||
|
||||
/**
|
||||
* Returns the N-th element of a type list.
|
||||
* Example:
|
||||
* int == element_t<1, typelist<float, int, char>>
|
||||
*/
|
||||
|
||||
/// Base template.
|
||||
template <size_t Index, class TypeList>
|
||||
struct element final {
|
||||
static_assert(
|
||||
false_t<TypeList>::value,
|
||||
"In typelist::element<T>, the T argument must be typelist<...>.");
|
||||
};
|
||||
|
||||
/// Successful case, we have reached the zero index and can "return" the head
|
||||
/// type.
|
||||
template <class Head, class... Tail>
|
||||
struct element<0, typelist<Head, Tail...>> {
|
||||
using type = Head;
|
||||
};
|
||||
|
||||
/// Error case, we have an index but ran out of types! It will only be selected
|
||||
/// if `Ts...` is actually empty!
|
||||
template <size_t Index, class... Ts>
|
||||
struct element<Index, typelist<Ts...>> {
|
||||
static_assert(
|
||||
Index < sizeof...(Ts),
|
||||
"Index is out of bounds in typelist::element");
|
||||
};
|
||||
|
||||
/// Shave off types until we hit the <0, Head, Tail...> or <Index> case.
|
||||
template <size_t Index, class Head, class... Tail>
|
||||
struct element<Index, typelist<Head, Tail...>>
|
||||
: element<Index - 1, typelist<Tail...>> {};
|
||||
|
||||
/// Convenience alias.
|
||||
template <size_t Index, class TypeList>
|
||||
using element_t = typename element<Index, TypeList>::type;
|
||||
|
||||
/**
|
||||
* Returns the last element of a type list.
|
||||
* Example:
|
||||
* int == last_t<typelist<int, string>>
|
||||
*/
|
||||
template <class TypeList>
|
||||
struct last final {
|
||||
static_assert(
|
||||
false_t<TypeList>::value,
|
||||
"In typelist::last<T>, the T argument must be typelist<...>.");
|
||||
};
|
||||
template <class Head, class... Tail>
|
||||
struct last<typelist<Head, Tail...>> final {
|
||||
using type = typename last<typelist<Tail...>>::type;
|
||||
};
|
||||
template <class Head>
|
||||
struct last<typelist<Head>> final {
|
||||
using type = Head;
|
||||
};
|
||||
template <class TypeList>
|
||||
using last_t = typename last<TypeList>::type;
|
||||
static_assert(std::is_same_v<int, last_t<typelist<double, float, int>>>);
|
||||
|
||||
/**
|
||||
* Take/drop a number of arguments from a typelist.
|
||||
* Example:
|
||||
* typelist<int, string> == take_t<typelist<int, string, bool>, 2>
|
||||
* typelist<bool> == drop_t<typelist<int, string, bool>, 2>
|
||||
*/
|
||||
namespace detail {
|
||||
template <class TypeList, size_t offset, class IndexSequence>
|
||||
struct take_elements final {};
|
||||
|
||||
template <class TypeList, size_t offset, size_t... Indices>
|
||||
struct take_elements<TypeList, offset, std::index_sequence<Indices...>> final {
|
||||
using type = typelist<typename element<offset + Indices, TypeList>::type...>;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <class TypeList, size_t num>
|
||||
struct take final {
|
||||
static_assert(
|
||||
is_instantiation_of<typelist, TypeList>::value,
|
||||
"In typelist::take<T, num>, the T argument must be typelist<...>.");
|
||||
static_assert(
|
||||
num <= size<TypeList>::value,
|
||||
"Tried to typelist::take more elements than there are in the list");
|
||||
using type = typename detail::
|
||||
take_elements<TypeList, 0, std::make_index_sequence<num>>::type;
|
||||
};
|
||||
template <class TypeList, size_t num>
|
||||
using take_t = typename take<TypeList, num>::type;
|
||||
|
||||
template <class TypeList, size_t num>
|
||||
struct drop final {
|
||||
static_assert(
|
||||
is_instantiation_of<typelist, TypeList>::value,
|
||||
"In typelist::drop<T, num>, the T argument must be typelist<...>.");
|
||||
static_assert(
|
||||
num <= size<TypeList>::value,
|
||||
"Tried to typelist::drop more elements than there are in the list");
|
||||
using type = typename detail::take_elements<
|
||||
TypeList,
|
||||
num,
|
||||
std::make_index_sequence<size<TypeList>::value - num>>::type;
|
||||
};
|
||||
template <class TypeList, size_t num>
|
||||
using drop_t = typename drop<TypeList, num>::type;
|
||||
|
||||
/**
|
||||
* Like drop, but returns an empty list rather than an assertion error if `num`
|
||||
* is larger than the size of the TypeList.
|
||||
* Example:
|
||||
* typelist<> == drop_if_nonempty_t<typelist<string, bool>, 2>
|
||||
* typelist<> == drop_if_nonempty_t<typelist<int, string, bool>, 3>
|
||||
*/
|
||||
template <class TypeList, size_t num>
|
||||
struct drop_if_nonempty final {
|
||||
static_assert(
|
||||
is_instantiation_of<typelist, TypeList>::value,
|
||||
"In typelist::drop<T, num>, the T argument must be typelist<...>.");
|
||||
using type = typename detail::take_elements<
|
||||
TypeList,
|
||||
std::min(num, size<TypeList>::value),
|
||||
std::make_index_sequence<
|
||||
size<TypeList>::value - std::min(num, size<TypeList>::value)>>::type;
|
||||
};
|
||||
template <class TypeList, size_t num>
|
||||
using drop_if_nonempty_t = typename drop_if_nonempty<TypeList, num>::type;
|
||||
|
||||
/**
|
||||
* Reverses a typelist.
|
||||
* Example:
|
||||
* typelist<int, string> == reverse_t<typelist<string, int>>
|
||||
*/
|
||||
template <class TypeList>
|
||||
struct reverse final {
|
||||
static_assert(
|
||||
false_t<TypeList>::value,
|
||||
"In typelist::reverse<T>, the T argument must be typelist<...>.");
|
||||
};
|
||||
template <class Head, class... Tail>
|
||||
struct reverse<typelist<Head, Tail...>> final {
|
||||
using type =
|
||||
concat_t<typename reverse<typelist<Tail...>>::type, typelist<Head>>;
|
||||
};
|
||||
template <>
|
||||
struct reverse<typelist<>> final {
|
||||
using type = typelist<>;
|
||||
};
|
||||
template <class TypeList>
|
||||
using reverse_t = typename reverse<TypeList>::type;
|
||||
|
||||
/**
|
||||
* Find the index of the first type in a typelist fulfilling a type trait
|
||||
* condition. Example:
|
||||
*
|
||||
* 2 == find_if<typelist<char, int, char&, int&>, std::is_reference>::value
|
||||
*/
|
||||
template <class TypeList, template <class> class Condition, class Enable = void>
|
||||
struct find_if final {
|
||||
static_assert(
|
||||
false_t<TypeList>::value,
|
||||
"In typelist::find_if<TypeList, Condition>, the TypeList argument must be typelist<...>.");
|
||||
};
|
||||
template <template <class> class Condition>
|
||||
struct find_if<typelist<>, Condition, void> final {
|
||||
static_assert(
|
||||
false_higher_t<Condition>::value,
|
||||
"In typelist::find_if<Type/List, Condition>, didn't find any type fulfilling the Condition.");
|
||||
};
|
||||
template <class Head, class... Tail, template <class> class Condition>
|
||||
struct find_if<
|
||||
typelist<Head, Tail...>,
|
||||
Condition,
|
||||
std::enable_if_t<Condition<Head>::value>>
|
||||
final {
|
||||
static constexpr size_t value = 0;
|
||||
};
|
||||
template <class Head, class... Tail, template <class> class Condition>
|
||||
struct find_if<
|
||||
typelist<Head, Tail...>,
|
||||
Condition,
|
||||
std::enable_if_t<!Condition<Head>::value>>
|
||||
final {
|
||||
static constexpr size_t value =
|
||||
1 + find_if<typelist<Tail...>, Condition>::value;
|
||||
};
|
||||
|
||||
/**
|
||||
* Maps a list of types into a list of values.
|
||||
* Examples:
|
||||
* // Example 1
|
||||
* auto sizes =
|
||||
* map_types_to_values<typelist<int64_t, bool, uint32_t>>(
|
||||
* [] (auto t) { return sizeof(decltype(t)::type); }
|
||||
* );
|
||||
* // sizes == std::tuple<size_t, size_t, size_t>{8, 1, 4}
|
||||
*
|
||||
* // Example 2
|
||||
* auto shared_ptrs =
|
||||
* map_types_to_values<typelist<int, double>>(
|
||||
* [] (auto t) { return make_shared<typename decltype(t)::type>(); }
|
||||
* );
|
||||
* // shared_ptrs == std::tuple<shared_ptr<int>, shared_ptr<double>>()
|
||||
*/
|
||||
namespace detail {
|
||||
template <class T>
|
||||
struct type_ final {
|
||||
using type = T;
|
||||
};
|
||||
template <class TypeList>
|
||||
struct map_types_to_values final {
|
||||
static_assert(
|
||||
false_t<TypeList>::value,
|
||||
"In typelist::map_types_to_values<T>, the T argument must be typelist<...>.");
|
||||
};
|
||||
template <class... Types>
|
||||
struct map_types_to_values<typelist<Types...>> final {
|
||||
template <class Func>
|
||||
static auto call(Func&& func) {
|
||||
return std::tuple{std::forward<Func>(func)(type_<Types>())...};
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <class TypeList, class Func>
|
||||
auto map_types_to_values(Func&& func) {
|
||||
return detail::map_types_to_values<TypeList>::call(std::forward<Func>(func));
|
||||
}
|
||||
|
||||
} // namespace typelist
|
||||
} // namespace c10::guts
|
||||
#include <torch/headeronly/util/TypeList.h>
|
||||
|
||||
@ -1,151 +1 @@
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <type_traits>
|
||||
|
||||
namespace c10::guts {
|
||||
|
||||
/**
|
||||
* is_equality_comparable<T> is true_type iff the equality operator is defined
|
||||
* for T.
|
||||
*/
|
||||
template <class T, class Enable = void>
|
||||
struct is_equality_comparable : std::false_type {};
|
||||
template <class T>
|
||||
struct is_equality_comparable<
|
||||
T,
|
||||
std::void_t<decltype(std::declval<T&>() == std::declval<T&>())>>
|
||||
: std::true_type {};
|
||||
template <class T>
|
||||
using is_equality_comparable_t = typename is_equality_comparable<T>::type;
|
||||
|
||||
/**
|
||||
* is_hashable<T> is true_type iff std::hash is defined for T
|
||||
*/
|
||||
template <class T, class Enable = void>
|
||||
struct is_hashable : std::false_type {};
|
||||
template <class T>
|
||||
struct is_hashable<T, std::void_t<decltype(std::hash<T>()(std::declval<T&>()))>>
|
||||
: std::true_type {};
|
||||
template <class T>
|
||||
using is_hashable_t = typename is_hashable<T>::type;
|
||||
|
||||
/**
|
||||
* is_function_type<T> is true_type iff T is a plain function type (i.e.
|
||||
* "Result(Args...)")
|
||||
*/
|
||||
template <class T>
|
||||
struct is_function_type : std::false_type {};
|
||||
template <class Result, class... Args>
|
||||
struct is_function_type<Result(Args...)> : std::true_type {};
|
||||
template <class T>
|
||||
using is_function_type_t = typename is_function_type<T>::type;
|
||||
|
||||
/**
|
||||
* is_instantiation_of<T, I> is true_type iff I is a template instantiation of T
|
||||
* (e.g. vector<int> is an instantiation of vector) Example:
|
||||
* is_instantiation_of_t<vector, vector<int>> // true
|
||||
* is_instantiation_of_t<pair, pair<int, string>> // true
|
||||
* is_instantiation_of_t<vector, pair<int, string>> // false
|
||||
*/
|
||||
template <template <class...> class Template, class T>
|
||||
struct is_instantiation_of : std::false_type {};
|
||||
template <template <class...> class Template, class... Args>
|
||||
struct is_instantiation_of<Template, Template<Args...>> : std::true_type {};
|
||||
template <template <class...> class Template, class T>
|
||||
using is_instantiation_of_t = typename is_instantiation_of<Template, T>::type;
|
||||
|
||||
namespace detail {
|
||||
/**
|
||||
* strip_class: helper to remove the class type from pointers to `operator()`.
|
||||
*/
|
||||
|
||||
template <typename T>
|
||||
struct strip_class {};
|
||||
template <typename Class, typename Result, typename... Args>
|
||||
struct strip_class<Result (Class::*)(Args...)> {
|
||||
using type = Result(Args...);
|
||||
};
|
||||
template <typename Class, typename Result, typename... Args>
|
||||
struct strip_class<Result (Class::*)(Args...) const> {
|
||||
using type = Result(Args...);
|
||||
};
|
||||
template <typename T>
|
||||
using strip_class_t = typename strip_class<T>::type;
|
||||
} // namespace detail
|
||||
|
||||
/**
|
||||
* Evaluates to true_type, iff the given class is a Functor
|
||||
* (i.e. has a call operator with some set of arguments)
|
||||
*/
|
||||
|
||||
template <class Functor, class Enable = void>
|
||||
struct is_functor : std::false_type {};
|
||||
template <class Functor>
|
||||
struct is_functor<
|
||||
Functor,
|
||||
std::enable_if_t<is_function_type<
|
||||
detail::strip_class_t<decltype(&Functor::operator())>>::value>>
|
||||
: std::true_type {};
|
||||
|
||||
/**
|
||||
* lambda_is_stateless<T> is true iff the lambda type T is stateless
|
||||
* (i.e. does not have a closure).
|
||||
* Example:
|
||||
* auto stateless_lambda = [] (int a) {return a;};
|
||||
* lambda_is_stateless<decltype(stateless_lambda)> // true
|
||||
* auto stateful_lambda = [&] (int a) {return a;};
|
||||
* lambda_is_stateless<decltype(stateful_lambda)> // false
|
||||
*/
|
||||
namespace detail {
|
||||
template <class LambdaType, class FuncType>
|
||||
struct is_stateless_lambda__ final {
|
||||
static_assert(
|
||||
!std::is_same_v<LambdaType, LambdaType>,
|
||||
"Base case shouldn't be hit");
|
||||
};
|
||||
// implementation idea: According to the C++ standard, stateless lambdas are
|
||||
// convertible to function pointers
|
||||
template <class LambdaType, class C, class Result, class... Args>
|
||||
struct is_stateless_lambda__<LambdaType, Result (C::*)(Args...) const>
|
||||
: std::is_convertible<LambdaType, Result (*)(Args...)> {};
|
||||
template <class LambdaType, class C, class Result, class... Args>
|
||||
struct is_stateless_lambda__<LambdaType, Result (C::*)(Args...)>
|
||||
: std::is_convertible<LambdaType, Result (*)(Args...)> {};
|
||||
|
||||
// case where LambdaType is not even a functor
|
||||
template <class LambdaType, class Enable = void>
|
||||
struct is_stateless_lambda_ final : std::false_type {};
|
||||
// case where LambdaType is a functor
|
||||
template <class LambdaType>
|
||||
struct is_stateless_lambda_<
|
||||
LambdaType,
|
||||
std::enable_if_t<is_functor<LambdaType>::value>>
|
||||
: is_stateless_lambda__<LambdaType, decltype(&LambdaType::operator())> {};
|
||||
} // namespace detail
|
||||
template <class T>
|
||||
using is_stateless_lambda = detail::is_stateless_lambda_<std::decay_t<T>>;
|
||||
|
||||
/**
|
||||
* is_type_condition<C> is true_type iff C<...> is a type trait representing a
|
||||
* condition (i.e. has a constexpr static bool ::value member) Example:
|
||||
* is_type_condition<std::is_reference> // true
|
||||
*/
|
||||
template <template <class> class C, class Enable = void>
|
||||
struct is_type_condition : std::false_type {};
|
||||
template <template <class> class C>
|
||||
struct is_type_condition<
|
||||
C,
|
||||
std::enable_if_t<
|
||||
std::is_same_v<bool, std::remove_cv_t<decltype(C<int>::value)>>>>
|
||||
: std::true_type {};
|
||||
|
||||
/**
|
||||
* is_fundamental<T> is true_type iff the lambda type T is a fundamental type
|
||||
* (that is, arithmetic type, void, or nullptr_t). Example: is_fundamental<int>
|
||||
* // true We define it here to resolve a MSVC bug. See
|
||||
* https://github.com/pytorch/pytorch/issues/30932 for details.
|
||||
*/
|
||||
template <class T>
|
||||
struct is_fundamental : std::is_fundamental<T> {};
|
||||
} // namespace c10::guts
|
||||
#include <torch/headeronly/util/TypeTraits.h>
|
||||
|
||||
@ -24,6 +24,7 @@ set(C10_XPU_HEADERS
|
||||
XPUCachingAllocator.h
|
||||
XPUDeviceProp.h
|
||||
XPUException.h
|
||||
XPUEvent.h
|
||||
XPUFunctions.h
|
||||
XPUMacros.h
|
||||
XPUStream.h
|
||||
|
||||
178
c10/xpu/XPUEvent.h
Normal file
178
c10/xpu/XPUEvent.h
Normal file
@ -0,0 +1,178 @@
|
||||
#pragma once
|
||||
#include <c10/xpu/XPUStream.h>
|
||||
|
||||
namespace c10::xpu {
|
||||
|
||||
/*
|
||||
* XPUEvent are movable not copyable wrappers around SYCL event. XPUEvent are
|
||||
* constructed lazily when first recorded. It has a device, and this device is
|
||||
* acquired from the first recording stream. Later streams that record the event
|
||||
* must match the same device.
|
||||
*
|
||||
* Currently, XPUEvent does NOT support to export an inter-process event from
|
||||
* another process via inter-process communication(IPC). So it means that
|
||||
* inter-process communication for event handles between different processes is
|
||||
* not available. This could impact some applications that rely on cross-process
|
||||
* synchronization and communication.
|
||||
*/
|
||||
struct XPUEvent {
|
||||
// Constructors
|
||||
XPUEvent(bool enable_timing = false) noexcept
|
||||
: enable_timing_{enable_timing} {}
|
||||
|
||||
~XPUEvent() {
|
||||
if (isCreated()) {
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_deletion(
|
||||
c10::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
C10_DISABLE_COPY_AND_ASSIGN(XPUEvent);
|
||||
|
||||
XPUEvent(XPUEvent&& other) = default;
|
||||
XPUEvent& operator=(XPUEvent&& other) = default;
|
||||
|
||||
operator sycl::event&() const {
|
||||
return event();
|
||||
}
|
||||
|
||||
std::optional<c10::Device> device() const {
|
||||
if (isCreated()) {
|
||||
return c10::Device(c10::kXPU, device_index_);
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool isCreated() const {
|
||||
return (event_.get() != nullptr);
|
||||
}
|
||||
|
||||
DeviceIndex device_index() const {
|
||||
return device_index_;
|
||||
}
|
||||
|
||||
sycl::event& event() const {
|
||||
return *event_;
|
||||
}
|
||||
|
||||
bool query() const {
|
||||
using namespace sycl::info;
|
||||
if (!isCreated()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return event().get_info<event::command_execution_status>() ==
|
||||
event_command_status::complete;
|
||||
}
|
||||
|
||||
void record() {
|
||||
record(getCurrentXPUStream());
|
||||
}
|
||||
|
||||
void recordOnce(const XPUStream& stream) {
|
||||
if (!isCreated()) {
|
||||
record(stream);
|
||||
}
|
||||
}
|
||||
|
||||
void record(const XPUStream& stream) {
|
||||
if (!isCreated()) {
|
||||
device_index_ = stream.device_index();
|
||||
assignEvent(stream.queue());
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_creation(
|
||||
c10::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
device_index_ == stream.device_index(),
|
||||
"Event device ",
|
||||
device_index_,
|
||||
" does not match recording stream's device ",
|
||||
stream.device_index(),
|
||||
".");
|
||||
reassignEvent(stream.queue());
|
||||
}
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_record(
|
||||
c10::kXPU,
|
||||
reinterpret_cast<uintptr_t>(event_.get()),
|
||||
reinterpret_cast<uintptr_t>(&stream.queue()));
|
||||
}
|
||||
}
|
||||
|
||||
void block(const XPUStream& stream) {
|
||||
if (isCreated()) {
|
||||
std::vector<sycl::event> event_list{event()};
|
||||
// Make this stream wait until event_ is completed.
|
||||
stream.queue().ext_oneapi_submit_barrier(event_list);
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_wait(
|
||||
c10::kXPU,
|
||||
reinterpret_cast<uintptr_t>(event_.get()),
|
||||
reinterpret_cast<uintptr_t>(&stream.queue()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
double elapsed_time(const XPUEvent& other) const {
|
||||
TORCH_CHECK(
|
||||
isCreated() && other.isCreated(),
|
||||
"Both events must be recorded before calculating elapsed time.");
|
||||
TORCH_CHECK(
|
||||
query() && other.query(),
|
||||
"Both events must be completed before calculating elapsed time.");
|
||||
TORCH_CHECK(
|
||||
enable_timing_ && other.enable_timing_,
|
||||
"Both events must be created with argument 'enable_timing=True'.");
|
||||
|
||||
using namespace sycl::info::event_profiling;
|
||||
// Block until both of the recorded events are completed.
|
||||
uint64_t end_time_ns = other.event().get_profiling_info<command_end>();
|
||||
uint64_t start_time_ns = event().get_profiling_info<command_end>();
|
||||
// Return the eplased time in milliseconds.
|
||||
return 1e-6 *
|
||||
(static_cast<double>(end_time_ns) - static_cast<double>(start_time_ns));
|
||||
}
|
||||
|
||||
void synchronize() const {
|
||||
if (isCreated()) {
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_synchronization(
|
||||
c10::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
|
||||
}
|
||||
event().wait_and_throw();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void assignEvent(sycl::queue& queue) {
|
||||
if (enable_timing_) {
|
||||
event_ = std::make_unique<sycl::event>(
|
||||
sycl::ext::oneapi::experimental::submit_profiling_tag(queue));
|
||||
} else {
|
||||
event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
|
||||
}
|
||||
}
|
||||
|
||||
void reassignEvent(sycl::queue& queue) {
|
||||
event_.reset();
|
||||
assignEvent(queue);
|
||||
}
|
||||
|
||||
bool enable_timing_ = false;
|
||||
c10::DeviceIndex device_index_ = -1;
|
||||
// Only need to track the last event, as events in an in-order queue are
|
||||
// executed sequentially.
|
||||
std::unique_ptr<sycl::event> event_;
|
||||
};
|
||||
|
||||
} // namespace c10::xpu
|
||||
@ -1,7 +1,7 @@
|
||||
# This will define the following variables:
|
||||
# SYCL_FOUND : True if the system has the SYCL library.
|
||||
# SYCL_INCLUDE_DIR : Include directories needed to use SYCL.
|
||||
# SYCL_LIBRARY_DIR :The path to the SYCL library.
|
||||
# SYCL_LIBRARY_DIR : The path to the SYCL library.
|
||||
# SYCL_LIBRARY : SYCL library fullname.
|
||||
# SYCL_COMPILER_VERSION : SYCL compiler version.
|
||||
|
||||
|
||||
@ -478,6 +478,7 @@ function(torch_update_find_cuda_flags)
|
||||
endfunction()
|
||||
|
||||
include(CheckCXXCompilerFlag)
|
||||
include(CheckCCompilerFlag)
|
||||
include(CheckLinkerFlag)
|
||||
|
||||
##############################################################################
|
||||
@ -501,6 +502,24 @@ function(append_cxx_flag_if_supported flag outputvar)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
function(append_c_flag_if_supported flag outputvar)
|
||||
string(TOUPPER "HAS${flag}" _FLAG_NAME)
|
||||
string(REGEX REPLACE "[=-]" "_" _FLAG_NAME "${_FLAG_NAME}")
|
||||
|
||||
# GCC silences unknown -Wno-XXX flags, so test the corresponding -WXXX.
|
||||
if(CMAKE_C_COMPILER_ID STREQUAL "GNU")
|
||||
string(REGEX REPLACE "^Wno-" "W" new_flag "${flag}")
|
||||
else()
|
||||
set(new_flag "${flag}")
|
||||
endif()
|
||||
|
||||
check_c_compiler_flag("${new_flag}" ${_FLAG_NAME})
|
||||
if(${_FLAG_NAME})
|
||||
string(APPEND ${outputvar} " ${flag}")
|
||||
set(${outputvar} "${${outputvar}}" PARENT_SCOPE)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
function(target_compile_options_if_supported target flag)
|
||||
set(_compile_options "")
|
||||
append_cxx_flag_if_supported("${flag}" _compile_options)
|
||||
|
||||
164
docs/source/accelerator/hooks.md
Normal file
164
docs/source/accelerator/hooks.md
Normal file
@ -0,0 +1,164 @@
|
||||
# Accelerator Hooks
|
||||
|
||||
## Background
|
||||
|
||||
OpenReg hooks provide a mechanism for integrating custom accelerator devices into PyTorch's runtime system. OpenReg (Open Registration) is PyTorch's extensibility framework that allows accelerator vendors to register custom device backends without modifying PyTorch core code.
|
||||
|
||||
## Design
|
||||
|
||||
The following tables list all hooks that accelerator vendors need to implement when integrating a new device backend. These hooks are categorized into two priority levels:
|
||||
|
||||
- **High Priority Hooks**: Core APIs that PyTorch runtime directly depends on. Accelerator vendors are recommended to implement all high priority hooks to ensure full PyTorch compatibility and enable basic device functionality.
|
||||
|
||||
- **Low Priority Hooks**: Device management and utility APIs that PyTorch does not directly depend on. These hooks enhance user experience and multi-device support but are *optional*. Accelerator vendors can choose to implement them based on their specific requirements and use cases.
|
||||
|
||||
### High Priority Hooks
|
||||
|
||||
| Hook Method | Description | Application Scenario |
|
||||
| ---------------------------------- | --------------------------------------------------------- | -------------------------------------------------------------------------------- |
|
||||
| `init()` | Initializes the accelerator runtime and device contexts | Set up necessary state when PyTorch first accesses the device |
|
||||
| `hasPrimaryContext(DeviceIndex)` | Checks if a primary context exists for the device | Determine whether device initialization has occurred |
|
||||
| `getDefaultGenerator(DeviceIndex)` | Returns the default random number generator for a device | Access the device's primary RNG for reproducible random operations |
|
||||
| `getNewGenerator(DeviceIndex)` | Creates a new independent random number generator | Create isolated RNG instances for parallel operations |
|
||||
| `getDeviceFromPtr(void*)` | Determines which device a memory pointer belongs to | Identify the accelerator device associated with a memory allocation |
|
||||
| `getPinnedMemoryAllocator()` | Returns an allocator for pinned (page-locked) host memory | Allocate host memory that can be efficiently transferred to/from the accelerator |
|
||||
| `isPinnedPtr(void*)` | Checks if a pointer points to pinned memory | Validate memory types before performing operations |
|
||||
|
||||
### Low Priority Hooks
|
||||
|
||||
| Hook Method | Description | Application Scenario |
|
||||
| ---------------------------------- | ---------------------------------------------------------------------------- | -------------------------------------------------------------------- |
|
||||
| `isBuilt()` | Returns whether the accelerator backend is built/compiled into the extension | Check whether the accelerator library is available at compile time |
|
||||
| `isAvailable()` | Returns whether the accelerator hardware is available at runtime | Verify whether accelerator devices can be detected and initialized |
|
||||
| `deviceCount()` | Returns the number of available accelerator devices | Enumerate all available accelerator devices for device selection |
|
||||
| `setCurrentDevice(DeviceIndex)` | Sets the active device for the current thread | Switch the current thread's context to a specific accelerator device |
|
||||
| `getCurrentDevice()` | Returns the currently active device index | Query which accelerator device is active in the current thread |
|
||||
| `exchangeDevice(DeviceIndex)` | Atomically exchanges the current device and returns the previous one | Temporarily switch devices and restore the previous device afterward |
|
||||
| `maybeExchangeDevice(DeviceIndex)` | Conditionally exchanges device only if the index is valid | Safely attempt device switching with validation |
|
||||
|
||||
## Implementation
|
||||
|
||||
We can just take `getDefaultGenerator` as an implementation example:
|
||||
|
||||
```{eval-rst}
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h
|
||||
:language: c++
|
||||
:start-after: LITERALINCLUDE START: OPENREG HOOK EXAMPLES
|
||||
:end-before: LITERALINCLUDE END: OPENREG HOOK EXAMPLES
|
||||
:linenos:
|
||||
```
|
||||
|
||||
In this implementation:
|
||||
|
||||
1. **Override the base interface**: The `getDefaultGenerator` method overrides the virtual method from `at::PrivateUse1HooksInterface`.
|
||||
|
||||
2. **Delegate to device-specific implementation**: It calls `getDefaultOpenRegGenerator(device_index)`, which manages a per-device generator instance.
|
||||
|
||||
3. **Return device-specific generator**: The returned `at::Generator` wraps an `OpenRegGeneratorImpl` that implements device-specific random number generation.
|
||||
|
||||
This pattern applies to all hooks: override the interface method, validate inputs, delegate to your device-specific API, and return results in PyTorch's expected format.
|
||||
|
||||
## Integration Example
|
||||
|
||||
The following sections demonstrate how PyTorch integrates with accelerator hooks when accessing the default random number generator. The example traces the complete flow from user-facing Python code down to the device-specific implementation.
|
||||
|
||||
### Layer 1: User Code
|
||||
|
||||
User code initiates the operation by calling `manual_seed` to set the random seed for reproducible results:
|
||||
|
||||
```python
|
||||
import torch
|
||||
torch.openreg.manual_seed(42)
|
||||
```
|
||||
|
||||
### Layer 2: Extension Python API
|
||||
|
||||
The Python API layer handles device management and calls into the C++ extension (defined in [`torch_openreg/openreg/random.py`][random.py]):
|
||||
|
||||
```{eval-rst}
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py
|
||||
:language: python
|
||||
:start-after: LITERALINCLUDE START: OPENREG MANUAL SEED
|
||||
:end-before: LITERALINCLUDE END: OPENREG MANUAL SEED
|
||||
:linenos:
|
||||
```
|
||||
|
||||
The `manual_seed` function gets the current device index and calls `torch_openreg._C._get_default_generator(idx)` to obtain the device-specific generator, then sets the seed on it.
|
||||
|
||||
### Layer 3: Python/C++ Bridge
|
||||
|
||||
The C++ extension exposes `_getDefaultGenerator` to Python, which bridges to PyTorch's core runtime:
|
||||
|
||||
```{eval-rst}
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
|
||||
:language: c++
|
||||
:start-after: LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR
|
||||
:end-before: LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR
|
||||
:linenos:
|
||||
:emphasize-lines: 10-11
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
|
||||
:language: c++
|
||||
:start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
|
||||
:end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
|
||||
:linenos:
|
||||
:emphasize-lines: 3
|
||||
```
|
||||
|
||||
This function unpacks the device index from Python, creates a `PrivateUse1` device object, and calls `at::globalContext().defaultGenerator()`. PyTorch's context then dispatches to the registered hooks.
|
||||
|
||||
### Layer 4: PyTorch Core Context
|
||||
|
||||
PyTorch's Context class dispatches to the appropriate accelerator hooks ([`aten/src/ATen/Context.h`][Context.h]):
|
||||
|
||||
```{eval-rst}
|
||||
.. literalinclude:: ../../../aten/src/ATen/Context.h
|
||||
:language: c++
|
||||
:lines: 60-103
|
||||
:linenos:
|
||||
:emphasize-lines: 8-9, 24-25
|
||||
```
|
||||
|
||||
This layered architecture enables PyTorch to remain device-agnostic while delegating hardware-specific operations to accelerator implementations. The hooks are registered once at module load time:
|
||||
|
||||
```{eval-rst}
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.cpp
|
||||
:language: c++
|
||||
:start-after: LITERALINCLUDE START: OPENREG HOOK REGISTER
|
||||
:end-before: LITERALINCLUDE END: OPENREG HOOK REGISTER
|
||||
:linenos:
|
||||
:emphasize-lines: 4
|
||||
```
|
||||
|
||||
### Layer 5: Accelerator Hooks
|
||||
|
||||
The hooks interface provides the abstraction that PyTorch uses to delegate to device-specific implementations:
|
||||
|
||||
```{eval-rst}
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h
|
||||
:language: c++
|
||||
:start-after: LITERALINCLUDE START: OPENREG HOOK EXAMPLES
|
||||
:end-before: LITERALINCLUDE END: OPENREG HOOK EXAMPLES
|
||||
:linenos:
|
||||
```
|
||||
|
||||
The `getDefaultGenerator` hook method overrides the base interface and delegates to `getDefaultOpenRegGenerator`, which manages the actual generator instances.
|
||||
|
||||
### Layer 6: Device-Specific Implementation
|
||||
|
||||
The device-specific implementation manages per-device generator instances:
|
||||
|
||||
```{eval-rst}
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.cpp
|
||||
:language: c++
|
||||
:start-after: LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR IMPL
|
||||
:end-before: LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR IMPL
|
||||
:linenos:
|
||||
```
|
||||
|
||||
This function maintains a static vector of generators (one per device), initializes them on first access, validates the device index, and returns the appropriate generator instance.
|
||||
|
||||
[random.py]: https://github.com/pytorch/pytorch/tree/main/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py#L48-L53 "random.py"
|
||||
[Context.h]: https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/Context.h#L61-L102 "Context.h"
|
||||
@ -42,6 +42,7 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
|
||||
:glob:
|
||||
:maxdepth: 1
|
||||
|
||||
hooks
|
||||
autoload
|
||||
operators
|
||||
amp
|
||||
|
||||
@ -1308,8 +1308,319 @@ coverage_ignore_functions = [
|
||||
# torch.onnx.symbolic_opset7
|
||||
"max",
|
||||
"min",
|
||||
# torch.onnx.symbolic_opset8
|
||||
"addmm",
|
||||
"bmm",
|
||||
"empty",
|
||||
"empty_like",
|
||||
"flatten",
|
||||
"full",
|
||||
"full_like",
|
||||
"gt",
|
||||
"lt",
|
||||
"matmul",
|
||||
"mm",
|
||||
"ones",
|
||||
"ones_like",
|
||||
"prelu",
|
||||
"repeat",
|
||||
"zeros",
|
||||
"zeros_like",
|
||||
# torch.onnx.symbolic_opset9
|
||||
"abs",
|
||||
"acos",
|
||||
"adaptive_avg_pool1d",
|
||||
"adaptive_avg_pool2d",
|
||||
"adaptive_avg_pool3d",
|
||||
"adaptive_max_pool1d",
|
||||
"adaptive_max_pool2d",
|
||||
"adaptive_max_pool3d",
|
||||
"add",
|
||||
"addcmul",
|
||||
"addmm",
|
||||
"alias",
|
||||
"amax",
|
||||
"amin",
|
||||
"aminmax",
|
||||
"arange",
|
||||
"argmax",
|
||||
"argmin",
|
||||
"as_strided",
|
||||
"as_tensor",
|
||||
"asin",
|
||||
"atan",
|
||||
"atan2",
|
||||
"avg_pool1d",
|
||||
"avg_pool2d",
|
||||
"avg_pool3d",
|
||||
"baddbmm",
|
||||
"batch_norm",
|
||||
"bernoulli",
|
||||
"bitwise_not",
|
||||
"bitwise_or",
|
||||
"bmm",
|
||||
"broadcast_tensors",
|
||||
"broadcast_to",
|
||||
"bucketize",
|
||||
"cat",
|
||||
"cdist",
|
||||
"ceil",
|
||||
"clamp",
|
||||
"clamp_max",
|
||||
"clamp_min",
|
||||
"clone",
|
||||
"constant_pad_nd",
|
||||
"contiguous",
|
||||
"conv1d",
|
||||
"conv2d",
|
||||
"conv3d",
|
||||
"conv_tbc",
|
||||
"conv_transpose1d",
|
||||
"conv_transpose2d",
|
||||
"conv_transpose3d",
|
||||
"convert_element_type",
|
||||
"convolution",
|
||||
"cos",
|
||||
"cosine_similarity",
|
||||
"cross",
|
||||
"cumsum",
|
||||
"detach",
|
||||
"dim",
|
||||
"div",
|
||||
"dot",
|
||||
"dropout",
|
||||
"elu",
|
||||
"embedding",
|
||||
"embedding_bag",
|
||||
"empty",
|
||||
"empty_like",
|
||||
"eq",
|
||||
"erf",
|
||||
"exp",
|
||||
"expand",
|
||||
"expand_as",
|
||||
"eye",
|
||||
"fill",
|
||||
"flatten",
|
||||
"floor",
|
||||
"floor_divide",
|
||||
"floordiv",
|
||||
"frobenius_norm",
|
||||
"full",
|
||||
"full_like",
|
||||
"gather",
|
||||
"ge",
|
||||
"gelu",
|
||||
"get_pool_ceil_padding",
|
||||
"glu",
|
||||
"group_norm",
|
||||
"gru",
|
||||
"gt",
|
||||
"hann_window",
|
||||
"hardshrink",
|
||||
"hardsigmoid",
|
||||
"hardswish",
|
||||
"hardtanh",
|
||||
"index",
|
||||
"index_add",
|
||||
"index_copy",
|
||||
"index_fill",
|
||||
"index_put",
|
||||
"index_select",
|
||||
"instance_norm",
|
||||
"is_floating_point",
|
||||
"is_pinned",
|
||||
"isnan",
|
||||
"item",
|
||||
"kl_div",
|
||||
"layer_norm",
|
||||
"le",
|
||||
"leaky_relu",
|
||||
"lerp",
|
||||
"lift",
|
||||
"linalg_cross",
|
||||
"linalg_matrix_norm",
|
||||
"linalg_norm",
|
||||
"linalg_vector_norm",
|
||||
"linear",
|
||||
"linspace",
|
||||
"log",
|
||||
"log10",
|
||||
"log1p",
|
||||
"log2",
|
||||
"log_sigmoid",
|
||||
"log_softmax",
|
||||
"logical_and",
|
||||
"logical_not",
|
||||
"logical_or",
|
||||
"logical_xor",
|
||||
"logit",
|
||||
"logsumexp",
|
||||
"lstm",
|
||||
"lstm_cell",
|
||||
"lt",
|
||||
"masked_fill",
|
||||
"masked_fill_",
|
||||
"matmul",
|
||||
"max",
|
||||
"max_pool1d",
|
||||
"max_pool1d_with_indices",
|
||||
"max_pool2d",
|
||||
"max_pool2d_with_indices",
|
||||
"max_pool3d",
|
||||
"max_pool3d_with_indices",
|
||||
"maximum",
|
||||
"meshgrid",
|
||||
"min",
|
||||
"minimum",
|
||||
"mish",
|
||||
"mm",
|
||||
"movedim",
|
||||
"mse_loss",
|
||||
"mul",
|
||||
"multinomial",
|
||||
"mv",
|
||||
"narrow",
|
||||
"native_layer_norm",
|
||||
"ne",
|
||||
"neg",
|
||||
"new_empty",
|
||||
"new_full",
|
||||
"new_ones",
|
||||
"new_zeros",
|
||||
"nonzero",
|
||||
"nonzero_numpy",
|
||||
"noop_complex_operators",
|
||||
"norm",
|
||||
"numel",
|
||||
"numpy_T",
|
||||
"one_hot",
|
||||
"ones",
|
||||
"ones_like",
|
||||
"onnx_placeholder",
|
||||
"overload_by_arg_count",
|
||||
"pad",
|
||||
"pairwise_distance",
|
||||
"permute",
|
||||
"pixel_shuffle",
|
||||
"pixel_unshuffle",
|
||||
"pow",
|
||||
"prelu",
|
||||
"prim_constant",
|
||||
"prim_constant_chunk",
|
||||
"prim_constant_split",
|
||||
"prim_data",
|
||||
"prim_device",
|
||||
"prim_dtype",
|
||||
"prim_if",
|
||||
"prim_layout",
|
||||
"prim_list_construct",
|
||||
"prim_list_unpack",
|
||||
"prim_loop",
|
||||
"prim_max",
|
||||
"prim_min",
|
||||
"prim_shape",
|
||||
"prim_tolist",
|
||||
"prim_tuple_construct",
|
||||
"prim_type",
|
||||
"prim_unchecked_cast",
|
||||
"prim_uninitialized",
|
||||
"rand",
|
||||
"rand_like",
|
||||
"randint",
|
||||
"randint_like",
|
||||
"randn",
|
||||
"randn_like",
|
||||
"reciprocal",
|
||||
"reflection_pad",
|
||||
"relu",
|
||||
"relu6",
|
||||
"remainder",
|
||||
"repeat",
|
||||
"repeat_interleave",
|
||||
"replication_pad",
|
||||
"reshape",
|
||||
"reshape_as",
|
||||
"rnn_relu",
|
||||
"rnn_tanh",
|
||||
"roll",
|
||||
"rrelu",
|
||||
"rsqrt",
|
||||
"rsub",
|
||||
"scalar_tensor",
|
||||
"scatter",
|
||||
"scatter_add",
|
||||
"select",
|
||||
"selu",
|
||||
"sigmoid",
|
||||
"sign",
|
||||
"silu",
|
||||
"sin",
|
||||
"size",
|
||||
"slice",
|
||||
"softmax",
|
||||
"softplus",
|
||||
"softshrink",
|
||||
"sort",
|
||||
"split",
|
||||
"split_with_sizes",
|
||||
"sqrt",
|
||||
"square",
|
||||
"squeeze",
|
||||
"stack",
|
||||
"std",
|
||||
"std_mean",
|
||||
"sub",
|
||||
"t",
|
||||
"take",
|
||||
"tan",
|
||||
"tanh",
|
||||
"tanhshrink",
|
||||
"tensor",
|
||||
"threshold",
|
||||
"to",
|
||||
"topk",
|
||||
"transpose",
|
||||
"true_divide",
|
||||
"type_as",
|
||||
"unbind",
|
||||
"unfold",
|
||||
"unsafe_chunk",
|
||||
"unsafe_split",
|
||||
"unsafe_split_with_sizes",
|
||||
"unsqueeze",
|
||||
"unsupported_complex_operators",
|
||||
"unused",
|
||||
"upsample_bilinear2d",
|
||||
"upsample_linear1d",
|
||||
"upsample_nearest1d",
|
||||
"upsample_nearest2d",
|
||||
"upsample_nearest3d",
|
||||
"upsample_trilinear3d",
|
||||
"var",
|
||||
"var_mean",
|
||||
"view",
|
||||
"view_as",
|
||||
"where",
|
||||
"wrap_logical_op_with_cast_to",
|
||||
"wrap_logical_op_with_negation",
|
||||
"zero",
|
||||
"zeros",
|
||||
"zeros_like",
|
||||
# torch.onnx.utils
|
||||
"disable_apex_o2_state_dict_hook",
|
||||
"export",
|
||||
"export_to_pretty_string",
|
||||
"exporter_context",
|
||||
"is_in_onnx_export",
|
||||
"model_signature",
|
||||
"register_custom_op_symbolic",
|
||||
"select_model_mode_for_export",
|
||||
"setup_onnx_logging",
|
||||
"unconvertible_ops",
|
||||
"unpack_quantized_tensor",
|
||||
"warn_on_static_input_change",
|
||||
# torch.onnx.verification
|
||||
"check_export_model_diff",
|
||||
"verify",
|
||||
"verify_aten_graph",
|
||||
@ -1400,6 +1711,32 @@ coverage_ignore_functions = [
|
||||
"noop_context_fn",
|
||||
"set_checkpoint_early_stop",
|
||||
"set_device_states",
|
||||
# torch.utils.collect_env
|
||||
"check_release_file",
|
||||
"get_cachingallocator_config",
|
||||
"get_clang_version",
|
||||
"get_cmake_version",
|
||||
"get_conda_packages",
|
||||
"get_cpu_info",
|
||||
"get_cuda_module_loading_config",
|
||||
"get_cudnn_version",
|
||||
"get_env_info",
|
||||
"get_gcc_version",
|
||||
"get_gpu_info",
|
||||
"get_libc_version",
|
||||
"get_lsb_version",
|
||||
"get_mac_version",
|
||||
"get_nvidia_driver_version",
|
||||
"get_nvidia_smi",
|
||||
"get_os",
|
||||
"get_pip_packages",
|
||||
"get_platform",
|
||||
"get_pretty_env_info",
|
||||
"get_python_platform",
|
||||
"get_running_cuda_version",
|
||||
"get_windows_version",
|
||||
"is_xnnpack_available",
|
||||
"pretty_str",
|
||||
# torch.utils.cpp_backtrace
|
||||
"get_cpp_backtrace",
|
||||
# torch.utils.cpp_extension
|
||||
@ -1463,6 +1800,52 @@ coverage_ignore_functions = [
|
||||
"apply_shuffle_seed",
|
||||
"apply_shuffle_settings",
|
||||
"get_all_graph_pipes",
|
||||
# torch.utils.flop_counter
|
||||
"addmm_flop",
|
||||
"baddbmm_flop",
|
||||
"bmm_flop",
|
||||
"conv_backward_flop",
|
||||
"conv_flop",
|
||||
"conv_flop_count",
|
||||
"convert_num_with_suffix",
|
||||
"get_shape",
|
||||
"get_suffix_str",
|
||||
"mm_flop",
|
||||
"normalize_tuple",
|
||||
"register_flop_formula",
|
||||
"sdpa_backward_flop",
|
||||
"sdpa_backward_flop_count",
|
||||
"sdpa_flop",
|
||||
"sdpa_flop_count",
|
||||
"shape_wrapper",
|
||||
"transpose_shape",
|
||||
# torch.utils.hipify.hipify_python
|
||||
"add_dim3",
|
||||
"compute_stats",
|
||||
"extract_arguments",
|
||||
"file_add_header",
|
||||
"file_specific_replacement",
|
||||
"find_bracket_group",
|
||||
"find_closure_group",
|
||||
"find_parentheses_group",
|
||||
"fix_static_global_kernels",
|
||||
"get_hip_file_path",
|
||||
"hip_header_magic",
|
||||
"hipify",
|
||||
"is_caffe2_gpu_file",
|
||||
"is_cusparse_file",
|
||||
"is_out_of_place",
|
||||
"is_pytorch_file",
|
||||
"is_special_file",
|
||||
"match_extensions",
|
||||
"matched_files_iter",
|
||||
"openf",
|
||||
"preprocess_file_and_save_result",
|
||||
"preprocessor",
|
||||
"processKernelLaunches",
|
||||
"replace_extern_shared",
|
||||
"replace_math_functions",
|
||||
"str2bool",
|
||||
# torch.utils.hooks
|
||||
"unserializable_hook",
|
||||
"warn_if_has_hooks",
|
||||
|
||||
21
docs/source/mtia.mtia_graph.md
Normal file
21
docs/source/mtia.mtia_graph.md
Normal file
@ -0,0 +1,21 @@
|
||||
# torch.mtia.mtia_graph
|
||||
|
||||
The MTIA backend is implemented out of the tree, only interfaces are defined here.
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.mtia.mtia_graph
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.mtia.mtia_graph
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: MTIAGraph
|
||||
:members:
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: graph
|
||||
:members:
|
||||
```
|
||||
@ -14,6 +14,10 @@ Utils
|
||||
|
||||
sdpa_kernel
|
||||
SDPBackend
|
||||
register_flash_attention_impl
|
||||
activate_flash_attention_impl
|
||||
list_flash_attention_impls
|
||||
current_flash_attention_impl
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
@ -29,6 +29,7 @@ mps
|
||||
xpu
|
||||
mtia
|
||||
mtia.memory
|
||||
mtia.mtia_graph
|
||||
meta
|
||||
torch.backends <backends>
|
||||
torch.export <export>
|
||||
|
||||
@ -19,91 +19,6 @@
|
||||
swap_tensors
|
||||
```
|
||||
|
||||
# torch.utils.collect_env
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.utils.collect_env
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.utils.collect_env
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
check_release_file
|
||||
is_xnnpack_available
|
||||
pretty_str
|
||||
```
|
||||
|
||||
# torch.utils.flop_counter
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.utils.flop_counter
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.utils.flop_counter
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
baddbmm_flop
|
||||
bmm_flop
|
||||
conv_backward_flop
|
||||
conv_flop
|
||||
conv_flop_count
|
||||
register_flop_formula
|
||||
sdpa_backward_flop
|
||||
sdpa_backward_flop_count
|
||||
sdpa_flop
|
||||
sdpa_flop_count
|
||||
shape_wrapper
|
||||
```
|
||||
|
||||
# torch.utils.hipify.hipify_python
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.utils.hipify.hipify_python
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.utils.hipify.hipify_python
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
compute_stats
|
||||
extract_arguments
|
||||
file_add_header
|
||||
file_specific_replacement
|
||||
find_bracket_group
|
||||
find_closure_group
|
||||
find_parentheses_group
|
||||
fix_static_global_kernels
|
||||
hip_header_magic
|
||||
hipify
|
||||
is_caffe2_gpu_file
|
||||
is_cusparse_file
|
||||
is_out_of_place
|
||||
is_pytorch_file
|
||||
is_special_file
|
||||
openf
|
||||
preprocess_file_and_save_result
|
||||
preprocessor
|
||||
processKernelLaunches
|
||||
replace_extern_shared
|
||||
replace_math_functions
|
||||
str2bool
|
||||
```
|
||||
|
||||
|
||||
<!-- This module needs to be documented. Adding here in the meantime
|
||||
for tracking purposes -->
|
||||
```{eval-rst}
|
||||
@ -128,6 +43,7 @@ for tracking purposes -->
|
||||
.. py:module:: torch.utils.benchmark.utils.valgrind_wrapper.timer_interface
|
||||
.. py:module:: torch.utils.bundled_inputs
|
||||
.. py:module:: torch.utils.checkpoint
|
||||
.. py:module:: torch.utils.collect_env
|
||||
.. py:module:: torch.utils.cpp_backtrace
|
||||
.. py:module:: torch.utils.cpp_extension
|
||||
.. py:module:: torch.utils.data.backward_compatibility
|
||||
@ -164,8 +80,10 @@ for tracking purposes -->
|
||||
.. py:module:: torch.utils.data.sampler
|
||||
.. py:module:: torch.utils.dlpack
|
||||
.. py:module:: torch.utils.file_baton
|
||||
.. py:module:: torch.utils.flop_counter
|
||||
.. py:module:: torch.utils.hipify.constants
|
||||
.. py:module:: torch.utils.hipify.cuda_to_hip_mappings
|
||||
.. py:module:: torch.utils.hipify.hipify_python
|
||||
.. py:module:: torch.utils.hipify.version
|
||||
.. py:module:: torch.utils.hooks
|
||||
.. py:module:: torch.utils.jit.log_extract
|
||||
|
||||
@ -260,6 +260,7 @@ select = [
|
||||
"TRY401", # verbose-log-message
|
||||
"UP",
|
||||
"YTT",
|
||||
"S101",
|
||||
]
|
||||
|
||||
[tool.ruff.lint.pyupgrade]
|
||||
@ -339,6 +340,39 @@ keep-runtime-typing = true
|
||||
"tools/linter/**" = [
|
||||
"LOG015" # please fix
|
||||
]
|
||||
"benchmarks/**" = [
|
||||
"S101"
|
||||
]
|
||||
"test/**" = [
|
||||
"S101"
|
||||
]
|
||||
"torchgen/**" = [
|
||||
"S101"
|
||||
]
|
||||
"torch/**" = [
|
||||
"S101"
|
||||
]
|
||||
"tools/**" = [
|
||||
"S101"
|
||||
]
|
||||
"setup.py" = [
|
||||
"S101"
|
||||
]
|
||||
"functorch/**" = [
|
||||
"S101"
|
||||
]
|
||||
"docs/**" = [
|
||||
"S101"
|
||||
]
|
||||
"android/**" = [
|
||||
"S101"
|
||||
]
|
||||
".github/**" = [
|
||||
"S101"
|
||||
]
|
||||
".ci/**" = [
|
||||
"S101"
|
||||
]
|
||||
|
||||
[tool.codespell]
|
||||
ignore-words = "tools/linter/dictionary.txt"
|
||||
|
||||
@ -10,7 +10,7 @@ tp2_dir="$top_dir/third_party"
|
||||
pip install ninja
|
||||
|
||||
# Install onnx
|
||||
pip install --no-use-pep517 -e "$tp2_dir/onnx"
|
||||
pip install -e "$tp2_dir/onnx"
|
||||
|
||||
# Install caffe2 and pytorch
|
||||
pip install -r "$top_dir/caffe2/requirements.txt"
|
||||
|
||||
47
setup.py
47
setup.py
@ -1358,6 +1358,45 @@ class concat_license_files:
|
||||
|
||||
# Need to create the proper LICENSE.txt for the wheel
|
||||
class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
|
||||
def _wrap_headers_with_macro(self, bdist_dir: Path) -> None:
|
||||
"""Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION).
|
||||
|
||||
Excludes:
|
||||
- torch/include/torch/headeronly/*
|
||||
- torch/include/torch/csrc/stable/*
|
||||
- torch/include/torch/csrc/inductor/aoti_torch/c/ (only shim headers)
|
||||
- torch/include/torch/csrc/inductor/aoti_torch/generated/
|
||||
"""
|
||||
header_extensions = (".h", ".hpp", ".cuh")
|
||||
header_files = [
|
||||
f for ext in header_extensions for f in bdist_dir.rglob(f"*{ext}")
|
||||
]
|
||||
|
||||
# Paths to exclude from wrapping
|
||||
exclude_dir_patterns = [
|
||||
"torch/include/torch/headeronly/",
|
||||
"torch/include/torch/csrc/stable/",
|
||||
"torch/include/torch/csrc/inductor/aoti_torch/c/",
|
||||
"torch/include/torch/csrc/inductor/aoti_torch/generated/",
|
||||
]
|
||||
|
||||
for header_file in header_files:
|
||||
rel_path = header_file.relative_to(bdist_dir).as_posix()
|
||||
|
||||
if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns):
|
||||
report(f"Skipping header: {rel_path}")
|
||||
continue
|
||||
|
||||
original_content = header_file.read_text(encoding="utf-8")
|
||||
wrapped_content = (
|
||||
"#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
|
||||
f"{original_content}"
|
||||
"\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
|
||||
)
|
||||
|
||||
header_file.write_text(wrapped_content, encoding="utf-8")
|
||||
report(f"Wrapped header: {rel_path}")
|
||||
|
||||
def run(self) -> None:
|
||||
with concat_license_files(include_files=True):
|
||||
super().run()
|
||||
@ -1380,6 +1419,14 @@ class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
|
||||
# need an __init__.py file otherwise we wouldn't have a package
|
||||
(bdist_dir / "torch" / "__init__.py").touch()
|
||||
|
||||
# Wrap all header files with TORCH_STABLE_ONLY macro
|
||||
assert self.bdist_dir is not None, "bdist_dir should be set during wheel build"
|
||||
bdist_dir = Path(self.bdist_dir)
|
||||
report(
|
||||
"-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)"
|
||||
)
|
||||
self._wrap_headers_with_macro(bdist_dir)
|
||||
|
||||
|
||||
class clean(Command):
|
||||
user_options: ClassVar[list[tuple[str, str | None, str]]] = []
|
||||
|
||||
@ -308,12 +308,16 @@ class StepcurrentPlugin:
|
||||
self.report_status = ""
|
||||
assert config.cache is not None
|
||||
self.cache: pytest.Cache = config.cache
|
||||
self.directory = f"{STEPCURRENT_CACHE_DIR}/{config.getoption('stepcurrent')}"
|
||||
self.lastrun: Optional[str] = self.cache.get(self.directory, None)
|
||||
directory = f"{STEPCURRENT_CACHE_DIR}/{config.getoption('stepcurrent')}"
|
||||
self.lastrun_location = f"{directory}/lastrun"
|
||||
self.lastrun: Optional[str] = self.cache.get(self.lastrun_location, None)
|
||||
self.initial_val = self.lastrun
|
||||
self.skip: bool = config.getoption("stepcurrent_skip")
|
||||
self.run_single: bool = config.getoption("run_single")
|
||||
|
||||
self.made_failing_xml_location = f"{directory}/made_failing_xml"
|
||||
self.cache.set(self.made_failing_xml_location, False)
|
||||
|
||||
def pytest_collection_modifyitems(self, config: Config, items: list[Any]) -> None:
|
||||
if not self.lastrun:
|
||||
self.report_status = "Cannot find last run test, not skipping"
|
||||
@ -349,8 +353,10 @@ class StepcurrentPlugin:
|
||||
|
||||
def pytest_runtest_protocol(self, item, nextitem) -> None:
|
||||
self.lastrun = item.nodeid
|
||||
self.cache.set(self.directory, self.lastrun)
|
||||
self.cache.set(self.lastrun_location, self.lastrun)
|
||||
|
||||
def pytest_sessionfinish(self, session, exitstatus):
|
||||
if exitstatus == 0:
|
||||
self.cache.set(self.directory, self.initial_val)
|
||||
self.cache.set(self.lastrun_location, self.initial_val)
|
||||
if exitstatus != 0:
|
||||
self.cache.set(self.made_failing_xml_location, True)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user