mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-16 07:24:54 +08:00

Compare commits: 16 commits, sy_invoke_ ... ciflow/tru
| SHA1 |
|---|
| 54f7347a5b |
| 82afb7deda |
| 7aa210d215 |
| 5a368b8010 |
| 602102be50 |
| 200156e385 |
| 3710cad6d7 |
| 4d2cc7d490 |
| 0ca26bbd2b |
| 76c6d99ba9 |
| cf1ea48d0a |
| 68283bd54c |
| b6ab8b28a4 |
| 4c9ebe5b2f |
| 7104addf1e |
| 0deaae4852 |
@@ -100,337 +100,6 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
    )


def _compile_and_extract_symbols(
    cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None
) -> list[str]:
    """
    Helper to compile a C++ file and extract all symbols.

    Args:
        cpp_content: C++ source code to compile
        compile_flags: Compilation flags
        exclude_list: List of symbol names to exclude. Defaults to ["main"].

    Returns:
        List of all symbols found in the object file (excluding those in exclude_list).
    """
    import subprocess
    import tempfile

    if exclude_list is None:
        exclude_list = ["main"]

    with tempfile.TemporaryDirectory() as tmpdir:
        tmppath = Path(tmpdir)
        cpp_file = tmppath / "test.cpp"
        obj_file = tmppath / "test.o"

        cpp_file.write_text(cpp_content)

        result = subprocess.run(
            compile_flags + [str(cpp_file), "-o", str(obj_file)],
            capture_output=True,
            text=True,
            timeout=60,
        )

        if result.returncode != 0:
            raise RuntimeError(f"Compilation failed: {result.stderr}")

        symbols = get_symbols(str(obj_file))

        # Return all symbol names, excluding those in the exclude list
        return [name for _addr, _stype, name in symbols if name not in exclude_list]
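For orientation, a minimal sketch of how this helper is meant to be driven. The paths and source string here are hypothetical placeholders, not values from this diff:

```python
# Hypothetical usage sketch of _compile_and_extract_symbols above.
from pathlib import Path

install_root = Path("/tmp/torch_install")  # assumption: an unpacked libtorch tree
flags = ["g++", "-std=c++17", f"-I{install_root / 'include'}", "-c"]
src = "#include <c10/core/DeviceType.h>\nint main() { return 0; }"

# Would return every symbol emitted into the object file, minus "main":
# symbols = _compile_and_extract_symbols(cpp_content=src, compile_flags=flags)
```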
def check_stable_only_symbols(install_root: Path) -> None:
    """
    Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.

    This approach tests:
    1. WITHOUT macros -> many torch symbols exposed
    2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
    3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
    4. WITH both macros -> zero torch symbols (all hidden)
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    test_cpp_content = """
// Main torch C++ API headers
#include <torch/torch.h>
#include <torch/all.h>

// ATen tensor library
#include <ATen/ATen.h>

// Core c10 headers (commonly used)
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Optional.h>

int main() { return 0; }
"""

    base_compile_flags = [
        "g++",
        "-std=c++17",
        f"-I{include_dir}",
        f"-I{include_dir}/torch/csrc/api/include",
        "-c",  # Compile only, don't link
    ]

    # Compile WITHOUT any macros
    symbols_without = _compile_and_extract_symbols(
        cpp_content=test_cpp_content,
        compile_flags=base_compile_flags,
    )

    # We expect constexpr symbols, inline functions used by other headers etc.
    # to produce symbols
    num_symbols_without = len(symbols_without)
    print(f"Found {num_symbols_without} symbols without any macros defined")
    assert num_symbols_without != 0, (
        "Expected a non-zero number of symbols without any macros"
    )

    # Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
    compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"]

    symbols_with_stable_only = _compile_and_extract_symbols(
        cpp_content=test_cpp_content,
        compile_flags=compile_flags_with_stable_only,
    )

    num_symbols_with_stable_only = len(symbols_with_stable_only)
    assert num_symbols_with_stable_only == 0, (
        f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}"
    )

    # Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
    compile_flags_with_target_version = base_compile_flags + [
        "-DTORCH_TARGET_VERSION=1"
    ]

    symbols_with_target_version = _compile_and_extract_symbols(
        cpp_content=test_cpp_content,
        compile_flags=compile_flags_with_target_version,
    )

    num_symbols_with_target_version = len(symbols_with_target_version)
    assert num_symbols_with_target_version == 0, (
        f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}"
    )

    # Compile WITH both macros (expect 0 symbols)
    compile_flags_with_both = base_compile_flags + [
        "-DTORCH_STABLE_ONLY",
        "-DTORCH_TARGET_VERSION=1",
    ]

    symbols_with_both = _compile_and_extract_symbols(
        cpp_content=test_cpp_content,
        compile_flags=compile_flags_with_both,
    )

    num_symbols_with_both = len(symbols_with_both)
    assert num_symbols_with_both == 0, (
        f"Expected no symbols with both macros, but found {num_symbols_with_both}"
    )


def check_stable_api_symbols(install_root: Path) -> None:
    """
    Test that stable API headers still expose symbols with TORCH_STABLE_ONLY.
    The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    stable_dir = include_dir / "torch" / "csrc" / "stable"
    assert stable_dir.exists(), f"Expected {stable_dir} to be present"

    stable_headers = list(stable_dir.rglob("*.h"))
    if not stable_headers:
        raise RuntimeError("Could not find any stable headers")

    includes = []
    for header in stable_headers:
        rel_path = header.relative_to(include_dir)
        includes.append(f"#include <{rel_path.as_posix()}>")

    includes_str = "\n".join(includes)
    test_stable_content = f"""
{includes_str}
int main() {{ return 0; }}
"""

    compile_flags = [
        "g++",
        "-std=c++17",
        f"-I{include_dir}",
        f"-I{include_dir}/torch/csrc/api/include",
        "-c",
        "-DTORCH_STABLE_ONLY",
    ]

    symbols_stable = _compile_and_extract_symbols(
        cpp_content=test_stable_content,
        compile_flags=compile_flags,
    )
    num_symbols_stable = len(symbols_stable)
    print(f"Found {num_symbols_stable} symbols in torch/csrc/stable")
    assert num_symbols_stable > 0, (
        f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
        f"but found {num_symbols_stable} symbols"
    )


def check_headeronly_symbols(install_root: Path) -> None:
    """
    Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY.
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    # Find all headers in torch/headeronly
    headeronly_dir = include_dir / "torch" / "headeronly"
    assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present"
    headeronly_headers = list(headeronly_dir.rglob("*.h"))
    if not headeronly_headers:
        raise RuntimeError("Could not find any headeronly headers")

    # Filter out platform-specific headers that may not compile everywhere
    platform_specific_keywords = [
        "cpu/vec",
    ]

    filtered_headers = []
    for header in headeronly_headers:
        rel_path = header.relative_to(include_dir).as_posix()
        if not any(
            keyword in rel_path.lower() for keyword in platform_specific_keywords
        ):
            filtered_headers.append(header)

    includes = []
    for header in filtered_headers:
        rel_path = header.relative_to(include_dir)
        includes.append(f"#include <{rel_path.as_posix()}>")

    includes_str = "\n".join(includes)
    test_headeronly_content = f"""
{includes_str}
int main() {{ return 0; }}
"""

    compile_flags = [
        "g++",
        "-std=c++17",
        f"-I{include_dir}",
        f"-I{include_dir}/torch/csrc/api/include",
        "-c",
        "-DTORCH_STABLE_ONLY",
    ]

    symbols_headeronly = _compile_and_extract_symbols(
        cpp_content=test_headeronly_content,
        compile_flags=compile_flags,
    )
    num_symbols_headeronly = len(symbols_headeronly)
    print(f"Found {num_symbols_headeronly} symbols in torch/headeronly")
    assert num_symbols_headeronly > 0, (
        f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
        f"but found {num_symbols_headeronly} symbols"
    )


def check_aoti_shim_symbols(install_root: Path) -> None:
    """
    Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY.
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    # There are no constexpr symbols etc., so we need to actually use functions
    # so that some symbols are found.
    test_shim_content = """
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
int main() {
    int32_t (*fp1)() = &aoti_torch_device_type_cpu;
    int32_t (*fp2)() = &aoti_torch_dtype_float32;
    (void)fp1; (void)fp2;
    return 0;
}
"""

    compile_flags = [
        "g++",
        "-std=c++17",
        f"-I{include_dir}",
        f"-I{include_dir}/torch/csrc/api/include",
        "-c",
        "-DTORCH_STABLE_ONLY",
    ]

    symbols_shim = _compile_and_extract_symbols(
        cpp_content=test_shim_content,
        compile_flags=compile_flags,
    )
    num_symbols_shim = len(symbols_shim)
    assert num_symbols_shim > 0, (
        f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
        f"but found {num_symbols_shim} symbols"
    )


def check_stable_c_shim_symbols(install_root: Path) -> None:
    """
    Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY.
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    # Check if the stable C shim exists
    stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
    if not stable_shim.exists():
        raise RuntimeError("Could not find stable c shim")

    # There are no constexpr symbols etc., so we need to actually use functions
    # so that some symbols are found.
    test_stable_shim_content = """
#include <torch/csrc/stable/c/shim.h>
int main() {
    // Reference stable C API functions to create undefined symbols
    AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
    AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
    (void)fp1; (void)fp2;
    return 0;
}
"""

    compile_flags = [
        "g++",
        "-std=c++17",
        f"-I{include_dir}",
        f"-I{include_dir}/torch/csrc/api/include",
        "-c",
        "-DTORCH_STABLE_ONLY",
    ]

    symbols_stable_shim = _compile_and_extract_symbols(
        cpp_content=test_stable_shim_content,
        compile_flags=compile_flags,
    )
    num_symbols_stable_shim = len(symbols_stable_shim)
    assert num_symbols_stable_shim > 0, (
        f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
        f"but found {num_symbols_stable_shim} symbols"
    )


def check_lib_symbols_for_abi_correctness(lib: str) -> None:
    print(f"lib: {lib}")
    cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)

@@ -460,13 +129,6 @@ def main() -> None:
    check_lib_symbols_for_abi_correctness(libtorch_cpu_path)
    check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path)

    # Check symbols when TORCH_STABLE_ONLY is defined
    check_stable_only_symbols(install_root)
    check_stable_api_symbols(install_root)
    check_headeronly_symbols(install_root)
    check_aoti_shim_symbols(install_root)
    check_stable_c_shim_symbols(install_root)


if __name__ == "__main__":
    main()
setup.py (47 changed lines)
@@ -1358,45 +1358,6 @@ class concat_license_files:

# Need to create the proper LICENSE.txt for the wheel
class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
    def _wrap_headers_with_macro(self, bdist_dir: Path) -> None:
        """Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION).

        Excludes:
        - torch/include/torch/headeronly/*
        - torch/include/torch/csrc/stable/*
        - torch/include/torch/csrc/inductor/aoti_torch/c/ (only shim headers)
        - torch/include/torch/csrc/inductor/aoti_torch/generated/
        """
        header_extensions = (".h", ".hpp", ".cuh")
        header_files = [
            f for ext in header_extensions for f in bdist_dir.rglob(f"*{ext}")
        ]

        # Paths to exclude from wrapping
        exclude_dir_patterns = [
            "torch/include/torch/headeronly/",
            "torch/include/torch/csrc/stable/",
            "torch/include/torch/csrc/inductor/aoti_torch/c/",
            "torch/include/torch/csrc/inductor/aoti_torch/generated/",
        ]

        for header_file in header_files:
            rel_path = header_file.relative_to(bdist_dir).as_posix()

            if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns):
                report(f"Skipping header: {rel_path}")
                continue

            original_content = header_file.read_text(encoding="utf-8")
            wrapped_content = (
                "#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
                f"{original_content}"
                "\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
            )

            header_file.write_text(wrapped_content, encoding="utf-8")
            report(f"Wrapped header: {rel_path}")
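As a rough illustration of the transformation above (the header body is made up for the example), the wrap turns every non-excluded header into a no-op whenever either macro is defined:

```python
# Sketch of the wrapping applied by _wrap_headers_with_macro, on a made-up header body.
original_content = "struct Foo { int x; };\n"  # hypothetical header contents
wrapped_content = (
    "#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
    f"{original_content}"
    "\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
)
# Compiling with -DTORCH_STABLE_ONLY (or -DTORCH_TARGET_VERSION) now
# preprocesses the header body to nothing, so none of its symbols are emitted.
print(wrapped_content)
```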
    def run(self) -> None:
        with concat_license_files(include_files=True):
            super().run()

@@ -1419,14 +1380,6 @@ class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
        # need an __init__.py file otherwise we wouldn't have a package
        (bdist_dir / "torch" / "__init__.py").touch()

        # Wrap all header files with TORCH_STABLE_ONLY macro
        assert self.bdist_dir is not None, "bdist_dir should be set during wheel build"
        bdist_dir = Path(self.bdist_dir)
        report(
            "-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)"
        )
        self._wrap_headers_with_macro(bdist_dir)


class clean(Command):
    user_options: ClassVar[list[tuple[str, str | None, str]]] = []
@@ -33,7 +33,7 @@ class clean(distutils.command.clean.clean):

def get_extension():
    extra_compile_args = {
        "cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
        "cxx": ["-fdiagnostics-color=always"],
    }

    extension = CppExtension
test/cpp_extensions/torch_stable_test_extension/setup.py (new file, 67 lines)
@@ -0,0 +1,67 @@
import distutils.command.clean
import shutil
from pathlib import Path

from setuptools import find_packages, setup

from torch.utils.cpp_extension import BuildExtension, CppExtension


ROOT_DIR = Path(__file__).parent
CSRC_DIR = ROOT_DIR / "torch_stable_test" / "csrc"


class clean(distutils.command.clean.clean):
    def run(self):
        # Run default behavior first
        distutils.command.clean.clean.run(self)

        # Remove extension
        for path in (ROOT_DIR / "torch_stable_test").glob("**/*.so"):
            path.unlink()
        # Remove build and dist and egg-info directories
        dirs = [
            ROOT_DIR / "build",
            ROOT_DIR / "dist",
            ROOT_DIR / "torch_stable_test.egg-info",
        ]
        for path in dirs:
            if path.exists():
                shutil.rmtree(str(path), ignore_errors=True)


def get_extension():
    extra_compile_args = {
        "cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
    }

    sources = list(CSRC_DIR.glob("**/*.cpp"))

    return [
        CppExtension(
            "torch_stable_test._C",
            sources=sorted(str(s) for s in sources),
            py_limited_api=True,
            extra_compile_args=extra_compile_args,
            extra_link_args=[],
        )
    ]


setup(
    name="torch_stable_test",
    version="0.0",
    author="PyTorch Core Team",
    description="Test extension to verify TORCH_STABLE_ONLY flag",
    packages=find_packages(exclude=("test",)),
    package_data={"torch_stable_test": ["*.dll", "*.dylib", "*.so"]},
    install_requires=[
        "torch",
    ],
    ext_modules=get_extension(),
    cmdclass={
        "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
        "clean": clean,
    },
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
@@ -0,0 +1 @@
#include <ATen/core/TensorBase.h> // This should trigger the TORCH_STABLE_ONLY error
@@ -0,0 +1,22 @@
# Owner(s): ["module: cpp"]

from pathlib import Path

from torch.testing._internal.common_utils import (
    install_cpp_extension,
    IS_WINDOWS,
    run_tests,
    TestCase,
)


if not IS_WINDOWS:

    class TestTorchStable(TestCase):
        def test_setup_fails(self):
            with self.assertRaisesRegex(RuntimeError, "build failed for cpp extension"):
                install_cpp_extension(extension_root=Path(__file__).parent.parent)


if __name__ == "__main__":
    run_tests()
@@ -90,12 +90,12 @@ class GraphModule(torch.nn.Module):
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(u0)", primals_2: "Sym(u1)", primals_3: "Sym(u2)", primals_4: "f32[u0, u1, u2]"):
        ge_1: "Sym(u0 >= 0)" = primals_1 >= 0
        _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
        ge_3: "Sym(u1 >= 0)" = primals_2 >= 0
        _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
        ge_5: "Sym(u2 >= 0)" = primals_3 >= 0
        _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_5 = _assert_scalar_2 = None
        ge: "Sym(u0 >= 0)" = primals_1 >= 0
        _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
        ge_1: "Sym(u1 >= 0)" = primals_2 >= 0
        _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
        ge_2: "Sym(u2 >= 0)" = primals_3 >= 0
        _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_2 = None

        floordiv: "Sym((u0//2))" = primals_1 // 2

@@ -727,7 +727,7 @@ class GraphModule(torch.nn.Module):
        x = torch.randn(3)
        arg_count = ifdynstaticdefault(4, 5)
        # when compiled with dynamic, we don't have upper bound runtime assertions for u0
        expected_op_count = ifdynstaticdefault(10, 8)
        expected_op_count = ifdynstaticdefault(9, 7)
        out_graph = self._test_wrap_simple(
            f,
            default_args_generator((x,)),
@@ -747,7 +747,6 @@ class GraphModule(torch.nn.Module):
        c: "i64[u0, 1]" = l_x_.nonzero()

        sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
        _check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None

        ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
        _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@@ -784,7 +783,6 @@ class GraphModule(torch.nn.Module):
        c: "i64[u0, 1]" = l_x_.nonzero()

        sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
        _check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None

        ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
        _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@@ -883,7 +881,7 @@ class GraphModule(torch.nn.Module):
        x = torch.randn(3)
        arg_count = ifdynstaticdefault(4, 5)
        # when compiled with dynamic, we don't have upper bound runtime assertions for u0
        expected_op_count = ifdynstaticdefault(10, 8)
        expected_op_count = ifdynstaticdefault(9, 7)
        out_graph = self._test_wrap_simple(
            f,
            default_args_generator((x,)),
@@ -905,7 +903,6 @@ class GraphModule(torch.nn.Module):
        c: "i64[u0, 1]" = l_x_.nonzero()

        sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
        _check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None

        ge: "Sym(u0 >= 0)" = sym_size_int >= 0
        _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@@ -956,7 +953,7 @@ class GraphModule(torch.nn.Module):
        y = torch.randn(3)
        arg_count = ifdynstaticdefault(5, 6)
        # when compiled with dynamic, we don't have upper bound runtime assertions for u0 and u1
        expected_op_count = ifdynstaticdefault(17, 13)
        expected_op_count = ifdynstaticdefault(15, 11)
        out_graph = self._test_wrap_simple(
            f,
            default_args_generator((x, y)),
@@ -977,7 +974,6 @@ class GraphModule(torch.nn.Module):
        c: "i64[u0, 1]" = l_x_.nonzero()

        sym_size_int_2: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
        _check_is_size = torch._check_is_size(sym_size_int_2); _check_is_size = None

        ge: "Sym(u0 >= 0)" = sym_size_int_2 >= 0
        _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@@ -987,7 +983,6 @@ class GraphModule(torch.nn.Module):
        d: "i64[u1, 1]" = l_y_.nonzero(); l_y_ = None

        sym_size_int_3: "Sym(u1)" = torch.ops.aten.sym_size.int(d, 0)
        _check_is_size_1 = torch._check_is_size(sym_size_int_3); _check_is_size_1 = None

        ge_1: "Sym(u1 >= 0)" = sym_size_int_3 >= 0
        _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_2 = None
@@ -10,10 +10,6 @@ import torch.utils.checkpoint
from torch._dynamo.backends.common import aot_autograd
from torch._functorch._aot_autograd.autograd_cache import BundledCompiledForward
from torch._guards import detect_fake_mode
from torch._higher_order_ops.invoke_subgraph import (
    NestedCompileBackend,
    NestedCompileRegionOptions,
)
from torch._inductor.output_code import RegionalOutputCode
from torch._inductor.test_case import run_tests
from torch._inductor.utils import run_fw_bw_and_get_code
@@ -472,86 +468,6 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
        # flex in forward and flex_backward in backward
        self.assertEqual(len(codes), 2)

    @parametrize("serialize", [True, False])
    def test_invoke_subgraph_regional_compile(self, serialize):
        call_test_partitioner_ct = 0
        original_default_partitioner = torch._functorch.partitioners.default_partition

        def test_partitioner(
            *args, **kwargs
        ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
            nonlocal call_test_partitioner_ct
            call_test_partitioner_ct += 1
            return original_default_partitioner(*args, **kwargs)

        # pyrefly: ignore [not-iterable]
        if serialize:
            # Callable cannot be serialized
            torch._functorch.partitioners.default_partition = test_partitioner
            partitioner = "default_partition"
        else:
            partitioner = test_partitioner
        backend = NestedCompileRegionOptions(
            backend=NestedCompileBackend.INDUCTOR,
            inductor_configs={
                "max_autotune": True,
                "triton.cudagraphs": False,
            },
            partitioner=partitioner,
        )

        @torch.compiler.nested_compile_region(backend_options=backend)
        def gn_with_backend(x):
            return torch.sin(x)

        @torch.compiler.nested_compile_region
        def gn_without_backend(x):
            return torch.cos(x)

        def fn(x):
            return gn_with_backend(x) + gn_without_backend(x)

        backend = aot_eager_regional_inductor(serialize=serialize)
        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)

        import torch._inductor.config as inductor_config

        # Hook to verify options
        original_compile = torch._inductor.standalone_compile
        captured_options = []

        def verify_options(*args, **kwargs):
            options = kwargs.get("options", {})
            captured_options.append(options)

            # Verify config is set as expected from explicit options
            assert inductor_config.max_autotune, "max_autotune should be True"
            assert not inductor_config.triton.cudagraphs, (
                "triton.cudagraphs should be False"
            )

            return original_compile(*args, **kwargs)

        torch._inductor.standalone_compile = verify_options

        try:
            x = torch.randn(8, 8, requires_grad=True)
            # opt_fn(x)
            res, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
            self.assertEqual(len(codes), 2)
            self.assertTrue("repeated_subgraph0" in codes[0])
            self.assertTrue("repeated_subgraph1" not in codes[0])
            self.assertTrue("repeated_subgraph0" in codes[1])
            self.assertTrue("repeated_subgraph1" not in codes[1])
            self.assertEqual(call_test_partitioner_ct, 1)
            true_res = fn(x)
            self.assertEqual(res, true_res)
        finally:
            torch._inductor.standalone_compile = original_compile
            torch._functorch.partitioners.default_partition = (
                original_default_partitioner
            )


@skipIfTorchDynamo("Not a suitable dynamo wrapped test")
class TestRegionalOutputCode(torch._inductor.test_case.TestCase):
@@ -3081,15 +3081,12 @@ def forward(self, x, y):
    foo = torch.ops.export.foo.default(x, y); x = None
    sym_size_int = torch.ops.aten.sym_size.int(foo, 0)
    sym_size_int_1 = torch.ops.aten.sym_size.int(foo, 1)
    sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int); sym_constrain_range_for_size_default = None
    ge = sym_size_int >= 0; sym_size_int = None
    _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
    sym_constrain_range_for_size_default_1 = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1); sym_constrain_range_for_size_default_1 = None
    ge_1 = sym_size_int_1 >= 0; sym_size_int_1 = None
    _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_1 = None
    bar = torch.ops.export.bar.default(y); y = None
    sym_size_int_2 = torch.ops.aten.sym_size.int(bar, 0)
    sym_constrain_range_for_size_default_2 = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_2); sym_constrain_range_for_size_default_2 = None
    ge_2 = sym_size_int_2 >= 0; sym_size_int_2 = None
    _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_default_2 = None
    return (foo, bar)""",
@@ -17743,7 +17740,6 @@ class TestExportCustomClass(TorchTestCase):
def forward(self, x, mask):
    masked_select = torch.ops.aten.masked_select.default(x, mask); x = mask = None
    sym_size_int_1 = torch.ops.aten.sym_size.int(masked_select, 0)
    sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1); sym_constrain_range_for_size_default = None
    ge = sym_size_int_1 >= 0
    _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
    le = sym_size_int_1 <= 1188864
@@ -21,10 +21,6 @@ from torch._dynamo.testing import (
    InductorAndRecordGraphs,
    normalize_gm,
)
from torch._higher_order_ops.invoke_subgraph import (
    NestedCompileBackend,
    NestedCompileRegionOptions,
)
from torch._higher_order_ops.schema import find_hop_schema
from torch._inductor import config as inductor_config
from torch._inductor.pattern_matcher import (
@@ -1560,101 +1556,6 @@ class GraphModule(torch.nn.Module):
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_unbacked_expr(self):
        @nested_compile_region
        def gn(x):
            return x + 1

        def fn(c):
            d = torch.concat([c, c], dim=0)
            d = gn(d)
            return d

        c = torch.randn((64, 32))
        torch._dynamo.decorators.mark_unbacked(c, 0)

        ref = fn(c)
        opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
        res = opt_fn(c)
        self.assertEqual(ref, res)

    def test_grad_accumulation(self):
        mod1 = torch.nn.Linear(8, 8)
        mod2 = torch.nn.Linear(8, 8)
        mod3 = torch.nn.Linear(8, 8)

        @nested_compile_region
        def gn(x):
            return mod1(x) - mod2(x)

        def fn(c):
            d = gn(c) - mod3(c)
            return d * 2

        c = torch.randn((8, 8), requires_grad=True)

        backend = AotEagerAndRecordGraphs()
        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
        res = opt_fn(c)
        res.sum().backward()

        # fw_add_nodes = backend.fw_graphs[0].graph.find_nodes(op="call_function", target = torch.ops.aten.add.Tensor)
        # The gradient addition node for mod3 is not in the subgraph.
        bw_add_nodes = backend.bw_graphs[0].graph.find_nodes(
            op="call_function", target=torch.ops.aten.add.Tensor
        )
        self.assertEqual(len(bw_add_nodes), 1)
        subgraph_node = backend.bw_graphs[0].graph.find_nodes(op="get_attr")[0]
        subgraph_name = subgraph_node.target
        # The gradient addition node between mod1 and mod2 will be in the subgraph
        bw_add_nodes = getattr(backend.bw_graphs[0], subgraph_name).graph.find_nodes(
            op="call_function", target=torch.ops.aten.add.Tensor
        )
        self.assertEqual(len(bw_add_nodes), 1)

    def test_backend_parameter(self):
        backend = NestedCompileRegionOptions(NestedCompileBackend.INDUCTOR)

        # Test that backend parameter is properly set in node.meta
        @nested_compile_region(backend_options=backend)
        def gn_with_backend(x):
            return torch.sin(x)

        @nested_compile_region
        def gn_without_backend(x):
            return torch.cos(x)

        def fn(x):
            return gn_with_backend(x) + gn_without_backend(x)

        backend = EagerAndRecordGraphs()
        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)

        x = torch.randn(8, 8, requires_grad=False)
        opt_fn(x)

        # Check that we captured the graph
        self.assertEqual(len(backend.graphs), 1)
        graph = backend.graphs[0]

        # Find invoke_subgraph nodes and check their backend metadata
        invoke_subgraph_nodes = [
            node
            for node in graph.graph.nodes
            if node.op == "call_function"
            and node.target == torch.ops.higher_order.invoke_subgraph
        ]

        # We should have 2 invoke_subgraph calls
        self.assertEqual(len(invoke_subgraph_nodes), 2)

        # First invoke_subgraph (gn_with_backend) should have backend
        self.assertIn("custom", invoke_subgraph_nodes[0].meta)

        # Second invoke_subgraph (gn_without_backend) should have custom=None or no custom
        backend_value = invoke_subgraph_nodes[1].meta.get("custom", None)
        self.assertIsNone(backend_value)

    def test_complex(self):
        # Observed in Wan2.1
        @nested_compile_region
@@ -1492,8 +1492,8 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
    clone: "f32[s77][1]cpu" = torch.ops.aten.clone.default(arg1_1)
    nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
    sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
    ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
    ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
    _to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
    auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg1_1, _to_copy]); _to_copy = None
    getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1]
@@ -1513,8 +1513,8 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
    clone: "f32[2][1]cpu" = torch.ops.aten.clone.default(arg0_1)
    nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
    sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
    ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
    ge: "Sym(u0 >= 0)" = sym_size_int >= 0
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
    le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None
    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None
    _to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
@@ -1538,8 +1538,8 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
    nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg1_1)
    sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
    ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
    ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
    convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
    alias_default: "f32[s77][1]cpu" = torch.ops.aten.alias.default(arg1_1)
    alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type)
@@ -1557,8 +1557,8 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
def forward(self, arg0_1: "f32[2][1]cpu"):
    nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg0_1)
    sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
    ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
    ge: "Sym(u0 >= 0)" = sym_size_int >= 0
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
    le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None
    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None
    convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None

@@ -3532,11 +3532,11 @@ class TestUbackedOps(TestCase):
            aot_graphs,
            """\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", arg3_1: "i64[u1][s7]cpu"):
    ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
    ge: "Sym(u1 >= 0)" = arg1_1 >= 0
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge = _assert_scalar = None
    _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
    ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None
    ge_1: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
    pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
    eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
    _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
@@ -3573,11 +3573,11 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)",
            aot_graphs,
            """\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]cpu"):
    ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
    ge: "Sym(u1 >= 0)" = arg1_1 >= 0
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge = _assert_scalar = None
    _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
    ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None
    ge_1: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
    pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
    eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
    _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
@@ -3632,21 +3632,21 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
            aot_graphs,
            """\
def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", arg3_1: "f32[u2, u3][1, u2]cpu"):
    ge_1: "Sym(u2 >= 0)" = arg1_1 >= 0
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
    ge_3: "Sym(u3 >= 0)" = arg2_1 >= 0
    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
    ge: "Sym(u2 >= 0)" = arg1_1 >= 0
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge = _assert_scalar = None
    ge_1: "Sym(u3 >= 0)" = arg2_1 >= 0
    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
    select: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
    _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None
    ge_4: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
    _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_4, "Runtime assertion failed for expression u0 >= 0 on node 'ge_2'"); ge_4 = _assert_scalar_2 = None
    ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
    _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_2 = None
    sym_sum: "Sym(u0 + 1)" = torch.sym_sum((1, _local_scalar_dense))
    gt: "Sym(u0 + 1 > 0)" = sym_sum > 0; sym_sum = None
    _assert_scalar_3 = torch.ops.aten._assert_scalar.default(gt, "Runtime assertion failed for expression 0 < u0 + 1 on node 'gt'"); gt = _assert_scalar_3 = None
    select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
    _local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
    ge_5: "Sym(u1 >= 0)" = _local_scalar_dense_1 >= 0
    _assert_scalar_4 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_5 = _assert_scalar_4 = None
    ge_3: "Sym(u1 >= 0)" = _local_scalar_dense_1 >= 0
    _assert_scalar_4 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_3 = _assert_scalar_4 = None
    sym_sum_1: "Sym(u1 + 1)" = torch.sym_sum((1, _local_scalar_dense_1))
    gt_1: "Sym(u1 + 1 > 0)" = sym_sum_1 > 0; sym_sum_1 = None
    _assert_scalar_5 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 + 1 on node 'gt_1'"); gt_1 = _assert_scalar_5 = None
@@ -4068,10 +4068,10 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
        self.assertExpectedInline(
            output,
            """\
    ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
    ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
    ge: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
    ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
    clone: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.clone.default(arg2_1, memory_format = torch.contiguous_format); arg2_1 = None
    add_3: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(clone, 1); clone = None
    mul_6: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add_3, 100); add_3 = None
@@ -4097,10 +4097,10 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
        self.assertExpectedInline(
            output,
            """\
    ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
    ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
    ge: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
    ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
    add: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(arg2_1, 1); arg2_1 = None
    mul_5: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add, 100); add = None
    return (mul_5,)""",  # noqa: B950
@@ -4283,11 +4283,11 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
            aot_graphs,
            """\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", arg3_1: "i64[u1][s7]cpu"):
    ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
    ge: "Sym(u1 >= 0)" = arg1_1 >= 0
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge = _assert_scalar = None
    _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
    ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None
    ge_1: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
    pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
    eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
    _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
@@ -4319,11 +4319,11 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)",
            aot_graphs,
            """\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]cpu"):
    ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
    ge: "Sym(u1 >= 0)" = arg1_1 >= 0
    _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge = _assert_scalar = None
    _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
    ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None
    ge_1: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
    pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
    eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
    _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
@@ -121,7 +121,7 @@ class TestOpaqueObject(TestCase):
        def size_impl_fake(q: OpaqueQueue) -> int:
            ctx = torch._custom_op.impl.get_ctx()
            u0 = ctx.new_dynamic_size()
            torch._check_is_size(u0)
            torch._check(u0 >= 0)
            return u0

        torch.library.define(
@@ -33,7 +33,11 @@ from typing import (
    TypeVar as _TypeVar,
    Union as _Union,
)
from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs
from typing_extensions import (
    deprecated as _deprecated,
    ParamSpec as _ParamSpec,
    TypeIs as _TypeIs,
)


# As a bunch of torch.packages internally still have this check
@@ -1735,7 +1739,10 @@ def _check(cond, message=None):  # noqa: F811
    _check_with(RuntimeError, cond, message)  # pyrefly: ignore [bad-argument-type]


# TODO add deprecation annotation
@_deprecated(
    "_check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. \
Use _check(i >= 0) instead."
)
def _check_is_size(i, message=None, *, max=None):
    """Checks that a given integer is a valid size (i.e., is non-negative).
    You should use this over ``_check(i >= 0)`` because it can prevent
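The deprecation message above points at a direct replacement. A minimal before/after sketch of the migration it suggests (the tensor value here is arbitrary):

```python
import torch

n = torch.tensor([3]).item()  # some integer that should be a valid size

# Deprecated: now emits a deprecation warning via the decorator above.
# torch._check_is_size(n)

# Replacement suggested by the deprecation message:
torch._check(n >= 0)
```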
@@ -20,7 +20,6 @@ their semantic behavior.
"""

import contextlib
import copy
import functools
import inspect
import itertools
@@ -43,10 +42,6 @@ from torch._dynamo.variables.functions import UserFunctionVariable
from torch._dynamo.variables.nn_module import UnspecializedNNModuleVariable
from torch._dynamo.variables.tensor import SymNodeVariable
from torch._guards import Source
from torch._higher_order_ops.invoke_subgraph import (
    NestedCompileBackend,
    NestedCompileRegionOptions,
)
from torch._ops import HigherOrderOperator
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils import _pytree as pytree
@@ -264,7 +259,6 @@ def _call_function_with_auto_output_flattening(
    flat_example_value: Any,
    body_r: Optional[VariableTracker],
    graph_output_vts: VariableTracker | tuple[VariableTracker, ...],
    backend_options: Optional[NestedCompileRegionOptions] = None,
) -> Optional[VariableTracker]:
    """
    Create HOP call node and reproxify output VTs for HOPs with auto output semantics.
@@ -291,30 +285,14 @@ def _call_function_with_auto_output_flattening(
    from .builder import wrap_fx_proxy

    # Store the invocation as a call
    proxy = tx.output.create_proxy(
        "call_function",
        fn,
        args=args,
        kwargs=kwargs,
    )

    # Set backend metadata if provided
    if backend_options is not None:
        if "custom" not in proxy.node.meta:
            proxy.node.meta["custom"] = {}
        if backend_options.backend == NestedCompileBackend.INDUCTOR:
            inductor_configs = {}
            if backend_options.inductor_configs:
                inductor_configs = copy.deepcopy(backend_options.inductor_configs)
            proxy.node.meta["custom"]["compile_with_inductor"] = {
                "inductor_configs": inductor_configs
            }
        if backend_options.partitioner is not None:
            proxy.node.meta["custom"]["partitioner"] = backend_options.partitioner

    flat_variable = wrap_fx_proxy(
        tx=tx,
        proxy=proxy,
        proxy=tx.output.create_proxy(
            "call_function",
            fn,
            args=args,
            kwargs=kwargs,
        ),
        example_value=flat_example_value,
    )
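For reference, the removed branch above would leave metadata of roughly this shape on the HOP call node. This is a sketch reconstructed from the code, not captured from a real run:

```python
# Approximate shape of proxy.node.meta["custom"] produced by the removed code,
# for backend=INDUCTOR with explicit configs and a string partitioner.
node_meta_custom = {
    "compile_with_inductor": {
        "inductor_configs": {"max_autotune": True, "triton.cudagraphs": False},
    },
    "partitioner": "default_partition",
}
```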
@ -346,13 +324,7 @@ def _call_function_with_auto_output_flattening(
|
||||
|
||||
|
||||
def _call_function_and_unflatten_output(
|
||||
tx,
|
||||
fn,
|
||||
args,
|
||||
kwargs,
|
||||
flat_example_value,
|
||||
ret_spec,
|
||||
body_r,
|
||||
tx, fn, args, kwargs, flat_example_value, ret_spec, body_r
|
||||
):
|
||||
from .builder import wrap_fx_proxy
|
||||
|
||||
@ -4263,18 +4235,6 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
|
||||
],
|
||||
)
|
||||
|
||||
# Extract backend from the function if it was decorated with nested_compile_region(backend=...)
|
||||
backend_options = None
|
||||
fn_var = args[0]
|
||||
if hasattr(fn_var, "get_function"):
|
||||
try:
|
||||
fn = fn_var.get_function()
|
||||
|
||||
if hasattr(fn, "__marked_compile_region_backend__"):
|
||||
backend_options = fn.__marked_compile_region_backend__
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
p_args = (
|
||||
p_args[0],
|
||||
body_name,
|
||||
@ -4288,7 +4248,6 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
|
||||
example_value,
|
||||
body_r,
|
||||
body_graph_output_vts,
|
||||
backend_options=backend_options,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -675,42 +675,6 @@ def prepare_for_partitioner(mod, num_primals, num_fw_outputs):
|
||||
return out
|
||||
|
||||
|
||||
def _get_partition_fn(fw_hop_node, aot_config):
|
||||
"""
|
||||
Return either the default `partition_fn` in aot_config or a HOP specific partition
|
||||
function.
|
||||
|
||||
If a HOP specific partition function is returned, used_hop_custom_partition is True.
|
||||
|
||||
See Note [InvokeSubgraphHOP Partitioner]
|
||||
"""
|
||||
used_hop_custom_partition = False
|
||||
partition_fn: Callable[..., tuple[torch.fx.GraphModule, torch.fx.GraphModule]] = (
|
||||
aot_config.partition_fn
|
||||
)
|
||||
if (
|
||||
fw_hop_node.target == torch._higher_order_ops.invoke_subgraph
|
||||
and "custom" in fw_hop_node.meta
|
||||
and "partitioner" in fw_hop_node.meta["custom"]
|
||||
):
|
||||
hop_partition_fn = fw_hop_node.meta["custom"]["partitioner"]
|
||||
if callable(hop_partition_fn):
|
||||
partition_fn = hop_partition_fn # pyrefly: ignore [bad-assignment]
|
||||
used_hop_custom_partition = True
|
||||
else:
|
||||
assert isinstance(hop_partition_fn, str)
|
||||
match hop_partition_fn:
|
||||
case "default_partition":
|
||||
partition_fn = torch._functorch.partitioners.default_partition
|
||||
case "min_cut_rematerialization_partition":
|
||||
partition_fn = torch._functorch.partitioners.min_cut_rematerialization_partition
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Unknown HOP partitioner config: {hop_partition_fn}"
|
||||
)
|
||||
return used_hop_custom_partition, partition_fn
|
||||
|
||||
|
||||
def run_joint_graph_passes_on_hops(
|
||||
joint_gm: torch.fx.GraphModule,
|
||||
joint_inputs: Any,
|
||||
@@ -815,26 +779,13 @@ def run_joint_graph_passes_on_hops(
         # TODO: invoke_subgraph should track which of its inputs are static indices
         # so it can propagate them to the partitioner (and use in cudagraphs)
         static_lifetime_input_indices: list[int] = []

-        used_hop_custom_partition, partition_fn = _get_partition_fn(
-            fw_hop_node, aot_config
-        )
-
         # Step 2) and 3) - Run joint graph passes and partitioner
-        try:
-            new_fw_hop_gm, new_bw_hop_gm = partition_fn(
-                joint_hop_gm,
-                [],
-                num_fwd_outputs=num_fw_outputs,
-                static_lifetime_input_indices=static_lifetime_input_indices,
-            )
-        except Exception as e:
-            if used_hop_custom_partition:
-                raise RuntimeError(
-                    f"Error in custom partition function for invoke_subgraph node {fw_hop_node.name}: {e}"
-                ) from e
-            else:
-                raise
+        new_fw_hop_gm, new_bw_hop_gm = aot_config.partition_fn(
+            joint_hop_gm,
+            [],
+            num_fwd_outputs=num_fw_outputs,
+            static_lifetime_input_indices=static_lifetime_input_indices,
+        )

         # Save the new forward and backward graph modules
         new_hop_graphs[identifier].new_fw_hop_gm = new_fw_hop_gm

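Both branches of this hunk call a partition function with the same contract: it takes a joint forward-plus-backward fx.GraphModule and the joint inputs, and returns a (forward, backward) pair of graph modules. A hedged sketch of a custom partitioner that just delegates to the stock default; the keyword names follow the call sites above and should be treated as illustrative, not a stable API:

import torch
from torch._functorch.partitioners import default_partition


def my_partitioner(joint_gm, joint_inputs, *, num_fwd_outputs, **kwargs):
    # Hypothetical custom partitioner: inspect the joint graph, then fall
    # back to default_partition with the same keyword arguments.
    print(f"partitioning joint graph with {len(joint_gm.graph.nodes)} nodes")
    return default_partition(
        joint_gm, joint_inputs, num_fwd_outputs=num_fwd_outputs, **kwargs
    )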
@@ -1,11 +1,9 @@
 # mypy: allow-untyped-defs

 import contextlib
-import enum
-from collections.abc import Callable
 from contextlib import nullcontext
 from dataclasses import dataclass, field
-from typing import Any, Optional, Union
+from typing import Any, Optional, TYPE_CHECKING, Union

 import torch
 import torch.utils._pytree as pytree
@@ -38,6 +36,10 @@ from torch.fx.graph_module import GraphModule
 from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts


+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+
 invoke_subgraph_counter = 0

@@ -51,32 +53,6 @@ class OutputMetadata:
     indexes_with_no_grad: set[int] = field(default_factory=set)


-class NestedCompileBackend(enum.Enum):
-    INDUCTOR = "inductor"
-    DEFAULT = "default"
-
-
-@dataclass
-class NestedCompileRegionOptions:
-    # If default, does nothing and inherits the torch.compile backend.
-    # If "inductor", adds {"compile_with_inductor": {"inductor_configs": config}} to the HOP node meta "custom".
-    # If "custom" already has "compile_with_inductor", this config overrides it.
-    backend: NestedCompileBackend = NestedCompileBackend.DEFAULT
-
-    # If backend == "inductor", the configs
-    inductor_configs: Optional[dict[str, Any]] = None
-
-    # Note: [InvokeSubgraphHOP Partitioner]
-    # If not None, add "partitioner" to the HOP node meta.
-    # If Callable, directly assign the callable, but the callable cannot be pickled.
-    # If str, the options are "default_partition" and "min_cut_rematerialization_partition";
-    # the HOP joint graph will be partitioned using the corresponding functions in
-    # torch/_functorch/partitioners.py.
-    partitioner: Optional[Callable | str] = None
-
-    # TODO: add decomposition function
-
-
 class InvokeSubgraphHOP(HigherOrderOperator):
     def __init__(self) -> None:
         # Invoke subgraph does not have any state, it is just a wrapper over a
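Taken together, the removed dataclass configures how one marked region is compiled and partitioned. A hedged sketch of how it was meant to be populated; this reflects the removed side of the hunk above, so it will not import on a tree that includes this diff:

from torch._higher_order_ops.invoke_subgraph import (
    NestedCompileBackend,
    NestedCompileRegionOptions,
)

# Compile this region with inductor and partition its joint graph with the
# min-cut rematerialization strategy (one of the two string options above).
opts = NestedCompileRegionOptions(
    backend=NestedCompileBackend.INDUCTOR,
    inductor_configs={"max_autotune": True},
    partitioner="min_cut_rematerialization_partition",
)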
@@ -177,9 +153,7 @@ def invoke_subgraph_placeholder(func, *args, **kwargs):
     return func(*args, **kwargs)


-def mark_compile_region(
-    fn=None, backend_options: Optional[NestedCompileRegionOptions] = None
-):
+def mark_compile_region(fn=None):
     """
     This wrapper instructs torch.compile to compile the wrapped region once and
     reuse the compiled artifact, instead of the usual way of aggressively
@@ -187,10 +161,6 @@ def mark_compile_region(

     Under the hood, it tells TorchDynamo to use InvokeSubgraph HOP for the
     region. For PyTorch eager, this is a no-op.
-
-    Args:
-        fn: The function to wrap
-        backend: Optional backend to use for compiling the subgraph
     """

     def wrap(func):
@@ -202,7 +172,6 @@ def mark_compile_region(
             return invoke_subgraph_placeholder(inner_func, *args, **kwargs)

         inner.__marked_compile_region_fn__ = func  # type: ignore[attr-defined]
-        func.__marked_compile_region_backend__ = backend_options  # type: ignore[attr-defined]

         return inner

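mark_compile_region is the primitive behind the public torch.compiler.nested_compile_region wrapper. A minimal sketch of direct use; note it lives under the private torch._higher_order_ops namespace, so prefer the public entry point shown later in this diff:

import torch
from torch._higher_order_ops.invoke_subgraph import mark_compile_region


@mark_compile_region
def gelu_block(x):
    return torch.nn.functional.gelu(x) * x


@torch.compile
def model(x):
    # The region is traced once and the compiled artifact is stamped out for
    # every call, instead of being re-traced inline each time.
    return gelu_block(gelu_block(x))


out = model(torch.randn(8, 8))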
@@ -2538,7 +2538,7 @@ class CppKernel(Kernel):

     @property
     def assert_function(self) -> str:
         if V.graph.aot_mode:
-            return "STD_TORCH_CHECK"
+            return "AOTI_TORCH_CHECK"
         else:
             return "TORCH_CHECK"

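assert_function only picks the name of the check macro; the C++ codegen splices that name into emitted source. A hypothetical sketch of how such a property is typically consumed; emit_bounds_check, cond, and msg are illustrative names, not the real codegen internals:

def emit_bounds_check(kernel, cond: str, msg: str) -> str:
    # Yields e.g. 'AOTI_TORCH_CHECK(i0 < 128, "index out of bounds");' in AOT
    # mode and 'TORCH_CHECK(...)' otherwise.
    return f'{kernel.assert_function}({cond}, "{msg}");'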
@@ -1,21 +1,14 @@
 # mypy: allow-untyped-defs
 import io
 from collections.abc import Callable
 from dataclasses import dataclass
 from typing import Any, Optional, TYPE_CHECKING, TypeVar, Union
 from typing_extensions import ParamSpec

 import torch
-from torch._higher_order_ops.invoke_subgraph import NestedCompileRegionOptions

 from . import config


 try:
     from typing import LiteralString
 except ImportError:
     from typing_extensions import LiteralString

 if TYPE_CHECKING:
     from ._cache import CacheInfo

@@ -642,9 +635,7 @@ def skip_all_guards_unsafe(guard_entries):
     return [False for entry in guard_entries]


-def nested_compile_region(
-    fn=None, backend_options: Optional[NestedCompileRegionOptions] = None
-):
+def nested_compile_region(fn=None):
     """
     Tells **``torch.compile``** that the marked set of operations forms a nested
     compile region (which is often repeated in the full model) whose code can be
@@ -653,8 +644,8 @@ def nested_compile_region(

     During **``torch.compile``** tracing, the compiler applies *hierarchical
     compilation* with ``nested_compile_region``: it emits optimized code for the
-    marked region the first time it is encountered and re-emits (or "stamps
-    out") the previously compiled code on every subsequent invocation. This can
+    marked region the first time it is encountered and re-emits (or “stamps
+    out”) the previously compiled code on every subsequent invocation. This can
     substantially reduce overall compile time for deeply-stacked,
     structurally-identical components such as the transformer layers of a
     large-language-model (LLM).
@@ -668,17 +659,13 @@ def nested_compile_region(
     to reuse, it will transparently re-compile the region. Using it is
     therefore *safe*: correctness is always preserved, and you pay the extra
     compilation cost only when required.
-
-    Args:
-        fn: The function to wrap
-        backend: Optional backend to use for compiling the subgraph.
     """

     from torch._higher_order_ops.invoke_subgraph import (
         mark_compile_region as _mark_compile_region,
     )

-    return _mark_compile_region(fn, backend_options=backend_options)
+    return _mark_compile_region(fn)


 def load_compiled_function(file: io.IOBase) -> Callable[..., Any]:
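The docstring describes the public entry point; a small usage example of torch.compiler.nested_compile_region on a repeated, structurally identical block:

import torch


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 64)

    @torch.compiler.nested_compile_region
    def forward(self, x):
        return torch.relu(self.linear(x))


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Identical layers: the region is compiled once and stamped out for
        # each of them on subsequent encounters during tracing.
        self.blocks = torch.nn.ModuleList(Block() for _ in range(4))

    def forward(self, x):
        for b in self.blocks:
            x = b(x)
        return x


model = torch.compile(Model())
model(torch.randn(2, 64))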
@@ -836,9 +836,10 @@ class AOTInductorModelBase {
   }

   void update_constants_array_from_map() {
-    STD_TORCH_CHECK(
-        constants_map_,
-        "constants_map_ was not ready when constants_ is trying to be constructed from it!");
+    if (!constants_map_) {
+      throw std::runtime_error{
+          "constants_map_ was not ready when constants_ is trying to be constructed from it!"};
+    }
     if (!constants_) {
       constants_ =
           std::make_shared<std::vector<ConstantHandle>>(constants_info_.size());
@@ -874,7 +875,9 @@ class AOTInductorModelBase {
   /// Returns true if the model is complete.
   bool is_finished() {
 #ifdef USE_CUDA
-    STD_TORCH_CHECK(run_finished_, "Model CUDA event was not initialized");
+    if (!run_finished_) {
+      throw std::runtime_error{"Model CUDA event was not initialized"};
+    }

     auto event_status = cudaEventQuery(*run_finished_);
     if (event_status == cudaSuccess) {
@@ -883,13 +886,13 @@ class AOTInductorModelBase {
       return false;
     }

-    STD_TORCH_CHECK(
-        false,
-        "The model did not finish successfully. Error: ",
+    throw std::runtime_error(
+        std::string("The model did not finish successfully. Error: ") +
         cudaGetErrorString(cudaGetLastError()));
 #elif defined(USE_XPU)
-    STD_TORCH_CHECK(run_finished_, "Model XPU event was not initialized");
+    if (!run_finished_) {
+      throw std::runtime_error{"Model XPU event was not initialized"};
+    }
     using namespace sycl::info;
     return (*run_finished_)->get_info<event::command_execution_status>() ==
         event_command_status::complete;
@@ -901,14 +904,19 @@ class AOTInductorModelBase {

   /// Synchronizes completion event.
   void wait_for_completion() {
-    STD_TORCH_CHECK(run_finished_, "Model event was not initialized");
 #ifdef USE_CUDA
+    if (!run_finished_) {
+      throw std::runtime_error{"Model event was not initialized"};
+    }
+
     AOTI_RUNTIME_CUDA_CHECK(cudaEventSynchronize(*run_finished_));
 #endif // USE_CUDA

 #ifdef USE_XPU
+    if (!run_finished_) {
+      throw std::runtime_error{"Model event was not initialized"};
+    }
     (*run_finished_)->wait_and_throw();
 #endif // USE_XPU
 #endif
   }

  protected:

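is_finished() is a non-blocking poll: cudaEventQuery returns cudaSuccess once the event has completed and cudaErrorNotReady while work is still in flight. A standalone sketch of the same pattern, assuming only the CUDA runtime headers; error handling is reduced to a bare throw:

#include <cuda_runtime.h>
#include <stdexcept>

// Poll an event without blocking the host thread.
bool event_is_done(cudaEvent_t ev) {
  cudaError_t status = cudaEventQuery(ev);
  if (status == cudaSuccess) {
    return true;   // work recorded before the event has finished
  }
  if (status == cudaErrorNotReady) {
    return false;  // still running; poll again later
  }
  throw std::runtime_error(cudaGetErrorString(status));
}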
@@ -123,10 +123,8 @@ class AOTInductorModelContainer {
       constants_folding_lk.unlock();
       model_lk.lock();
     } else if (const_folded != ConstantState::FOLDED) {
-      STD_TORCH_CHECK(
-          false,
-          "Unknown constant state: ",
-          toStringConstantState(constant_folded_));
+      throw std::runtime_error(
+          "Unknown constant state: " + toStringConstantState(constant_folded_));
     }

     try {
@@ -169,10 +167,8 @@ class AOTInductorModelContainer {
           /* validate_full_update = */ false);
       const_folded = ConstantState::FOLDED;
     } else if (constant_folded_ != ConstantState::FOLDED) {
-      STD_TORCH_CHECK(
-          false,
-          "Unknown constant state: ",
-          toStringConstantState(constant_folded_));
+      throw std::runtime_error(
+          "Unknown constant state: " + toStringConstantState(constant_folded_));
     }

     model->run_single_threaded(
@@ -206,56 +202,56 @@ class AOTInductorModelContainer {
   }

   size_t num_constants() const {
-    STD_TORCH_CHECK(
-        this->num_models() != 0, "No available models in container!");
+    if (this->num_models() == 0) {
+      throw std::runtime_error("No available models in container!");
+    }
     return models_[0]->num_constants();
   }

   // retrieve the constant name of constants_info_[idx]
   const char* constant_name(size_t idx) const {
-    STD_TORCH_CHECK(
-        this->num_models() != 0, "No available models in container!");
+    if (this->num_models() == 0) {
+      throw std::runtime_error("No available models in container!");
+    }
     return models_[0]->constant_name(static_cast<int64_t>(idx));
   }

   // retrieve original FQN of constants_info_[idx]
   const char* constant_original_fqn(size_t idx) const {
-    STD_TORCH_CHECK(
-        this->num_models() != 0, "No available models in container!");
+    if (this->num_models() == 0) {
+      throw std::runtime_error("No available models in container!");
+    }
     return models_[0]->constant_original_fqn(static_cast<int64_t>(idx));
   }

   // retrieve whether constant is from folded of constants_info_[idx]
   bool constant_from_folded(size_t idx) const {
-    STD_TORCH_CHECK(
-        this->num_models() != 0, "No available models in container!");
+    if (this->num_models() == 0) {
+      throw std::runtime_error("No available models in container!");
+    }
     return models_[0]->constant_from_folded(static_cast<int64_t>(idx));
   }

   size_t constant_data_size(size_t idx) const {
-    STD_TORCH_CHECK(
-        this->num_models() != 0, "No available models in container!");
+    if (this->num_models() == 0) {
+      throw std::runtime_error("No available models in container!");
+    }
     return models_[0]->constant_data_size(static_cast<int64_t>(idx));
   }

   // retrieve type of constants_info_[idx]
   int32_t constant_type(size_t idx) const {
-    STD_TORCH_CHECK(
-        this->num_models() != 0, "No available models in container!");
+    if (this->num_models() == 0) {
+      throw std::runtime_error("No available models in container!");
+    }
     return models_[0]->constant_type(static_cast<int64_t>(idx));
   }

   // retrieve dtype of constants_info_[idx]
   int32_t constant_dtype(size_t idx) const {
-    STD_TORCH_CHECK(
-        this->num_models() != 0, "No available models in container!");
+    if (this->num_models() == 0) {
+      throw std::runtime_error("No available models in container!");
+    }
     return models_[0]->constant_dtype(static_cast<int64_t>(idx));
   }

@@ -387,12 +383,9 @@ class AOTInductorModelContainer {
                   << " in model, but not provided by user!\n";
         continue;
       }

-      STD_TORCH_CHECK(
-          false,
-          "Cannot find constants ",
-          constant_name,
-          " in constants_map!");
+      throw std::runtime_error(
+          std::string("Cannot find constants ") + constant_name +
+          std::string(" in constants_map!"));
     }
   }

@@ -402,8 +395,9 @@ class AOTInductorModelContainer {
       std::unordered_map<std::string, AtenTensorHandle>&& constants_map,
       bool use_inactive,
       bool validate_full_update) {
-    STD_TORCH_CHECK(
-        this->num_models() != 0, "No available models in container!");
+    if (this->num_models() == 0) {
+      throw std::runtime_error("No model available in container!");
+    }
     if (validate_full_update) {
       assert_all_constants(constants_map);
     }
@@ -449,9 +443,9 @@ class AOTInductorModelContainer {
       bool use_inactive,
      bool validate_full_update,
       bool user_managed = false) {
-    STD_TORCH_CHECK(
-        this->num_models() != 0, "No model available in container!");
+    if (this->num_models() == 0) {
+      throw std::runtime_error("No model available in container!");
+    }
     if (validate_full_update) {
       assert_all_constants(constants_map);
     }

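Every accessor in the hunk above restates the same empty-container guard. A hedged refactoring sketch showing how the repetition could be factored into one helper; the name throw_if_empty is hypothetical and not part of this diff:

#include <stdexcept>

// Hypothetical helper: centralize the "container must hold at least one
// model" invariant instead of restating it in each accessor.
inline void throw_if_empty(size_t num_models) {
  if (num_models == 0) {
    throw std::runtime_error("No available models in container!");
  }
}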
@@ -7,7 +7,7 @@ namespace torch::aot_inductor {

 template <typename T>
 inline RAIIAtenTensorHandle scalar_to_tensor_handle(T value) {
-  STD_TORCH_CHECK(false, "Unsupported scalar_to_tensor_handle");
+  throw std::runtime_error("Unsupported scalar_to_tensor_handle");
 }

 // Specialize for supported C++ primitive types

@@ -11,11 +11,11 @@ template <>
 struct ThreadLocalCachedOutputTensor<RAIIAtenTensorHandle> {
   explicit ThreadLocalCachedOutputTensor(const RAIIAtenTensorHandle&) {}
   void copy_data_from(const RAIIAtenTensorHandle& handle) {
-    STD_TORCH_CHECK(false, "can't happen");
+    throw std::runtime_error("can't happen");
   }

   AtenTensorHandle tensor() const {
-    STD_TORCH_CHECK(false, "can't happen");
+    throw std::runtime_error("can't happen");
   }
 };

@@ -23,11 +23,11 @@ template <>
 struct ThreadLocalCachedOutputTensor<AtenTensorHandle> {
   explicit ThreadLocalCachedOutputTensor(const AtenTensorHandle&) {}
   void copy_data_from(const AtenTensorHandle& handle) {
-    STD_TORCH_CHECK(false, "can't happen");
+    throw std::runtime_error("can't happen");
   }

   AtenTensorHandle tensor() const {
-    STD_TORCH_CHECK(false, "can't happen");
+    throw std::runtime_error("can't happen");
   }
 };

@@ -35,11 +35,11 @@ template <>
 struct ThreadLocalCachedOutputTensor<ConstantHandle> {
   explicit ThreadLocalCachedOutputTensor(const ConstantHandle&) {}
   void copy_data_from(const ConstantHandle& handle) {
-    STD_TORCH_CHECK(false, "can't happen");
+    throw std::runtime_error("can't happen");
   }

   AtenTensorHandle tensor() const {
-    STD_TORCH_CHECK(false, "can't happen");
+    throw std::runtime_error("can't happen");
   }
 };

@@ -92,18 +92,18 @@ struct ThreadLocalCachedOutputArray;
 template <>
 struct ThreadLocalCachedOutputArray<RAIIAtenTensorHandle> {
   explicit ThreadLocalCachedOutputArray(const RAIIAtenTensorHandle&) {
-    STD_TORCH_CHECK(false, "can't happen");
+    throw std::runtime_error("can't happen");
   }

   // Not supported yet! We would need to put contiguous() or
   // expect_contiguous() into the ABI.
   void copy_data_from(const RAIIAtenTensorHandle&) {
-    STD_TORCH_CHECK(false, "can't happen");
+    throw std::runtime_error("can't happen");
   }

   template <typename U>
   ArrayRefTensor<U> arrayref_tensor() const {
-    STD_TORCH_CHECK(false, "can't happen");
+    throw std::runtime_error("can't happen");
   }
 };

@@ -111,18 +111,18 @@ struct ThreadLocalCachedOutputArray<RAIIAtenTensorHandle> {
 template <>
 struct ThreadLocalCachedOutputArray<ConstantHandle> {
   explicit ThreadLocalCachedOutputArray(const ConstantHandle&) {
-    STD_TORCH_CHECK(false, "can't happen");
+    throw std::runtime_error("can't happen");
   }

   // Not supported yet! We would need to put contiguous() or
   // expect_contiguous() into the ABI.
   void copy_data_from(const ConstantHandle&) {
-    STD_TORCH_CHECK(false, "can't happen");
+    throw std::runtime_error("can't happen");
   }

   template <typename U>
   ArrayRefTensor<U> arrayref_tensor() const {
-    STD_TORCH_CHECK(false, "can't happen");
+    throw std::runtime_error("can't happen");
   }
 };

@@ -38,10 +38,9 @@

 // The following files are implemented in a header-only way and are guarded by
 // test/cpp/aoti_abi_check
-#include <torch/headeronly/util/BFloat16.h>
-#include <torch/headeronly/util/Exception.h>
-#include <torch/headeronly/util/Half.h>
-#include <torch/headeronly/util/complex.h>
+#include <c10/util/BFloat16.h>
+#include <c10/util/Half.h>
+#include <c10/util/complex.h>

 #ifdef __cplusplus
 extern "C" {
@@ -622,8 +621,34 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_proxy_executor_call_function(
     int num_tensors,
     AtenTensorHandle* flatten_tensor_args);

-// Preserve for BC and will delete it later, using the STD_TORCH_CHECK directly
-#define AOTI_TORCH_CHECK(cond, ...) STD_TORCH_CHECK(cond, ##__VA_ARGS__)
+AOTI_TORCH_EXPORT void aoti_torch_check(
+    bool cond,
+    const char* func,
+    const char* file,
+    uint32_t line,
+    const char* msg);
+
+#ifdef STRIP_ERROR_MESSAGES
+#define AOTI_TORCH_CHECK(cond, ...)              \
+  if (!(cond)) {                                 \
+    aoti_torch_check(                            \
+        false,                                   \
+        __func__,                                \
+        __FILE__,                                \
+        static_cast<uint32_t>(__LINE__),         \
+        TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \
+  }
+#else
+#define AOTI_TORCH_CHECK(cond, ...)                \
+  if (!(cond)) {                                   \
+    aoti_torch_check(                              \
+        false,                                     \
+        __func__,                                  \
+        __FILE__,                                  \
+        static_cast<uint32_t>(__LINE__),           \
+        TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); \
+  }
+#endif

 AOTI_TORCH_EXPORT void aoti_torch_warn(
     const char* func,

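With the macro restored, call sites check a condition and, on failure, forward function, file, line, and a formatted message to aoti_torch_check, which funnels into the usual torch check-failure path. A small usage sketch; the handle name is illustrative:

// Validate an argument before dereferencing it inside shim code.
AOTI_TORCH_CHECK(handle != nullptr, "expected a non-null tensor handle");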
@@ -1339,14 +1339,13 @@ AOTITorchError aoti_torch_proxy_executor_call_function(
     int num_tensors,
     AtenTensorHandle* flatten_tensor_args) {
   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
-    TORCH_CHECK(
-        proxy_executor != nullptr,
-        "Unable to find a proxy executor to run custom ops.",
-        "Please check if there is a json file generated",
-        "in the same directory as the so,",
-        "or use torch._inductor.aoti_compile_and_package",
-        "to package everything into a PT2 artifact.");
+    if (!proxy_executor) {
+      throw std::runtime_error(
+          "Unable to find a proxy executor to run custom ops. Please check if "
+          "there is a json file generated in the same directory as the so, or use "
+          "torch._inductor.aoti_compile_and_package to package everything into a "
+          "PT2 artifact.");
+    }
     ProxyExecutor* executor = reinterpret_cast<ProxyExecutor*>(proxy_executor);
     executor->call_function(
         extern_node_index,
@@ -1357,6 +1356,17 @@ AOTITorchError aoti_torch_proxy_executor_call_function(
   });
 }

+void aoti_torch_check(
+    bool cond,
+    const char* func,
+    const char* file,
+    uint32_t line,
+    const char* msg) {
+  if (C10_UNLIKELY_OR_CONST(!cond)) {
+    ::c10::detail::torchCheckFail(func, file, line, msg);
+  }
+}
+
 void aoti_torch_warn(
     const char* func,
     const char* file,
@@ -10,7 +10,9 @@ AOTITorchError aoti_torch_mps_set_arg_tensor(
     AtenTensorHandle tensor) {
   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
     auto t = tensor_handle_to_tensor_pointer(tensor);
-    TORCH_CHECK(t != nullptr, "Tensor is null.");
+    if (t == nullptr) {
+      throw std::runtime_error("Tensor is null.");
+    }
     auto func = reinterpret_cast<at::native::mps::MetalKernelFunction*>(handle);
     func->setArg(idx, *t);
   });

@@ -92,11 +92,13 @@ inline void assert_inf_and_nan(
     const std::string& tensor_name,
     at::Tensor& check_tensor) {
   auto isnan_tensor = check_tensor.isnan();
-  TORCH_CHECK(
-      !isnan_tensor.any().item<bool>(), "At least one NaN in ", tensor_name);
+  if (isnan_tensor.any().item<bool>()) {
+    throw std::runtime_error("At least one NaN in " + tensor_name);
+  }
   auto isinf_tensor = check_tensor.isinf();
-  TORCH_CHECK(
-      !isinf_tensor.any().item<bool>(), "At least one INF in ", tensor_name);
+  if (isinf_tensor.any().item<bool>()) {
+    throw std::runtime_error("At least one INF in " + tensor_name);
+  }
 }

 // utility functions to convert a pointer to an optional value

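assert_inf_and_nan materializes two boolean reductions over the checked tensor, so it is a debug-path helper rather than something for hot loops. A minimal call sketch following the signature in the hunk above:

#include <ATen/ATen.h>

void check_activation(at::Tensor& out) {
  // Throws (or TORCH_CHECK-fails, on the removed side of the hunk) if any
  // element of `out` is NaN or infinite; the string only names the tensor
  // in the error message.
  assert_inf_and_nan("activation_out", out);
}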
@@ -1,6 +1,7 @@
 # mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 # Copyright (c) Meta Platforms, Inc. and affiliates
+import copy
 import inspect
 import warnings
 from collections.abc import Callable, Sequence
@@ -96,16 +97,23 @@ class _ToTorchTensor(torch.autograd.Function):
         )
         tensor_stride = tuple(tensor_stride)
         grad_placements = grad_placements or dtensor_spec.placements
-        grad_spec = DTensorSpec(
-            mesh,
-            grad_placements,
-            tensor_meta=TensorMeta(
-                shape=dtensor_meta.shape,
-                stride=tensor_stride,
-                dtype=dtensor_meta.dtype,
-            ),
-        )

+        if (
+            tensor_stride == dtensor_meta.stride
+            and grad_placements == dtensor_spec.placements
+        ):
+            # Avoid actual sharing of specs in case they're modified during (e.g.)
+            # sharding propagation.
+            grad_spec = copy.copy(dtensor_spec)
+        else:
+            grad_spec = DTensorSpec(
+                mesh,
+                grad_placements,
+                tensor_meta=TensorMeta(
+                    shape=dtensor_meta.shape,
+                    stride=tensor_stride,
+                    dtype=dtensor_meta.dtype,
+                ),
+            )
         return (
             # pyrefly: ignore [bad-argument-type]
             DTensor(
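The new branch reuses the existing spec when nothing changed, but copies it first. A tiny sketch of why the shallow copy matters; Spec is a generic stand-in for DTensorSpec:

import copy
from dataclasses import dataclass


@dataclass
class Spec:
    placements: tuple


shared = Spec(placements=("shard",))
alias = shared            # later edits through `alias` would leak into `shared`
safe = copy.copy(shared)  # a distinct object: edits stay local

safe.placements = ("replicate",)
assert shared.placements == ("shard",)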
@@ -470,6 +470,10 @@ def has_static_value(a: Union[SymBool, SymFloat, SymInt, bool, float, int]) -> bool:
     return a.node.shape_env.bound_sympy(a.node.expr).is_singleton()  # type: ignore[union-attr]


+@deprecated(
+    "guard_size_oblivious will be removed. Consider using explicit unbacked handling \
+potentially utilizing guard_or_false, guard_or_true, or statically_known_true"
+)
 def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool:
     """
     Perform a guard on a symbolic boolean expression in a size oblivious way.

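The deprecation message points at the graph-safe replacements. A short migration sketch using guard_or_false, which lives in torch.fx.experimental.symbolic_shapes; treat the branch semantics as the illustrative point, not a drop-in rewrite for every call site:

from torch.fx.experimental.symbolic_shapes import guard_or_false


def maybe_squeeze(t):
    # Old: if guard_size_oblivious(t.shape[0] == 1): ...
    # New: take the branch only when the condition is knowable; when the
    # symbol is unbacked and undecidable, assume False instead of guarding.
    if guard_or_false(t.shape[0] == 1):
        return t.squeeze(0)
    return t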
@@ -576,17 +576,6 @@ def insert_deferred_runtime_asserts(
                 if i0 in constrained_unbacked_symbols:
                     continue  # constrain symbol just once

-                if i0 in shape_env.size_like:
-                    if export:
-                        graph.call_function(
-                            torch.ops.aten.sym_constrain_range_for_size.default,
-                            (expr_to_proxy[i0].node,),
-                        )
-                    else:
-                        graph.call_function(
-                            torch._check_is_size, (expr_to_proxy[i0].node,)
-                        )
-
                 vr = shape_env.var_to_range[i0]
                 if vr.is_int and vr.upper == sys.maxsize - 1:
                     # treat upper bound == sys.maxsize - 1 for int symbols as +oo