Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-11-15 23:04:54 +08:00

Compare commits (53 commits): dev/joona/ ... ciflow/tru
| SHA1 |
|---|
| 6d8946c013 |
| 4ed26f7382 |
| 4c79305b87 |
| f4b8c4f907 |
| d629b7a459 |
| 0922ba5f42 |
| c87295c044 |
| 7aa210d215 |
| 5a368b8010 |
| 602102be50 |
| 200156e385 |
| a2daf3fc86 |
| 52b45c16de |
| 2ef85bed5a |
| d99c6bcf69 |
| 8378abda84 |
| 5b42a5d9a6 |
| caca3f2eec |
| 9e2bf129e1 |
| c429b1fc5c |
| 1176b2b0b7 |
| dd37a1a434 |
| a74adcf80e |
| 5eac46a011 |
| e0fff31ae3 |
| 7ede33b8e3 |
| 065176cd97 |
| 02ee7dd7d3 |
| 99fdca8f4d |
| 9d1a74cb0c |
| 40e6f090d9 |
| bfddfde50c |
| b6570615f8 |
| 226850cc66 |
| f8a2ce3b9a |
| e2c6834584 |
| 0e7235ed73 |
| 3522e0ce74 |
| 50bf1f0b81 |
| c78e64622e |
| 5623628894 |
| 2aba180114 |
| 45b2c3d312 |
| 5b1e112cf9 |
| 5e6ac5c6e1 |
| 79317dc7a7 |
| 96a4c4b3d1 |
| 05bcfcc5d1 |
| 8cf0bdde45 |
| 813e5eae9b |
| 2ef236e3e3 |
| 532389fe9e |
| 08de54f1ea |
@@ -402,3 +402,6 @@ scikit-build==0.18.1
pyre-extensions==0.0.32
tabulate==0.9.0
#Description: These package are needed to build FBGEMM and torchrec on PyTorch CI
+
+Flask==3.1.2
+#Description: required for torch.distributed.debug
@@ -100,337 +100,6 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
    )


def _compile_and_extract_symbols(
    cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None
) -> list[str]:
    """
    Helper to compile a C++ file and extract all symbols.

    Args:
        cpp_content: C++ source code to compile
        compile_flags: Compilation flags
        exclude_list: List of symbol names to exclude. Defaults to ["main"].

    Returns:
        List of all symbols found in the object file (excluding those in exclude_list).
    """
    import subprocess
    import tempfile

    if exclude_list is None:
        exclude_list = ["main"]

    with tempfile.TemporaryDirectory() as tmpdir:
        tmppath = Path(tmpdir)
        cpp_file = tmppath / "test.cpp"
        obj_file = tmppath / "test.o"

        cpp_file.write_text(cpp_content)

        result = subprocess.run(
            compile_flags + [str(cpp_file), "-o", str(obj_file)],
            capture_output=True,
            text=True,
            timeout=60,
        )

        if result.returncode != 0:
            raise RuntimeError(f"Compilation failed: {result.stderr}")

        symbols = get_symbols(str(obj_file))

        # Return all symbol names, excluding those in the exclude list
        return [name for _addr, _stype, name in symbols if name not in exclude_list]


def check_stable_only_symbols(install_root: Path) -> None:
    """
    Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.

    This approach tests:
    1. WITHOUT macros -> many torch symbols exposed
    2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
    3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
    4. WITH both macros -> zero torch symbols (all hidden)
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    test_cpp_content = """
// Main torch C++ API headers
#include <torch/torch.h>
#include <torch/all.h>

// ATen tensor library
#include <ATen/ATen.h>

// Core c10 headers (commonly used)
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Optional.h>

int main() { return 0; }
"""

    base_compile_flags = [
        "g++",
        "-std=c++17",
        f"-I{include_dir}",
        f"-I{include_dir}/torch/csrc/api/include",
        "-c",  # Compile only, don't link
    ]

    # Compile WITHOUT any macros
    symbols_without = _compile_and_extract_symbols(
        cpp_content=test_cpp_content,
        compile_flags=base_compile_flags,
    )

    # We expect constexpr symbols, inline functions used by other headers etc.
    # to produce symbols
    num_symbols_without = len(symbols_without)
    print(f"Found {num_symbols_without} symbols without any macros defined")
    assert num_symbols_without != 0, (
        "Expected a non-zero number of symbols without any macros"
    )

    # Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
    compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"]

    symbols_with_stable_only = _compile_and_extract_symbols(
        cpp_content=test_cpp_content,
        compile_flags=compile_flags_with_stable_only,
    )

    num_symbols_with_stable_only = len(symbols_with_stable_only)
    assert num_symbols_with_stable_only == 0, (
        f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}"
    )

    # Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
    compile_flags_with_target_version = base_compile_flags + [
        "-DTORCH_TARGET_VERSION=1"
    ]

    symbols_with_target_version = _compile_and_extract_symbols(
        cpp_content=test_cpp_content,
        compile_flags=compile_flags_with_target_version,
    )

    num_symbols_with_target_version = len(symbols_with_target_version)
    assert num_symbols_with_target_version == 0, (
        f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}"
    )

    # Compile WITH both macros (expect 0 symbols)
    compile_flags_with_both = base_compile_flags + [
        "-DTORCH_STABLE_ONLY",
        "-DTORCH_TARGET_VERSION=1",
    ]

    symbols_with_both = _compile_and_extract_symbols(
        cpp_content=test_cpp_content,
        compile_flags=compile_flags_with_both,
    )

    num_symbols_with_both = len(symbols_with_both)
    assert num_symbols_with_both == 0, (
        f"Expected no symbols with both macros, but found {num_symbols_with_both}"
    )


def check_stable_api_symbols(install_root: Path) -> None:
    """
    Test that stable API headers still expose symbols with TORCH_STABLE_ONLY.
    The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    stable_dir = include_dir / "torch" / "csrc" / "stable"
    assert stable_dir.exists(), f"Expected {stable_dir} to be present"

    stable_headers = list(stable_dir.rglob("*.h"))
    if not stable_headers:
        raise RuntimeError("Could not find any stable headers")

    includes = []
    for header in stable_headers:
        rel_path = header.relative_to(include_dir)
        includes.append(f"#include <{rel_path.as_posix()}>")

    includes_str = "\n".join(includes)
    test_stable_content = f"""
{includes_str}
int main() {{ return 0; }}
"""

    compile_flags = [
        "g++",
        "-std=c++17",
        f"-I{include_dir}",
        f"-I{include_dir}/torch/csrc/api/include",
        "-c",
        "-DTORCH_STABLE_ONLY",
    ]

    symbols_stable = _compile_and_extract_symbols(
        cpp_content=test_stable_content,
        compile_flags=compile_flags,
    )
    num_symbols_stable = len(symbols_stable)
    print(f"Found {num_symbols_stable} symbols in torch/csrc/stable")
    assert num_symbols_stable > 0, (
        f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
        f"but found {num_symbols_stable} symbols"
    )


def check_headeronly_symbols(install_root: Path) -> None:
    """
    Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY.
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    # Find all headers in torch/headeronly
    headeronly_dir = include_dir / "torch" / "headeronly"
    assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present"
    headeronly_headers = list(headeronly_dir.rglob("*.h"))
    if not headeronly_headers:
        raise RuntimeError("Could not find any headeronly headers")

    # Filter out platform-specific headers that may not compile everywhere
    platform_specific_keywords = [
        "cpu/vec",
    ]

    filtered_headers = []
    for header in headeronly_headers:
        rel_path = header.relative_to(include_dir).as_posix()
        if not any(
            keyword in rel_path.lower() for keyword in platform_specific_keywords
        ):
            filtered_headers.append(header)

    includes = []
    for header in filtered_headers:
        rel_path = header.relative_to(include_dir)
        includes.append(f"#include <{rel_path.as_posix()}>")

    includes_str = "\n".join(includes)
    test_headeronly_content = f"""
{includes_str}
int main() {{ return 0; }}
"""

    compile_flags = [
        "g++",
        "-std=c++17",
        f"-I{include_dir}",
        f"-I{include_dir}/torch/csrc/api/include",
        "-c",
        "-DTORCH_STABLE_ONLY",
    ]

    symbols_headeronly = _compile_and_extract_symbols(
        cpp_content=test_headeronly_content,
        compile_flags=compile_flags,
    )
    num_symbols_headeronly = len(symbols_headeronly)
    print(f"Found {num_symbols_headeronly} symbols in torch/headeronly")
    assert num_symbols_headeronly > 0, (
        f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
        f"but found {num_symbols_headeronly} symbols"
    )


def check_aoti_shim_symbols(install_root: Path) -> None:
    """
    Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY.
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    # There are no constexpr symbols etc., so we need to actually use functions
    # so that some symbols are found.
    test_shim_content = """
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
int main() {
    int32_t (*fp1)() = &aoti_torch_device_type_cpu;
    int32_t (*fp2)() = &aoti_torch_dtype_float32;
    (void)fp1; (void)fp2;
    return 0;
}
"""

    compile_flags = [
        "g++",
        "-std=c++17",
        f"-I{include_dir}",
        f"-I{include_dir}/torch/csrc/api/include",
        "-c",
        "-DTORCH_STABLE_ONLY",
    ]

    symbols_shim = _compile_and_extract_symbols(
        cpp_content=test_shim_content,
        compile_flags=compile_flags,
    )
    num_symbols_shim = len(symbols_shim)
    assert num_symbols_shim > 0, (
        f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
        f"but found {num_symbols_shim} symbols"
    )


def check_stable_c_shim_symbols(install_root: Path) -> None:
    """
    Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY.
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    # Check if the stable C shim exists
    stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
    if not stable_shim.exists():
        raise RuntimeError("Could not find stable c shim")

    # There are no constexpr symbols etc., so we need to actually use functions
    # so that some symbols are found.
    test_stable_shim_content = """
#include <torch/csrc/stable/c/shim.h>
int main() {
    // Reference stable C API functions to create undefined symbols
    AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
    AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
    (void)fp1; (void)fp2;
    return 0;
}
"""

    compile_flags = [
        "g++",
        "-std=c++17",
        f"-I{include_dir}",
        f"-I{include_dir}/torch/csrc/api/include",
        "-c",
        "-DTORCH_STABLE_ONLY",
    ]

    symbols_stable_shim = _compile_and_extract_symbols(
        cpp_content=test_stable_shim_content,
        compile_flags=compile_flags,
    )
    num_symbols_stable_shim = len(symbols_stable_shim)
    assert num_symbols_stable_shim > 0, (
        f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
        f"but found {num_symbols_stable_shim} symbols"
    )


def check_lib_symbols_for_abi_correctness(lib: str) -> None:
    print(f"lib: {lib}")
    cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)

@@ -460,13 +129,6 @@ def main() -> None:
    check_lib_symbols_for_abi_correctness(libtorch_cpu_path)
    check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path)

-    # Check symbols when TORCH_STABLE_ONLY is defined
-    check_stable_only_symbols(install_root)
-    check_stable_api_symbols(install_root)
-    check_headeronly_symbols(install_root)
-    check_aoti_shim_symbols(install_root)
-    check_stable_c_shim_symbols(install_root)
-

if __name__ == "__main__":
    main()
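The removed checks rely on a get_symbols() helper defined elsewhere in the script (not shown in this hunk). For orientation only, here is a minimal sketch of what such a helper can look like on top of nm; the name and shape here are assumptions, not the script's actual implementation:

```python
import subprocess


def get_symbols_sketch(obj_file: str) -> list[tuple[str, str, str]]:
    """Hypothetical stand-in for the script's get_symbols().

    Parses `nm` output into (address, type, name) tuples; undefined
    symbols, which have no address column, are skipped.
    """
    out = subprocess.run(
        ["nm", obj_file], capture_output=True, text=True, check=True
    )
    symbols = []
    for line in out.stdout.splitlines():
        parts = line.split(maxsplit=2)
        if len(parts) == 3:  # "ADDRESS TYPE NAME"
            symbols.append((parts[0], parts[1], parts[2]))
    return symbols
```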
.spin/cmds.py (new file, 330 lines)
@@ -0,0 +1,330 @@
import hashlib
import subprocess
import sys
from pathlib import Path

import click
import spin


def file_digest(file, algorithm: str):
    try:
        return hashlib.file_digest(file, algorithm)
    except AttributeError:
        pass  # Fallback to manual implementation below
    hash = hashlib.new(algorithm)
    while chunk := file.read(8192):
        hash.update(chunk)
    return hash
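# NOTE: hashlib.file_digest() is only available on Python 3.11+, which is why
# the AttributeError branch above falls back to a manual 8 KiB chunked read.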


def _hash_file(file):
    with open(file, "rb") as f:
        hash = file_digest(f, "sha256")
    return hash.hexdigest()


def _hash_files(files):
    hashes = {file: _hash_file(file) for file in files}
    return hashes


def _read_hashes(hash_file: Path):
    if not hash_file.exists():
        return {}
    with hash_file.open("r") as f:
        lines = f.readlines()
    hashes = {}
    for line in lines:
        hash = line[:64]
        file = line[66:].strip()
        hashes[file] = hash
    return hashes
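# NOTE: each line of the hash file is "<64-hex-char sha256><2-char separator><path>"
# (offsets 0-63 and 66+), the same layout that `sha256sum` emits.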


def _updated_hashes(hash_file, files_to_hash):
    old_hashes = _read_hashes(hash_file)
    new_hashes = _hash_files(files_to_hash)
    if new_hashes != old_hashes:
        return new_hashes
    return None


@click.command()
def regenerate_version():
    """Regenerate version.py."""
    cmd = [
        sys.executable,
        "-m",
        "tools.generate_torch_version",
        "--is-debug=false",
    ]
    spin.util.run(cmd)


TYPE_STUBS = [
    (
        "Pytorch type stubs",
        Path(".lintbin/.pytorch-type-stubs.sha256"),
        [
            "aten/src/ATen/native/native_functions.yaml",
            "aten/src/ATen/native/tags.yaml",
            "tools/autograd/deprecated.yaml",
        ],
        [
            sys.executable,
            "-m",
            "tools.pyi.gen_pyi",
            "--native-functions-path",
            "aten/src/ATen/native/native_functions.yaml",
            "--tags-path",
            "aten/src/ATen/native/tags.yaml",
            "--deprecated-functions-path",
            "tools/autograd/deprecated.yaml",
        ],
    ),
    (
        "Datapipes type stubs",
        None,
        [],
        [
            sys.executable,
            "torch/utils/data/datapipes/gen_pyi.py",
        ],
    ),
]
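#: Each entry above is (name, hash file, files to hash, command). A None hash
#: file (the datapipes entry) means there is nothing to compare against, so
#: that generator is simply re-run every time.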


@click.command()
def regenerate_type_stubs():
    """Regenerate type stubs."""
    for name, hash_file, files_to_hash, cmd in TYPE_STUBS:
        if hash_file:
            if hashes := _updated_hashes(hash_file, files_to_hash):
                click.echo(
                    f"Changes detected in type stub files for {name}. Regenerating..."
                )
                spin.util.run(cmd)
                hash_file.parent.mkdir(parents=True, exist_ok=True)
                with hash_file.open("w") as f:
                    for file, hash in hashes.items():
                        f.write(f"{hash}  {file}\n")
                click.echo("Type stubs and hashes updated.")
            else:
                click.echo(f"No changes detected in type stub files for {name}.")
        else:
            click.echo(f"No hash file for {name}. Regenerating...")
            spin.util.run(cmd)
            click.echo("Type stubs regenerated.")


@click.command()
def regenerate_clangtidy_files():
    """Regenerate clang-tidy files."""
    cmd = [
        sys.executable,
        "-m",
        "tools.linter.clang_tidy.generate_build_files",
    ]
    spin.util.run(cmd)


#: These linters are expected to need less than 3s cpu time total
VERY_FAST_LINTERS = {
    "ATEN_CPU_GPU_AGNOSTIC",
    "BAZEL_LINTER",
    "C10_NODISCARD",
    "C10_UNUSED",
    "CALL_ONCE",
    "CMAKE_MINIMUM_REQUIRED",
    "CONTEXT_DECORATOR",
    "COPYRIGHT",
    "CUBINCLUDE",
    "DEPLOY_DETECTION",
    "ERROR_PRONE_ISINSTANCE",
    "EXEC",
    "HEADER_ONLY_LINTER",
    "IMPORT_LINTER",
    "INCLUDE",
    "LINTRUNNER_VERSION",
    "MERGE_CONFLICTLESS_CSV",
    "META_NO_CREATE_UNBACKED",
    "NEWLINE",
    "NOQA",
    "NO_WORKFLOWS_ON_FORK",
    "ONCE_FLAG",
    "PYBIND11_INCLUDE",
    "PYBIND11_SPECIALIZATION",
    "PYPIDEP",
    "PYPROJECT",
    "RAWCUDA",
    "RAWCUDADEVICE",
    "ROOT_LOGGING",
    "TABS",
    "TESTOWNERS",
    "TYPEIGNORE",
    "TYPENOSKIP",
    "WORKFLOWSYNC",
}


#: These linters are expected to take a few seconds, but less than 10s cpu time total
FAST_LINTERS = {
    "CMAKE",
    "DOCSTRING_LINTER",
    "GHA",
    "NATIVEFUNCTIONS",
    "RUFF",
    "SET_LINTER",
    "SHELLCHECK",
    "SPACES",
}


#: These linters are expected to take more than 10s cpu time total;
#: some need more than 1 hour.
SLOW_LINTERS = {
    "ACTIONLINT",
    "CLANGFORMAT",
    "CLANGTIDY",
    "CODESPELL",
    "FLAKE8",
    "GB_REGISTRY",
    "PYFMT",
    "PYREFLY",
    "TEST_DEVICE_BIAS",
    "TEST_HAS_MAIN",
}


ALL_LINTERS = VERY_FAST_LINTERS | FAST_LINTERS | SLOW_LINTERS


LINTRUNNER_CACHE_INFO = (
    Path(".lintbin/.lintrunner.sha256"),
    [
        "requirements.txt",
        "pyproject.toml",
        ".lintrunner.toml",
    ],
)


LINTRUNNER_BASE_CMD = [
    "uvx",
    "--python",
    "3.10",
    "lintrunner@0.12.7",
]
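# NOTE: running lintrunner through uvx pins both the interpreter (3.10) and the
# lintrunner version (0.12.7), so local lint runs match the version CI uses
# without a permanent install.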


@click.command()
def setup_lint():
    """Set up lintrunner with current CI version."""
    cmd = LINTRUNNER_BASE_CMD + ["init"]
    subprocess.run(cmd, check=True, capture_output=True, text=True)


def _check_linters():
    cmd = LINTRUNNER_BASE_CMD + ["list"]
    ret = spin.util.run(cmd, output=False, stderr=subprocess.PIPE)
    linters = {l.strip() for l in ret.stdout.decode().strip().split("\n")[1:]}
    unknown_linters = linters - ALL_LINTERS
    missing_linters = ALL_LINTERS - linters
    if unknown_linters:
        click.secho(
            f"Unknown linters found; please add them to the correct category "
            f"in .spin/cmds.py: {', '.join(unknown_linters)}",
            fg="yellow",
        )
    if missing_linters:
        click.secho(
            f"Missing linters found; please update the corresponding category "
            f"in .spin/cmds.py: {', '.join(missing_linters)}",
            fg="yellow",
        )
    return unknown_linters, missing_linters


@spin.util.extend_command(
    setup_lint,
    doc=f"""
    If configuration has changed, update lintrunner.

    Compares the stored old hashes of configuration files with new ones and
    performs setup via setup-lint if the hashes have changed.
    Hashes are stored in {LINTRUNNER_CACHE_INFO[0]}; the following files are
    considered: {", ".join(LINTRUNNER_CACHE_INFO[1])}.
    """,
)
@click.pass_context
def lazy_setup_lint(ctx, parent_callback, **kwargs):
    if hashes := _updated_hashes(*LINTRUNNER_CACHE_INFO):
        click.echo(
            "Changes detected in lint configuration files. Setting up linting tools..."
        )
        parent_callback(**kwargs)
        hash_file = LINTRUNNER_CACHE_INFO[0]
        hash_file.parent.mkdir(parents=True, exist_ok=True)
        with hash_file.open("w") as f:
            for file, hash in hashes.items():
                f.write(f"{hash}  {file}\n")
        click.echo("Linting tools set up and hashes updated.")
    else:
        click.echo("No changes detected in lint configuration files. Skipping setup.")
    click.echo("Regenerating version...")
    ctx.invoke(regenerate_version)
    click.echo("Regenerating type stubs...")
    ctx.invoke(regenerate_type_stubs)
    click.echo("Done.")
    _check_linters()


@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def lint(ctx, apply_patches, **kwargs):
    """Lint all files."""
    ctx.invoke(lazy_setup_lint)
    all_files_linters = VERY_FAST_LINTERS | FAST_LINTERS
    changed_files_linters = SLOW_LINTERS
    cmd = LINTRUNNER_BASE_CMD
    if apply_patches:
        cmd += ["--apply-patches"]
    all_files_cmd = cmd + [
        "--take",
        ",".join(all_files_linters),
        "--all-files",
    ]
    spin.util.run(all_files_cmd)
    changed_files_cmd = cmd + [
        "--take",
        ",".join(changed_files_linters),
    ]
    spin.util.run(changed_files_cmd)
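# NOTE: the split above runs the cheap linters over the whole tree but limits
# the slow ones (clang-format, clang-tidy, etc.) to changed files, which keeps
# a full lint invocation tractable.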


@click.command()
@click.pass_context
def fixlint(ctx, **kwargs):
    """Autofix all files."""
    ctx.invoke(lint, apply_patches=True)


@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def quicklint(ctx, apply_patches, **kwargs):
    """Lint changed files."""
    ctx.invoke(lazy_setup_lint)
    cmd = LINTRUNNER_BASE_CMD
    if apply_patches:
        cmd += ["--apply-patches"]
    spin.util.run(cmd)


@click.command()
@click.pass_context
def quickfix(ctx, **kwargs):
    """Autofix changed files."""
    ctx.invoke(quicklint, apply_patches=True)
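Assuming these commands are registered with spin in pyproject.toml (the registration is not part of this diff), they would be invoked as, for example, `spin lint`, `spin quicklint`, or `spin fixlint`.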
@@ -144,7 +144,7 @@ inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
}

inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
-  out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
+  out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ')';
  return out;
}

@@ -9,7 +9,7 @@ namespace indexing {
const EllipsisIndexType Ellipsis = EllipsisIndexType();

std::ostream& operator<<(std::ostream& stream, const Slice& slice) {
-  stream << slice.start() << ":" << slice.stop() << ":" << slice.step();
+  stream << slice.start() << ':' << slice.stop() << ':' << slice.step();
  return stream;
}

@@ -31,12 +31,12 @@ std::ostream& operator<<(std::ostream& stream, const TensorIndex& tensor_index)
}

std::ostream& operator<<(std::ostream& stream, const std::vector<TensorIndex>& tensor_indices) {
-  stream << "(";
+  stream << '(';
  for (const auto i : c10::irange(tensor_indices.size())) {
    stream << tensor_indices[i];
    if (i < tensor_indices.size() - 1) stream << ", ";
  }
-  stream << ")";
+  stream << ')';
  return stream;
}

@@ -113,7 +113,7 @@ void TensorNames::checkUnique(const char* op_name) const {
std::ostream& operator<<(std::ostream& out, const TensorName& tensorname) {
  out << tensorname.name_ << " (index ";
  out << tensorname.origin_idx_ << " of ";
-  out << tensorname.origin_ << ")";
+  out << tensorname.origin_ << ')';
  return out;
}

@@ -13,9 +13,9 @@ std::ostream& operator<<(std::ostream & out, const TensorGeometryArg& t) {
  if (t.pos == 0) {
    // 0 is distinguished; it usually indicates 'self' or the return
    // tensor
-    out << "'" << t.name << "'";
+    out << '\'' << t.name << '\'';
  } else {
-    out << "argument #" << t.pos << " '" << t.name << "'";
+    out << "argument #" << t.pos << " '" << t.name << '\'';
  }
  return out;
}

@@ -154,7 +154,7 @@ void checkSameGPU(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
      oss << "Tensor for " << t2 << " is on CPU, ";
    }
    oss << "but expected " << ((!t1->is_cpu() && !t2->is_cpu()) ? "them" : "it")
-        << " to be on GPU (while checking arguments for " << c << ")";
+        << " to be on GPU (while checking arguments for " << c << ')';
    TORCH_CHECK(false, oss.str());
  }
  TORCH_CHECK(

@@ -199,7 +199,7 @@ void checkScalarTypes(CheckedFrom c, const TensorArg& t,
      i++;
    }
    oss << "; but got " << t->toString()
-        << " instead (while checking arguments for " << c << ")";
+        << " instead (while checking arguments for " << c << ')';
    TORCH_CHECK(false, oss.str());
  }
}

@@ -43,8 +43,8 @@ std::string get_mkldnn_version() {
  // https://github.com/intel/ideep/issues/29
  {
    const dnnl_version_t* ver = dnnl_version();
-    ss << "Intel(R) MKL-DNN v" << ver->major << "." << ver->minor << "." << ver->patch
-       << " (Git Hash " << ver->hash << ")";
+    ss << "Intel(R) MKL-DNN v" << ver->major << '.' << ver->minor << '.' << ver->patch
+       << " (Git Hash " << ver->hash << ')';
  }
#else
  ss << "MKLDNN not found";

@@ -81,7 +81,7 @@ std::string get_openmp_version() {
        break;
      }
      if (ver_str) {
-        ss << " (a.k.a. OpenMP " << ver_str << ")";
+        ss << " (a.k.a. OpenMP " << ver_str << ')';
      }
    }
#else

@@ -135,7 +135,7 @@ std::string show_config() {

#if defined(__GNUC__)
  {
-    ss << "  - GCC " << __GNUC__ << "." << __GNUC_MINOR__ << "\n";
+    ss << "  - GCC " << __GNUC__ << '.' << __GNUC_MINOR__ << "\n";
  }
#endif

@@ -147,7 +147,7 @@ std::string show_config() {

#if defined(__clang_major__)
  {
-    ss << "  - clang " << __clang_major__ << "." << __clang_minor__ << "." << __clang_patchlevel__ << "\n";
+    ss << "  - clang " << __clang_major__ << '.' << __clang_minor__ << '.' << __clang_patchlevel__ << "\n";
  }
#endif

@@ -200,7 +200,7 @@ std::string show_config() {
  ss << "  - Build settings: ";
  for (const auto& pair : caffe2::GetBuildOptions()) {
    if (!pair.second.empty()) {
-      ss << pair.first << "=" << pair.second << ", ";
+      ss << pair.first << '=' << pair.second << ", ";
    }
  }
  ss << "\n";

@@ -209,7 +209,7 @@ struct CodeTemplate {
  // to indent correctly in the context.
  void emitIndent(std::ostream& out, size_t indent) const {
    for ([[maybe_unused]] const auto i : c10::irange(indent)) {
-      out << " ";
+      out << ' ';
    }
  }
  void emitStringWithIndents(

@@ -10,7 +10,7 @@ std::ostream& operator<<(std::ostream& out, const Dimname& dimname) {
  if (dimname.type() == NameType::WILDCARD) {
    out << "None";
  } else {
-    out << "'" << dimname.symbol().toUnqualString() << "'";
+    out << '\'' << dimname.symbol().toUnqualString() << '\'';
  }
  return out;
}

@@ -5,7 +5,7 @@
namespace at {

std::ostream& operator<<(std::ostream& out, const Range& range) {
-  out << "Range[" << range.begin << ", " << range.end << "]";
+  out << "Range[" << range.begin << ", " << range.end << ']';
  return out;
}

@@ -71,7 +71,7 @@ void TensorBase::enforce_invariants() {

void TensorBase::print() const {
  if (defined()) {
-    std::cerr << "[" << toString() << " " << sizes() << "]" << '\n';
+    std::cerr << '[' << toString() << ' ' << sizes() << ']' << '\n';
  } else {
    std::cerr << "[UndefinedTensor]" << '\n';
  }

@@ -9,7 +9,7 @@ APIVitals VitalsAPI;

std::ostream& operator<<(std::ostream& os, TorchVital const& tv) {
  for (const auto& m : tv.attrs) {
-    os << "[TORCH_VITAL] " << tv.name << "." << m.first << "\t\t "
+    os << "[TORCH_VITAL] " << tv.name << '.' << m.first << "\t\t "
       << m.second.value << "\n";
  }
  return os;

@@ -100,18 +100,18 @@ inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) {

// this does match the way things are represented in the schema
inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
-  out << "(";
+  out << '(';
  bool first = true;
  for (const auto& set : aliasInfo.beforeSets()) {
    if (first) {
      first = false;
    } else {
-      out << "|";
+      out << '|';
    }
    out << set.toUnqualString();
  }
  if (aliasInfo.isWrite()) {
-    out << "!";
+    out << '!';
  }
  if (aliasInfo.beforeSets() != aliasInfo.afterSets()) {
    out << " -> ";
@@ -120,12 +120,12 @@ inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
      if (first) {
        first = false;
      } else {
-        out << "|";
+        out << '|';
      }
      out << set.toUnqualString();
    }
  }
-  out << ")";
+  out << ')';
  return out;
}
} // namespace c10

@@ -198,7 +198,7 @@ inline void swap(Blob& lhs, Blob& rhs) noexcept {
}

inline std::ostream& operator<<(std::ostream& out, const Blob& v) {
-  return out << "Blob[" << v.TypeName() << "]";
+  return out << "Blob[" << v.TypeName() << ']';
}

} // namespace caffe2

@@ -100,7 +100,7 @@ struct TORCH_API ClassType : public NamedType {
  std::string repr_str() const override {
    std::stringstream ss;
    ss << str()
-       << " (of Python compilation unit at: " << compilation_unit().get() << ")";
+       << " (of Python compilation unit at: " << compilation_unit().get() << ')';
    return ss.str();
  }

@@ -58,12 +58,12 @@ std::string DispatchKeyExtractor::dumpState() const {
  std::ostringstream oss;
  for (const auto i : c10::irange(c10::utils::bitset::NUM_BITS())) {
    if (dispatch_arg_indices_reverse_.get(i)) {
-      oss << "1";
+      oss << '1';
    } else {
-      oss << "0";
+      oss << '0';
    }
  }
-  oss << " " << nonFallthroughKeys_ << "\n";
+  oss << ' ' << nonFallthroughKeys_ << '\n';
  return oss.str();
}

@@ -69,8 +69,8 @@ private:

void _print_dispatch_trace(const std::string& label, const std::string& op_name, const DispatchKeySet& dispatchKeySet) {
  auto nesting_value = dispatch_trace_nesting_value();
-  for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
-  std::cerr << label << " op=[" << op_name << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
+  for (int64_t i = 0; i < nesting_value; ++i) std::cerr << ' ';
+  std::cerr << label << " op=[" << op_name << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << ']' << std::endl;
}
} // namespace detail

@@ -570,7 +570,7 @@ void OperatorEntry::checkInvariants() const {

std::string OperatorEntry::listAllDispatchKeys() const {
  std::ostringstream str;
-  str << "[";
+  str << '[';

  bool has_kernels = false;
  for (auto k : allDispatchKeysInFullSet()) {
@@ -584,7 +584,7 @@ std::string OperatorEntry::listAllDispatchKeys() const {
    str << k;
    has_kernels = true;
  }
-  str << "]";
+  str << ']';
  return str.str();
}

@@ -210,9 +210,9 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {

  out << schema.name();
  if (!schema.overload_name().empty()) {
-    out << "." << schema.overload_name();
+    out << '.' << schema.overload_name();
  }
-  out << "(";
+  out << '(';

  bool seen_kwarg_only = false;
  for (const auto i : c10::irange(schema.arguments().size())) {
@@ -273,7 +273,7 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
  }

  if (need_paren) {
-    out << "(";
+    out << '(';
  }
  for (const auto i : c10::irange(returns.size())) {
    if (i > 0) {
@@ -288,7 +288,7 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
    out << "...";
  }
  if (need_paren) {
-    out << ")";
+    out << ')';
  }
  return out;
}
@@ -471,7 +471,7 @@ bool FunctionSchema::isForwardCompatibleWith(
    if (!arguments().at(i).isForwardCompatibleWith(old.arguments().at(i))) {
      if (why_not) {
        why_not
-            << "'" << arguments().at(i).name() << "'"
+            << '\'' << arguments().at(i).name() << '\''
            << " is not forward compatible with the older version of the schema";
      }
      return false;
@@ -511,7 +511,7 @@ bool FunctionSchema::isForwardCompatibleWith(
             .isForwardCompatibleWith(old.arguments().at(i))) {
      if (why_not) {
        why_not << "Out argument '"
-                << "'" << arguments().at(i).name()
+                << '\'' << arguments().at(i).name()
                << " is not FC with the older version of the schema";
      }
      return false;

@@ -571,7 +571,7 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
    if (arg.N()) {
      N = std::to_string(*arg.N());
    }
-    out << "[" << N << "]";
+    out << '[' << N << ']';
  } else {
    out << unopt_type->str();
  }
@@ -582,15 +582,15 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
  }

  if (is_opt) {
-    out << "?";
+    out << '?';
  }

  if (!arg.name().empty()) {
-    out << " " << arg.name();
+    out << ' ' << arg.name();
  }

  if (arg.default_value()) {
-    out << "=";
+    out << '=';
    if ((type->kind() == c10::TypeKind::StringType ||
         unopt_type->kind() == c10::TypeKind::StringType) &&
        arg.default_value().value().isString()) {

@@ -66,7 +66,7 @@ bool operator==(const ivalue::Tuple& lhs, const ivalue::Tuple& rhs) {
}

std::ostream& operator<<(std::ostream& out, const ivalue::EnumHolder& v) {
-  out << v.qualifiedClassName() << "." << v.name();
+  out << v.qualifiedClassName() << '.' << v.name();
  return out;
}

@@ -526,7 +526,7 @@ std::ostream& printMaybeAnnotatedList(
      !elementTypeCanBeInferredFromMembers(list_elem_type)) {
    out << "annotate(" << the_list.type<c10::Type>()->annotation_str() << ", ";
    printList(out, the_list.toListRef(), "[", "]", formatter);
-    out << ")";
+    out << ')';
    return out;
  } else {
    return printList(out, the_list.toListRef(), "[", "]", formatter);
@@ -538,7 +538,7 @@ std::ostream& printDict(
    std::ostream& out,
    const Dict& v,
    const IValueFormatter& formatter) {
-  out << "{";
+  out << '{';

  bool first = true;
  for (const auto& pair : v) {
@@ -552,7 +552,7 @@ std::ostream& printDict(
    first = false;
  }

-  out << "}";
+  out << '}';
  return out;
}
}
@@ -565,8 +565,8 @@ static std::ostream& printMaybeAnnotatedDict(
  auto value_type = the_dict.type()->castRaw<DictType>()->getValueType();
  if (the_dict.toGenericDict().empty() ||
      !elementTypeCanBeInferredFromMembers(value_type)) {
-    out << "annotate(" << the_dict.type<c10::Type>()->annotation_str() << ",";
-    printDict(out, the_dict.toGenericDict(), formatter) << ")";
+    out << "annotate(" << the_dict.type<c10::Type>()->annotation_str() << ',';
+    printDict(out, the_dict.toGenericDict(), formatter) << ')';
  } else {
    return printDict(out, the_dict.toGenericDict(), formatter);
  }
@@ -577,7 +577,7 @@ static std::ostream& printComplex(std::ostream & out, const IValue & v) {
  c10::complex<double> d = v.toComplexDouble();
  IValue real(d.real()), imag(std::abs(d.imag()));
  auto sign = d.imag() >= 0 ? '+' : '-';
-  return out << real << sign << imag << "j";
+  return out << real << sign << imag << 'j';
}

std::ostream& IValue::repr(
@@ -605,9 +605,9 @@ std::ostream& IValue::repr(
      if (static_cast<double>(i) == d) {
        // -0.0 (signed zero) needs to be parsed as -0.
        if (i == 0 && std::signbit(d)) {
-          return out << "-" << i << ".";
+          return out << '-' << i << '.';
        }
-        return out << i << ".";
+        return out << i << '.';
      }
    }
    auto orig_prec = out.precision();
@@ -643,20 +643,20 @@ std::ostream& IValue::repr(
      device_stream << v.toDevice();
      out << "torch.device(";
      c10::printQuotedString(out, device_stream.str());
-      return out << ")";
+      return out << ')';
    }
    case IValue::Tag::Generator: {
      auto generator = v.toGenerator();
      out << "torch.Generator(device=";
      c10::printQuotedString(out, generator.device().str());
-      out << ", seed=" << generator.current_seed() << ")";
+      out << ", seed=" << generator.current_seed() << ')';
      return out;
    }
    case IValue::Tag::GenericDict:
      return printMaybeAnnotatedDict(out, v, formatter);
    case IValue::Tag::Enum: {
      auto enum_holder = v.toEnumHolder();
-      return out << enum_holder->qualifiedClassName() << "." <<
+      return out << enum_holder->qualifiedClassName() << '.' <<
          enum_holder->name();
    }
    case IValue::Tag::Object: {
@@ -801,7 +801,7 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
      if (c == FP_NORMAL || c == FP_ZERO) {
        int64_t i = static_cast<int64_t>(d);
        if (static_cast<double>(i) == d) {
-          return out << i << ".";
+          return out << i << '.';
        }
      }
      auto orig_prec = out.precision();
@@ -852,7 +852,7 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
      return printDict(out, v.toGenericDict(), formatter);
    case IValue::Tag::PyObject: {
      auto py_obj = v.toPyObject();
-      return out << "<PyObject at" << py_obj << ">";
+      return out << "<PyObject at" << py_obj << '>';
    }
    case IValue::Tag::Generator:
      return out << "Generator";
@@ -862,16 +862,16 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
      // TODO we should attempt to call __str__ if the object defines it.
      auto obj = v.toObject();
      // print this out the way python would do it
-      return out << "<" << obj->name() << " object at " << obj.get() << ">";
+      return out << '<' << obj->name() << " object at " << obj.get() << '>';
    }
    case IValue::Tag::Enum: {
      auto enum_holder = v.toEnumHolder();
-      return out << "Enum<" << enum_holder->unqualifiedClassName() << "." <<
-          enum_holder->name() << ">";
+      return out << "Enum<" << enum_holder->unqualifiedClassName() << '.' <<
+          enum_holder->name() << '>';
    }

  }
-  return out << "<Invalid IValue tag=" << std::to_string(static_cast<uint32_t>(v.tag)) << ">";
+  return out << "<Invalid IValue tag=" << std::to_string(static_cast<uint32_t>(v.tag)) << '>';
}

#undef TORCH_FORALL_TAGS
@@ -1050,7 +1050,7 @@ c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(
  std::stringstream err;
  err << "Cannot serialize custom bound C++ class";
  if (auto qualname = type()->name()) {
-    err << " " << qualname->qualifiedName();
+    err << ' ' << qualname->qualifiedName();
  }
  err << ". Please define serialization methods via def_pickle() for "
         "this class.";

@@ -211,7 +211,7 @@ struct TORCH_API OptionalType : public UnionType {

  std::string str() const override {
    std::stringstream ss;
-    ss << getElementType()->str() << "?";
+    ss << getElementType()->str() << '?';
    return ss.str();
  }

@@ -240,7 +240,7 @@ struct TORCH_API OptionalType : public UnionType {

  std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
    std::stringstream ss;
-    ss << "Optional[" << getElementType()->annotation_str(printer) << "]";
+    ss << "Optional[" << getElementType()->annotation_str(printer) << ']';
    return ss.str();
  }
};
@@ -906,7 +906,7 @@ struct TORCH_API ListType

  std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
    std::stringstream ss;
-    ss << "List[" << getElementType()->annotation_str(printer) << "]";
+    ss << "List[" << getElementType()->annotation_str(printer) << ']';
    return ss.str();
  }
};
@@ -946,7 +946,7 @@ struct TORCH_API DictType : public SharedType {
  std::string str() const override {
    std::stringstream ss;
    ss << "Dict(" << getKeyType()->str() << ", " << getValueType()->str()
-       << ")";
+       << ')';
    return ss.str();
  }

@@ -1018,7 +1018,7 @@ struct TORCH_API FutureType

  std::string str() const override {
    std::stringstream ss;
-    ss << "Future(" << getElementType()->str() << ")";
+    ss << "Future(" << getElementType()->str() << ')';
    return ss.str();
  }
  TypePtr createWithContained(
@@ -1041,7 +1041,7 @@ struct TORCH_API FutureType

  std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
    std::stringstream ss;
-    ss << "Future[" << getElementType()->annotation_str(printer) << "]";
+    ss << "Future[" << getElementType()->annotation_str(printer) << ']';
    return ss.str();
  }
};
@@ -1060,7 +1060,7 @@ struct TORCH_API AwaitType

  std::string str() const override {
    std::stringstream ss;
-    ss << "Await(" << getElementType()->str() << ")";
+    ss << "Await(" << getElementType()->str() << ')';
    return ss.str();
  }
  TypePtr createWithContained(
@@ -1083,7 +1083,7 @@ struct TORCH_API AwaitType

  std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
    std::stringstream ss;
-    ss << "Await[" << getElementType()->annotation_str(printer) << "]";
+    ss << "Await[" << getElementType()->annotation_str(printer) << ']';
    return ss.str();
  }
};
@@ -1102,7 +1102,7 @@ struct TORCH_API RRefType

  std::string str() const override {
    std::stringstream ss;
-    ss << "RRef(" << getElementType()->str() << ")";
+    ss << "RRef(" << getElementType()->str() << ')';
    return ss.str();
  }
  TypePtr createWithContained(
@@ -1115,7 +1115,7 @@ struct TORCH_API RRefType

  std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
    std::stringstream ss;
-    ss << "RRef[" << getElementType()->annotation_str(printer) << "]";
+    ss << "RRef[" << getElementType()->annotation_str(printer) << ']';
    return ss.str();
  }
};

@@ -11,7 +11,7 @@ std::string toString(const OperatorName& opName) {
std::ostream& operator<<(std::ostream& os, const OperatorName& opName) {
  os << opName.name;
  if (!opName.overload_name.empty()) {
-    os << "." << opName.overload_name;
+    os << '.' << opName.overload_name;
  }
  return os;
}

@@ -65,7 +65,7 @@ VaryingShape<T> VaryingShape<T>::merge(const VaryingShape<T>& other) const {

template <typename T>
std::ostream& operator<<(std::ostream& out, const VaryingShape<T>& vs) {
-  out << "(";
+  out << '(';
  if (!vs.size()) {
    out << "*)";
    return out;
@@ -79,10 +79,10 @@ std::ostream& operator<<(std::ostream& out, const VaryingShape<T>& vs) {
    if (v.has_value()) {
      out << v.value();
    } else {
-      out << "*";
+      out << '*';
    }
  }
-  out << ")";
+  out << ')';
  return out;
}

@@ -105,7 +105,7 @@ std::ostream& operator<<(
  }
  auto sizes_opt = ss.sizes();

-  os << "(";
+  os << '(';
  for (size_t i = 0; i < rank_opt.value(); i++) {
    if (i > 0) {
      os << ", ";
@@ -113,10 +113,10 @@ std::ostream& operator<<(
    if(sizes_opt.has_value() && sizes_opt.value()[i].is_static()) {
      os << sizes_opt.value()[i];
    } else {
-      os << "*";
+      os << '*';
    }
  }
-  os << ")";
+  os << ')';

  return os;
}
@@ -131,17 +131,17 @@ std::ostream& operator<<(std::ostream& os, const ShapeSymbol& s) {
}

std::ostream& operator<<(std::ostream& os, const Stride& s) {
-  os << "{";
+  os << '{';
  if (s.stride_index_.has_value()) {
    os << *s.stride_index_;
  } else {
-    os << "*";
+    os << '*';
  }
-  os << ":";
+  os << ':';
  if (s.stride_.has_value()) {
    os << *s.stride_;
  } else {
-    os << "*";
+    os << '*';
  }
  os << '}';
  return os;

@@ -67,7 +67,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
    bool has_valid_strides_info = ndim > 0 &&
        value->strides().isComplete() && value->strides().size() == ndim;

-    out << "(";
+    out << '(';
    size_t i = 0;
    bool symbolic = type_verbosity() == TypeVerbosity::Symbolic;
    for (i = 0; i < *ndim; ++i) {
@@ -79,7 +79,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
      } else if (symbolic) {
        out << value->symbolic_sizes().at(i);
      } else {
-        out << "*";
+        out << '*';
      }
    }
    if (has_valid_strides_info &&
@@ -91,7 +91,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
        }
        out << value->strides()[i].value();
      }
-      out << "]";
+      out << ']';
    }
    if (type_verbosity() >= TypeVerbosity::Full) {
      if (value->requiresGrad()) {
@@ -107,12 +107,12 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
        out << "device=" << *value->device();
      }
    }
-    out << ")";
+    out << ')';
  } else {
    if (type_verbosity() >= TypeVerbosity::Full) {
      size_t i = 0;
      if (value->requiresGrad()) {
-        out << "("
+        out << '('
            << "requires_grad=" << *value->requiresGrad();
        i++;
      }
@@ -120,7 +120,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
        out << ((i++ > 0) ? ", " : "(") << "device=" << *value->device();
      }
      if (i > 0) {
-        out << ")";
+        out << ')';
      }
    }
  }
@@ -133,18 +133,18 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
    out << *prim << "[]";
  } else if (t.kind() == TypeKind::OptionalType) {
    auto prim = t.castRaw<OptionalType>()->getElementType();
-    out << *prim << "?";
+    out << *prim << '?';
  } else if(t.kind() == TypeKind::FutureType) {
    auto elem = t.castRaw<FutureType>()->getElementType();
-    out << "Future[" << *elem << "]";
+    out << "Future[" << *elem << ']';
  } else if(t.kind() == TypeKind::RRefType) {
    auto elem = t.castRaw<RRefType>()->getElementType();
-    out << "RRef[" << *elem << "]";
+    out << "RRef[" << *elem << ']';
  } else if(auto tup = t.cast<TupleType>()) {
    if (tup->schema()) {
      out << "NamedTuple";
    }
-    out << "(";
+    out << '(';
    for(size_t i = 0; i < tup->elements().size(); ++i) {
      if(i > 0)
        out << ", ";
@@ -160,7 +160,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
        out << *(tup->elements()[i]);
      }
    }
-    out << ")";
+    out << ')';
  } else if (t.kind() == TypeKind::FunctionType) {
    out << "Function";
  } else {
@@ -475,7 +475,7 @@ std::optional<TypePtr> unifyTypeList(
      why_not << "Could not unify type list since element " << i << " of type "
              << elements.at(i)->repr_str()
              << " did not match the types before it ("
-              << ret_type->repr_str() << ")";
+              << ret_type->repr_str() << ')';
      return std::nullopt;
    }
    ret_type = *maybe_unified;
@@ -907,13 +907,13 @@ std::string TupleType::str() const {
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    ss << name()->qualifiedName();
  } else {
-    ss << "(";
+    ss << '(';
    for(size_t i = 0; i < elements().size(); ++i) {
      if(i > 0)
        ss << ", ";
      ss << elements()[i]->str();
    }
-    ss << ")";
+    ss << ')';
  }
  return ss.str();
}

@@ -205,9 +205,9 @@ UnionType::UnionType(std::vector<TypePtr> reference, TypeKind kind) : SharedType
  for (const auto i : c10::irange(reference.size())) {
    msg << reference[i]->repr_str();
    if (i > 0) {
-      msg << ",";
+      msg << ',';
    }
-    msg << " ";
+    msg << ' ';
  }
  msg << "} has the single type " << types_[0]->repr_str()
      << ". Use the common supertype instead of creating a Union"

@@ -223,6 +223,62 @@ CONVERT_FROM_BF16_TEMPLATE(double)
CONVERT_FROM_BF16_TEMPLATE(float16_t)
#endif

+#ifdef __ARM_FEATURE_BF16
+
+// clang-[17, 20] crashes when autovectorizing static cast to bf16
+// Below is a workaround to have some vectorization
+// Works decently well for smaller int types
+template <typename from_type>
+inline void convertToBf16Impl(
+    const from_type* __restrict src,
+    c10::BFloat16* __restrict dst,
+    uint64_t n) {
+  bfloat16_t* dstPtr = reinterpret_cast<bfloat16_t*>(dst);
+  uint64_t loopBound = n - (n % 16);
+  uint64_t i = 0;
+  for (; i < loopBound; i += 16) {
+    float32x4_t a, b, c, d;
+    a[0] = static_cast<float>(src[i]);
+    a[1] = static_cast<float>(src[i + 1]);
+    a[2] = static_cast<float>(src[i + 2]);
+    a[3] = static_cast<float>(src[i + 3]);
+    b[0] = static_cast<float>(src[i + 4]);
+    b[1] = static_cast<float>(src[i + 5]);
+    b[2] = static_cast<float>(src[i + 6]);
+    b[3] = static_cast<float>(src[i + 7]);
+    c[0] = static_cast<float>(src[i + 8]);
+    c[1] = static_cast<float>(src[i + 9]);
+    c[2] = static_cast<float>(src[i + 10]);
+    c[3] = static_cast<float>(src[i + 11]);
+    d[0] = static_cast<float>(src[i + 12]);
+    d[1] = static_cast<float>(src[i + 13]);
+    d[2] = static_cast<float>(src[i + 14]);
+    d[3] = static_cast<float>(src[i + 15]);
+
+    vst1q_bf16(dstPtr + i, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(a), b));
+    vst1q_bf16(dstPtr + i + 8, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(c), d));
+  }
+
+#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
+  for (; i < n; i++) {
+    float a = static_cast<float>(src[i]);
+    dstPtr[i] = vcvth_bf16_f32(a);
+  }
+}
+
+#define CONVERT_TO_BF16_TEMPLATE(from_type) \
+  template <> \
+  inline void convert(const from_type* src, c10::BFloat16* dst, int64_t n) { \
+    return convertToBf16Impl<from_type>(src, dst, n); \
+  }
+
+CONVERT_TO_BF16_TEMPLATE(uint8_t)
+CONVERT_TO_BF16_TEMPLATE(int8_t)
+CONVERT_TO_BF16_TEMPLATE(int16_t)
+CONVERT_TO_BF16_TEMPLATE(int32_t)
+
+#endif

inline void convertBoolToBfloat16Impl(
    const bool* __restrict src,
    c10::BFloat16* __restrict dst,

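// NOTE: per the ARM ACLE, vcvtq_low_bf16_f32 narrows four floats into the low
// half of a bfloat16x8_t and vcvtq_high_bf16_f32 fills the high half, so each
// iteration of the main loop above stores 16 converted elements via two
// vst1q_bf16 writes.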
@ -80,7 +80,7 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
|
||||
}
|
||||
stream << buf[i];
|
||||
}
|
||||
stream << "]";
|
||||
stream << ']';
|
||||
return stream;
|
||||
}
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
|
||||
}
|
||||
stream << buf[i];
|
||||
}
|
||||
stream << "]";
|
||||
stream << ']';
|
||||
return stream;
|
||||
}
|
||||
|
||||
|
||||
@ -411,16 +411,16 @@ std::string CUDAHooks::showConfig() const {
|
||||
// HIP_VERSION value format was changed after ROCm v4.2 to include the patch number
|
||||
if(v < 500) {
|
||||
// If major=xx, minor=yy then format -> xxyy
|
||||
oss << (v / 100) << "." << (v % 10);
|
||||
oss << (v / 100) << '.' << (v % 10);
|
||||
}
|
||||
else {
|
||||
// If major=xx, minor=yy & patch=zzzzz then format -> xxyyzzzzz
|
||||
oss << (v / 10000000) << "." << (v / 100000 % 100) << "." << (v % 100000);
|
||||
oss << (v / 10000000) << '.' << (v / 100000 % 100) << '.' << (v % 100000);
|
||||
}
|
||||
#else
|
||||
oss << (v / 1000) << "." << (v / 10 % 100);
|
||||
oss << (v / 1000) << '.' << (v / 10 % 100);
|
||||
if (v % 10 != 0) {
|
||||
oss << "." << (v % 10);
|
||||
oss << '.' << (v % 10);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
@ -448,9 +448,9 @@ std::string CUDAHooks::showConfig() const {
|
||||
|
||||
|
||||
auto printCudnnStyleVersion = [&](size_t v) {
|
||||
oss << (v / 1000) << "." << (v / 100 % 10);
|
||||
oss << (v / 1000) << '.' << (v / 100 % 10);
|
||||
if (v % 100 != 0) {
|
||||
oss << "." << (v % 100);
|
||||
oss << '.' << (v % 100);
|
||||
}
|
||||
};
|
||||
|
||||
@ -461,7 +461,7 @@ std::string CUDAHooks::showConfig() const {
|
||||
if (cudnnCudartVersion != CUDART_VERSION) {
|
||||
oss << " (built against CUDA ";
|
||||
printCudaStyleVersion(cudnnCudartVersion);
|
||||
oss << ")";
|
||||
oss << ')';
|
||||
}
|
||||
oss << "\n";
|
||||
if (cudnnVersion != CUDNN_VERSION) {
|
||||
@ -472,11 +472,11 @@ std::string CUDAHooks::showConfig() const {
|
||||
#endif
|
||||
#else
|
||||
// TODO: Check if miopen has the functions above and unify
|
||||
oss << " - MIOpen " << MIOPEN_VERSION_MAJOR << "." << MIOPEN_VERSION_MINOR << "." << MIOPEN_VERSION_PATCH << "\n";
|
||||
oss << " - MIOpen " << MIOPEN_VERSION_MAJOR << '.' << MIOPEN_VERSION_MINOR << '.' << MIOPEN_VERSION_PATCH << "\n";
|
||||
#endif
|
||||
|
||||
#if AT_MAGMA_ENABLED()
|
||||
oss << " - Magma " << MAGMA_VERSION_MAJOR << "." << MAGMA_VERSION_MINOR << "." << MAGMA_VERSION_MICRO << "\n";
|
||||
oss << " - Magma " << MAGMA_VERSION_MAJOR << '.' << MAGMA_VERSION_MINOR << '.' << MAGMA_VERSION_MICRO << "\n";
|
||||
#endif
|
||||
|
||||
return oss.str();
|
||||
|
||||
@@ -42,7 +42,7 @@ static inline void launch_jitted_vectorized_kernel_dynamic(

   // The cache key includes all the parameters to generate_code + vec_size + dev_idx
   std::stringstream ss;
-  ss << nInputs << "_" << nOutputs << f;
+  ss << nInputs << '_' << nOutputs << f;
   ss << f_inputs_type_str << compute_type_str << result_type_str;
   ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
   ss << extra_args_types;

@@ -144,7 +144,7 @@ static inline void launch_jitted_unrolled_kernel_dynamic(

   // The cache key includes all the parameters to generate_code + dev_idx
   std::stringstream ss;
-  ss << nInputs << "_" << nOutputs << f;
+  ss << nInputs << '_' << nOutputs << f;
   ss << f_inputs_type_str << compute_type_str << result_type_str;
   ss << contiguous << dynamic_casting;
   ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
@@ -52,10 +52,10 @@ TuningContext* getTuningContext() {
 std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry) {
   static const bool blaslog = c10::utils::get_env("PYTORCH_TUNABLEOP_BLAS_LOG") == "1";
   if (!blaslog) {
-    return stream << entry.key_ << "," << entry.time_;
+    return stream << entry.key_ << ',' << entry.time_;
   }
   else {
-    return stream << entry.key_ << "," << entry.time_ << ",BLAS_PARAMS: " << entry.blas_sig_;
+    return stream << entry.key_ << ',' << entry.time_ << ",BLAS_PARAMS: " << entry.blas_sig_;
   }
 }

@@ -156,10 +156,10 @@ void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std
   if (isNew) {
     static const bool blaslog = c10::utils::get_env("PYTORCH_TUNABLEOP_BLAS_LOG") == "1";
     if (!blaslog) {
-      untuned_file << op_signature << "," << params_signature << std::endl;
+      untuned_file << op_signature << ',' << params_signature << std::endl;
     }
     else {
-      untuned_file << op_signature << "," << params_signature << ",BLAS_PARAMS: " << blas_signature << std::endl;
+      untuned_file << op_signature << ',' << params_signature << ",BLAS_PARAMS: " << blas_signature << std::endl;
     }
     TUNABLE_LOG3("Untuned,", op_signature, ",", params_signature);
   }

@@ -201,7 +201,7 @@ void TuningResultsManager::InitRealtimeAppend(const std::string& filename, const

   if(!file_exists || file_empty) {
     for(const auto& [key, val] : validators) {
-      (*realtime_out_) << "Validator," << key << "," << val << std::endl;
+      (*realtime_out_) << "Validator," << key << ',' << val << std::endl;
       realtime_out_->flush();
     }
     validators_written_ = true;

@@ -219,7 +219,7 @@ void TuningResultsManager::AppendResultLine(const std::string& op_sig, const std
     return;
   }

-  (*realtime_out_) << op_sig << "," << param_sig << "," << result << std::endl;
+  (*realtime_out_) << op_sig << ',' << param_sig << ',' << result << std::endl;
   realtime_out_->flush(); //ensure immediate write to disk

   TUNABLE_LOG3("Realtime append: ", op_sig, "(", param_sig, ") -> ", result);
@@ -93,7 +93,7 @@ std::string cudnnTypeToString(cudnnDataType_t dtype) {
       return "CUDNN_DATA_UINT8x4";
     default:
       std::ostringstream oss;
-      oss << "(unknown data-type " << static_cast<int>(dtype) << ")";
+      oss << "(unknown data-type " << static_cast<int>(dtype) << ')';
       return oss.str();
   }
 }

@@ -168,7 +168,7 @@ std::string cudnnMemoryFormatToString(cudnnTensorFormat_t tformat) {
       return "CUDNN_TENSOR_NHWC";
     default:
       std::ostringstream oss;
-      oss << "(unknown cudnn tensor format " << static_cast<int>(tformat) << ")";
+      oss << "(unknown cudnn tensor format " << static_cast<int>(tformat) << ')';
       return oss.str();
   }
 }
@@ -346,15 +346,15 @@ void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int6
 }

 std::ostream& operator<< (std::ostream& os, const DynamicLayer& layer) {
-  os << layer.layerId() << ":" << layer.key();
+  os << layer.layerId() << ':' << layer.key();
   return os;
 }
 std::ostream& operator<< (std::ostream& os, const std::vector<DynamicLayer>& dls) {
   os << "DynamicLayerStack[ ";
   for (const auto& layer : dls) {
-    os << layer << " ";
+    os << layer << ' ';
   }
-  os << "]";
+  os << ']';
   return os;
 }

@@ -22,7 +22,7 @@ void dumpTensor(std::ostream& ss, const Tensor& tensor) {
   if (batched) {
     ss << "Batched[lvl=" << batched->level() << " dim=" << batched->bdim() << ", ";
     dumpTensor(ss, batched->value());
-    ss << "]";
+    ss << ']';
     return;
   }
   ss << "Tensor" << tensor.sizes();

@@ -36,7 +36,7 @@ void dumpTensor(std::ostream& ss, const Tensor& tensor) {
     ss << "dead, ";
   }
   dumpTensor(ss, wrapped->value());
-  ss << "]";
+  ss << ']';
 }

 void TensorWrapper::refreshMetadata() {

@@ -73,7 +73,7 @@ std::string miopenTypeToString(miopenDataType_t dtype) {
       return "miopenBFloat16";
     default:
       std::ostringstream oss;
-      oss << "(unknown data-type " << static_cast<int>(dtype) << ")";
+      oss << "(unknown data-type " << static_cast<int>(dtype) << ')';
       return oss.str();
   }
 }
@@ -91,7 +91,7 @@ struct OperationInfo : BaseInfo {
     std::stringstream kernelStr;
     kernelStr << kernelName;
     for (const Tensor& tensor : tensors) {
-      kernelStr << ":" << BaseInfo::buildTensorString(tensor, includeBufferId);
+      kernelStr << ':' << BaseInfo::buildTensorString(tensor, includeBufferId);
     }
     return kernelStr.str();
   }

@@ -39,9 +39,9 @@ std::string BaseInfo::buildTensorString(const Tensor& tensor, bool includeBuffer
     // see comments for INCLUDE_BUFFER_ID
     if (includeBufferId && deviceType == at::kMPS) {
       id<MTLBuffer> buffer = __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
-      tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer)) << ":" << buffer.retainCount << ")";
+      tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer)) << ':' << buffer.retainCount << ')';
     }
-    tensorStr << ":" << tensor.scalar_type() << tensor.sizes();
+    tensorStr << ':' << tensor.scalar_type() << tensor.sizes();
     return tensorStr.str();
   } else {
     return "undefined";
@@ -167,7 +167,7 @@ static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, co
     std::stringstream ss;
     ss << arg_name << " should be greater than zero but got (";
     std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int>(ss,", "));
-    ss << args.back() << ")" << " (while checking arguments for " << c << ")";
+    ss << args.back() << ")" << " (while checking arguments for " << c << ')';
     TORCH_CHECK(false, ss.str());
   }
 }

@@ -639,7 +639,7 @@ static std::ostream& operator<<(std::ostream & out, const ConvParams<T>& params)
       << "  deterministic = " << params.deterministic
       << "  cudnn_enabled = " << params.cudnn_enabled
       << "  allow_tf32 = " << params.allow_tf32
-      << "}";
+      << '}';
   return out;
 }
@@ -847,7 +847,7 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional<int64_t
      << ", hop_length=" << hop_length << ", win_length=" << win_length \
      << ", window="; \
   if (window.defined()) { \
-    SS << window.toString() << "{" << window.sizes() << "}"; \
+    SS << window.toString() << '{' << window.sizes() << '}'; \
   } else { \
     SS << "None"; \
   } \

@@ -1046,7 +1046,7 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const std::optional<int64_
      << ", hop_length=" << hop_length << ", win_length=" << win_length \
      << ", window="; \
   if (window.defined()) { \
-    SS << window.toString() << "{" << window.sizes() << "}"; \
+    SS << window.toString() << '{' << window.sizes() << '}'; \
   } else { \
     SS << "None"; \
   } \
@@ -904,19 +904,11 @@ Tensor mvlgamma(const Tensor& self, int64_t p) {
   return args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER);
 }

 // since mvlgamma_ has different signature from its
 // out and functional variant, we explicitly
 // define it (instead of using structured kernel).
 Tensor& mvlgamma_(Tensor& self, int64_t p) {
   mvlgamma_check(self, p);
-  Tensor args = native::arange(
-      -p * HALF + HALF,
-      HALF,
-      HALF,
-      optTypeMetaToScalarType(self.options().dtype_opt()),
-      self.options().layout_opt(),
-      self.options().device_opt(),
-      self.options().pinned_memory_opt());
-  args = args.add(self.unsqueeze(-1));
-  const auto p2_sub_p = static_cast<double>(p * (p - 1));
-  return self.copy_(args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER));
+  return at::mvlgamma_out(self, self, p);
 }

 Tensor& mvlgamma_out(const Tensor& self, int64_t p, Tensor& result) {
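The mvlgamma_ hunk above is a common tidy-up: rather than duplicating the computation, the in-place variant now forwards to the out variant with self passed as both input and output. A minimal standalone sketch of that delegation pattern, using a hypothetical scale op rather than mvlgamma itself:

#include <iostream>
#include <vector>

// Hypothetical op family mirroring functional / out / in-place variants.
std::vector<double>& scale_out(const std::vector<double>& self, double f,
                               std::vector<double>& result) {
  result.resize(self.size());
  for (size_t i = 0; i < self.size(); ++i) result[i] = self[i] * f;
  return result;
}

std::vector<double>& scale_(std::vector<double>& self, double f) {
  // The in-place variant simply forwards, passing self as both input and
  // output, the same shape as `return at::mvlgamma_out(self, self, p);`.
  return scale_out(self, f, self);
}

int main() {
  std::vector<double> v{1, 2, 3};
  scale_(v, 2.0);
  std::cout << v[0] << ' ' << v[1] << ' ' << v[2] << '\n';  // prints 2 4 6
  return 0;
}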
@@ -78,9 +78,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
     const Tensor& mat_a,
     const Tensor& mat_b,
     const Tensor& scale_a,
-    const SwizzleType& swizzle_a,
+    const SwizzleType swizzle_a,
     const Tensor& scale_b,
-    const SwizzleType& swizzle_b,
+    const SwizzleType swizzle_b,
     const std::optional<at::Tensor>& offs,
     Tensor& out) {
   const bool a_is_2d = mat_a.dim() == 2;
@@ -11,7 +11,7 @@ static inline std::ostream& operator<<(std::ostream& out, dim3 dim) {
   if (dim.y == 1 && dim.z == 1) {
     out << dim.x;
   } else {
-    out << "[" << dim.x << "," << dim.y << "," << dim.z << "]";
+    out << '[' << dim.x << ',' << dim.y << ',' << dim.z << ']';
   }
   return out;
 }

@@ -27,7 +27,7 @@ std::ostream& operator<<(std::ostream& out, const ReduceConfig& config) {
   out << "input_mult=[";
   for (int i = 0; i < 3; i++) {
     if (i != 0) {
-      out << ",";
+      out << ',';
     }
     out << config.input_mult[i];
   }

@@ -35,7 +35,7 @@ std::ostream& operator<<(std::ostream& out, const ReduceConfig& config) {
   out << "output_mult=[";
   for (int i = 0; i < 2; i++) {
     if (i != 0) {
-      out << ",";
+      out << ',';
     }
     out << config.output_mult[i];
   }

@@ -49,7 +49,7 @@ std::ostream& operator<<(std::ostream& out, const ReduceConfig& config) {
   out << "block=" << config.block() << ", ";
   out << "grid=" << config.grid() << ", ";
   out << "global_memory_size=" << config.global_memory_size();
-  out << ")";
+  out << ')';
   return out;
 }
@@ -364,9 +364,9 @@ void f8f8bf16_grouped_gemm_impl_sm90(
   //     reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(
   //         stride_output_h + group_count);

-  // std::cout << "PTRS " << mat_a.data_ptr() << " " << mat_b.data_ptr() << "
+  // std::cout << "PTRS " << mat_a.data_ptr() << ' ' << mat_b.data_ptr() << "
   // "
-  //           << out.data_ptr() << " " << scale_a.data_ptr() << " "
+  //           << out.data_ptr() << ' ' << scale_a.data_ptr() << ' '
   //           << scale_b.data_ptr() << "\n";
   // for (int i = 0; i < group_count; i++) {
   //   std::cout << "A " << (void*)inputA_ptrs_h[i] << "\n";
@@ -1057,14 +1057,14 @@ std::string generate_code(
     // TODO these arrays are potentially of the different types, use function
     // traits to determine the types
     declare_load_arrays << f_inputs_type << " arg" << std::to_string(i)
-                        << "[" << std::to_string(thread_work_size) << "];\n";
+                        << '[' << std::to_string(thread_work_size) << "];\n";
   }
   env.s("declare_load_arrays", declare_load_arrays.str());

   std::stringstream declare_store_arrays;
   for (int i = 0; i < nOutputs; i++) {
     declare_store_arrays << result_type << " out" << std::to_string(i)
-                         << "[" << std::to_string(thread_work_size) << "];\n";
+                         << '[' << std::to_string(thread_work_size) << "];\n";
   }
   env.s("declare_store_arrays", declare_store_arrays.str());

@@ -1217,7 +1217,7 @@ std::string generate_code(
   for (const auto i : c10::irange(nInputs)){
     auto i_string = std::to_string(i);
     vector_inputs << "auto * input" << i_string <<
-        " = reinterpret_cast<const scalar_t*>(data[" << i_string << "+" << nOutputs << "])" <<
+        " = reinterpret_cast<const scalar_t*>(data[" << i_string << '+' << nOutputs << "])" <<
         " + block_work_size * idx;\n";
   }
   env.s("vector_inputs", vector_inputs.str());

@@ -1543,17 +1543,17 @@ NvrtcFunction jit_pwise_function(

   // Constructs file path by appending constructed cubin name to cache path
   std::stringstream ss;
-  ss << *cache_dir << "/";
+  ss << *cache_dir << '/';
   ss << kernel_name;
 #ifdef USE_ROCM
   ss << "_arch" << prop->gcnArchName;
 #else
-  ss << "_arch" << cuda_major << "." << cuda_minor;
+  ss << "_arch" << cuda_major << '.' << cuda_minor;
 #endif
-  ss << "_nvrtc" << nvrtc_major << "." << nvrtc_minor;
+  ss << "_nvrtc" << nvrtc_major << '.' << nvrtc_minor;
   ss << (compile_to_sass ? "_sass" : "_ptx");
-  ss << "_" << code.length();
-  ss << "_" << hash_code;
+  ss << '_' << code.length();
+  ss << '_' << hash_code;
   file_path = ss.str();

   std::ifstream readin{file_path, std::ios::in | std::ifstream::binary};
@@ -115,7 +115,7 @@ std::ostream& operator<<(
   std::copy(
       strides.begin(), strides.end() - 1, std::ostream_iterator<int>(oss, ","));
   oss << sizes.back();
-  output << oss.str() << "}";
+  output << oss.str() << '}';
   return output;
 }

@@ -53,7 +53,7 @@ std::ostream& operator<<(std::ostream& out, const ConvParams& params) {
      << " transposed = " << params.transposed
      << " output_padding = " << IntArrayRef{params.output_padding}
      << " groups = " << params.groups << " benchmark = " << params.benchmark
-     << " deterministic = " << params.deterministic << "}";
+     << " deterministic = " << params.deterministic << '}';
   return out;
 }
@@ -337,10 +337,6 @@ Tensor _convolution_out(
   TORCH_CHECK(
       3 == ndim || 4 == ndim || 5 == ndim,
       "convolution only supports 3D, 4D, 5D tensor");
-  // get computation format for Conv/TransposedConv
-  bool is_channels_last_suggested =
-      use_channels_last_for_conv(input_r, weight_r);

   Tensor input = input_r, weight = weight_r;
   // PyTorch does not support ChannelsLast1D case,
   // thus we need the transformation here

@@ -348,13 +344,8 @@ Tensor _convolution_out(
     input = view4d(input_r);
     weight = view4d(weight_r);
   }
-  // ensure the input/weight/bias/output are congituous in desired format
-  at::MemoryFormat mfmt = is_channels_last_suggested
-      ? get_cl_tag_by_ndim(input.ndimension())
-      : at::MemoryFormat::Contiguous;
-  auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r;
-  input = input.contiguous(mfmt);
-  weight = weight.contiguous(mfmt);
+  // get computation format for Conv/TransposedConv
+  bool is_channels_last_suggested = use_channels_last_for_conv(input, weight);

   auto k = weight.ndimension();
   if (k == input.ndimension() + 1) {

@@ -388,6 +379,14 @@ Tensor _convolution_out(
         expand_param_if_needed(output_padding_, "output_padding", dim);
     params.groups = groups_;
   }

+  // ensure the input/weight/bias/output are congituous in desired format
+  at::MemoryFormat mfmt = is_channels_last_suggested
+      ? get_cl_tag_by_ndim(input.ndimension())
+      : at::MemoryFormat::Contiguous;
+  auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r;
+  input = input.contiguous(mfmt);
+  weight = weight.contiguous(mfmt);
+
   check_shape_forward(input, weight, bias, params, true);

   Tensor output;

@@ -514,18 +513,9 @@ Tensor convolution_overrideable(
       at::borrow_from_optional_tensor(bias_r_opt);
   const Tensor& bias_r = *bias_r_maybe_owned;

-  auto k = weight_r.ndimension();
-  at::MemoryFormat backend_memory_format = at::MemoryFormat::Contiguous;
-  if (xpu_conv_use_channels_last(input_r, weight_r)) {
-    backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d
-                                     : at::MemoryFormat::ChannelsLast;
-  }
-  Tensor input_c = input_r.contiguous(backend_memory_format);
-  Tensor weight_c = weight_r.contiguous(backend_memory_format);
-
   return _convolution(
-      input_c,
-      weight_c,
+      input_r,
+      weight_r,
       bias_r,
       stride_,
       padding_,
aten/src/ATen/native/mkldnn/xpu/ScaledBlas.cpp (new file, 342 lines)
@@ -0,0 +1,342 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/BlasBackend.h>
+#include <ATen/WrapDimUtilsMulti.h>
+#include <ATen/ceil_div.h>
+#include <ATen/native/Resize.h>
+#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
+#include <ATen/native/xpu/Blas.h>
+#include <torch/library.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
+#include <ATen/ops/_addmm_activation_native.h>
+#include <ATen/ops/_efficientzerotensor.h>
+#include <ATen/ops/_scaled_mm_native.h>
+#include <ATen/ops/_unsafe_view_native.h>
+#include <ATen/ops/abs.h>
+#include <ATen/ops/addmm_native.h>
+#include <ATen/ops/addmv_native.h>
+#include <ATen/ops/baddbmm_native.h>
+#include <ATen/ops/bmm_native.h>
+#include <ATen/ops/copy_native.h>
+#include <ATen/ops/dot_native.h>
+#include <ATen/ops/empty.h>
+#include <ATen/ops/empty_strided.h>
+#include <ATen/ops/gelu.h>
+#include <ATen/ops/max.h>
+#include <ATen/ops/mm_native.h>
+#include <ATen/ops/mul.h>
+#include <ATen/ops/ones.h>
+#include <ATen/ops/relu.h>
+#include <ATen/ops/scalar_tensor_native.h>
+#include <ATen/ops/vdot_native.h>
+#endif
+
+namespace at::native {
+
+using at::blas::ScalingType;
+using at::blas::SwizzleType;
+
+namespace {
+/*
+ * Scaling Type Determination:
+ * ---------------------------
+ * Conditions and corresponding Scaling Types:
+ *
+ * - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`:
+ *   - Returns BlockWise (with additional size checks).
+ *
+ * - Else if scale.numel() == 1:
+ *   - Returns TensorWise.
+ *
+ * - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) == 1:
+ *   - Returns RowWise.
+ *
+ * - Otherwise:
+ *   - Returns Error.
+ */
+
+bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
+  return at::isFloat8Type(t.scalar_type()) &&
+      scale.scalar_type() == at::kFloat && scale.numel() == 1;
+}
+
+bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
+  return (
+      at::isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat &&
+      scale.dim() == 2 && scale.size(0) == t.size(0) && scale.size(1) == 1 &&
+      scale.is_contiguous());
+}
+
+bool is_desired_scaling(
+    const at::Tensor& t,
+    const at::Tensor& scale,
+    ScalingType desired_scaling) {
+  auto result = desired_scaling == ScalingType::TensorWise
+      ? is_tensorwise_scaling(t, scale)
+      : is_rowwise_scaling(t, scale);
+  return result;
+}
+
+std::pair<ScalingType, ScalingType> get_joint_scaling(
+    std::initializer_list<std::pair<ScalingType, ScalingType>> options,
+    const at::Tensor& a,
+    const at::Tensor& b,
+    const at::Tensor& scale_a,
+    const at::Tensor& scale_b) {
+  for (auto [lhs, rhs] : options) {
+    if (is_desired_scaling(a, scale_a, lhs) &&
+        is_desired_scaling(b.t(), scale_b.t(), rhs)) {
+      return {lhs, rhs};
+    }
+  }
+  TORCH_CHECK(
+      false,
+      "Invalid scaling configuration.\n"
+      "- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
+      "- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (",
+      a.size(0),
+      ", 1) and scale_b should be (1, ",
+      b.size(1),
+      "), and both should be contiguous.\n"
+      "Got a.dtype()=",
+      a.scalar_type(),
+      ", scale_a.dtype()=",
+      scale_a.scalar_type(),
+      ", scale_a.size()=",
+      scale_a.sizes(),
+      ", scale_a.stride()=",
+      scale_a.strides(),
+      ", ",
+      "b.dtype()=",
+      b.scalar_type(),
+      ", scale_b.dtype()=",
+      scale_b.scalar_type(),
+      ", scale_b.size()=",
+      scale_b.sizes(),
+      " and scale_b.stride()=",
+      scale_b.strides());
+}
+
+Tensor& _scaled_gemm(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const Tensor& scale_a,
+    const Tensor& scale_b,
+    const ScalingType scaling_choice_a,
+    const ScalingType scaling_choice_b,
+    const std::optional<Tensor>& bias,
+    const bool use_fast_accum,
+    Tensor& out,
+    const std::optional<Tensor>& alpha = std::nullopt) {
+  // TODO: scale_result and alpha is not defined or used!
+  std::optional<Tensor> scaled_result = std::nullopt;
+  at::native::onednn::scaled_matmul(
+      mat1,
+      mat2,
+      out,
+      scale_a,
+      scale_b,
+      scaling_choice_a,
+      scaling_choice_b,
+      bias,
+      scaled_result,
+      use_fast_accum);
+
+  return out;
+}
+
+} // namespace
+
+// Computes matrix multiply + bias while applying scaling to input and output
+// matrices. Scales are only applicable when matrices are of Float8 type and
+// assumed to be equal to 1.0 by default. If output matrix type is 16 or 32-bit
+// type, scale_result is not applied. Known limitations:
+//  - Only works if mat1 is row-major and mat2 is column-major
+//  - Only works if matrices sizes are divisible by 32
+//  - If 1-dimensional tensors are used then scale_a should be size =
+//    mat1.size(0) and scale_b should have size = to mat2.size(1)
+// Arguments:
+//  - `mat1`: the first operand of the matrix multiply, can be type
+//    `torch.float8_e4m3fn` or `torch.float8_e5m2`
+//  - `mat2`: the second operand of the matrix multiply, can be type
+//    `torch.float8_e4m3fn` or `torch.float8_e5m2`
+//  - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
+//  - `out_dtype`: the output dtype, can either be a float8 or a higher
+//    precision floating point type
+//  - `scale_a`: a tensor with the inverse scale of `mat1`, whose
+//    shape/strides/dtype depend on the scaling scheme
+//  - `scale_b`: a tensor with the inverse scale of `mat2`, whose
+//    shape/strides/dtype depend on the scaling scheme
+//  - `scale_result`: a scalar tensor with the scale of the output, only
+//    utilized if the output is a float8 type
+//  - `use_fast_accum`: Not applicable for XPU. For now, it should always be
+//    false.
+//  - `out`: a reference to the output tensor
+
+Tensor& _scaled_mm_out_xpu(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const Tensor& scale_a,
+    const Tensor& scale_b,
+    const std::optional<at::Tensor>& bias,
+    const std::optional<at::Tensor>& scale_result,
+    std::optional<c10::ScalarType> out_dtype,
+    bool use_fast_accum,
+    Tensor& out) {
+  // Note: fast_accum is not supported in XPU for now.
+  TORCH_CHECK(!use_fast_accum, "fast_accum is not supported in XPU for now.");
+
+  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
+  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
+
+  TORCH_CHECK(
+      mat1.sizes()[1] == mat2.sizes()[0],
+      "mat1 and mat2 shapes cannot be multiplied (",
+      mat1.sizes()[0],
+      "x",
+      mat1.sizes()[1],
+      " and ",
+      mat2.sizes()[0],
+      "x",
+      mat2.sizes()[1],
+      ")");
+
+  // Check what type of scaling we are doing based on inputs. This list is
+  // sorted by decreasing priority.
+
+  // List of supported datatypes for XPU with oneDNN:
+  // https://uxlfoundation.github.io/oneDNN/dev_guide_matmul.html#data-types
+  auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling(
+      {
+          std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
+          std::make_pair(ScalingType::RowWise, ScalingType::RowWise),
+      },
+      mat1,
+      mat2,
+      scale_a,
+      scale_b);
+  TORCH_CHECK(
+      !scale_result ||
+          (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
+      "scale_result must be a float scalar");
+  TORCH_CHECK(
+      !bias || bias->numel() == mat2.sizes()[1],
+      "Bias must be size ",
+      mat2.sizes()[1],
+      " but got ",
+      bias->numel());
+  TORCH_CHECK(
+      mat1.sizes()[1] % 16 == 0,
+      "Expected trailing dimension of mat1 to be divisible by 16 ",
+      "but got mat1 shape: (",
+      mat1.sizes()[0],
+      "x",
+      mat1.sizes()[1],
+      ").");
+  TORCH_CHECK(
+      mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0,
+      "mat2 shape (",
+      mat2.sizes()[0],
+      "x",
+      mat2.sizes()[1],
+      ") must be divisible by 16");
+  // Check types
+  TORCH_CHECK(
+      !out_dtype || *out_dtype == out.scalar_type(),
+      "out_dtype must match output matrix type");
+  TORCH_CHECK(
+      at::isFloat8Type(mat1.scalar_type()),
+      "Expected mat1 to be Float8 matrix got ",
+      mat1.scalar_type());
+  TORCH_CHECK(
+      at::isFloat8Type(mat2.scalar_type()),
+      "Expected mat2 to be Float8 matrix got ",
+      mat2.scalar_type());
+  // TODO: oneDNN Currently only supports e4m3 with group scales on BMG. Not
+  // support 2D scales, only 1D. Needs to add more checks there.
+
+  if (bias) {
+    TORCH_CHECK(
+        bias->scalar_type() == kFloat ||
+            bias->scalar_type() == c10::ScalarType::BFloat16 ||
+            bias->scalar_type() == c10::ScalarType::Half,
+        "Bias must be Float32 or BFloat16 or Half, but got ",
+        bias->scalar_type());
+  }
+
+  {
+    auto bias_ = bias.value_or(Tensor());
+    auto scale_result_ = scale_result.value_or(Tensor());
+
+    // NOLINTNEXTLINE(*c-array*)
+    TensorArg targs[]{
+        {out, "out", 0},
+        {mat1, "mat1", 1},
+        {mat2, "mat2", 2},
+        {bias_, "bias", 3},
+        {scale_a, "scale_a", 4},
+        {scale_b, "scale_b", 5},
+        {scale_result_, "scale_result", 6}};
+    checkAllSameGPU(__func__, targs);
+  }
+
+  // Validation checks have passed lets resize the output to actual size
+  IntArrayRef mat1_sizes = mat1.sizes();
+  IntArrayRef mat2_sizes = mat2.sizes();
+  at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
+
+  // If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm
+  // kernels do not support this case).
+  if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) {
+    // `out` was created with `at::empty`. In the case where we are multiplying
+    // MxK by KxN and K is the zero dim, we need to initialize here to properly
+    // return a tensor of zeros.
+    if (mat1_sizes[1] == 0) {
+      out.zero_();
+    }
+
+    return out;
+  }
+
+  // TODO: Scale_result is not supported by now!!
+  return _scaled_gemm(
+      mat1,
+      mat2,
+      scale_a,
+      scale_b,
+      scaling_choice_a,
+      scaling_choice_b,
+      bias,
+      use_fast_accum,
+      out);
+}
+
+Tensor _scaled_mm_xpu(
+    const Tensor& mat_a,
+    const Tensor& mat_b,
+    const Tensor& scale_a,
+    const Tensor& scale_b,
+    const std::optional<at::Tensor>& bias,
+    const std::optional<at::Tensor>& scale_result,
+    std::optional<c10::ScalarType> out_dtype,
+    bool use_fast_accum) {
+  const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
+  Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
+  return _scaled_mm_out_xpu(
+      mat_a,
+      mat_b,
+      scale_a,
+      scale_b,
+      bias,
+      scale_result,
+      out_dtype,
+      use_fast_accum,
+      out);
+}
+
+} // namespace at::native
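To see what the priority-ordered selection in get_joint_scaling buys, here is a self-contained mirror of the classification logic on plain descriptors. The Desc struct and its values are illustrative stand-ins, not PyTorch API; the walk over candidate pairs is the same idea as in the file above:

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

enum class Scaling { TensorWise, RowWise, Error };

// Illustrative stand-in for a scale tensor's metadata.
struct Desc {
  int64_t numel;
  int64_t dim;
  int64_t size0, size1;
  int64_t rows;  // rows of the matrix being scaled
};

bool tensorwise(const Desc& s) { return s.numel == 1; }
bool rowwise(const Desc& s) {
  return s.dim == 2 && s.size0 == s.rows && s.size1 == 1;
}

// Walk candidate (lhs, rhs) pairs in decreasing priority, like
// get_joint_scaling, and take the first pair both operands satisfy.
std::pair<Scaling, Scaling> joint(const Desc& a, const Desc& b) {
  const std::vector<std::pair<Scaling, Scaling>> options = {
      {Scaling::TensorWise, Scaling::TensorWise},
      {Scaling::RowWise, Scaling::RowWise}};
  for (auto [l, r] : options) {
    auto ok = [](Scaling w, const Desc& d) {
      return w == Scaling::TensorWise ? tensorwise(d) : rowwise(d);
    };
    if (ok(l, a) && ok(r, b)) return {l, r};
  }
  return {Scaling::Error, Scaling::Error};
}

int main() {
  Desc scalar{1, 0, 0, 0, 64};     // numel()==1 -> TensorWise
  Desc per_row{64, 2, 64, 1, 64};  // shape (rows, 1) -> RowWise
  std::cout << (joint(scalar, scalar).first == Scaling::TensorWise) << '\n';
  std::cout << (joint(per_row, per_row).first == Scaling::RowWise) << '\n';
  return 0;
}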
@@ -1,3 +1,4 @@
+#include <ATen/BlasBackend.h>
 #include <ATen/Tensor.h>
 #include <ATen/core/Tensor.h>
 #include <c10/core/ScalarType.h>

@@ -8,7 +9,6 @@
 #include <oneapi/dnnl/dnnl.hpp>

 namespace at::native::onednn {

 at::Tensor broadcast_bias2D(
     at::Tensor& dst,
     at::Tensor& bias,
@@ -328,4 +328,236 @@ void quantized_matmul(
   result.copy_(dst);
 }

+// Describes how to configure oneDNN scales for a given role/ScalingType
+struct ScaleSpec {
+  // specifies the way scale values will be applied to an ARG tensor.
+  int mask;
+  // specifies how scales are grouped along dimensions where
+  // multiple scale factors are used.
+  dnnl::memory::dims groups;
+  // specifies data type for scale factors.
+  dnnl::memory::data_type dtype;
+
+  // Helper to compute expected number of elements for scale tensors
+  // arg_type: "src" for SRC (groups pattern {1, X}),
+  //           "wei" for WEIGHTS (groups pattern {X, 1})
+  int64_t expected_numel(
+      int64_t outer_dim,
+      int64_t inner_dim,
+      const std::string& arg_type) const {
+    if (groups == dnnl::memory::dims{1, 1})
+      return 1; // tensorwise scaling
+
+    TORCH_CHECK(
+        arg_type == "src" || arg_type == "wei",
+        "Expected arg_type to be 'src' or 'wei', but got '",
+        arg_type,
+        "'");
+
+    // For rowwise: SRC groups={1, K}, WEI groups={K, 1}
+    TORCH_INTERNAL_ASSERT(
+        (groups == dnnl::memory::dims{1, inner_dim} ||
+         groups == dnnl::memory::dims{inner_dim, 1}),
+        "The groups must be either {1, inner_dim} or {inner_dim, 1}. But got ",
+        groups,
+        ".");
+    return outer_dim;
+  }
+
+  // Normalize an incoming scale tensor to contiguous storage and appropriate
+  // dtype/view
+  at::Tensor normalize(const at::Tensor& scale) const {
+    TORCH_INTERNAL_ASSERT(
+        dtype == dnnl::memory::data_type::f32,
+        "tensor scale currently must be f32, but got scale dtype: ",
+        scale.scalar_type());
+    return scale.to(at::kFloat).contiguous();
+  }
+};
+
+// This function defines how to set scales mask and groups according to:
+// https://github.com/uxlfoundation/oneDNN/blob/main/tests/benchdnn/doc/knobs_attr.md#--attr-scales
+// The returned value will be used in
+// `set_scales(arg, mask, groups, data_type)`.
+inline ScaleSpec make_scale_spec(
+    at::blas::ScalingType scaling_type,
+    int64_t M,
+    int64_t K,
+    int64_t N,
+    const std::string& arg_type) {
+  TORCH_CHECK(
+      arg_type == "src" || arg_type == "wei",
+      "Expected arg_type to be 'src' or 'wei', but got '",
+      arg_type,
+      "'");
+  TORCH_INTERNAL_ASSERT(
+      (scaling_type == at::blas::ScalingType::TensorWise ||
+       scaling_type == at::blas::ScalingType::RowWise),
+      "Currently only support scaling_type for TensorWise or RowWise");
+  int64_t dim = K; // Currently only K is used for grouping
+  bool is_src = (arg_type == "src");
+  if (scaling_type == at::blas::ScalingType::TensorWise) {
+    // Scale tensorwise. The same as `--attr-scales=common`.
+    // mask=0 : scale whole tensor
+    // groups={1, 1}: indicates that there is only one group for scaling
+    return {0, {1, 1}, dnnl::memory::data_type::f32};
+  } else {
+    // (scaling_type == at::blas::ScalingType::RowWise)
+    // Scale RowWise. The same as `--attr-scales=per_dim_01`.
+    // mask={(1 << 0) | (1 << 1)}: Scale on both dim0 and dim1
+    // SRC: groups={1, K}, WEIGHTS: groups={K, 1}
+    return {
+        (1 << 0) | (1 << 1),
+        is_src ? dnnl::memory::dims{1, dim} : dnnl::memory::dims{dim, 1},
+        dnnl::memory::data_type::f32};
+  }
+}
+
+sycl::event scaled_matmul(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    Tensor& result,
+    const Tensor& scale_a,
+    const Tensor& scale_b,
+    at::blas::ScalingType scaling_choice_a,
+    at::blas::ScalingType scaling_choice_b,
+    const std::optional<at::Tensor>& bias,
+    const std::optional<at::Tensor>& scale_result,
+    bool use_fast_accum) {
+  auto& engine = GpuEngineManager::Instance().get_engine();
+  auto& stream = GpuStreamManager::Instance().get_stream();
+
+  // This function will do steps with following steps
+  // 1. create memory descriptor
+  // 2. call write_to_dnnl_memory() to actually write memory
+  // 3. execute
+
+  const int64_t M = mat1.size(0);
+  const int64_t K = mat1.size(1);
+  const int64_t N = mat2.size(1);
+
+  // 1.1 Create memory descriptor
+  dnnl::memory::desc src_md = get_onednn_md(mat1);
+  dnnl::memory::desc weights_md = get_onednn_md(mat2);
+  dnnl::memory::desc dst_md = get_onednn_md(result);
+
+  // scale_a and scale_b has already be checked in `is_desired_scaling()` call.
+  // So we could directly get their memory desc and set later.
+  dnnl::memory::desc scale_a_md = get_onednn_md(scale_a);
+  dnnl::memory::desc scale_b_md = get_onednn_md(scale_b);
+
+  dnnl::memory::desc bias_md;
+  bool with_bias = bias.has_value();
+  at::Tensor possible_reshaped_bias = bias.value_or(at::Tensor());
+  if (with_bias) {
+    if (possible_reshaped_bias.dim() == 1) {
+      possible_reshaped_bias =
+          possible_reshaped_bias.reshape({1, possible_reshaped_bias.size(0)});
+      bias_md = get_onednn_md(possible_reshaped_bias);
+    } else {
+      bias_md = get_onednn_md(possible_reshaped_bias);
+    }
+  }
+
+  // 1.2 Create primitive descriptor and set scales mask
+  const ScaleSpec src_spec = make_scale_spec(scaling_choice_a, M, K, N, "src");
+  const ScaleSpec wei_spec = make_scale_spec(scaling_choice_b, M, K, N, "wei");
+
+  dnnl::primitive_attr op_attr = dnnl::primitive_attr();
+
+#if ONEDNN_SUPPORT_DETERMINISTIC
+  if (at::globalContext().deterministicAlgorithms() ||
+      at::globalContext().deterministicMkldnn())
+    op_attr.set_deterministic(true);
+#endif
+
+  std::vector<int64_t> default_groups;
+  op_attr.set_scales(
+      DNNL_ARG_SRC, src_spec.mask, src_spec.groups, src_spec.dtype);
+  op_attr.set_scales(
+      DNNL_ARG_WEIGHTS, wei_spec.mask, wei_spec.groups, wei_spec.dtype);
+  // scale_result tensor currently only supports scalar(TensorWise Scaling).
+  bool with_dst_scale = scale_result && scale_result->defined();
+  if (with_dst_scale) {
+    op_attr.set_scales(DNNL_ARG_DST, 0, {1}, dnnl::memory::data_type::f32);
+  }
+
+  op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+
+  // 1.3 Create the matmul primitive descriptor
+  dnnl::matmul::primitive_desc matmul_pd = with_bias
+      ? dnnl::matmul::primitive_desc(
+            engine, src_md, weights_md, bias_md, dst_md, op_attr)
+      : dnnl::matmul::primitive_desc(
+            engine, src_md, weights_md, dst_md, op_attr);
+
+  // 1.4 (Possible) Additional Checks
+  // TODO: In case there are memory desc does not align with the actual tensor,
+  // we might need to reorder weights similar to CPU's reorder_if_differ_in()
+  // call. For example, weights not the same as matmul_pd.weights_desc(),
+
+  // 2. Prepare memory
+
+  // Create memory
+  auto src_usr_m = make_onednn_memory(src_md, engine, mat1.data_ptr());
+  auto weights_usr_m = make_onednn_memory(weights_md, engine, mat2.data_ptr());
+  auto dst_usr_m = make_onednn_memory(dst_md, engine, result.data_ptr());
+  dnnl::memory b_usr_m;
+  if (with_bias) {
+    b_usr_m =
+        make_onednn_memory(bias_md, engine, possible_reshaped_bias.data_ptr());
+  }
+
+  // Prepare runtime scale memories (flat 1-D views) using the specs
+  auto make_scale_mem_from_spec = [&](const ScaleSpec& spec,
+                                      int64_t expected_numel,
+                                      const at::Tensor& scale_tensor) {
+    at::Tensor prepared = spec.normalize(scale_tensor);
+    TORCH_CHECK(
+        prepared.numel() == expected_numel,
+        "Scale buffer length mismatch. Expected ",
+        expected_numel,
+        ", got ",
+        prepared.numel());
+    dnnl::memory::desc scale_md(
+        {prepared.numel()}, spec.dtype, dnnl::memory::format_tag::x);
+    return make_onednn_memory(scale_md, engine, prepared.data_ptr());
+  };
+
+  auto scratchpad =
+      make_onednn_memory(matmul_pd.scratchpad_desc(), engine, nullptr);
+
+  // 3. Setup Args for exec
+  std::unordered_map<int, dnnl::memory> args;
+  args.insert({DNNL_ARG_SRC, src_usr_m});
+  args.insert({DNNL_ARG_WEIGHTS, weights_usr_m});
+  args.insert({DNNL_ARG_DST, dst_usr_m});
+  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
+  if (with_bias) {
+    args.insert({DNNL_ARG_BIAS, b_usr_m});
+  }
+
+  // Attach runtime scales using specs
+  auto src_sc_mem = make_scale_mem_from_spec(
+      src_spec, src_spec.expected_numel(M, K, "src"), scale_a);
+  auto wei_sc_mem = make_scale_mem_from_spec(
+      wei_spec, wei_spec.expected_numel(N, K, "wei"), scale_b);
+  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_sc_mem});
+  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_sc_mem});
+  if (with_dst_scale) {
+    // Bind single f32 scalar as DST scale
+    at::Tensor dst_scale_f32 = scale_result->to(at::kFloat).contiguous();
+    dnnl::memory::desc dst_sc_md(
+        {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
+    auto dst_sc_mem =
+        make_onednn_memory(dst_sc_md, engine, dst_scale_f32.data_ptr());
+    args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_sc_mem});
+  }
+
+  dnnl::matmul matmul_p = dnnl::matmul(matmul_pd);
+  sycl::event matmul_fwd_event =
+      dnnl::sycl_interop::execute(matmul_p, stream, args);
+  return matmul_fwd_event;
+}
+
 } // namespace at::native::onednn
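The ScaleSpec logic above encodes oneDNN's convention: mask selects which dimensions carry independent scales, and groups says how many consecutive elements along each dimension share one scale factor. A standalone mirror of just that mapping, with the values as documented in the oneDNN attr-scales page linked above (the Spec struct below is illustrative, not oneDNN API):

#include <array>
#include <cstdint>
#include <iostream>
#include <string>

struct Spec {
  int mask;
  std::array<int64_t, 2> groups;
};

// Mirrors make_scale_spec: tensorwise -> mask 0, groups {1,1};
// rowwise -> mask over both dims, SRC groups {1,K}, WEI groups {K,1}.
Spec make_spec(bool tensorwise, int64_t K, const std::string& arg) {
  if (tensorwise) return {0, {1, 1}};
  const int mask = (1 << 0) | (1 << 1);
  return arg == "src" ? Spec{mask, {1, K}} : Spec{mask, {K, 1}};
}

int main() {
  auto src = make_spec(false, 128, "src");
  auto wei = make_spec(false, 128, "wei");
  std::cout << src.mask << " {" << src.groups[0] << ',' << src.groups[1]
            << "}\n";  // prints: 3 {1,128}
  std::cout << wei.mask << " {" << wei.groups[0] << ',' << wei.groups[1]
            << "}\n";  // prints: 3 {128,1}
  return 0;
}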
@@ -78,6 +78,10 @@ dnnl::memory::data_type get_onednn_dtype(
       return dnnl::memory::data_type::f32;
     case at::ScalarType::BFloat16:
       return dnnl::memory::data_type::bf16;
+    case at::ScalarType::Float8_e4m3fn:
+      return dnnl::memory::data_type::f8_e4m3;
+    case at::ScalarType::Float8_e5m2:
+      return dnnl::memory::data_type::f8_e5m2;
     default:
       if (!allow_undef) {
         TORCH_CHECK(
@@ -1,6 +1,7 @@
 #pragma once

 #include <ATen/ATen.h>
+#include <ATen/BlasBackend.h>
 #include <ATen/native/mkldnn/xpu/detail/Attr.h>
 #include <ATen/native/mkldnn/xpu/detail/Utils.h>
 #include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
@@ -202,4 +203,16 @@ void sdpa_backward(
     Tensor& grad_query,
     Tensor& grad_key,
     Tensor& grad_value);

+sycl::event scaled_matmul(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    Tensor& result,
+    const Tensor& scale_a,
+    const Tensor& scale_b,
+    at::blas::ScalingType scaling_choice_a,
+    at::blas::ScalingType scaling_choice_b,
+    const std::optional<at::Tensor>& bias,
+    const std::optional<at::Tensor>& scale_result,
+    bool use_fast_accum);
 } // namespace at::native::onednn
@@ -82,7 +82,6 @@ NSArray<NSNumber*>* getTensorAxes(const TensorBase& t);
 NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
 std::string getMPSShapeString(MPSShape* shape);
 std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
-std::string to_hex_key(float);
 std::string getArrayRefString(const IntArrayRef s);
 // use has_storage() on the returned tensor to determine if src actually is a view
 Tensor gatherViewTensor(const Tensor& src, Tensor& dst);

@@ -301,10 +301,6 @@ std::string getArrayRefString(const IntArrayRef s) {
   return fmt::to_string(fmt::join(s, ","));
 }

-std::string to_hex_key(float f) {
-  return fmt::format("{:a}", f);
-}
-
 std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, bool exclude_shape) {
   fmt::basic_memory_buffer<char, 100> buffer;
   auto buf_iterator = std::back_inserter(buffer);
@@ -96,7 +96,9 @@ kernel void addmm(
     auto bias =
         biasData[thread_id.y * strides[3].x + thread_id.x * strides[3].y];
     outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] =
-        static_cast<T>(alpha_beta[0] * sum + alpha_beta[1] * bias);
+        static_cast<T>(
+            c10::metal::mul(alpha_beta[0], sum) +
+            c10::metal::mul(alpha_beta[1], bias));
   }
 }

@@ -121,7 +121,7 @@ Tensor& do_metal_addmm(const Tensor& self,
                        const Scalar& alpha,
                        const Scalar& beta,
                        const Tensor& bias) {
-  if (beta.toDouble() == 0 && alpha.toDouble() == 1) {
+  if (beta.isFloatingPoint() && alpha.isFloatingPoint() && beta.toDouble() == 0 && alpha.toDouble() == 1) {
    return do_metal_mm(self, other, output);
  }
  auto stream = getCurrentMPSStream();

@@ -147,13 +147,15 @@ Tensor& do_metal_addmm(const Tensor& self,
     std::array<int64_t, 2> i64;
     std::array<int32_t, 2> i32;
     std::array<float, 2> f32;
-  } alpha_beta;
+    std::array<c10::complex<float>, 2> c64;
+  } alpha_beta{};
   if (output.scalar_type() == kLong) {
     alpha_beta.i64 = {alpha.toLong(), beta.toLong()};
   } else if (c10::isIntegralType(output.scalar_type(), true)) {
     alpha_beta.i32 = {alpha.toInt(), beta.toInt()};
+  } else if (c10::isComplexType(output.scalar_type())) {
+    alpha_beta.c64 = {alpha.toComplexFloat(), beta.toComplexFloat()};
   } else {
+    TORCH_INTERNAL_ASSERT(c10::isFloatingType(output.scalar_type()));
     alpha_beta.f32 = {alpha.toFloat(), beta.toFloat()};
   }
   constexpr uint32_t TILE_DIM = 16; // fastest performance from tests on multiple macs
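The widened union above packs whichever alpha/beta representation the output dtype needs into one fixed-size blob that the GPU kernel reads back with matching types. A self-contained sketch of that packing idea (plain C++, no Metal; the field set mirrors the patch, and the trivial cfloat struct stands in for c10::complex<float>):

#include <array>
#include <cstdint>
#include <iostream>

// Trivially copyable stand-in for c10::complex<float>.
struct cfloat { float re; float im; };

union AlphaBeta {
  std::array<int64_t, 2> i64;
  std::array<int32_t, 2> i32;
  std::array<float, 2> f32;
  std::array<cfloat, 2> c64;
};

int main() {
  AlphaBeta ab{};          // value-initialized, like `} alpha_beta{};` above
  ab.f32 = {2.0f, 0.5f};   // pick the representation the output dtype needs
  // The device-side reader must reinterpret the same bytes with the type the
  // host wrote; here we simply read back on the host.
  std::cout << ab.f32[0] << ' ' << ab.f32[1] << '\n';  // prints: 2 0.5
  return 0;
}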
@@ -244,8 +244,8 @@ static void clamp_scalar_out_mps(const Tensor& input_t,

   @autoreleasepool {
     // the optional min/max refs could affect how we build the cached graph
-    std::string key = op_name + (has_min ? ("_min:" + to_hex_key(min_scalar)) : "") +
-        (has_max ? ("_max:" + to_hex_key(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
+    std::string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") +
+        (has_max ? ("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       if (has_min)
         newCachedGraph->minTensor = [mpsGraph constantWithScalar:min_scalar
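For context on what this key change trades away: the removed to_hex_key used hexfloat formatting, which round-trips every float bit pattern exactly, while std::to_string prints six fixed decimal digits, so nearby clamp bounds can collide in the cache key. A standalone illustration of the difference (plain C++, iostream in place of fmt):

#include <cmath>
#include <iostream>
#include <sstream>
#include <string>

std::string hex_key(float f) {  // same idea as the removed to_hex_key
  std::ostringstream ss;
  ss << std::hexfloat << f;
  return ss.str();
}

int main() {
  float a = 0.3f;
  float b = std::nextafterf(a, 1.0f);  // adjacent float, different bits
  std::cout << std::to_string(a) << " vs " << std::to_string(b) << '\n';
  std::cout << hex_key(a) << " vs " << hex_key(b) << '\n';
  // to_string prints both values as "0.300000", so the two keys collide;
  // hexfloat formatting keeps distinct bit patterns distinct.
  return 0;
}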
@@ -301,12 +301,12 @@ class AvgPoolMicrokernelTester {
           ASSERT_NEAR(
               float(int32_t(y[i * yStride() + k])), yFP[i * kc() + k], 0.5001f)
               << "at pixel " << i << ", channel " << k << ", n = " << n()
-              << ", ks = " << kh() << "x" << kw() << " (" << ks()
+              << ", ks = " << kh() << 'x' << kw() << " (" << ks()
               << "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
           ASSERT_EQ(
               uint32_t(yRef[i * kc() + k]), uint32_t(y[i * yStride() + k]))
               << "at pixel " << i << ", channel " << k << ", n = " << n()
-              << ", ks = " << kh() << "x" << kw() << " (" << ks()
+              << ", ks = " << kh() << 'x' << kw() << " (" << ks()
               << "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
         }
       }

@@ -396,12 +396,12 @@ class AvgPoolMicrokernelTester {
           ASSERT_NEAR(
               float(int32_t(y[i * yStride() + k])), yFP[i * kc() + k], 0.5001f)
               << "at pixel " << i << ", channel " << k << ", n = " << n()
-              << ", ks = " << kh() << "x" << kw() << " (" << ks()
+              << ", ks = " << kh() << 'x' << kw() << " (" << ks()
               << "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
           ASSERT_EQ(
               uint32_t(yRef[i * kc() + k]), uint32_t(y[i * yStride() + k]))
               << "at pixel " << i << ", channel " << k << ", n = " << n()
-              << ", ks = " << kh() << "x" << kw() << " (" << ks()
+              << ", ks = " << kh() << 'x' << kw() << " (" << ks()
               << "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
         }
       }

@@ -232,7 +232,7 @@ class MaxPoolMicrokernelTester {
         ASSERT_EQ(
             uint32_t(yRef[i * kc() + k]), uint32_t(y[i * yStride() + k]))
             << "at pixel " << i << ", channel " << k << ", n = " << n()
-            << ", ks = " << kh() << "x" << kw() << " (" << ks()
+            << ", ks = " << kh() << 'x' << kw() << " (" << ks()
             << "), kc = " << kc();
       }
     }
@@ -17,7 +17,7 @@ inline std::vector<T> _expand_param_if_needed(
     std::ostringstream ss;
     ss << "expected " << param_name << " to be a single integer value or a "
        << "list of " << expected_dim << " values to match the convolution "
-       << "dimensions, but got " << param_name << "=" << list_param;
+       << "dimensions, but got " << param_name << '=' << list_param;
     TORCH_CHECK(false, ss.str());
   } else {
     return list_param.vec();
@@ -358,9 +358,9 @@ std::string Adapter::stringize() const {
   std::string device_type = get_device_type_str(properties.deviceType);
   VkPhysicalDeviceLimits limits = properties.limits;

-  ss << "{" << std::endl;
+  ss << '{' << std::endl;
   ss << "  Physical Device Info {" << std::endl;
-  ss << "    apiVersion:    " << v_major << "." << v_minor << std::endl;
+  ss << "    apiVersion:    " << v_major << '.' << v_minor << std::endl;
   ss << "    driverversion: " << properties.driverVersion << std::endl;
   ss << "    deviceType:    " << device_type << std::endl;
   ss << "    deviceName:    " << properties.deviceName << std::endl;

@@ -371,7 +371,7 @@ std::string Adapter::stringize() const {

 #define PRINT_LIMIT_PROP_VEC3(name)                                  \
   ss << "      " << std::left << std::setw(36) << #name << limits.name[0] \
-     << "," << limits.name[1] << "," << limits.name[2] << std::endl;
+     << ',' << limits.name[1] << ',' << limits.name[2] << std::endl;

   ss << "  Physical Device Limits {" << std::endl;
   PRINT_LIMIT_PROP(maxImageDimension1D);

@@ -425,7 +425,7 @@ std::string Adapter::stringize() const {
     ;
   }
   ss << "  ]" << std::endl;
-  ss << "}";
+  ss << '}';

   return ss.str();
 }
@@ -33,7 +33,7 @@ std::ostream& operator<<(std::ostream& out, const VkResult result) {
     VK_RESULT_CASE(VK_ERROR_FORMAT_NOT_SUPPORTED)
     VK_RESULT_CASE(VK_ERROR_FRAGMENTED_POOL)
     default:
-      out << "VK_ERROR_UNKNOWN (VkResult " << result << ")";
+      out << "VK_ERROR_UNKNOWN (VkResult " << result << ')';
       break;
   }
   return out;

@@ -46,7 +46,7 @@ std::ostream& operator<<(std::ostream& out, const VkResult result) {
 //

 std::ostream& operator<<(std::ostream& out, const SourceLocation& loc) {
-  out << loc.function << " at " << loc.file << ":" << loc.line;
+  out << loc.function << " at " << loc.file << ':' << loc.line;
   return out;
 }

@@ -66,7 +66,7 @@ Error::Error(SourceLocation source_location, const char* cond, std::string msg)
     : msg_(std::move(msg)), source_location_{source_location} {
   std::ostringstream oss;
   oss << "Exception raised from " << source_location_ << ": ";
-  oss << "(" << cond << ") is false! ";
+  oss << '(' << cond << ") is false! ";
   oss << msg_;
   what_ = oss.str();
 }
@@ -173,8 +173,8 @@ void QueryPool::extract_results() {

 static std::string stringize(const VkExtent3D& extents) {
   std::stringstream ss;
-  ss << "{" << extents.width << ", " << extents.height << ", " << extents.depth
-     << "}";
+  ss << '{' << extents.width << ", " << extents.height << ", " << extents.depth
+     << '}';
   return ss.str();
 }

@@ -149,7 +149,7 @@ VKAPI_ATTR VkBool32 VKAPI_CALL debug_report_callback_fn(
   (void)flags;

   std::stringstream stream;
-  stream << layer_prefix << " " << message_code << " " << message << std::endl;
+  stream << layer_prefix << ' ' << message_code << ' ' << message << std::endl;
   const std::string log = stream.str();

   std::cout << log;
@@ -253,7 +253,7 @@ using vec4 = vec<4u>;

 // uvec3 is the type representing tensor extents. Useful for debugging.
 inline std::ostream& operator<<(std::ostream& os, const uvec3& v) {
-  os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ")";
+  os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ')';
   return os;
 }
@@ -73,8 +73,8 @@ TEST(TestScalar, TestScalar) {
   Scalar bar = 3.0;
   Half h = bar.toHalf();
   Scalar h2 = h;
-  cout << "H2: " << h2.toDouble() << " " << what.toFloat() << " "
-       << bar.toDouble() << " " << what.isIntegral(false) << "\n";
+  cout << "H2: " << h2.toDouble() << ' ' << what.toFloat() << ' '
+       << bar.toDouble() << ' ' << what.isIntegral(false) << "\n";
   auto gen = at::detail::getDefaultCPUGenerator();
   {
     // See Note [Acquire lock when using random generators]
@@ -19,7 +19,7 @@ TEST(Vitals, Basic) {
     c10::utils::set_env("TORCH_VITAL", "1");
     TORCH_VITAL_DEFINE(Testing);
     TORCH_VITAL(Testing, Attribute0) << 1;
-    TORCH_VITAL(Testing, Attribute1) << "1";
+    TORCH_VITAL(Testing, Attribute1) << '1';
     TORCH_VITAL(Testing, Attribute2) << 1.0f;
     TORCH_VITAL(Testing, Attribute3) << 1.0;
     auto t = at::ones({1, 1});
@@ -129,14 +129,14 @@ void showRtol(const at::Tensor& a, const at::Tensor& b) {
   std::cout << "Max Diff allowed: " << maxDiff << std::endl;
   if (diff.sizes().size() == 2) {
     for (const auto y : c10::irange(diff.sizes()[0])) {
-      std::cout << y << ":";
+      std::cout << y << ':';
       for (const auto x : c10::irange(diff.sizes()[1])) {
         float diff_xy = diff[y][x].item<float>();
         if (diff_xy > maxDiff) {
           std::cout << std::setw(5) << x;
         }
         else {
-          std::cout << std::setw(5) << " ";
+          std::cout << std::setw(5) << ' ';
         }
       }
       std::cout << std::endl;
@@ -3276,7 +3276,7 @@ TEST_F(VulkanAPITest, masked_fill_invalidinputs_exceptions) {

 void print_shape(const std::vector<int64_t>& shape) {
   for (const auto& num : shape) {
-    std::cout << num << " ";
+    std::cout << num << ' ';
   }
 }

@@ -3367,7 +3367,7 @@ void test_masked_fill_scalar(
       print_shape(tmp_curr_input_shape);
       std::cout << "], and mask of shape [";
       print_shape(tmp_curr_mask_shape);
-      std::cout << "]" << std::endl;
+      std::cout << ']' << std::endl;
     }

     ASSERT_TRUE(check);
@@ -4542,9 +4542,9 @@ void test_softmax(const at::IntArrayRef shape, bool log_softmax = false) {
     if (!check) {
       std::cout << "Softmax test failed on axis " << dim << "for tensor dims {";
       for (uint32_t place = 0; place < shape.size() - 1; place++) {
-        std::cout << shape[place] << " ";
+        std::cout << shape[place] << ' ';
       }
-      std::cout << shape.back() << "}" << std::endl;
+      std::cout << shape.back() << '}' << std::endl;
       showRtol(out_cpu, out_vulkan.cpu());
     }
     ASSERT_TRUE(check);
@@ -95,7 +95,7 @@ void showRtol(
   std::cout << "Max Diff found is: " << diff.max().item<double>() << std::endl;
   if (diff.sizes().size() == 2) {
     for (const auto y : c10::irange(diff.sizes()[0])) {
-      std::cout << y << ":";
+      std::cout << y << ':';
       for (const auto x : c10::irange(diff.sizes()[1])) {
         double diff_xy = diff[y][x].item<double>();
         if (diff_xy > maxDiff) {

@@ -109,7 +109,7 @@ void showRtol(
           }
         } else {
-          std::cout << std::setw(5) << " ";
+          std::cout << std::setw(5) << ' ';
         }
       }
       std::cout << std::endl;
@@ -148,19 +148,19 @@ using at::native::vulkan::api::utils::ivec4;
 using at::native::vulkan::api::utils::vec4;

 std::ostream& operator<<(std::ostream& os, const vec4& v) {
-  os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
-     << v.data[3u] << ")";
+  os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
+     << v.data[3u] << ')';
   return os;
 }

 std::ostream& operator<<(std::ostream& os, const ivec3& v) {
-  os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ")";
+  os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ')';
   return os;
 }

 std::ostream& operator<<(std::ostream& os, const ivec4& v) {
-  os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
-     << v.data[3u] << ")";
+  os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
+     << v.data[3u] << ')';
   return os;
 }
@@ -3379,7 +3379,7 @@ bool _test_quantized_linear(
     showRtol(out_cpu_dequant, out_vk_to_cpu_dequant);
   }
   if (xpos != -1 && ypos != -1) {
-    std::cout << "\nFailure caused on row/col: " << ypos << "/" << xpos
+    std::cout << "\nFailure caused on row/col: " << ypos << '/' << xpos
               << "\n";
     std::cout << "Input tensor scale: " << scale << " zerop: " << zero_point
               << "\n";
@@ -2,6 +2,7 @@
 # These load paths point to different files in internal and OSS environment

 load("@bazel_skylib//lib:paths.bzl", "paths")
+load("//tools/build_defs:cell_defs.bzl", "get_fbsource_cell")
 load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
 load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
 load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")

@@ -590,6 +591,9 @@ def pt_operator_query_codegen(
         pt_allow_forced_schema_registration = True,
         compatible_with = [],
         apple_sdks = None):
+    if get_fbsource_cell() == "fbcode":
+        return
+
     oplist_dir_name = name + "_pt_oplist"

     # @lint-ignore BUCKLINT

@@ -865,6 +869,9 @@ def define_buck_targets(
         pt_xplat_cxx_library = fb_xplat_cxx_library,
         c2_fbandroid_xplat_compiler_flags = [],
         labels = []):
+    if get_fbsource_cell() == "fbcode":
+        return
+
     # @lint-ignore BUCKLINT
     fb_native.filegroup(
         name = "metal_build_srcs",
@@ -176,7 +176,7 @@ std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
     os << k;
     first = false;
   }
-  os << ")";
+  os << ')';
   return os;
 }
@@ -34,20 +34,6 @@ namespace c10 {
 // See [dtype Macros note] in torch/headeronly/core/ScalarType.h
 // regarding macros.

-template <typename T>
-struct CppTypeToScalarType;
-
-#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type)                  \
-  template <>                                                                  \
-  struct CppTypeToScalarType<cpp_type>                                         \
-      : std::                                                                  \
-            integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \
-  };
-
-AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
-
-#undef SPECIALIZE_CppTypeToScalarType
-
 #define DEFINE_CONSTANT(_, name) \
   constexpr ScalarType k##name = ScalarType::name;
@@ -33,7 +33,7 @@ std::ostream& operator<<(std::ostream& stream, const TensorOptions& options) {
   } else {
     stream << "(nullopt)";
   }
-  stream << ")";
+  stream << ')';

   return stream;
 }
@@ -136,7 +136,7 @@ std::string c10_retrieve_device_side_assertion_info() {
   // Something failed, let's talk about that
   oss << failures_found
       << " CUDA device-side assertion failures were found on GPU #"
-      << device_num << "!" << std::endl;
+      << device_num << '!' << std::endl;
   if (assertion_data_for_device.assertion_count >
       C10_CUDA_DSA_ASSERTION_COUNT) {
     oss << "But at least " << assertion_data_for_device.assertion_count
@@ -151,17 +151,17 @@ std::string c10_retrieve_device_side_assertion_info() {
     oss << "Assertion failure " << i << std::endl;
     oss << " GPU assertion failure message = " << self.assertion_msg
        << std::endl;
-    oss << " File containing assertion = " << self.filename << ":"
+    oss << " File containing assertion = " << self.filename << ':'
        << self.line_number << std::endl;
     oss << " Device function containing assertion = " << self.function_name
        << std::endl;
-    oss << " Thread ID that failed assertion = [" << self.thread_id[0] << ","
-       << self.thread_id[1] << "," << self.thread_id[2] << "]" << std::endl;
-    oss << " Block ID that failed assertion = [" << self.block_id[0] << ","
-       << self.block_id[1] << "," << self.block_id[2] << "]" << std::endl;
+    oss << " Thread ID that failed assertion = [" << self.thread_id[0] << ','
+       << self.thread_id[1] << ',' << self.thread_id[2] << ']' << std::endl;
+    oss << " Block ID that failed assertion = [" << self.block_id[0] << ','
+       << self.block_id[1] << ',' << self.block_id[2] << ']' << std::endl;
     if (launch_info.generation_number == self.caller) {
       oss << " File containing kernel launch = "
-          << launch_info.launch_filename << ":" << launch_info.launch_linenum
+          << launch_info.launch_filename << ':' << launch_info.launch_linenum
           << std::endl;
       oss << " Function containing kernel launch = "
           << launch_info.launch_function << std::endl;
@@ -435,7 +435,7 @@ TEST(DispatchKeySet, TestFunctionalityDispatchKeyToString) {
     if (i > 0) {
       ASSERT_TRUE(res.find("Unknown") == std::string::npos)
           << i << " (before is " << toString(static_cast<DispatchKey>(i - 1))
-          << ")";
+          << ')';
     } else {
       ASSERT_TRUE(res.find("Unknown") == std::string::npos) << i;
     }
@@ -98,7 +98,7 @@ struct Noncopyable {
 };

 std::ostream& operator<<(std::ostream& out, const Noncopyable& nc) {
-  out << "Noncopyable(" << nc.x << ")";
+  out << "Noncopyable(" << nc.x << ')';
   return out;
 }
 } // namespace
@@ -198,13 +198,13 @@ ArrayRef(const std::initializer_list<T>&) -> ArrayRef<T>;
 template <typename T>
 std::ostream& operator<<(std::ostream& out, ArrayRef<T> list) {
   int i = 0;
-  out << "[";
+  out << '[';
   for (const auto& e : list) {
     if (i++ > 0)
       out << ", ";
     out << e;
   }
-  out << "]";
+  out << ']';
   return out;
 }
@@ -107,7 +107,7 @@ class GetBacktraceImpl {
         /*status*/ &status);

     os << " frame #" << idx++ << "\t"
-       << ((demangled != NULL && status == 0) ? demangled : symbol) << "["
+       << ((demangled != NULL && status == 0) ? demangled : symbol) << '['
        << addr << "]\t" << std::endl;
   }
   free(demangled);
@@ -413,8 +413,8 @@ class GetBacktraceImpl {
           << back_trace_[i_frame] << std::dec;
       if (with_symbol) {
         stream << std::setfill('0') << std::setw(16) << std::uppercase
-               << std::hex << p_symbol->Address << std::dec << " " << module
-               << "!" << p_symbol->Name;
+               << std::hex << p_symbol->Address << std::dec << ' ' << module
+               << '!' << p_symbol->Name;
       } else {
         stream << " <unknown symbol address> " << module << "!<unknown symbol>";
       }
@@ -424,7 +424,7 @@ class GetBacktraceImpl {
       } else {
         stream << "<unknown file> @ <unknown line number>";
       }
-      stream << "]" << std::endl;
+      stream << ']' << std::endl;
     }

     return stream.str();
@@ -44,7 +44,7 @@ std::string Error::compute_what(bool include_backtrace) const {

   if (context_.size() == 1) {
     // Fold error and context in one line
-    oss << " (" << context_[0] << ")";
+    oss << " (" << context_[0] << ')';
   } else {
     for (const auto& c : context_) {
       oss << "\n  " << c;
@@ -247,7 +247,7 @@ void WarningHandler::process(const Warning& warning) {
   LOG_AT_FILE_LINE(
       WARNING, warning.source_location().file, warning.source_location().line)
       << "Warning: " << warning.msg() << " (function "
-      << warning.source_location().function << ")";
+      << warning.source_location().function << ')';
 }

 std::string GetExceptionString(const std::exception& e) {
@@ -473,12 +473,12 @@ MessageLogger::MessageLogger(
   if (GLOBAL_RANK != -1) {
     stream_ << "[rank" << GLOBAL_RANK << "]:";
   }
-  stream_ << "[" << CAFFE2_SEVERITY_PREFIX[std::min(4, GLOG_FATAL - severity_)]
+  stream_ << '[' << CAFFE2_SEVERITY_PREFIX[std::min(4, GLOG_FATAL - severity_)]
           << (timeinfo->tm_mon + 1) * 100 + timeinfo->tm_mday
-          << std::setfill('0') << " " << std::setw(2) << timeinfo->tm_hour
-          << ":" << std::setw(2) << timeinfo->tm_min << ":" << std::setw(2)
-          << timeinfo->tm_sec << "." << std::setw(9) << ns << " "
-          << c10::detail::StripBasename(std::string(file)) << ":" << line
+          << std::setfill('0') << ' ' << std::setw(2) << timeinfo->tm_hour
+          << ':' << std::setw(2) << timeinfo->tm_min << ':' << std::setw(2)
+          << timeinfo->tm_sec << '.' << std::setw(9) << ns << ' '
+          << c10::detail::StripBasename(std::string(file)) << ':' << line
           << "] ";
 }
@@ -1412,13 +1412,13 @@ inline size_t capacity_in_bytes(const SmallVector<T, N>& X) {
 template <typename T, unsigned N>
 std::ostream& operator<<(std::ostream& out, const SmallVector<T, N>& list) {
   int i = 0;
-  out << "[";
+  out << '[';
   for (auto e : list) {
     if (i++ > 0)
       out << ", ";
     out << e;
   }
-  out << "]";
+  out << ']';
   return out;
 }
@@ -79,7 +79,7 @@ std::ostream& _str(std::ostream& ss, const std::wstring& wString) {
 } // namespace detail

 std::ostream& operator<<(std::ostream& out, const SourceLocation& loc) {
-  out << loc.function << " at " << loc.file << ":" << loc.line;
+  out << loc.function << " at " << loc.file << ':' << loc.line;
   return out;
 }
@@ -170,7 +170,7 @@ inline bool isPrint(char s) {
 }

 inline void printQuotedString(std::ostream& stmt, const std::string_view str) {
-  stmt << "\"";
+  stmt << '"';
   for (auto s : str) {
     switch (s) {
       case '\\':
@@ -224,7 +224,7 @@ inline void printQuotedString(std::ostream& stmt, const std::string_view str) {
         break;
     }
   }
-  stmt << "\"";
+  stmt << '"';
 }

 template <typename T>
@@ -223,7 +223,7 @@ void FatalSignalHandler::fatalSignalHandler(int signum) {
     // a single thread that wouldn't receive the SIGUSR2
     if (std::cv_status::timeout == writingCond.wait_for(ul, 2s)) {
       if (!signalReceived) {
-        std::cerr << "signal lost waiting for stacktrace " << pid << ":"
+        std::cerr << "signal lost waiting for stacktrace " << pid << ':'
                   << tid << '\n';
         break;
       }
@@ -877,7 +877,7 @@ std::ostream& operator<<(
     std::ostream& stream,
     const SparseBitVector<ElementSize>& vec) {
   bool first = true;
-  stream << "{";
+  stream << '{';
   for (auto el : vec) {
     if (first) {
       first = false;
@@ -886,7 +886,7 @@ std::ostream& operator<<(
     }
     stream << el;
   }
-  stream << "}";
+  stream << '}';
   return stream;
 }
@@ -734,7 +734,7 @@ void PyTorchStreamWriter::setup(const string& file_name) {
         file_name,
         std::ofstream::out | std::ofstream::trunc | std::ofstream::binary
     );
-  } catch (const std::ios_base::failure& e) {
+  } catch (const std::ios_base::failure&) {
 #ifdef _WIN32
     // Windows have verbose error code, we prefer to use it than std errno.
     uint32_t error_code = GetLastError();
@@ -118,6 +118,11 @@ if(INTERN_BUILD_ATEN_OPS)
         list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
       endif()
     endif()
+    if("${_arch}" STREQUAL "121a")
+      if(_existing_arch_flags MATCHES ".*compute_120.*")
+        list(APPEND _file_compile_flags "-gencode;arch=compute_121a,code=sm_121a")
+      endif()
+    endif()
   endforeach()
   list(JOIN _file_compile_flags " " _file_compile_flags)

@@ -126,7 +131,7 @@ if(INTERN_BUILD_ATEN_OPS)
   _BUILD_FOR_ADDITIONAL_ARCHS(
     "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
-    "89;90a;100a;103a;120a")
+    "89;90a;100a;103a;120a;121a")
   _BUILD_FOR_ADDITIONAL_ARCHS(
     "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
     "90a")
docs/source/accelerator/device.md (new file, 113 lines)
@@ -0,0 +1,113 @@
# Device Management

## Background

Device management covers the basic operations of querying how many devices are available and switching between them. An accelerator backend needs to wrap its device runtime's APIs and expose them to PyTorch.

The OpenReg implementation ([`OpenRegFunctions.h/cpp`][OpenReg Device Management]) shows how to wrap a third-party runtime. These functions are used throughout the backend: by streams, events, generators, and the Python bindings.

## Design

Accelerator vendors need to implement these core functions:

| Function Name | Description | Application Scenarios |
| ------------------------- | ---------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------- |
| `device_count()` | Query the total number of available devices in the system | - Application initialization<br>- Multi-device workload distribution<br>- Validating device indices before use |
| `current_device()` | Get the currently active device for the calling thread | - Debugging and logging<br>- Determining tensor placement<br>- Guard implementations |
| `set_device()` | Change the active device for subsequent operations | - Switching context between devices<br>- Initializing specific device resources<br>- Multi-GPU training loops |
| `exchange_device()` | Atomically swap device and return the previous device | - Implementing device guards<br>- Temporarily switching device context<br>- RAII-based device management |
| `maybe_exchange_device()` | Conditionally exchange device only if the index is valid (-1 OK) | - Safe device switching with optional indices<br>- Guard implementations with nullable device values |

These functions are the building blocks for more complex features such as streams, events, and memory management. Make sure to validate inputs and handle errors properly; a sketch of the expected signatures follows below.
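As a rough reference, the table above might translate into declarations like these. This is a sketch only: `DeviceIndex` and the exact signatures are assumptions modeled on other PyTorch backends, not the OpenReg headers.

```cpp
// Hypothetical declarations for the functions in the table above.
// Types and noexcept-ness are illustrative assumptions.
#include <cstdint>

using DeviceIndex = int8_t;

DeviceIndex device_count() noexcept;             // total usable devices; 0 on failure
DeviceIndex current_device();                    // active device for the calling thread
void set_device(DeviceIndex index);              // switch the active device
DeviceIndex exchange_device(DeviceIndex index);  // swap and return the previous device
DeviceIndex maybe_exchange_device(DeviceIndex index);  // no-op when index == -1
```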
## Implementation

This section shows how to implement device management, using `set_device` as the running example. The implementation requires:

1. C++ wrappers around the device runtime
2. Python bindings to expose the C++ functions
3. User-friendly Python APIs

### C++ Side

Wrap the device runtime's API and add error handling. The `SetDevice` function shows this pattern:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG SetDevice FUNCTION
   :end-before: LITERALINCLUDE END: OPENREG SetDevice FUNCTION
   :linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG set_device FUNCTION
   :end-before: LITERALINCLUDE END: OPENREG set_device FUNCTION
   :linenos:
```
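The directives above pull the real code in verbatim when the docs are built. For readers without the source at hand, the gist of the pattern is sketched below; the `or*` driver calls are stubbed stand-ins, not the OpenReg runtime.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

// Stand-in driver API so the sketch compiles on its own.
enum orError_t { orSuccess, orErrorInvalidDevice };
inline orError_t orSetDevice(int device) {
  return device >= 0 ? orSuccess : orErrorInvalidDevice;
}

using DeviceIndex = int8_t;
constexpr DeviceIndex kDeviceCount = 2;  // assume a fixed two-device system

// Thin wrapper: forward to the driver and surface its error code.
orError_t SetDevice(DeviceIndex device) {
  return orSetDevice(device);
}

// Checked variant used by the bindings: validate the index first,
// then turn driver failures into exceptions.
void set_device(DeviceIndex device) {
  if (device < 0 || device >= kDeviceCount) {
    throw std::runtime_error(
        "invalid device index: " + std::to_string(device));
  }
  if (SetDevice(device) != orSuccess) {
    throw std::runtime_error("failed to set device");
  }
}
```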
### Binding

Expose the C++ functions to Python using pybind11:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: MODULE SET DEVICE HELPER
   :end-before: LITERALINCLUDE END: MODULE SET DEVICE HELPER
   :linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
   :end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
   :linenos:
   :emphasize-lines: 5
```
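Schematically, the binding layer is only a few lines of pybind11. The snippet below is an illustrative reduction rather than the actual `Module.cpp`, and assumes the C++ wrappers from the previous section are linked in.

```cpp
#include <pybind11/pybind11.h>

#include <cstdint>

// Assumed to be provided by the C++ side sketched earlier.
void set_device(int8_t device);
int8_t current_device();
int8_t device_count();

PYBIND11_MODULE(_C, m) {
  // Leading underscores mark these as private entry points; the
  // user-facing Python API wraps them with validation and docstrings.
  m.def("_set_device", [](int device) { set_device(static_cast<int8_t>(device)); });
  m.def("_get_device", []() { return static_cast<int>(current_device()); });
  m.def("_get_device_count", []() { return static_cast<int>(device_count()); });
}
```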
### Python Side

Wrap the C++ bindings with user-friendly Python functions:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py
   :language: python
   :start-after: LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
   :end-before: LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
   :linenos:
```

Here's the complete mapping from C++ to Python:

| C++ Binding Function | C++ Binding API (pybind11) | Python User API | Description |
| -------------------- | ---------------------------------------- | -------------------------------- | -------------------------------------------- |
| `_getDeviceCount` | `torch_openreg._C._get_device_count()` | `torch.openreg.device_count()` | Returns the total number of devices |
| `_getDevice` | `torch_openreg._C._get_device()` | `torch.openreg.current_device()` | Returns the current active device index |
| `_setDevice` | `torch_openreg._C._set_device(idx)` | `torch.openreg.set_device(idx)` | Sets the active device |
| `_exchangeDevice` | `torch_openreg._C._exchange_device(idx)` | N/A (internal use only) | Atomically swaps device and returns previous |
## Guard

Device guards provide automatic device switching with exception safety. They are similar to lock guards in C++: they switch the device on construction and restore it on destruction.

Implement `DeviceGuardImplInterface` to integrate with PyTorch's guard system:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
   :end-before: LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
   :linenos:
```

**What needs to be implemented:**

1. **exchangeDevice()**: Switch to a new device and return the old one (used by guard constructors)
2. **getDevice()**: Get the current device
3. **setDevice()**: Set the active device
4. **Type checking**: Validate that the device type matches the backend

This makes the guard available to PyTorch for the `PrivateUse1` device type, so users can apply standard PyTorch device guards to the custom backend.
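Stripped of the c10 interface details, the pattern the guard implements reduces to the self-contained sketch below: exchange on construction, restore on destruction. A real backend implements `c10::impl::DeviceGuardImplInterface` and registers it for `PrivateUse1` rather than hand-rolling a guard class; once that registration exists, PyTorch's own guards route through the same exchange/restore pair.

```cpp
#include <cstdint>

using DeviceIndex = int8_t;

// Assumed wrappers from the device-management layer above.
DeviceIndex exchange_device(DeviceIndex index);
void set_device(DeviceIndex index);

class DeviceGuard {
 public:
  // Switch to `index`, remembering whatever was active before.
  explicit DeviceGuard(DeviceIndex index) : prev_(exchange_device(index)) {}

  // Restore the previous device, even when an exception unwinds the stack.
  ~DeviceGuard() {
    set_device(prev_);
  }

  // Guards manage thread-local state; copying one would be meaningless.
  DeviceGuard(const DeviceGuard&) = delete;
  DeviceGuard& operator=(const DeviceGuard&) = delete;

 private:
  DeviceIndex prev_;
};
```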
[OpenReg Device Management]: https://github.com/pytorch/pytorch/blob/main/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp "OpenReg Device Management"
@@ -42,6 +42,7 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
   :glob:
   :maxdepth: 1

+  device
   hooks
   autoload
   operators
@@ -30,5 +30,6 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
     skip_guard_on_all_nn_modules_unsafe
     keep_tensor_guards_unsafe
     skip_guard_on_globals_unsafe
     skip_all_guards_unsafe
+    nested_compile_region
 ```
@@ -376,3 +376,19 @@ keep-runtime-typing = true

 [tool.codespell]
 ignore-words = "tools/linter/dictionary.txt"
+
+[tool.spin]
+package = 'torch'
+
+[tool.spin.commands]
+"Build" = [
+    ".spin/cmds.py:lint",
+    ".spin/cmds.py:fixlint",
+    ".spin/cmds.py:quicklint",
+    ".spin/cmds.py:quickfix",
+]
+"Regenerate" = [
+    ".spin/cmds.py:regenerate_version",
+    ".spin/cmds.py:regenerate_type_stubs",
+    ".spin/cmds.py:regenerate_clangtidy_files",
+]
@@ -14,6 +14,7 @@ lintrunner ; platform_machine != "s390x" and platform_machine != "riscv64"
 networkx>=2.5.1
 optree>=0.13.0
 psutil
+spin
 sympy>=1.13.3
 typing-extensions>=4.13.2
 wheel
setup.py (47 changed lines)
@@ -1358,45 +1358,6 @@ class concat_license_files:

 # Need to create the proper LICENSE.txt for the wheel
 class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
-    def _wrap_headers_with_macro(self, bdist_dir: Path) -> None:
-        """Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION).
-
-        Excludes:
-        - torch/include/torch/headeronly/*
-        - torch/include/torch/csrc/stable/*
-        - torch/include/torch/csrc/inductor/aoti_torch/c/ (only shim headers)
-        - torch/include/torch/csrc/inductor/aoti_torch/generated/
-        """
-        header_extensions = (".h", ".hpp", ".cuh")
-        header_files = [
-            f for ext in header_extensions for f in bdist_dir.rglob(f"*{ext}")
-        ]
-
-        # Paths to exclude from wrapping
-        exclude_dir_patterns = [
-            "torch/include/torch/headeronly/",
-            "torch/include/torch/csrc/stable/",
-            "torch/include/torch/csrc/inductor/aoti_torch/c/",
-            "torch/include/torch/csrc/inductor/aoti_torch/generated/",
-        ]
-
-        for header_file in header_files:
-            rel_path = header_file.relative_to(bdist_dir).as_posix()
-
-            if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns):
-                report(f"Skipping header: {rel_path}")
-                continue
-
-            original_content = header_file.read_text(encoding="utf-8")
-            wrapped_content = (
-                "#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
-                f"{original_content}"
-                "\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
-            )
-
-            header_file.write_text(wrapped_content, encoding="utf-8")
-            report(f"Wrapped header: {rel_path}")
-
     def run(self) -> None:
         with concat_license_files(include_files=True):
             super().run()
@@ -1419,14 +1380,6 @@ class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
         # need an __init__.py file otherwise we wouldn't have a package
         (bdist_dir / "torch" / "__init__.py").touch()
-
-        # Wrap all header files with TORCH_STABLE_ONLY macro
-        assert self.bdist_dir is not None, "bdist_dir should be set during wheel build"
-        bdist_dir = Path(self.bdist_dir)
-        report(
-            "-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)"
-        )
-        self._wrap_headers_with_macro(bdist_dir)


 class clean(Command):
     user_options: ClassVar[list[tuple[str, str | None, str]]] = []
@@ -13,6 +13,17 @@ TEST(TestScalarType, ScalarTypeToCPPTypeT) {
 #undef DEFINE_CHECK
 }

+TEST(TestScalarType, CppTypeToScalarType) {
+  using torch::headeronly::CppTypeToScalarType;
+  using torch::headeronly::ScalarType;
+
+#define DEFINE_CHECK(TYPE, SCALARTYPE) \
+  EXPECT_EQ(CppTypeToScalarType<TYPE>::value, ScalarType::SCALARTYPE);
+
+  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
+#undef DEFINE_CHECK
+}
+
 #define DEFINE_CHECK(TYPE, SCALARTYPE) \
   { \
     EXPECT_EQ( \
@@ -634,3 +634,38 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
   m.impl("test_parallel_for", &boxed_test_parallel_for);
   m.impl("test_get_num_threads", &boxed_test_get_num_threads);
 }
+
+Tensor my_empty(
+    torch::headeronly::HeaderOnlyArrayRef<int64_t> size,
+    std::optional<torch::headeronly::ScalarType> dtype,
+    std::optional<torch::stable::Device> device,
+    std::optional<bool> pin_memory) {
+  return empty(size, dtype, device, pin_memory);
+}
+
+Tensor my_flatten(Tensor t, int64_t start_dim, int64_t end_dim) {
+  return flatten(t, start_dim, end_dim);
+}
+
+Tensor my_reshape(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> shape) {
+  return reshape(t, shape);
+}
+
+Tensor my_view(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> size) {
+  return view(t, size);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
+  m.def(
+      "my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor");
+  m.def("my_flatten(Tensor t, int start_dim=0, int end_dim=-1) -> Tensor");
+  m.def("my_reshape(Tensor t, int[] shape) -> Tensor");
+  m.def("my_view(Tensor t, int[] size) -> Tensor");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
+  m.impl("my_empty", TORCH_BOX(&my_empty));
+  m.impl("my_flatten", TORCH_BOX(&my_flatten));
+  m.impl("my_reshape", TORCH_BOX(&my_reshape));
+  m.impl("my_view", TORCH_BOX(&my_view));
+}
@@ -487,3 +487,58 @@ def test_get_num_threads() -> int:
     Returns: int - the number of threads for the parallel backend
     """
     return torch.ops.libtorch_agnostic.test_get_num_threads.default()
+
+
+def my_empty(size, dtype=None, device=None, pin_memory=None) -> Tensor:
+    """
+    Creates an empty tensor with the specified size, dtype, device, and pin_memory.
+
+    Args:
+        size: list[int] - size of the tensor to create
+        dtype: ScalarType or None - data type of the tensor
+        device: Device or None - device on which to create the tensor
+        pin_memory: bool or None - whether to use pinned memory
+
+    Returns: Tensor - an uninitialized tensor with the specified properties
+    """
+    return torch.ops.libtorch_agnostic.my_empty.default(size, dtype, device, pin_memory)
+
+
+def my_flatten(t, start_dim=0, end_dim=-1) -> Tensor:
+    """
+    Flattens the input tensor from start_dim to end_dim into a single dimension.
+
+    Args:
+        t: Tensor - tensor to flatten
+        start_dim: int - first dimension to flatten (default: 0)
+        end_dim: int - last dimension to flatten (default: -1)
+
+    Returns: Tensor - flattened tensor
+    """
+    return torch.ops.libtorch_agnostic.my_flatten.default(t, start_dim, end_dim)
+
+
+def my_reshape(t, shape) -> Tensor:
+    """
+    Returns a tensor with the same data but different shape.
+
+    Args:
+        t: Tensor - tensor to reshape
+        shape: list[int] - new shape for the tensor
+
+    Returns: Tensor - reshaped tensor
+    """
+    return torch.ops.libtorch_agnostic.my_reshape.default(t, shape)
+
+
+def my_view(t, size) -> Tensor:
+    """
+    Returns a new tensor with the same data as the input tensor but of a different shape.
+
+    Args:
+        t: Tensor - tensor to view
+        size: list[int] - new size for the tensor
+
+    Returns: Tensor - tensor with new view
+    """
+    return torch.ops.libtorch_agnostic.my_view.default(t, size)
@@ -33,7 +33,7 @@ class clean(distutils.command.clean.clean):

 def get_extension():
     extra_compile_args = {
-        "cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
+        "cxx": ["-fdiagnostics-color=always"],
     }

     extension = CppExtension
@@ -525,6 +525,97 @@ if not IS_WINDOWS:
             expected_num_threads = torch.get_num_threads()
             self.assertEqual(num_threads, expected_num_threads)

+        def test_my_empty(self, device):
+            import libtorch_agnostic
+
+            deterministic = torch.are_deterministic_algorithms_enabled()
+            try:
+                # set use_deterministic_algorithms to fill uninitialized memory
+                torch.use_deterministic_algorithms(True)
+
+                size = [2, 3]
+                result = libtorch_agnostic.ops.my_empty(size, None, None, None)
+                expected = torch.empty(size)
+                self.assertEqual(result, expected, exact_device=True)
+
+                result_float = libtorch_agnostic.ops.my_empty(
+                    size, torch.float32, None, None
+                )
+                expected_float = torch.empty(size, dtype=torch.float32)
+                self.assertEqual(result_float, expected_float, exact_device=True)
+
+                result_with_device = libtorch_agnostic.ops.my_empty(
+                    size, torch.float64, device, None
+                )
+                expected_with_device = torch.empty(
+                    size, dtype=torch.float64, device=device
+                )
+                self.assertEqual(
+                    result_with_device, expected_with_device, exact_device=True
+                )
+
+                if device == "cuda":
+                    result_pinned = libtorch_agnostic.ops.my_empty(
+                        size, torch.float32, "cpu", True
+                    )
+                    expected_pinned = torch.empty(
+                        size, dtype=torch.float32, device="cpu", pin_memory=True
+                    )
+                    self.assertEqual(result_pinned, expected_pinned)
+                    self.assertTrue(result_pinned.is_pinned())
+            finally:
+                torch.use_deterministic_algorithms(deterministic)
+
+        def test_my_flatten(self, device):
+            import libtorch_agnostic
+
+            t = torch.randn(2, 3, 4, device=device)
+            result = libtorch_agnostic.ops.my_flatten(t)
+            expected = torch.flatten(t)
+            self.assertEqual(result, expected)
+
+            result_start = libtorch_agnostic.ops.my_flatten(t, 1)
+            expected_start = torch.flatten(t, 1)
+            self.assertEqual(result_start, expected_start)
+
+            result_range = libtorch_agnostic.ops.my_flatten(t, 2, -1)
+            expected_range = torch.flatten(t, 2, -1)
+            self.assertEqual(result_range, expected_range)
+
+        def test_my_reshape(self, device):
+            import libtorch_agnostic
+
+            t = torch.randn(2, 3, 4, device=device)
+
+            result = libtorch_agnostic.ops.my_reshape(t, [6, 4])
+            expected = torch.reshape(t, [6, 4])
+            self.assertEqual(result, expected)
+
+            result_infer = libtorch_agnostic.ops.my_reshape(t, [-1, 4])
+            expected_infer = torch.reshape(t, [-1, 4])
+            self.assertEqual(result_infer, expected_infer)
+
+            result_flat = libtorch_agnostic.ops.my_reshape(t, [-1])
+            expected_flat = torch.reshape(t, [-1])
+            self.assertEqual(result_flat, expected_flat)
+
+        def test_my_view(self, device):
+            import libtorch_agnostic
+
+            t = torch.randn(2, 3, 4, device=device)
+
+            result = libtorch_agnostic.ops.my_view(t, [6, 4])
+            expected = t.view([6, 4])
+            self.assertEqual(result, expected)
+
+            result_infer = libtorch_agnostic.ops.my_view(t, [-1, 4])
+            expected_infer = t.view([-1, 4])
+            self.assertEqual(result_infer, expected_infer)
+
+            result_flat = libtorch_agnostic.ops.my_view(t, [-1])
+            expected_flat = t.view([-1])
+            self.assertEqual(result_flat, expected_flat)
+
 instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)

 if __name__ == "__main__":
@@ -4,17 +4,12 @@

 #include <c10/util/Exception.h>

-void orCheckFail(
-    const char* func,
-    const char* file,
-    uint32_t line,
-    const char* msg = "");
+void orCheckFail(const char* func, const char* file, uint32_t line, const char* msg = "");

-#define OPENREG_CHECK(EXPR, ...)                                               \
-  do {                                                                         \
-    const orError_t __err = EXPR;                                              \
-    if (__err != orSuccess) {                                                  \
-      orCheckFail(                                                             \
-          __func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
-    }                                                                          \
+#define OPENREG_CHECK(EXPR, ...)                                               \
+  do {                                                                         \
+    const orError_t __err = EXPR;                                              \
+    if (C10_UNLIKELY(__err != orSuccess)) {                                    \
+      orCheckFail(__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
+    }                                                                          \
 } while (0)
Some files were not shown because too many files have changed in this diff.