Compare commits

..

16 Commits

Author SHA1 Message Date
54f7347a5b Update on "deprecate check_is_size and guard_size_oblivious"
cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben Lucaskabela

[ghstack-poisoned]
2025-11-14 19:25:04 -08:00
82afb7deda Update base for Update on "deprecate check_is_size and guard_size_oblivious"
cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben Lucaskabela

[ghstack-poisoned]
2025-11-14 19:25:04 -08:00
7aa210d215 Revert "[CodeClean] Remove the Unused MACRO for AOT Inductor Runtime (#165139)"
This reverts commit fcd5f8c352b5b75bd32e57fa044ec5df095032da.

Reverted https://github.com/pytorch/pytorch/pull/165139 on behalf of https://github.com/jeanschmidt due to trying to revert in the hopes it fixes internal errors, will land it back ([comment](https://github.com/pytorch/pytorch/pull/165139#issuecomment-3534662138))
2025-11-14 21:35:37 +00:00
5a368b8010 Revert "[CodeClean] Replace std::runtime_error with TORCH_CHECK (#165119)"
This reverts commit 398775a43e9808205f75c81d36f5087117d3f3f4.

Reverted https://github.com/pytorch/pytorch/pull/165119 on behalf of https://github.com/jeanschmidt due to trying to revert in the hopes it fixes internal errors, will land it back ([comment](https://github.com/pytorch/pytorch/pull/165139#issuecomment-3534662138))
2025-11-14 21:35:37 +00:00
602102be50 Revert "Hide all symbols (except stable/headeronly/shim) if TORCH_STABLE_ONLY is defined (#167496)"
This reverts commit bc09a84150eaadaadab8a8ecd76cd9afc60d8a19.

Reverted https://github.com/pytorch/pytorch/pull/167496 on behalf of https://github.com/jeanschmidt due to trying to revert 165139, my intention is to land it again, so, will land this once both are reverted ([comment](https://github.com/pytorch/pytorch/pull/167496#issuecomment-3534641209))
2025-11-14 21:33:02 +00:00
200156e385 DTensor: avoid unnecessary DTensorSpec creation in _ToTorchTensor.backward (#167588)
Looks like the check here is cheap and has a potentially large payoff.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167588
Approved by: https://github.com/ezyang
2025-11-14 21:08:12 +00:00
3710cad6d7 Update on "deprecate check_is_size and guard_size_oblivious"
cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben Lucaskabela

[ghstack-poisoned]
2025-11-14 09:19:05 -08:00
4d2cc7d490 Update base for Update on "deprecate check_is_size and guard_size_oblivious"
cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben Lucaskabela

[ghstack-poisoned]
2025-11-14 09:19:05 -08:00
0ca26bbd2b Update on "deprecate check_is_size and guard_size_oblivious"
cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-13 17:24:11 -08:00
76c6d99ba9 Update base for Update on "deprecate check_is_size and guard_size_oblivious"
cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-13 17:24:11 -08:00
cf1ea48d0a Update on "deprecate check_is_size and guard_size_oblivious"
cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-13 07:37:03 -08:00
68283bd54c Update base for Update on "deprecate check_is_size and guard_size_oblivious"
cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-13 07:37:03 -08:00
b6ab8b28a4 Update on "deprecate check_is_size and guard_size_oblivious"
cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-12 23:18:15 -08:00
4c9ebe5b2f Update base for Update on "deprecate check_is_size and guard_size_oblivious"
cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-12 23:18:15 -08:00
7104addf1e Update on "deprecate check_is_size and guard_size_oblivious"
cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
2025-11-06 07:52:23 -08:00
0deaae4852 deprecate sizelike and guard_size_oblivious
[ghstack-poisoned]
2025-11-05 22:12:04 -08:00
32 changed files with 321 additions and 893 deletions

View File

@ -100,337 +100,6 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
)
def _compile_and_extract_symbols(
cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None
) -> list[str]:
"""
Helper to compile a C++ file and extract all symbols.
Args:
cpp_content: C++ source code to compile
compile_flags: Compilation flags
exclude_list: List of symbol names to exclude. Defaults to ["main"].
Returns:
List of all symbols found in the object file (excluding those in exclude_list).
"""
import subprocess
import tempfile
if exclude_list is None:
exclude_list = ["main"]
with tempfile.TemporaryDirectory() as tmpdir:
tmppath = Path(tmpdir)
cpp_file = tmppath / "test.cpp"
obj_file = tmppath / "test.o"
cpp_file.write_text(cpp_content)
result = subprocess.run(
compile_flags + [str(cpp_file), "-o", str(obj_file)],
capture_output=True,
text=True,
timeout=60,
)
if result.returncode != 0:
raise RuntimeError(f"Compilation failed: {result.stderr}")
symbols = get_symbols(str(obj_file))
# Return all symbol names, excluding those in the exclude list
return [name for _addr, _stype, name in symbols if name not in exclude_list]
def check_stable_only_symbols(install_root: Path) -> None:
"""
Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.
This approach tests:
1. WITHOUT macros -> many torch symbols exposed
2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
4. WITH both macros -> zero torch symbols (all hidden)
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
test_cpp_content = """
// Main torch C++ API headers
#include <torch/torch.h>
#include <torch/all.h>
// ATen tensor library
#include <ATen/ATen.h>
// Core c10 headers (commonly used)
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Optional.h>
int main() { return 0; }
"""
base_compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c", # Compile only, don't link
]
# Compile WITHOUT any macros
symbols_without = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=base_compile_flags,
)
# We expect constexpr symbols, inline functions used by other headers etc.
# to produce symbols
num_symbols_without = len(symbols_without)
print(f"Found {num_symbols_without} symbols without any macros defined")
assert num_symbols_without != 0, (
"Expected a non-zero number of symbols without any macros"
)
# Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"]
symbols_with_stable_only = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_stable_only,
)
num_symbols_with_stable_only = len(symbols_with_stable_only)
assert num_symbols_with_stable_only == 0, (
f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}"
)
# Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
compile_flags_with_target_version = base_compile_flags + [
"-DTORCH_TARGET_VERSION=1"
]
symbols_with_target_version = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_target_version,
)
num_symbols_with_target_version = len(symbols_with_target_version)
assert num_symbols_with_target_version == 0, (
f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}"
)
# Compile WITH both macros (expect 0 symbols)
compile_flags_with_both = base_compile_flags + [
"-DTORCH_STABLE_ONLY",
"-DTORCH_TARGET_VERSION=1",
]
symbols_with_both = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_both,
)
num_symbols_with_both = len(symbols_with_both)
assert num_symbols_with_both == 0, (
f"Expected no symbols with both macros, but found {num_symbols_with_both}"
)
def check_stable_api_symbols(install_root: Path) -> None:
"""
Test that stable API headers still expose symbols with TORCH_STABLE_ONLY.
The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
stable_dir = include_dir / "torch" / "csrc" / "stable"
assert stable_dir.exists(), f"Expected {stable_dir} to be present"
stable_headers = list(stable_dir.rglob("*.h"))
if not stable_headers:
raise RuntimeError("Could not find any stable headers")
includes = []
for header in stable_headers:
rel_path = header.relative_to(include_dir)
includes.append(f"#include <{rel_path.as_posix()}>")
includes_str = "\n".join(includes)
test_stable_content = f"""
{includes_str}
int main() {{ return 0; }}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_stable = _compile_and_extract_symbols(
cpp_content=test_stable_content,
compile_flags=compile_flags,
)
num_symbols_stable = len(symbols_stable)
print(f"Found {num_symbols_stable} symbols in torch/csrc/stable")
assert num_symbols_stable > 0, (
f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_stable} symbols"
)
def check_headeronly_symbols(install_root: Path) -> None:
"""
Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# Find all headers in torch/headeronly
headeronly_dir = include_dir / "torch" / "headeronly"
assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present"
headeronly_headers = list(headeronly_dir.rglob("*.h"))
if not headeronly_headers:
raise RuntimeError("Could not find any headeronly headers")
# Filter out platform-specific headers that may not compile everywhere
platform_specific_keywords = [
"cpu/vec",
]
filtered_headers = []
for header in headeronly_headers:
rel_path = header.relative_to(include_dir).as_posix()
if not any(
keyword in rel_path.lower() for keyword in platform_specific_keywords
):
filtered_headers.append(header)
includes = []
for header in filtered_headers:
rel_path = header.relative_to(include_dir)
includes.append(f"#include <{rel_path.as_posix()}>")
includes_str = "\n".join(includes)
test_headeronly_content = f"""
{includes_str}
int main() {{ return 0; }}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_headeronly = _compile_and_extract_symbols(
cpp_content=test_headeronly_content,
compile_flags=compile_flags,
)
num_symbols_headeronly = len(symbols_headeronly)
print(f"Found {num_symbols_headeronly} symbols in torch/headeronly")
assert num_symbols_headeronly > 0, (
f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_headeronly} symbols"
)
def check_aoti_shim_symbols(install_root: Path) -> None:
"""
Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# There are no constexpr symbols etc., so we need to actually use functions
# so that some symbols are found.
test_shim_content = """
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
int main() {
int32_t (*fp1)() = &aoti_torch_device_type_cpu;
int32_t (*fp2)() = &aoti_torch_dtype_float32;
(void)fp1; (void)fp2;
return 0;
}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_shim = _compile_and_extract_symbols(
cpp_content=test_shim_content,
compile_flags=compile_flags,
)
num_symbols_shim = len(symbols_shim)
assert num_symbols_shim > 0, (
f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_shim} symbols"
)
def check_stable_c_shim_symbols(install_root: Path) -> None:
"""
Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# Check if the stable C shim exists
stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
if not stable_shim.exists():
raise RuntimeError("Could not find stable c shim")
# There are no constexpr symbols etc., so we need to actually use functions
# so that some symbols are found.
test_stable_shim_content = """
#include <torch/csrc/stable/c/shim.h>
int main() {
// Reference stable C API functions to create undefined symbols
AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
(void)fp1; (void)fp2;
return 0;
}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_stable_shim = _compile_and_extract_symbols(
cpp_content=test_stable_shim_content,
compile_flags=compile_flags,
)
num_symbols_stable_shim = len(symbols_stable_shim)
assert num_symbols_stable_shim > 0, (
f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_stable_shim} symbols"
)
def check_lib_symbols_for_abi_correctness(lib: str) -> None:
print(f"lib: {lib}")
cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)
@ -460,13 +129,6 @@ def main() -> None:
check_lib_symbols_for_abi_correctness(libtorch_cpu_path)
check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path)
# Check symbols when TORCH_STABLE_ONLY is defined
check_stable_only_symbols(install_root)
check_stable_api_symbols(install_root)
check_headeronly_symbols(install_root)
check_aoti_shim_symbols(install_root)
check_stable_c_shim_symbols(install_root)
if __name__ == "__main__":
main()
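
For reference, a minimal hand-run sketch of the probe the deleted script performed: compile a tiny translation unit against the installed headers with and without -DTORCH_STABLE_ONLY and count the surviving symbols with nm. The include path and the use of nm are illustrative assumptions, not part of this diff.

import subprocess
import tempfile
from pathlib import Path

SNIPPET = """
#include <torch/torch.h>
int main() { return 0; }
"""

def count_symbols(include_dir: str, extra_flags: list[str]) -> int:
    # Compile only (-c) and count every symbol in the object file except main.
    with tempfile.TemporaryDirectory() as tmp:
        cpp, obj = Path(tmp) / "probe.cpp", Path(tmp) / "probe.o"
        cpp.write_text(SNIPPET)
        subprocess.run(
            ["g++", "-std=c++17", f"-I{include_dir}",
             f"-I{include_dir}/torch/csrc/api/include", "-c", str(cpp), "-o", str(obj)]
            + extra_flags,
            check=True,
        )
        out = subprocess.run(["nm", str(obj)], capture_output=True, text=True, check=True)
        names = [line.split()[-1] for line in out.stdout.splitlines() if line.strip()]
        return len([n for n in names if n not in ("main", "_main")])

# Expected by the deleted checks: the first count is non-zero, the second is zero.
# print(count_symbols("/path/to/torch/include", []))
# print(count_symbols("/path/to/torch/include", ["-DTORCH_STABLE_ONLY"]))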

View File

@ -1358,45 +1358,6 @@ class concat_license_files:
# Need to create the proper LICENSE.txt for the wheel
class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
def _wrap_headers_with_macro(self, bdist_dir: Path) -> None:
"""Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION).
Excludes:
- torch/include/torch/headeronly/*
- torch/include/torch/csrc/stable/*
- torch/include/torch/csrc/inductor/aoti_torch/c/ (only shim headers)
- torch/include/torch/csrc/inductor/aoti_torch/generated/
"""
header_extensions = (".h", ".hpp", ".cuh")
header_files = [
f for ext in header_extensions for f in bdist_dir.rglob(f"*{ext}")
]
# Paths to exclude from wrapping
exclude_dir_patterns = [
"torch/include/torch/headeronly/",
"torch/include/torch/csrc/stable/",
"torch/include/torch/csrc/inductor/aoti_torch/c/",
"torch/include/torch/csrc/inductor/aoti_torch/generated/",
]
for header_file in header_files:
rel_path = header_file.relative_to(bdist_dir).as_posix()
if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns):
report(f"Skipping header: {rel_path}")
continue
original_content = header_file.read_text(encoding="utf-8")
wrapped_content = (
"#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
f"{original_content}"
"\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
)
header_file.write_text(wrapped_content, encoding="utf-8")
report(f"Wrapped header: {rel_path}")
def run(self) -> None:
with concat_license_files(include_files=True):
super().run()
@ -1419,14 +1380,6 @@ class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
# need an __init__.py file otherwise we wouldn't have a package
(bdist_dir / "torch" / "__init__.py").touch()
# Wrap all header files with TORCH_STABLE_ONLY macro
assert self.bdist_dir is not None, "bdist_dir should be set during wheel build"
bdist_dir = Path(self.bdist_dir)
report(
"-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)"
)
self._wrap_headers_with_macro(bdist_dir)
class clean(Command):
user_options: ClassVar[list[tuple[str, str | None, str]]] = []
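
For context, the transformation that the deleted _wrap_headers_with_macro applied to each non-excluded header can be sketched roughly as follows (a simplified paraphrase, not the setup.py code itself):

from pathlib import Path

GUARD = "!defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)"

def wrap_header(path: Path) -> None:
    # In-place rewrite; the real helper first skipped headeronly/stable/shim paths.
    body = path.read_text(encoding="utf-8")
    path.write_text(f"#if {GUARD}\n{body}\n#endif // {GUARD}\n", encoding="utf-8")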

View File

@ -33,7 +33,7 @@ class clean(distutils.command.clean.clean):
def get_extension():
extra_compile_args = {
"cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
"cxx": ["-fdiagnostics-color=always"],
}
extension = CppExtension

View File

@ -0,0 +1,67 @@
import distutils.command.clean
import shutil
from pathlib import Path
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
ROOT_DIR = Path(__file__).parent
CSRC_DIR = ROOT_DIR / "torch_stable_test" / "csrc"
class clean(distutils.command.clean.clean):
def run(self):
# Run default behavior first
distutils.command.clean.clean.run(self)
# Remove extension
for path in (ROOT_DIR / "torch_stable_test").glob("**/*.so"):
path.unlink()
# Remove build and dist and egg-info directories
dirs = [
ROOT_DIR / "build",
ROOT_DIR / "dist",
ROOT_DIR / "torch_stable_test.egg-info",
]
for path in dirs:
if path.exists():
shutil.rmtree(str(path), ignore_errors=True)
def get_extension():
extra_compile_args = {
"cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
}
sources = list(CSRC_DIR.glob("**/*.cpp"))
return [
CppExtension(
"torch_stable_test._C",
sources=sorted(str(s) for s in sources),
py_limited_api=True,
extra_compile_args=extra_compile_args,
extra_link_args=[],
)
]
setup(
name="torch_stable_test",
version="0.0",
author="PyTorch Core Team",
description="Test extension to verify TORCH_STABLE_ONLY flag",
packages=find_packages(exclude=("test",)),
package_data={"torch_stable_test": ["*.dll", "*.dylib", "*.so"]},
install_requires=[
"torch",
],
ext_modules=get_extension(),
cmdclass={
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
"clean": clean,
},
options={"bdist_wheel": {"py_limited_api": "cp39"}},
)

View File

@ -0,0 +1 @@
#include <ATen/core/TensorBase.h> // This should trigger the TORCH_STABLE_ONLY error

View File

@ -0,0 +1,22 @@
# Owner(s): ["module: cpp"]
from pathlib import Path
from torch.testing._internal.common_utils import (
install_cpp_extension,
IS_WINDOWS,
run_tests,
TestCase,
)
if not IS_WINDOWS:
class TestTorchStable(TestCase):
def test_setup_fails(self):
with self.assertRaisesRegex(RuntimeError, "build failed for cpp extension"):
install_cpp_extension(extension_root=Path(__file__).parent.parent)
if __name__ == "__main__":
run_tests()

View File

@ -90,12 +90,12 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(u0)", primals_2: "Sym(u1)", primals_3: "Sym(u2)", primals_4: "f32[u0, u1, u2]"):
ge_1: "Sym(u0 >= 0)" = primals_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u1 >= 0)" = primals_2 >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
ge_5: "Sym(u2 >= 0)" = primals_3 >= 0
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_5 = _assert_scalar_2 = None
ge: "Sym(u0 >= 0)" = primals_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
ge_1: "Sym(u1 >= 0)" = primals_2 >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
ge_2: "Sym(u2 >= 0)" = primals_3 >= 0
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_2 = None
floordiv: "Sym((u0//2))" = primals_1 // 2

View File

@ -727,7 +727,7 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3)
arg_count = ifdynstaticdefault(4, 5)
# when compiled with dynamic, we don't have upper bound runtime assertions for u0
expected_op_count = ifdynstaticdefault(10, 8)
expected_op_count = ifdynstaticdefault(9, 7)
out_graph = self._test_wrap_simple(
f,
default_args_generator((x,)),
@ -747,7 +747,6 @@ class GraphModule(torch.nn.Module):
c: "i64[u0, 1]" = l_x_.nonzero()
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@ -784,7 +783,6 @@ class GraphModule(torch.nn.Module):
c: "i64[u0, 1]" = l_x_.nonzero()
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@ -883,7 +881,7 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3)
arg_count = ifdynstaticdefault(4, 5)
# when compiled with dynamic, we don't have upper bound runtime assertions for u0
expected_op_count = ifdynstaticdefault(10, 8)
expected_op_count = ifdynstaticdefault(9, 7)
out_graph = self._test_wrap_simple(
f,
default_args_generator((x,)),
@ -905,7 +903,6 @@ class GraphModule(torch.nn.Module):
c: "i64[u0, 1]" = l_x_.nonzero()
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@ -956,7 +953,7 @@ class GraphModule(torch.nn.Module):
y = torch.randn(3)
arg_count = ifdynstaticdefault(5, 6)
# when compiled with dynamic, we don't have upper bound runtime assertions for u0 and u1
expected_op_count = ifdynstaticdefault(17, 13)
expected_op_count = ifdynstaticdefault(15, 11)
out_graph = self._test_wrap_simple(
f,
default_args_generator((x, y)),
@ -977,7 +974,6 @@ class GraphModule(torch.nn.Module):
c: "i64[u0, 1]" = l_x_.nonzero()
sym_size_int_2: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int_2); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int_2 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@ -987,7 +983,6 @@ class GraphModule(torch.nn.Module):
d: "i64[u1, 1]" = l_y_.nonzero(); l_y_ = None
sym_size_int_3: "Sym(u1)" = torch.ops.aten.sym_size.int(d, 0)
_check_is_size_1 = torch._check_is_size(sym_size_int_3); _check_is_size_1 = None
ge_1: "Sym(u1 >= 0)" = sym_size_int_3 >= 0
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_2 = None

View File

@ -10,10 +10,6 @@ import torch.utils.checkpoint
from torch._dynamo.backends.common import aot_autograd
from torch._functorch._aot_autograd.autograd_cache import BundledCompiledForward
from torch._guards import detect_fake_mode
from torch._higher_order_ops.invoke_subgraph import (
NestedCompileBackend,
NestedCompileRegionOptions,
)
from torch._inductor.output_code import RegionalOutputCode
from torch._inductor.test_case import run_tests
from torch._inductor.utils import run_fw_bw_and_get_code
@ -472,86 +468,6 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
# flex in forward and flex_backward in backward
self.assertEqual(len(codes), 2)
@parametrize("serialize", [True, False])
def test_invoke_subgraph_regional_compile(self, serialize):
call_test_partitioner_ct = 0
original_default_partitioner = torch._functorch.partitioners.default_partition
def test_partitioner(
*args, **kwargs
) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
nonlocal call_test_partitioner_ct
call_test_partitioner_ct += 1
return original_default_partitioner(*args, **kwargs)
# pyrefly: ignore [not-iterable]
if serialize:
# Callable cannot be serialized
torch._functorch.partitioners.default_partition = test_partitioner
partitioner = "default_partition"
else:
partitioner = test_partitioner
backend = NestedCompileRegionOptions(
backend=NestedCompileBackend.INDUCTOR,
inductor_configs={
"max_autotune": True,
"triton.cudagraphs": False,
},
partitioner=partitioner,
)
@torch.compiler.nested_compile_region(backend_options=backend)
def gn_with_backend(x):
return torch.sin(x)
@torch.compiler.nested_compile_region
def gn_without_backend(x):
return torch.cos(x)
def fn(x):
return gn_with_backend(x) + gn_without_backend(x)
backend = aot_eager_regional_inductor(serialize=serialize)
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
import torch._inductor.config as inductor_config
# Hook to verify options
original_compile = torch._inductor.standalone_compile
captured_options = []
def verify_options(*args, **kwargs):
options = kwargs.get("options", {})
captured_options.append(options)
# Verify config is set as expected from explicit options
assert inductor_config.max_autotune, "max_autotune should be True"
assert not inductor_config.triton.cudagraphs, (
"triton.cudagraphs should be False"
)
return original_compile(*args, **kwargs)
torch._inductor.standalone_compile = verify_options
try:
x = torch.randn(8, 8, requires_grad=True)
# opt_fn(x)
res, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
self.assertEqual(len(codes), 2)
self.assertTrue("repeated_subgraph0" in codes[0])
self.assertTrue("repeated_subgraph1" not in codes[0])
self.assertTrue("repeated_subgraph0" in codes[1])
self.assertTrue("repeated_subgraph1" not in codes[1])
self.assertEqual(call_test_partitioner_ct, 1)
true_res = fn(x)
self.assertEqual(res, true_res)
finally:
torch._inductor.standalone_compile = original_compile
torch._functorch.partitioners.default_partition = (
original_default_partitioner
)
@skipIfTorchDynamo("Not a suitable dynamo wrapped test")
class TestRegionalOutputCode(torch._inductor.test_case.TestCase):

View File

@ -3081,15 +3081,12 @@ def forward(self, x, y):
foo = torch.ops.export.foo.default(x, y); x = None
sym_size_int = torch.ops.aten.sym_size.int(foo, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(foo, 1)
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int); sym_constrain_range_for_size_default = None
ge = sym_size_int >= 0; sym_size_int = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
sym_constrain_range_for_size_default_1 = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1); sym_constrain_range_for_size_default_1 = None
ge_1 = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_1 = None
bar = torch.ops.export.bar.default(y); y = None
sym_size_int_2 = torch.ops.aten.sym_size.int(bar, 0)
sym_constrain_range_for_size_default_2 = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_2); sym_constrain_range_for_size_default_2 = None
ge_2 = sym_size_int_2 >= 0; sym_size_int_2 = None
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_default_2 = None
return (foo, bar)""",
@ -17743,7 +17740,6 @@ class TestExportCustomClass(TorchTestCase):
def forward(self, x, mask):
masked_select = torch.ops.aten.masked_select.default(x, mask); x = mask = None
sym_size_int_1 = torch.ops.aten.sym_size.int(masked_select, 0)
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1); sym_constrain_range_for_size_default = None
ge = sym_size_int_1 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
le = sym_size_int_1 <= 1188864

View File

@ -21,10 +21,6 @@ from torch._dynamo.testing import (
InductorAndRecordGraphs,
normalize_gm,
)
from torch._higher_order_ops.invoke_subgraph import (
NestedCompileBackend,
NestedCompileRegionOptions,
)
from torch._higher_order_ops.schema import find_hop_schema
from torch._inductor import config as inductor_config
from torch._inductor.pattern_matcher import (
@ -1560,101 +1556,6 @@ class GraphModule(torch.nn.Module):
res = opt_fn(x)
self.assertEqual(ref, res)
def test_unbacked_expr(self):
@nested_compile_region
def gn(x):
return x + 1
def fn(c):
d = torch.concat([c, c], dim=0)
d = gn(d)
return d
c = torch.randn((64, 32))
torch._dynamo.decorators.mark_unbacked(c, 0)
ref = fn(c)
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
res = opt_fn(c)
self.assertEqual(ref, res)
def test_grad_accumulation(self):
mod1 = torch.nn.Linear(8, 8)
mod2 = torch.nn.Linear(8, 8)
mod3 = torch.nn.Linear(8, 8)
@nested_compile_region
def gn(x):
return mod1(x) - mod2(x)
def fn(c):
d = gn(c) - mod3(c)
return d * 2
c = torch.randn((8, 8), requires_grad=True)
backend = AotEagerAndRecordGraphs()
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
res = opt_fn(c)
res.sum().backward()
# fw_add_nodes = backend.fw_graphs[0].graph.find_nodes(op="call_function", target = torch.ops.aten.add.Tensor)
# The gradient addition node for mod3 is not in the subgraph.
bw_add_nodes = backend.bw_graphs[0].graph.find_nodes(
op="call_function", target=torch.ops.aten.add.Tensor
)
self.assertEqual(len(bw_add_nodes), 1)
subgraph_node = backend.bw_graphs[0].graph.find_nodes(op="get_attr")[0]
subgraph_name = subgraph_node.target
# The gradient addition node between mod1 and mode2 will be in the subgraph
bw_add_nodes = getattr(backend.bw_graphs[0], subgraph_name).graph.find_nodes(
op="call_function", target=torch.ops.aten.add.Tensor
)
self.assertEqual(len(bw_add_nodes), 1)
def test_backend_parameter(self):
backend = NestedCompileRegionOptions(NestedCompileBackend.INDUCTOR)
# Test that backend parameter is properly set in node.meta
@nested_compile_region(backend_options=backend)
def gn_with_backend(x):
return torch.sin(x)
@nested_compile_region
def gn_without_backend(x):
return torch.cos(x)
def fn(x):
return gn_with_backend(x) + gn_without_backend(x)
backend = EagerAndRecordGraphs()
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
x = torch.randn(8, 8, requires_grad=False)
opt_fn(x)
# Check that we captured the graph
self.assertEqual(len(backend.graphs), 1)
graph = backend.graphs[0]
# Find invoke_subgraph nodes and check their backend metadata
invoke_subgraph_nodes = [
node
for node in graph.graph.nodes
if node.op == "call_function"
and node.target == torch.ops.higher_order.invoke_subgraph
]
# We should have 2 invoke_subgraph calls
self.assertEqual(len(invoke_subgraph_nodes), 2)
# First invoke_subgraph (gn_with_backend) should have backend
self.assertIn("custom", invoke_subgraph_nodes[0].meta)
# Second invoke_subgraph (gn_without_backend) should have custom=None or no custom
backend_value = invoke_subgraph_nodes[1].meta.get("custom", None)
self.assertIsNone(backend_value)
def test_complex(self):
# Observed in Wan2.1
@nested_compile_region

View File

@ -1492,8 +1492,8 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
clone: "f32[s77][1]cpu" = torch.ops.aten.clone.default(arg1_1)
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg1_1, _to_copy]); _to_copy = None
getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1]
@ -1513,8 +1513,8 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
clone: "f32[2][1]cpu" = torch.ops.aten.clone.default(arg0_1)
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None
_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
@ -1538,8 +1538,8 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg1_1)
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
alias_default: "f32[s77][1]cpu" = torch.ops.aten.alias.default(arg1_1)
alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type)
@ -1557,8 +1557,8 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
def forward(self, arg0_1: "f32[2][1]cpu"):
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg0_1)
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None
convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None

View File

@ -3532,11 +3532,11 @@ class TestUbackedOps(TestCase):
aot_graphs,
"""\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", arg3_1: "i64[u1][s7]cpu"):
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge = _assert_scalar = None
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None
ge_1: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
@ -3573,11 +3573,11 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)",
aot_graphs,
"""\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]cpu"):
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge = _assert_scalar = None
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None
ge_1: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
@ -3632,21 +3632,21 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
aot_graphs,
"""\
def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", arg3_1: "f32[u2, u3][1, u2]cpu"):
ge_1: "Sym(u2 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u3 >= 0)" = arg2_1 >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
ge: "Sym(u2 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge = _assert_scalar = None
ge_1: "Sym(u3 >= 0)" = arg2_1 >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
select: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None
ge_4: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_4, "Runtime assertion failed for expression u0 >= 0 on node 'ge_2'"); ge_4 = _assert_scalar_2 = None
ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_2 = None
sym_sum: "Sym(u0 + 1)" = torch.sym_sum((1, _local_scalar_dense))
gt: "Sym(u0 + 1 > 0)" = sym_sum > 0; sym_sum = None
_assert_scalar_3 = torch.ops.aten._assert_scalar.default(gt, "Runtime assertion failed for expression 0 < u0 + 1 on node 'gt'"); gt = _assert_scalar_3 = None
select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
_local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
ge_5: "Sym(u1 >= 0)" = _local_scalar_dense_1 >= 0
_assert_scalar_4 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_5 = _assert_scalar_4 = None
ge_3: "Sym(u1 >= 0)" = _local_scalar_dense_1 >= 0
_assert_scalar_4 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_3 = _assert_scalar_4 = None
sym_sum_1: "Sym(u1 + 1)" = torch.sym_sum((1, _local_scalar_dense_1))
gt_1: "Sym(u1 + 1 > 0)" = sym_sum_1 > 0; sym_sum_1 = None
_assert_scalar_5 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 + 1 on node 'gt_1'"); gt_1 = _assert_scalar_5 = None
@ -4068,10 +4068,10 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
self.assertExpectedInline(
output,
"""\
ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
ge: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
clone: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.clone.default(arg2_1, memory_format = torch.contiguous_format); arg2_1 = None
add_3: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(clone, 1); clone = None
mul_6: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add_3, 100); add_3 = None
@ -4097,10 +4097,10 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
self.assertExpectedInline(
output,
"""\
ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
ge: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
add: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(arg2_1, 1); arg2_1 = None
mul_5: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add, 100); add = None
return (mul_5,)""", # noqa: B950
@ -4283,11 +4283,11 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
aot_graphs,
"""\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", arg3_1: "i64[u1][s7]cpu"):
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge = _assert_scalar = None
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None
ge_1: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
@ -4319,11 +4319,11 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)",
aot_graphs,
"""\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]cpu"):
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge = _assert_scalar = None
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None
ge_1: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None

View File

@ -121,7 +121,7 @@ class TestOpaqueObject(TestCase):
def size_impl_fake(q: OpaqueQueue) -> int:
ctx = torch._custom_op.impl.get_ctx()
u0 = ctx.new_dynamic_size()
torch._check_is_size(u0)
torch._check(u0 >= 0)
return u0
torch.library.define(
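
For call sites that also passed the max= keyword of _check_is_size (see its signature further below), the equivalent spelling is two plain _check calls. A hedged sketch with illustrative names, reusing the ctx API shown above:

import torch

def size_fake_impl_sketch(limit: int = 2**31) -> int:
    ctx = torch._custom_op.impl.get_ctx()
    u0 = ctx.new_dynamic_size()
    torch._check(u0 >= 0)      # replaces torch._check_is_size(u0)
    torch._check(u0 <= limit)  # replaces the max= bound, if one was used
    return u0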

View File

@ -33,7 +33,11 @@ from typing import (
TypeVar as _TypeVar,
Union as _Union,
)
from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs
from typing_extensions import (
deprecated as _deprecated,
ParamSpec as _ParamSpec,
TypeIs as _TypeIs,
)
# As a bunch of torch.packages internally still have this check
@ -1735,7 +1739,10 @@ def _check(cond, message=None): # noqa: F811
_check_with(RuntimeError, cond, message) # pyrefly: ignore [bad-argument-type]
# TODO add deprecation annotation
@_deprecated(
"_check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. \
Use _check(i >= 0) instead."
)
def _check_is_size(i, message=None, *, max=None):
"""Checks that a given integer is a valid size (i.e., is non-negative).
You should use this over ``_check(i >= 0)`` because it can prevent
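
Assuming the decorator lands as shown above, calling the old spelling now emits a DeprecationWarning at runtime (typing_extensions.deprecated warns on call), while the suggested replacement does not. A small migration sketch:

import warnings
import torch

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch._check_is_size(3)  # deprecated spelling, warns
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

torch._check(3 >= 0)  # suggested replacement, no warning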

View File

@ -20,7 +20,6 @@ their semantic behavior.
"""
import contextlib
import copy
import functools
import inspect
import itertools
@ -43,10 +42,6 @@ from torch._dynamo.variables.functions import UserFunctionVariable
from torch._dynamo.variables.nn_module import UnspecializedNNModuleVariable
from torch._dynamo.variables.tensor import SymNodeVariable
from torch._guards import Source
from torch._higher_order_ops.invoke_subgraph import (
NestedCompileBackend,
NestedCompileRegionOptions,
)
from torch._ops import HigherOrderOperator
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils import _pytree as pytree
@ -264,7 +259,6 @@ def _call_function_with_auto_output_flattening(
flat_example_value: Any,
body_r: Optional[VariableTracker],
graph_output_vts: VariableTracker | tuple[VariableTracker, ...],
backend_options: Optional[NestedCompileRegionOptions] = None,
) -> Optional[VariableTracker]:
"""
Create HOP call node and reproxify output VTs for HOPs with auto output semantics.
@ -291,30 +285,14 @@ def _call_function_with_auto_output_flattening(
from .builder import wrap_fx_proxy
# Store the invocation as a call
proxy = tx.output.create_proxy(
"call_function",
fn,
args=args,
kwargs=kwargs,
)
# Set backend metadata if provided
if backend_options is not None:
if "custom" not in proxy.node.meta:
proxy.node.meta["custom"] = {}
if backend_options.backend == NestedCompileBackend.INDUCTOR:
inductor_configs = {}
if backend_options.inductor_configs:
inductor_configs = copy.deepcopy(backend_options.inductor_configs)
proxy.node.meta["custom"]["compile_with_inductor"] = {
"inductor_configs": inductor_configs
}
if backend_options.partitioner is not None:
proxy.node.meta["custom"]["partitioner"] = backend_options.partitioner
flat_variable = wrap_fx_proxy(
tx=tx,
proxy=proxy,
proxy=tx.output.create_proxy(
"call_function",
fn,
args=args,
kwargs=kwargs,
),
example_value=flat_example_value,
)
@ -346,13 +324,7 @@ def _call_function_with_auto_output_flattening(
def _call_function_and_unflatten_output(
tx,
fn,
args,
kwargs,
flat_example_value,
ret_spec,
body_r,
tx, fn, args, kwargs, flat_example_value, ret_spec, body_r
):
from .builder import wrap_fx_proxy
@ -4263,18 +4235,6 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
],
)
# Extract backend from the function if it was decorated with nested_compile_region(backend=...)
backend_options = None
fn_var = args[0]
if hasattr(fn_var, "get_function"):
try:
fn = fn_var.get_function()
if hasattr(fn, "__marked_compile_region_backend__"):
backend_options = fn.__marked_compile_region_backend__
except Exception:
pass
p_args = (
p_args[0],
body_name,
@ -4288,7 +4248,6 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
example_value,
body_r,
body_graph_output_vts,
backend_options=backend_options,
)

View File

@ -675,42 +675,6 @@ def prepare_for_partitioner(mod, num_primals, num_fw_outputs):
return out
def _get_partition_fn(fw_hop_node, aot_config):
"""
Return either the default `partition_fn` in aot_config or a HOP specific partition
function.
If a HOP specific partition function is returned, used_hop_custom_partition is True.
See Note [InvokeSubgraphHOP Partitioner]
"""
used_hop_custom_partition = False
partition_fn: Callable[..., tuple[torch.fx.GraphModule, torch.fx.GraphModule]] = (
aot_config.partition_fn
)
if (
fw_hop_node.target == torch._higher_order_ops.invoke_subgraph
and "custom" in fw_hop_node.meta
and "partitioner" in fw_hop_node.meta["custom"]
):
hop_partition_fn = fw_hop_node.meta["custom"]["partitioner"]
if callable(hop_partition_fn):
partition_fn = hop_partition_fn # pyrefly: ignore [bad-assignment]
used_hop_custom_partition = True
else:
assert isinstance(hop_partition_fn, str)
match hop_partition_fn:
case "default_partition":
partition_fn = torch._functorch.partitioners.default_partition
case "min_cut_rematerialization_partition":
partition_fn = torch._functorch.partitioners.min_cut_rematerialization_partition
case _:
raise ValueError(
f"Unknown HOP partitioner config: {hop_partition_fn}"
)
return used_hop_custom_partition, partition_fn
def run_joint_graph_passes_on_hops(
joint_gm: torch.fx.GraphModule,
joint_inputs: Any,
@ -815,26 +779,13 @@ def run_joint_graph_passes_on_hops(
# TODO: invoke_subgraph should track which of its inputs static indices
# so it can propagate them to the partitioner (and use in cudagraphs)
static_lifetime_input_indices: list[int] = []
used_hop_custom_partition, partition_fn = _get_partition_fn(
fw_hop_node, aot_config
)
# Step 2) and 3) - Run joint graph passes and partitioner
try:
new_fw_hop_gm, new_bw_hop_gm = partition_fn(
joint_hop_gm,
[],
num_fwd_outputs=num_fw_outputs,
static_lifetime_input_indices=static_lifetime_input_indices,
)
except Exception as e:
if used_hop_custom_partition:
raise RuntimeError(
f"Error in custom partition function for invoke_subgraph node {fw_hop_node.name}: {e}"
) from e
else:
raise
new_fw_hop_gm, new_bw_hop_gm = aot_config.partition_fn(
joint_hop_gm,
[],
num_fwd_outputs=num_fw_outputs,
static_lifetime_input_indices=static_lifetime_input_indices,
)
# Save the new forward and backward graph modules
new_hop_graphs[identifier].new_fw_hop_gm = new_fw_hop_gm

View File

@ -1,11 +1,9 @@
# mypy: allow-untyped-defs
import contextlib
import enum
from collections.abc import Callable
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Any, Optional, Union
from typing import Any, Optional, TYPE_CHECKING, Union
import torch
import torch.utils._pytree as pytree
@ -38,6 +36,10 @@ from torch.fx.graph_module import GraphModule
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
if TYPE_CHECKING:
from collections.abc import Callable
invoke_subgraph_counter = 0
@ -51,32 +53,6 @@ class OutputMetadata:
indexes_with_no_grad: set[int] = field(default_factory=set)
class NestedCompileBackend(enum.Enum):
INDUCTOR = "inductor"
DEFAULT = "default"
@dataclass
class NestedCompileRegionOptions:
# If default, does nothing, inherits the torch.compile backend
# If "inductor", will add {"compile_with_inductor": {"inductor_configs":config}} to HOP node meta "custom"
# If "custom" already has "compile_with_inductor", this config will override
backend: NestedCompileBackend = NestedCompileBackend.DEFAULT
# If backend == "inductor", the configs
inductor_configs: Optional[dict[str, Any]] = None
# Note: [InvokeSubgraphHOP Partitioner]
# If not None, add "partitioner" to HOP node meta.
# If Callable, directly assign the callable, but the callable cannot be pickled
# If str, the options are "default_partition" and "min_cut_rematerialization_partition".
# The HOP joint graph will be partitioned using the corresponding functions in
# torch/_functorch/partitioners.py
partitioner: Optional[Callable | str] = None
# TODO: add decomposition function
class InvokeSubgraphHOP(HigherOrderOperator):
def __init__(self) -> None:
# Invoke subgraph does not have any state, it is just a wrapper over a
@ -177,9 +153,7 @@ def invoke_subgraph_placeholder(func, *args, **kwargs):
return func(*args, **kwargs)
def mark_compile_region(
fn=None, backend_options: Optional[NestedCompileRegionOptions] = None
):
def mark_compile_region(fn=None):
"""
This wrapper instructs torch.compile to compile the wrapped region once and
reuse the compiled artifact, instead of the usual way of aggressively
@ -187,10 +161,6 @@ def mark_compile_region(
Under the hood, it tells TorchDynamo to use InvokeSubgraph HOP for the
region. For PyTorch eager, this is a no-op.
Args:
fn: The function to wrap
backend: Optional backend to use for compiling the subgraph
"""
def wrap(func):
@ -202,7 +172,6 @@ def mark_compile_region(
return invoke_subgraph_placeholder(inner_func, *args, **kwargs)
inner.__marked_compile_region_fn__ = func # type: ignore[attr-defined]
func.__marked_compile_region_backend__ = backend_options # type: ignore[attr-defined]
return inner

View File

@ -2538,7 +2538,7 @@ class CppKernel(Kernel):
@property
def assert_function(self) -> str:
if V.graph.aot_mode:
return "STD_TORCH_CHECK"
return "AOTI_TORCH_CHECK"
else:
return "TORCH_CHECK"

View File

@ -1,21 +1,14 @@
# mypy: allow-untyped-defs
import io
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec
import torch
from torch._higher_order_ops.invoke_subgraph import NestedCompileRegionOptions
from . import config
try:
from typing import LiteralString
except ImportError:
from typing_extensions import LiteralString
if TYPE_CHECKING:
from ._cache import CacheInfo
@ -642,9 +635,7 @@ def skip_all_guards_unsafe(guard_entries):
return [False for entry in guard_entries]
def nested_compile_region(
fn=None, backend_options: Optional[NestedCompileRegionOptions] = None
):
def nested_compile_region(fn=None):
"""
Tells **``torch.compile``** that the marked set of operations forms a nested
compile region (which is often repeated in the full model) whose code can be
@ -653,8 +644,8 @@ def nested_compile_region(
During **``torch.compile``** tracing, the compiler applies *hierarchical
compilation* with ``nested_compile_region``: it emits optimized code for the
marked region the first time it is encountered and re-emits (or "stamps
out") the previously compiled code on every subsequent invocation. This can
marked region the first time it is encountered and re-emits (or stamps
out) the previously compiled code on every subsequent invocation. This can
substantially reduce overall compile time for deeply-stacked,
structurally-identical components such as the transformer layers of a
large-language-model (LLM).
@ -668,17 +659,13 @@ def nested_compile_region(
to reuse, it will transparently re-compile the region. Using it is
therefore *safe*: correctness is always preserved, and you pay the extra
compilation cost only when required.
Args:
fn: The function to wrap
backend_options: Optional backend options to use for compiling the subgraph.
"""
from torch._higher_order_ops.invoke_subgraph import (
mark_compile_region as _mark_compile_region,
)
return _mark_compile_region(fn, backend_options=backend_options)
return _mark_compile_region(fn)
def load_compiled_function(file: io.IOBase) -> Callable[..., Any]:
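A sketch (not from this PR) of the public entry point after this change, which now takes only the function; the structurally identical calls below are the case the docstring describes, where the region is compiled once and stamped out per call.

import torch

@torch.compiler.nested_compile_region
def layer(x, w):
    # structurally identical work repeated for every "layer"
    return torch.nn.functional.relu(x @ w)

@torch.compile
def model(x, ws):
    for w in ws:
        x = layer(x, w)  # compiled once, then reused on each iteration
    return x

ws = [torch.randn(16, 16) for _ in range(4)]
out = model(torch.randn(2, 16), ws)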

View File

@ -836,9 +836,10 @@ class AOTInductorModelBase {
}
void update_constants_array_from_map() {
STD_TORCH_CHECK(
constants_map_,
"constants_map_ was not ready when constants_ is trying to be constructed from it!");
if (!constants_map_) {
throw std::runtime_error{
"constants_map_ was not ready when constants_ is trying to be constructed from it!"};
}
if (!constants_) {
constants_ =
std::make_shared<std::vector<ConstantHandle>>(constants_info_.size());
@ -874,7 +875,9 @@ class AOTInductorModelBase {
/// Returns true if the model is complete.
bool is_finished() {
#ifdef USE_CUDA
STD_TORCH_CHECK(run_finished_, "Model CUDA event was not initialized");
if (!run_finished_) {
throw std::runtime_error{"Model CUDA event was not initialized"};
}
auto event_status = cudaEventQuery(*run_finished_);
if (event_status == cudaSuccess) {
@ -883,13 +886,13 @@ class AOTInductorModelBase {
return false;
}
STD_TORCH_CHECK(
false,
"The model did not finish successfully. Error: ",
throw std::runtime_error(
std::string("The model did not finish successfully. Error: ") +
cudaGetErrorString(cudaGetLastError()));
#elif defined(USE_XPU)
STD_TORCH_CHECK(run_finished_, "Model XPU event was not initialized");
if (!run_finished_) {
throw std::runtime_error{"Model XPU event was not initialized"};
}
using namespace sycl::info;
return (*run_finished_)->get_info<event::command_execution_status>() ==
event_command_status::complete;
@ -901,14 +904,19 @@ class AOTInductorModelBase {
/// Synchronizes completion event.
void wait_for_completion() {
STD_TORCH_CHECK(run_finished_, "Model event was not initialized");
#ifdef USE_CUDA
if (!run_finished_) {
throw std::runtime_error{"Model event was not initialized"};
}
AOTI_RUNTIME_CUDA_CHECK(cudaEventSynchronize(*run_finished_));
#endif // USE_CUDA
#ifdef USE_XPU
if (!run_finished_) {
throw std::runtime_error{"Model event was not initialized"};
}
(*run_finished_)->wait_and_throw();
#endif // USE_XPU
#endif
}
protected:

View File

@ -123,10 +123,8 @@ class AOTInductorModelContainer {
constants_folding_lk.unlock();
model_lk.lock();
} else if (const_folded != ConstantState::FOLDED) {
STD_TORCH_CHECK(
false,
"Unknown constant state: ",
toStringConstantState(constant_folded_));
throw std::runtime_error(
"Unknown constant state: " + toStringConstantState(constant_folded_));
}
try {
@ -169,10 +167,8 @@ class AOTInductorModelContainer {
/* validate_full_update = */ false);
const_folded = ConstantState::FOLDED;
} else if (constant_folded_ != ConstantState::FOLDED) {
STD_TORCH_CHECK(
false,
"Unknown constant state: ",
toStringConstantState(constant_folded_));
throw std::runtime_error(
"Unknown constant state: " + toStringConstantState(constant_folded_));
}
model->run_single_threaded(
@ -206,56 +202,56 @@ class AOTInductorModelContainer {
}
size_t num_constants() const {
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
return models_[0]->num_constants();
}
// retrieve the constant name of constants_info_[idx]
const char* constant_name(size_t idx) const {
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
return models_[0]->constant_name(static_cast<int64_t>(idx));
}
// retrieve original FQN of constants_info_[idx]
const char* constant_original_fqn(size_t idx) const {
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
return models_[0]->constant_original_fqn(static_cast<int64_t>(idx));
}
// retrieve whether constant is from folded of constants_info_[idx]
bool constant_from_folded(size_t idx) const {
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
return models_[0]->constant_from_folded(static_cast<int64_t>(idx));
}
size_t constant_data_size(size_t idx) const {
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
return models_[0]->constant_data_size(static_cast<int64_t>(idx));
}
// retrieve type of constants_info_[idx]
int32_t constant_type(size_t idx) const {
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
return models_[0]->constant_type(static_cast<int64_t>(idx));
}
// retrieve dtype of constants_info_[idx]
int32_t constant_dtype(size_t idx) const {
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
return models_[0]->constant_dtype(static_cast<int64_t>(idx));
}
@ -387,12 +383,9 @@ class AOTInductorModelContainer {
<< " in model, but not provided by user!\n";
continue;
}
STD_TORCH_CHECK(
false,
"Cannot find constants ",
constant_name,
" in constants_map!");
throw std::runtime_error(
std::string("Cannot find constants ") + constant_name +
std::string(" in constants_map!"));
}
}
}
@ -402,8 +395,9 @@ class AOTInductorModelContainer {
std::unordered_map<std::string, AtenTensorHandle>&& constants_map,
bool use_inactive,
bool validate_full_update) {
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
if (this->num_models() == 0) {
throw std::runtime_error("No model available in container!");
}
if (validate_full_update) {
assert_all_constants(constants_map);
}
@ -449,9 +443,9 @@ class AOTInductorModelContainer {
bool use_inactive,
bool validate_full_update,
bool user_managed = false) {
STD_TORCH_CHECK(
this->num_models() != 0, "No model available in container!");
if (this->num_models() == 0) {
throw std::runtime_error("No model available in container!");
}
if (validate_full_update) {
assert_all_constants(constants_map);
}

View File

@ -7,7 +7,7 @@ namespace torch::aot_inductor {
template <typename T>
inline RAIIAtenTensorHandle scalar_to_tensor_handle(T value) {
STD_TORCH_CHECK(false, "Unsupported scalar_to_tensor_handle");
throw std::runtime_error("Unsupported scalar_to_tensor_handle");
}
// Specialize for supported C++ primitive types

View File

@ -11,11 +11,11 @@ template <>
struct ThreadLocalCachedOutputTensor<RAIIAtenTensorHandle> {
explicit ThreadLocalCachedOutputTensor(const RAIIAtenTensorHandle&) {}
void copy_data_from(const RAIIAtenTensorHandle& handle) {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
AtenTensorHandle tensor() const {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
};
@ -23,11 +23,11 @@ template <>
struct ThreadLocalCachedOutputTensor<AtenTensorHandle> {
explicit ThreadLocalCachedOutputTensor(const AtenTensorHandle&) {}
void copy_data_from(const AtenTensorHandle& handle) {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
AtenTensorHandle tensor() const {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
};
@ -35,11 +35,11 @@ template <>
struct ThreadLocalCachedOutputTensor<ConstantHandle> {
explicit ThreadLocalCachedOutputTensor(const ConstantHandle&) {}
void copy_data_from(const ConstantHandle& handle) {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
AtenTensorHandle tensor() const {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
};
@ -92,18 +92,18 @@ struct ThreadLocalCachedOutputArray;
template <>
struct ThreadLocalCachedOutputArray<RAIIAtenTensorHandle> {
explicit ThreadLocalCachedOutputArray(const RAIIAtenTensorHandle&) {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
// Not supported yet! We would need to put contiguous() or
// expect_contiguous() into the ABI.
void copy_data_from(const RAIIAtenTensorHandle&) {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
template <typename U>
ArrayRefTensor<U> arrayref_tensor() const {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
};
@ -111,18 +111,18 @@ struct ThreadLocalCachedOutputArray<RAIIAtenTensorHandle> {
template <>
struct ThreadLocalCachedOutputArray<ConstantHandle> {
explicit ThreadLocalCachedOutputArray(const ConstantHandle&) {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
// Not supported yet! We would need to put contiguous() or
// expect_contiguous() into the ABI.
void copy_data_from(const ConstantHandle&) {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
template <typename U>
ArrayRefTensor<U> arrayref_tensor() const {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
};

View File

@ -38,10 +38,9 @@
// The following files are implemented in a header-only way and are guarded by
// test/cpp/aoti_abi_check
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/Half.h>
#include <torch/headeronly/util/complex.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>
#ifdef __cplusplus
extern "C" {
@ -622,8 +621,34 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_proxy_executor_call_function(
int num_tensors,
AtenTensorHandle* flatten_tensor_args);
// Preserved for BC; will be deleted later, use STD_TORCH_CHECK directly
#define AOTI_TORCH_CHECK(cond, ...) STD_TORCH_CHECK(cond, ##__VA_ARGS__)
AOTI_TORCH_EXPORT void aoti_torch_check(
bool cond,
const char* func,
const char* file,
uint32_t line,
const char* msg);
#ifdef STRIP_ERROR_MESSAGES
#define AOTI_TORCH_CHECK(cond, ...) \
if (!(cond)) { \
aoti_torch_check( \
false, \
__func__, \
__FILE__, \
static_cast<uint32_t>(__LINE__), \
TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \
}
#else
#define AOTI_TORCH_CHECK(cond, ...) \
if (!(cond)) { \
aoti_torch_check( \
false, \
__func__, \
__FILE__, \
static_cast<uint32_t>(__LINE__), \
TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); \
}
#endif
AOTI_TORCH_EXPORT void aoti_torch_warn(
const char* func,

View File

@ -1339,14 +1339,13 @@ AOTITorchError aoti_torch_proxy_executor_call_function(
int num_tensors,
AtenTensorHandle* flatten_tensor_args) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
TORCH_CHECK(
proxy_executor != nullptr,
"Unable to find a proxy executor to run custom ops.",
"Please check if there is a json file generated",
"in the same directory as the so,",
"or use torch._inductor.aoti_compile_and_package",
"to package everything into a PT2 artifact.");
if (!proxy_executor) {
throw std::runtime_error(
"Unable to find a proxy executor to run custom ops. Please check if "
"there is a json file generated in the same directory as the so, or use "
"torch._inductor.aoti_compile_and_package to package everything into a "
"PT2 artifact.");
}
ProxyExecutor* executor = reinterpret_cast<ProxyExecutor*>(proxy_executor);
executor->call_function(
extern_node_index,
@ -1357,6 +1356,17 @@ AOTITorchError aoti_torch_proxy_executor_call_function(
});
}
void aoti_torch_check(
bool cond,
const char* func,
const char* file,
uint32_t line,
const char* msg) {
if (C10_UNLIKELY_OR_CONST(!cond)) {
::c10::detail::torchCheckFail(func, file, line, msg);
}
}
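Side note, not part of the diff: the restored error message above points users at torch._inductor.aoti_compile_and_package. A minimal sketch of that packaging flow (the module and file name here are made up), which bundles the compiled shared library together with the extern-kernel json the proxy executor looks for:

import torch
import torch._inductor

class M(torch.nn.Module):
    def forward(self, x):
        return x.sin() + 1

ep = torch.export.export(M(), (torch.randn(4),))
# Packages the AOTInductor output (shared library, weights, extern-kernel json)
# into a single .pt2 artifact instead of loose files next to the .so.
pkg = torch._inductor.aoti_compile_and_package(ep, package_path="m.pt2")
runner = torch._inductor.aoti_load_package(pkg)
print(runner(torch.randn(4)))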
void aoti_torch_warn(
const char* func,
const char* file,

View File

@ -10,7 +10,9 @@ AOTITorchError aoti_torch_mps_set_arg_tensor(
AtenTensorHandle tensor) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto t = tensor_handle_to_tensor_pointer(tensor);
TORCH_CHECK(t != nullptr, "Tensor is null.");
if (t == nullptr) {
throw std::runtime_error("Tensor is null.");
}
auto func = reinterpret_cast<at::native::mps::MetalKernelFunction*>(handle);
func->setArg(idx, *t);
});

View File

@ -92,11 +92,13 @@ inline void assert_inf_and_nan(
const std::string& tensor_name,
at::Tensor& check_tensor) {
auto isnan_tensor = check_tensor.isnan();
TORCH_CHECK(
!isnan_tensor.any().item<bool>(), "At least one NaN in ", tensor_name);
if (isnan_tensor.any().item<bool>()) {
throw std::runtime_error("At least one NaN in " + tensor_name);
}
auto isinf_tensor = check_tensor.isinf();
TORCH_CHECK(
!isinf_tensor.any().item<bool>(), "At least one INF in ", tensor_name);
if (isinf_tensor.any().item<bool>()) {
throw std::runtime_error("At least one INF in " + tensor_name);
}
}
// utility functions to convert a pointer to an optional value

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import copy
import inspect
import warnings
from collections.abc import Callable, Sequence
@ -96,16 +97,23 @@ class _ToTorchTensor(torch.autograd.Function):
)
tensor_stride = tuple(tensor_stride)
grad_placements = grad_placements or dtensor_spec.placements
grad_spec = DTensorSpec(
mesh,
grad_placements,
tensor_meta=TensorMeta(
shape=dtensor_meta.shape,
stride=tensor_stride,
dtype=dtensor_meta.dtype,
),
)
if (
tensor_stride == dtensor_meta.stride
and grad_placements == dtensor_spec.placements
):
# Avoid actual sharing of specs in case they're modified during (e.g.)
# sharding propagation.
grad_spec = copy.copy(dtensor_spec)
else:
grad_spec = DTensorSpec(
mesh,
grad_placements,
tensor_meta=TensorMeta(
shape=dtensor_meta.shape,
stride=tensor_stride,
dtype=dtensor_meta.dtype,
),
)
return (
# pyrefly: ignore [bad-argument-type]
DTensor(

View File

@ -470,6 +470,10 @@ def has_static_value(a: Union[SymBool, SymFloat, SymInt, bool, float, int]) -> b
return a.node.shape_env.bound_sympy(a.node.expr).is_singleton() # type: ignore[union-attr]
@deprecated(
"guard_size_oblivious will be removed. Consider using explicit unbacked handling \
potentially utilizing guard_or_false, guard_or_true, or statically_known_true"
)
def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool:
"""
Perform a guard on a symbolic boolean expression in a size oblivious way.
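Not part of the diff: a hedged sketch of the kind of rewrite the new deprecation message suggests, using guard_or_false from torch.fx.experimental.symbolic_shapes; the surrounding function is hypothetical.

import torch
from torch.fx.experimental.symbolic_shapes import guard_or_false

def maybe_skip_empty(x, y):
    # Old style: `if guard_size_oblivious(y.shape[0] == 0): ...`
    # New style: guard_or_false returns False instead of erroring when the
    # expression involves unbacked sizes and cannot be decided, so the code
    # falls through to the general branch that is correct for every size.
    if guard_or_false(y.shape[0] == 0):
        return x.clone()
    return torch.cat([x, y])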

View File

@ -576,17 +576,6 @@ def insert_deferred_runtime_asserts(
if i0 in constrained_unbacked_symbols:
continue # constrain symbol just once
if i0 in shape_env.size_like:
if export:
graph.call_function(
torch.ops.aten.sym_constrain_range_for_size.default,
(expr_to_proxy[i0].node,),
)
else:
graph.call_function(
torch._check_is_size, (expr_to_proxy[i0].node,)
)
vr = shape_env.var_to_range[i0]
if vr.is_int and vr.upper == sys.maxsize - 1:
# treat upper bound == sys.maxsize - 1 for int symbols as +oo