Compare commits


1 Commit

Author SHA1 Message Date
9abdc9505f create varlenmetadata 2025-11-14 11:54:34 -08:00
29 changed files with 508 additions and 652 deletions

View File

@ -100,6 +100,337 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
)
def _compile_and_extract_symbols(
cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None
) -> list[str]:
"""
Helper to compile a C++ file and extract all symbols.
Args:
cpp_content: C++ source code to compile
compile_flags: Compilation flags
exclude_list: List of symbol names to exclude. Defaults to ["main"].
Returns:
List of all symbols found in the object file (excluding those in exclude_list).
"""
import subprocess
import tempfile
if exclude_list is None:
exclude_list = ["main"]
with tempfile.TemporaryDirectory() as tmpdir:
tmppath = Path(tmpdir)
cpp_file = tmppath / "test.cpp"
obj_file = tmppath / "test.o"
cpp_file.write_text(cpp_content)
result = subprocess.run(
compile_flags + [str(cpp_file), "-o", str(obj_file)],
capture_output=True,
text=True,
timeout=60,
)
if result.returncode != 0:
raise RuntimeError(f"Compilation failed: {result.stderr}")
symbols = get_symbols(str(obj_file))
# Return all symbol names, excluding those in the exclude list
return [name for _addr, _stype, name in symbols if name not in exclude_list]
def check_stable_only_symbols(install_root: Path) -> None:
"""
Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.
This approach tests:
1. WITHOUT macros -> many torch symbols exposed
2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
4. WITH both macros -> zero torch symbols (all hidden)
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
test_cpp_content = """
// Main torch C++ API headers
#include <torch/torch.h>
#include <torch/all.h>
// ATen tensor library
#include <ATen/ATen.h>
// Core c10 headers (commonly used)
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Optional.h>
int main() { return 0; }
"""
base_compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c", # Compile only, don't link
]
# Compile WITHOUT any macros
symbols_without = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=base_compile_flags,
)
# We expect constexpr symbols, inline functions used by other headers etc.
# to produce symbols
num_symbols_without = len(symbols_without)
print(f"Found {num_symbols_without} symbols without any macros defined")
assert num_symbols_without != 0, (
"Expected a non-zero number of symbols without any macros"
)
# Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"]
symbols_with_stable_only = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_stable_only,
)
num_symbols_with_stable_only = len(symbols_with_stable_only)
assert num_symbols_with_stable_only == 0, (
f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}"
)
# Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
compile_flags_with_target_version = base_compile_flags + [
"-DTORCH_TARGET_VERSION=1"
]
symbols_with_target_version = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_target_version,
)
num_symbols_with_target_version = len(symbols_with_target_version)
assert num_symbols_with_target_version == 0, (
f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}"
)
# Compile WITH both macros (expect 0 symbols)
compile_flags_with_both = base_compile_flags + [
"-DTORCH_STABLE_ONLY",
"-DTORCH_TARGET_VERSION=1",
]
symbols_with_both = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_both,
)
num_symbols_with_both = len(symbols_with_both)
assert num_symbols_with_both == 0, (
f"Expected no symbols with both macros, but found {num_symbols_with_both}"
)
def check_stable_api_symbols(install_root: Path) -> None:
"""
Test that stable API headers still expose symbols with TORCH_STABLE_ONLY.
The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
stable_dir = include_dir / "torch" / "csrc" / "stable"
assert stable_dir.exists(), f"Expected {stable_dir} to be present"
stable_headers = list(stable_dir.rglob("*.h"))
if not stable_headers:
raise RuntimeError("Could not find any stable headers")
includes = []
for header in stable_headers:
rel_path = header.relative_to(include_dir)
includes.append(f"#include <{rel_path.as_posix()}>")
includes_str = "\n".join(includes)
test_stable_content = f"""
{includes_str}
int main() {{ return 0; }}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_stable = _compile_and_extract_symbols(
cpp_content=test_stable_content,
compile_flags=compile_flags,
)
num_symbols_stable = len(symbols_stable)
print(f"Found {num_symbols_stable} symbols in torch/csrc/stable")
assert num_symbols_stable > 0, (
f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_stable} symbols"
)
def check_headeronly_symbols(install_root: Path) -> None:
"""
Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# Find all headers in torch/headeronly
headeronly_dir = include_dir / "torch" / "headeronly"
assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present"
headeronly_headers = list(headeronly_dir.rglob("*.h"))
if not headeronly_headers:
raise RuntimeError("Could not find any headeronly headers")
# Filter out platform-specific headers that may not compile everywhere
platform_specific_keywords = [
"cpu/vec",
]
filtered_headers = []
for header in headeronly_headers:
rel_path = header.relative_to(include_dir).as_posix()
if not any(
keyword in rel_path.lower() for keyword in platform_specific_keywords
):
filtered_headers.append(header)
includes = []
for header in filtered_headers:
rel_path = header.relative_to(include_dir)
includes.append(f"#include <{rel_path.as_posix()}>")
includes_str = "\n".join(includes)
test_headeronly_content = f"""
{includes_str}
int main() {{ return 0; }}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_headeronly = _compile_and_extract_symbols(
cpp_content=test_headeronly_content,
compile_flags=compile_flags,
)
num_symbols_headeronly = len(symbols_headeronly)
print(f"Found {num_symbols_headeronly} symbols in torch/headeronly")
assert num_symbols_headeronly > 0, (
f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_headeronly} symbols"
)
def check_aoti_shim_symbols(install_root: Path) -> None:
"""
Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# There are no constexpr symbols etc., so we need to actually use functions
# so that some symbols are found.
test_shim_content = """
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
int main() {
int32_t (*fp1)() = &aoti_torch_device_type_cpu;
int32_t (*fp2)() = &aoti_torch_dtype_float32;
(void)fp1; (void)fp2;
return 0;
}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_shim = _compile_and_extract_symbols(
cpp_content=test_shim_content,
compile_flags=compile_flags,
)
num_symbols_shim = len(symbols_shim)
assert num_symbols_shim > 0, (
f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_shim} symbols"
)
def check_stable_c_shim_symbols(install_root: Path) -> None:
"""
Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# Check if the stable C shim exists
stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
if not stable_shim.exists():
raise RuntimeError("Could not find stable c shim")
# There are no constexpr symbols etc., so we need to actually use functions
# so that some symbols are found.
test_stable_shim_content = """
#include <torch/csrc/stable/c/shim.h>
int main() {
// Reference stable C API functions to create undefined symbols
AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
(void)fp1; (void)fp2;
return 0;
}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_stable_shim = _compile_and_extract_symbols(
cpp_content=test_stable_shim_content,
compile_flags=compile_flags,
)
num_symbols_stable_shim = len(symbols_stable_shim)
assert num_symbols_stable_shim > 0, (
f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_stable_shim} symbols"
)
def check_lib_symbols_for_abi_correctness(lib: str) -> None:
print(f"lib: {lib}")
cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)
@ -129,6 +460,13 @@ def main() -> None:
check_lib_symbols_for_abi_correctness(libtorch_cpu_path)
check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path)
# Check symbols when TORCH_STABLE_ONLY is defined
check_stable_only_symbols(install_root)
check_stable_api_symbols(install_root)
check_headeronly_symbols(install_root)
check_aoti_shim_symbols(install_root)
check_stable_c_shim_symbols(install_root)
if __name__ == "__main__":
main()
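
Note: _compile_and_extract_symbols relies on a get_symbols helper defined earlier in this script, outside the hunk shown above. As a hedged sketch only (the name get_symbols_sketch and the nm-based parsing are assumptions, not the script's actual implementation), such a helper could look like:

import subprocess

def get_symbols_sketch(obj_file: str) -> list[tuple[str, str, str]]:
    # Hypothetical stand-in for get_symbols: parse `nm` output into
    # (address, symbol_type, name) tuples; undefined symbols ("U") have
    # no address column.
    out = subprocess.run(
        ["nm", obj_file], capture_output=True, text=True, check=True
    ).stdout
    symbols: list[tuple[str, str, str]] = []
    for line in out.splitlines():
        parts = line.split()
        if len(parts) == 3:
            symbols.append((parts[0], parts[1], parts[2]))
        elif len(parts) == 2:
            symbols.append(("", parts[0], parts[1]))
    return symbols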

View File

@ -1358,6 +1358,45 @@ class concat_license_files:
# Need to create the proper LICENSE.txt for the wheel
class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
def _wrap_headers_with_macro(self, bdist_dir: Path) -> None:
"""Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION).
Excludes:
- torch/include/torch/headeronly/*
- torch/include/torch/csrc/stable/*
- torch/include/torch/csrc/inductor/aoti_torch/c/ (only shim headers)
- torch/include/torch/csrc/inductor/aoti_torch/generated/
"""
header_extensions = (".h", ".hpp", ".cuh")
header_files = [
f for ext in header_extensions for f in bdist_dir.rglob(f"*{ext}")
]
# Paths to exclude from wrapping
exclude_dir_patterns = [
"torch/include/torch/headeronly/",
"torch/include/torch/csrc/stable/",
"torch/include/torch/csrc/inductor/aoti_torch/c/",
"torch/include/torch/csrc/inductor/aoti_torch/generated/",
]
for header_file in header_files:
rel_path = header_file.relative_to(bdist_dir).as_posix()
if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns):
report(f"Skipping header: {rel_path}")
continue
original_content = header_file.read_text(encoding="utf-8")
wrapped_content = (
"#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
f"{original_content}"
"\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
)
header_file.write_text(wrapped_content, encoding="utf-8")
report(f"Wrapped header: {rel_path}")
def run(self) -> None:
with concat_license_files(include_files=True):
super().run()
@ -1380,6 +1419,14 @@ class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
# need an __init__.py file otherwise we wouldn't have a package
(bdist_dir / "torch" / "__init__.py").touch()
# Wrap all header files with TORCH_STABLE_ONLY macro
assert self.bdist_dir is not None, "bdist_dir should be set during wheel build"
bdist_dir = Path(self.bdist_dir)
report(
"-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)"
)
self._wrap_headers_with_macro(bdist_dir)
class clean(Command):
user_options: ClassVar[list[tuple[str, str | None, str]]] = []
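
For illustration only, the effect of _wrap_headers_with_macro on a non-excluded header can be reproduced with a small standalone sketch (the Example.h name and body below are hypothetical):

from pathlib import Path
import tempfile

guard = "!defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)"
with tempfile.TemporaryDirectory() as tmpdir:
    header = Path(tmpdir) / "Example.h"
    header.write_text("struct Example {};\n")
    original = header.read_text()
    # Same wrapping as above: the body compiles only when neither macro is defined.
    header.write_text(f"#if {guard}\n{original}\n#endif // {guard}\n")
    print(header.read_text())

Headers under torch/headeronly, torch/csrc/stable, and the AOTI shim directories are skipped by the wrapping step, so stable-ABI consumers can keep including them even with the macros defined.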

View File

@ -634,38 +634,3 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("test_parallel_for", &boxed_test_parallel_for);
m.impl("test_get_num_threads", &boxed_test_get_num_threads);
}
Tensor my_empty(
torch::headeronly::HeaderOnlyArrayRef<int64_t> size,
std::optional<torch::headeronly::ScalarType> dtype,
std::optional<torch::stable::Device> device,
std::optional<bool> pin_memory) {
return empty(size, dtype, device, pin_memory);
}
Tensor my_flatten(Tensor t, int64_t start_dim, int64_t end_dim) {
return flatten(t, start_dim, end_dim);
}
Tensor my_reshape(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> shape) {
return reshape(t, shape);
}
Tensor my_view(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> size) {
return view(t, size);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def(
"my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor");
m.def("my_flatten(Tensor t, int start_dim=0, int end_dim=-1) -> Tensor");
m.def("my_reshape(Tensor t, int[] shape) -> Tensor");
m.def("my_view(Tensor t, int[] size) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_empty", TORCH_BOX(&my_empty));
m.impl("my_flatten", TORCH_BOX(&my_flatten));
m.impl("my_reshape", TORCH_BOX(&my_reshape));
m.impl("my_view", TORCH_BOX(&my_view));
}

View File

@ -487,58 +487,3 @@ def test_get_num_threads() -> int:
Returns: int - the number of threads for the parallel backend
"""
return torch.ops.libtorch_agnostic.test_get_num_threads.default()
def my_empty(size, dtype=None, device=None, pin_memory=None) -> Tensor:
"""
Creates an empty tensor with the specified size, dtype, device, and pin_memory.
Args:
size: list[int] - size of the tensor to create
dtype: ScalarType or None - data type of the tensor
device: Device or None - device on which to create the tensor
pin_memory: bool or None - whether to use pinned memory
Returns: Tensor - an uninitialized tensor with the specified properties
"""
return torch.ops.libtorch_agnostic.my_empty.default(size, dtype, device, pin_memory)
def my_flatten(t, start_dim=0, end_dim=-1) -> Tensor:
"""
Flattens the input tensor from start_dim to end_dim into a single dimension.
Args:
t: Tensor - tensor to flatten
start_dim: int - first dimension to flatten (default: 0)
end_dim: int - last dimension to flatten (default: -1)
Returns: Tensor - flattened tensor
"""
return torch.ops.libtorch_agnostic.my_flatten.default(t, start_dim, end_dim)
def my_reshape(t, shape) -> Tensor:
"""
Returns a tensor with the same data but different shape.
Args:
t: Tensor - tensor to reshape
shape: list[int] - new shape for the tensor
Returns: Tensor - reshaped tensor
"""
return torch.ops.libtorch_agnostic.my_reshape.default(t, shape)
def my_view(t, size) -> Tensor:
"""
Returns a new tensor with the same data as the input tensor but of a different shape.
Args:
t: Tensor - tensor to view
size: list[int] - new size for the tensor
Returns: Tensor - tensor with new view
"""
return torch.ops.libtorch_agnostic.my_view.default(t, size)

View File

@ -33,7 +33,7 @@ class clean(distutils.command.clean.clean):
def get_extension():
extra_compile_args = {
"cxx": ["-fdiagnostics-color=always"],
"cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
}
extension = CppExtension

View File

@ -525,97 +525,6 @@ if not IS_WINDOWS:
expected_num_threads = torch.get_num_threads()
self.assertEqual(num_threads, expected_num_threads)
def test_my_empty(self, device):
import libtorch_agnostic
deterministic = torch.are_deterministic_algorithms_enabled()
try:
# set use_deterministic_algorithms to fill uninitialized memory
torch.use_deterministic_algorithms(True)
size = [2, 3]
result = libtorch_agnostic.ops.my_empty(size, None, None, None)
expected = torch.empty(size)
self.assertEqual(result, expected, exact_device=True)
result_float = libtorch_agnostic.ops.my_empty(
size, torch.float32, None, None
)
expected_float = torch.empty(size, dtype=torch.float32)
self.assertEqual(result_float, expected_float, exact_device=True)
result_with_device = libtorch_agnostic.ops.my_empty(
size, torch.float64, device, None
)
expected_with_device = torch.empty(
size, dtype=torch.float64, device=device
)
self.assertEqual(
result_with_device, expected_with_device, exact_device=True
)
if device == "cuda":
result_pinned = libtorch_agnostic.ops.my_empty(
size, torch.float32, "cpu", True
)
expected_pinned = torch.empty(
size, dtype=torch.float32, device="cpu", pin_memory=True
)
self.assertEqual(result_pinned, expected_pinned)
self.assertTrue(result_pinned.is_pinned())
finally:
torch.use_deterministic_algorithms(deterministic)
def test_my_flatten(self, device):
import libtorch_agnostic
t = torch.randn(2, 3, 4, device=device)
result = libtorch_agnostic.ops.my_flatten(t)
expected = torch.flatten(t)
self.assertEqual(result, expected)
result_start = libtorch_agnostic.ops.my_flatten(t, 1)
expected_start = torch.flatten(t, 1)
self.assertEqual(result_start, expected_start)
result_range = libtorch_agnostic.ops.my_flatten(t, 2, -1)
expected_range = torch.flatten(t, 2, -1)
self.assertEqual(result_range, expected_range)
def test_my_reshape(self, device):
import libtorch_agnostic
t = torch.randn(2, 3, 4, device=device)
result = libtorch_agnostic.ops.my_reshape(t, [6, 4])
expected = torch.reshape(t, [6, 4])
self.assertEqual(result, expected)
result_infer = libtorch_agnostic.ops.my_reshape(t, [-1, 4])
expected_infer = torch.reshape(t, [-1, 4])
self.assertEqual(result_infer, expected_infer)
result_flat = libtorch_agnostic.ops.my_reshape(t, [-1])
expected_flat = torch.reshape(t, [-1])
self.assertEqual(result_flat, expected_flat)
def test_my_view(self, device):
import libtorch_agnostic
t = torch.randn(2, 3, 4, device=device)
result = libtorch_agnostic.ops.my_view(t, [6, 4])
expected = t.view([6, 4])
self.assertEqual(result, expected)
result_infer = libtorch_agnostic.ops.my_view(t, [-1, 4])
expected_infer = t.view([-1, 4])
self.assertEqual(result_infer, expected_infer)
result_flat = libtorch_agnostic.ops.my_view(t, [-1])
expected_flat = t.view([-1])
self.assertEqual(result_flat, expected_flat)
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
if __name__ == "__main__":

View File

@ -1,67 +0,0 @@
import distutils.command.clean
import shutil
from pathlib import Path
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
ROOT_DIR = Path(__file__).parent
CSRC_DIR = ROOT_DIR / "torch_stable_test" / "csrc"
class clean(distutils.command.clean.clean):
def run(self):
# Run default behavior first
distutils.command.clean.clean.run(self)
# Remove extension
for path in (ROOT_DIR / "torch_stable_test").glob("**/*.so"):
path.unlink()
# Remove build and dist and egg-info directories
dirs = [
ROOT_DIR / "build",
ROOT_DIR / "dist",
ROOT_DIR / "torch_stable_test.egg-info",
]
for path in dirs:
if path.exists():
shutil.rmtree(str(path), ignore_errors=True)
def get_extension():
extra_compile_args = {
"cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
}
sources = list(CSRC_DIR.glob("**/*.cpp"))
return [
CppExtension(
"torch_stable_test._C",
sources=sorted(str(s) for s in sources),
py_limited_api=True,
extra_compile_args=extra_compile_args,
extra_link_args=[],
)
]
setup(
name="torch_stable_test",
version="0.0",
author="PyTorch Core Team",
description="Test extension to verify TORCH_STABLE_ONLY flag",
packages=find_packages(exclude=("test",)),
package_data={"torch_stable_test": ["*.dll", "*.dylib", "*.so"]},
install_requires=[
"torch",
],
ext_modules=get_extension(),
cmdclass={
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
"clean": clean,
},
options={"bdist_wheel": {"py_limited_api": "cp39"}},
)

View File

@ -1 +0,0 @@
#include <ATen/core/TensorBase.h> // This should trigger the TORCH_STABLE_ONLY error

View File

@ -1,22 +0,0 @@
# Owner(s): ["module: cpp"]
from pathlib import Path
from torch.testing._internal.common_utils import (
install_cpp_extension,
IS_WINDOWS,
run_tests,
TestCase,
)
if not IS_WINDOWS:
class TestTorchStable(TestCase):
def test_setup_fails(self):
with self.assertRaisesRegex(RuntimeError, "build failed for cpp extension"):
install_cpp_extension(extension_root=Path(__file__).parent.parent)
if __name__ == "__main__":
run_tests()

View File

@ -1037,30 +1037,6 @@ class DistMathOpsTest(DTensorTestBase):
self.assertTrue(out_with_redistribute.placements[0].is_replicate())
self.assertEqual(out_without_redistribute, out_with_redistribute)
@with_comms
def test_std(self):
mesh = DeviceMesh(self.device_type, torch.arange(4).reshape(2, 2))
rank = self.rank
comm_mode = CommDebugMode()
global_tensor = map_local_for_rank(
rank,
lambda rank: torch.tensor(
[[-20.0, -18.0, -12.0, 0.0], [-20.0, -18.0, -8.0, 4.0]]
),
)
dt = distribute_tensor(global_tensor, mesh, [Shard(0), Shard(1)])
with comm_mode:
res = dt.std(dim=1)
expected_answer = torch.tensor([9.0, 11.0])
self.assertEqual(comm_mode.get_total_counts(), 1)
self.assertEqual(comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], 1)
self.assertEqual(res.placements, [Shard(0), Replicate()])
self.assertEqual(res.full_tensor(), expected_answer)
DistMathOpsTestWithLocalTensor = create_local_tensor_test_class(
DistMathOpsTest,

View File

@ -3,7 +3,6 @@
import copy
import types
import unittest
import warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple
@ -19,9 +18,6 @@ from torch.testing import FileCheck
from torch.testing._internal.common_utils import TEST_CUDA
GLOBAL_LIST = []
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExperiment(TestCase):
def test_joint_basic(self) -> None:
@ -589,9 +585,9 @@ def forward(self, args_0):
_tree_leaf_0, _tree_leaf_1, = pytree.tree_leaves((self, args_0,))
L_args_0_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1)
l_args_0_ = L_args_0_
add = l_args_0_ + 1; add = None
add = l_args_0_ + 1
mul = l_args_0_ * 2; l_args_0_ = None
return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, mul), self._out_spec)""",
return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, mul, add), self._out_spec)""",
)
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
@ -615,34 +611,6 @@ def forward(self, args_0):
self.assertEqual(len(list(gm.buffers())), len(list(foo.buffers())))
self.assertEqual(len(list(gm.parameters())), len(list(foo.parameters())))
def test_dynamo_graph_capture_side_effects(self):
GLOBAL_LIST.clear()
def foo(x):
z = x + 1
GLOBAL_LIST.append(z)
return z
def make_inputs():
return (torch.randn(2, 3),)
trace_inputs = make_inputs()
with warnings.catch_warnings(record=True) as w:
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
cnt = 0
for entry in w:
if "While compiling, we found certain side effects happened" in str(
entry.message
):
cnt += 1
self.assertEqual(cnt, 1)
self.assertEqual(len(GLOBAL_LIST), 0)
test_inputs = make_inputs()
gm_results = gm(*test_inputs)
self.assertEqual(len(GLOBAL_LIST), 0)
self.assertEqual(gm_results, foo(*test_inputs))
self.assertEqual(len(GLOBAL_LIST), 1)
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
def test_dynamo_graph_capture_fx_graph_annotate_overlap_pass(self):
class DummyOp(torch.autograd.Function):

View File

@ -740,26 +740,18 @@ class TestExport(TestCase):
dynamic_shapes={"x": {0: Dim("b")}, "y": None},
)
# clean up _torchdynamo related meta data as it could vary depending on the caller
# https://github.com/pytorch/pytorch/issues/167432
for node in ep.graph.nodes:
if "custom" in node.meta:
node.meta["custom"] = {
k: v
for k, v in node.meta["custom"].items()
if "_torchdynamo_disable" not in k
}
custom_metadata = torch.fx.traceback._get_custom_metadata(ep.module())
self.assertExpectedInline(
str(custom_metadata),
"""\
('call_function', 'cat', {'moo': 0})
('call_function', 'item', {'moo': 0})
('call_function', 'ge_1', {'moo': 0})
('call_function', '_assert_scalar_default', {'moo': 0})
('call_function', 'mul', {'moo': 0})""",
('placeholder', 'x', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
('placeholder', 'y', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
('call_function', 'cat', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'item', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'ge_1', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', '_assert_scalar_default', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'mul', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('output', 'output', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})""",
)
@requires_gpu

View File

@ -1897,73 +1897,6 @@ class TestPatternMatcher(TestCase):
f"to be >= view count with remove_noop enabled ({view_count_default})",
)
def test_bound_method_pattern_matcher(self):
class ReluSumPattern:
def __init__(self, e: float):
self.e = e
def pattern(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
return x.pow(self.e) + y.pow(self.e) + z.pow(self.e)
def replacement(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
return (x + y + z).pow(self.e)
def inputs(self):
return [
torch.empty(5, 5), # x
torch.empty(5, 5), # y
torch.empty(5, 5), # z
]
def register(self, pm: PatternMatcherPass):
register_replacement(
self.pattern, self.replacement, self.inputs(), fwd_only, pm
)
my_patterns = PatternMatcherPass()
ReluSumPattern(4).register(my_patterns)
count = 0
def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph:
nonlocal count
count = my_patterns.apply(graph)
graph.eliminate_dead_code()
return graph
def custom_backend(graph: torch.fx.GraphModule, example_inputs):
from torch._inductor import config
current_config = config.shallow_copy_dict()
from torch._inductor.compile_fx import compile_fx
current_config["post_grad_custom_post_pass"] = custom_pass
current_config["enable_auto_functionalized_v2"] = False
return compile_fx(graph, example_inputs, config_patches=current_config)
@torch.compile(fullgraph=True, backend=custom_backend)
def fn(x):
y = x.relu()
z = y.tanh()
z2 = x.pow(2) + y.pow(2) + z.pow(2)
z3 = x.pow(3) + y.pow(3) + z2.pow(3)
z4 = x.pow(4) + y.pow(4) + z3.pow(4)
return z4 + 5
def fn_replaced(x):
y = x.relu()
z = y.tanh()
z2 = x.pow(2) + y.pow(2) + z.pow(2)
z3 = x.pow(3) + y.pow(3) + z2.pow(3)
z4 = (x + y + z3).pow(4)
return z4 + 5
x = [torch.ones((5, 4))]
fn_result = fn(*x)
fn_replaced_result = fn_replaced(*x)
self.assertEqual(count, 1)
self.assertEqual(fn_result, fn_replaced_result)
if __name__ == "__main__":
if IS_LINUX and HAS_GPU:

View File

@ -614,8 +614,6 @@ def dynamo_graph_capture_for_export(
def inner(*args: Any, **kwargs: Any) -> Any:
assert not torch._dynamo.config.install_free_tensors
with (
torch._dynamo.config.patch(replay_side_effects=False),
torch._dynamo.config.patch(side_effect_replay_policy="warn"),
get_metrics_context(),
dynamo_timed("fullgraph_capture"),
):

View File

@ -2538,7 +2538,7 @@ class CppKernel(Kernel):
@property
def assert_function(self) -> str:
if V.graph.aot_mode:
return "AOTI_TORCH_CHECK"
return "STD_TORCH_CHECK"
else:
return "TORCH_CHECK"

View File

@ -1442,13 +1442,6 @@ def register_replacement(
"""
argnames_static = [*inspect.signature(search_fn).parameters.keys()]
if inspect.ismethod(search_fn):
search_fn = _wrap_bound_method(search_fn, argnames_static)
if inspect.ismethod(replace_fn):
replace_argnames = [*inspect.signature(replace_fn).parameters.keys()]
replace_fn = _wrap_bound_method(replace_fn, replace_argnames)
def check_fn(match: Match) -> bool:
"""
Often shapes get burned into the pattern, so our initial match ran with
@ -1940,22 +1933,6 @@ def compute_mutation_region_ids(graph: torch.fx.Graph) -> None:
nd.meta["mutation_region_id"] = mutation_region_id
def _wrap_bound_method(fn: Any, argnames: list[str]) -> Any:
"""
Wrap a bound method to remove 'self' from its signature for FX tracing.
"""
def wrapper(*args: Any, **kwargs: Any) -> Any:
return fn(*args, **kwargs)
params = [
inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD)
for name in argnames
]
wrapper.__signature__ = inspect.Signature(params) # type: ignore[attr-defined]
return wrapper
class PatternMatcherPass:
def __init__(
self,

View File

@ -836,10 +836,9 @@ class AOTInductorModelBase {
}
void update_constants_array_from_map() {
if (!constants_map_) {
throw std::runtime_error{
"constants_map_ was not ready when constants_ is trying to be constructed from it!"};
}
STD_TORCH_CHECK(
constants_map_,
"constants_map_ was not ready when constants_ is trying to be constructed from it!");
if (!constants_) {
constants_ =
std::make_shared<std::vector<ConstantHandle>>(constants_info_.size());
@ -875,9 +874,7 @@ class AOTInductorModelBase {
/// Returns true if the model is complete.
bool is_finished() {
#ifdef USE_CUDA
if (!run_finished_) {
throw std::runtime_error{"Model CUDA event was not initialized"};
}
STD_TORCH_CHECK(run_finished_, "Model CUDA event was not initialized");
auto event_status = cudaEventQuery(*run_finished_);
if (event_status == cudaSuccess) {
@ -886,13 +883,13 @@ class AOTInductorModelBase {
return false;
}
throw std::runtime_error(
std::string("The model did not finish successfully. Error: ") +
STD_TORCH_CHECK(
false,
"The model did not finish successfully. Error: ",
cudaGetErrorString(cudaGetLastError()));
#elif defined(USE_XPU)
if (!run_finished_) {
throw std::runtime_error{"Model XPU event was not initialized"};
}
STD_TORCH_CHECK(run_finished_, "Model XPU event was not initialized");
using namespace sycl::info;
return (*run_finished_)->get_info<event::command_execution_status>() ==
event_command_status::complete;
@ -904,19 +901,14 @@ class AOTInductorModelBase {
/// Synchronizes completion event.
void wait_for_completion() {
STD_TORCH_CHECK(run_finished_, "Model event was not initialized");
#ifdef USE_CUDA
if (!run_finished_) {
throw std::runtime_error{"Model event was not initialized"};
}
AOTI_RUNTIME_CUDA_CHECK(cudaEventSynchronize(*run_finished_));
#endif // USE_CUDA
#ifdef USE_XPU
if (!run_finished_) {
throw std::runtime_error{"Model event was not initialized"};
}
(*run_finished_)->wait_and_throw();
#endif
#endif // USE_XPU
}
protected:

View File

@ -123,8 +123,10 @@ class AOTInductorModelContainer {
constants_folding_lk.unlock();
model_lk.lock();
} else if (const_folded != ConstantState::FOLDED) {
throw std::runtime_error(
"Unknown constant state: " + toStringConstantState(constant_folded_));
STD_TORCH_CHECK(
false,
"Unknown constant state: ",
toStringConstantState(constant_folded_));
}
try {
@ -167,8 +169,10 @@ class AOTInductorModelContainer {
/* validate_full_update = */ false);
const_folded = ConstantState::FOLDED;
} else if (constant_folded_ != ConstantState::FOLDED) {
throw std::runtime_error(
"Unknown constant state: " + toStringConstantState(constant_folded_));
STD_TORCH_CHECK(
false,
"Unknown constant state: ",
toStringConstantState(constant_folded_));
}
model->run_single_threaded(
@ -202,56 +206,56 @@ class AOTInductorModelContainer {
}
size_t num_constants() const {
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
return models_[0]->num_constants();
}
// retrieve the constant name of constants_info_[idx]
const char* constant_name(size_t idx) const {
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
return models_[0]->constant_name(static_cast<int64_t>(idx));
}
// retrieve original FQN of constants_info_[idx]
const char* constant_original_fqn(size_t idx) const {
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
return models_[0]->constant_original_fqn(static_cast<int64_t>(idx));
}
// retrieve whether constant is from folded of constants_info_[idx]
bool constant_from_folded(size_t idx) const {
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
return models_[0]->constant_from_folded(static_cast<int64_t>(idx));
}
size_t constant_data_size(size_t idx) const {
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
return models_[0]->constant_data_size(static_cast<int64_t>(idx));
}
// retrieve type of constants_info_[idx]
int32_t constant_type(size_t idx) const {
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
return models_[0]->constant_type(static_cast<int64_t>(idx));
}
// retrieve dtype of constants_info_[idx]
int32_t constant_dtype(size_t idx) const {
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
return models_[0]->constant_dtype(static_cast<int64_t>(idx));
}
@ -383,9 +387,12 @@ class AOTInductorModelContainer {
<< " in model, but not provided by user!\n";
continue;
}
throw std::runtime_error(
std::string("Cannot find constants ") + constant_name +
std::string(" in constants_map!"));
STD_TORCH_CHECK(
false,
"Cannot find constants ",
constant_name,
" in constants_map!");
}
}
}
@ -395,9 +402,8 @@ class AOTInductorModelContainer {
std::unordered_map<std::string, AtenTensorHandle>&& constants_map,
bool use_inactive,
bool validate_full_update) {
if (this->num_models() == 0) {
throw std::runtime_error("No model available in container!");
}
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
if (validate_full_update) {
assert_all_constants(constants_map);
}
@ -443,9 +449,9 @@ class AOTInductorModelContainer {
bool use_inactive,
bool validate_full_update,
bool user_managed = false) {
if (this->num_models() == 0) {
throw std::runtime_error("No model available in container!");
}
STD_TORCH_CHECK(
this->num_models() != 0, "No model available in container!");
if (validate_full_update) {
assert_all_constants(constants_map);
}

View File

@ -7,7 +7,7 @@ namespace torch::aot_inductor {
template <typename T>
inline RAIIAtenTensorHandle scalar_to_tensor_handle(T value) {
throw std::runtime_error("Unsupported scalar_to_tensor_handle");
STD_TORCH_CHECK(false, "Unsupported scalar_to_tensor_handle");
}
// Specialize for supported C++ primitive types

View File

@ -11,11 +11,11 @@ template <>
struct ThreadLocalCachedOutputTensor<RAIIAtenTensorHandle> {
explicit ThreadLocalCachedOutputTensor(const RAIIAtenTensorHandle&) {}
void copy_data_from(const RAIIAtenTensorHandle& handle) {
throw std::runtime_error("can't happen");
STD_TORCH_CHECK(false, "can't happen");
}
AtenTensorHandle tensor() const {
throw std::runtime_error("can't happen");
STD_TORCH_CHECK(false, "can't happen");
}
};
@ -23,11 +23,11 @@ template <>
struct ThreadLocalCachedOutputTensor<AtenTensorHandle> {
explicit ThreadLocalCachedOutputTensor(const AtenTensorHandle&) {}
void copy_data_from(const AtenTensorHandle& handle) {
throw std::runtime_error("can't happen");
STD_TORCH_CHECK(false, "can't happen");
}
AtenTensorHandle tensor() const {
throw std::runtime_error("can't happen");
STD_TORCH_CHECK(false, "can't happen");
}
};
@ -35,11 +35,11 @@ template <>
struct ThreadLocalCachedOutputTensor<ConstantHandle> {
explicit ThreadLocalCachedOutputTensor(const ConstantHandle&) {}
void copy_data_from(const ConstantHandle& handle) {
throw std::runtime_error("can't happen");
STD_TORCH_CHECK(false, "can't happen");
}
AtenTensorHandle tensor() const {
throw std::runtime_error("can't happen");
STD_TORCH_CHECK(false, "can't happen");
}
};
@ -92,18 +92,18 @@ struct ThreadLocalCachedOutputArray;
template <>
struct ThreadLocalCachedOutputArray<RAIIAtenTensorHandle> {
explicit ThreadLocalCachedOutputArray(const RAIIAtenTensorHandle&) {
throw std::runtime_error("can't happen");
STD_TORCH_CHECK(false, "can't happen");
}
// Not supported yet! We would need to put contiguous() or
// expect_contiguous() into the ABI.
void copy_data_from(const RAIIAtenTensorHandle&) {
throw std::runtime_error("can't happen");
STD_TORCH_CHECK(false, "can't happen");
}
template <typename U>
ArrayRefTensor<U> arrayref_tensor() const {
throw std::runtime_error("can't happen");
STD_TORCH_CHECK(false, "can't happen");
}
};
@ -111,18 +111,18 @@ struct ThreadLocalCachedOutputArray<RAIIAtenTensorHandle> {
template <>
struct ThreadLocalCachedOutputArray<ConstantHandle> {
explicit ThreadLocalCachedOutputArray(const ConstantHandle&) {
throw std::runtime_error("can't happen");
STD_TORCH_CHECK(false, "can't happen");
}
// Not supported yet! We would need to put contiguous() or
// expect_contiguous() into the ABI.
void copy_data_from(const ConstantHandle&) {
throw std::runtime_error("can't happen");
STD_TORCH_CHECK(false, "can't happen");
}
template <typename U>
ArrayRefTensor<U> arrayref_tensor() const {
throw std::runtime_error("can't happen");
STD_TORCH_CHECK(false, "can't happen");
}
};

View File

@ -38,9 +38,10 @@
// The following files are implemented in a header-only way and are guarded by
// test/cpp/aoti_abi_check
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/Half.h>
#include <torch/headeronly/util/complex.h>
#ifdef __cplusplus
extern "C" {
@ -621,34 +622,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_proxy_executor_call_function(
int num_tensors,
AtenTensorHandle* flatten_tensor_args);
AOTI_TORCH_EXPORT void aoti_torch_check(
bool cond,
const char* func,
const char* file,
uint32_t line,
const char* msg);
#ifdef STRIP_ERROR_MESSAGES
#define AOTI_TORCH_CHECK(cond, ...) \
if (!(cond)) { \
aoti_torch_check( \
false, \
__func__, \
__FILE__, \
static_cast<uint32_t>(__LINE__), \
TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \
}
#else
#define AOTI_TORCH_CHECK(cond, ...) \
if (!(cond)) { \
aoti_torch_check( \
false, \
__func__, \
__FILE__, \
static_cast<uint32_t>(__LINE__), \
TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); \
}
#endif
// Preserve for BC and will delete it later, using the STD_TORCH_CHECK directly
#define AOTI_TORCH_CHECK(cond, ...) STD_TORCH_CHECK(cond, ##__VA_ARGS__)
AOTI_TORCH_EXPORT void aoti_torch_warn(
const char* func,

View File

@ -1339,13 +1339,14 @@ AOTITorchError aoti_torch_proxy_executor_call_function(
int num_tensors,
AtenTensorHandle* flatten_tensor_args) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
if (!proxy_executor) {
throw std::runtime_error(
"Unable to find a proxy executor to run custom ops. Please check if "
"there is a json file generated in the same directory as the so, or use "
"torch._inductor.aoti_compile_and_package to package everything into a "
"PT2 artifact.");
}
TORCH_CHECK(
proxy_executor != nullptr,
"Unable to find a proxy executor to run custom ops.",
"Please check if there is a json file generated",
"in the same directory as the so,",
"or use torch._inductor.aoti_compile_and_package",
"to package everything into a PT2 artifact.");
ProxyExecutor* executor = reinterpret_cast<ProxyExecutor*>(proxy_executor);
executor->call_function(
extern_node_index,
@ -1356,17 +1357,6 @@ AOTITorchError aoti_torch_proxy_executor_call_function(
});
}
void aoti_torch_check(
bool cond,
const char* func,
const char* file,
uint32_t line,
const char* msg) {
if (C10_UNLIKELY_OR_CONST(!cond)) {
::c10::detail::torchCheckFail(func, file, line, msg);
}
}
void aoti_torch_warn(
const char* func,
const char* file,

View File

@ -10,9 +10,7 @@ AOTITorchError aoti_torch_mps_set_arg_tensor(
AtenTensorHandle tensor) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto t = tensor_handle_to_tensor_pointer(tensor);
if (t == nullptr) {
throw std::runtime_error("Tensor is null.");
}
TORCH_CHECK(t != nullptr, "Tensor is null.");
auto func = reinterpret_cast<at::native::mps::MetalKernelFunction*>(handle);
func->setArg(idx, *t);
});

View File

@ -92,13 +92,11 @@ inline void assert_inf_and_nan(
const std::string& tensor_name,
at::Tensor& check_tensor) {
auto isnan_tensor = check_tensor.isnan();
if (isnan_tensor.any().item<bool>()) {
throw std::runtime_error("At least one NaN in " + tensor_name);
}
TORCH_CHECK(
!isnan_tensor.any().item<bool>(), "At least one NaN in ", tensor_name);
auto isinf_tensor = check_tensor.isinf();
if (isinf_tensor.any().item<bool>()) {
throw std::runtime_error("At least one INF in " + tensor_name);
}
TORCH_CHECK(
!isinf_tensor.any().item<bool>(), "At least one INF in ", tensor_name);
}
// utility functions to convert a pointer to an optional value

View File

@ -69,7 +69,7 @@ inline torch::stable::Tensor narrow(
inline torch::stable::Tensor new_empty(
const torch::stable::Tensor& self,
torch::headeronly::IntHeaderOnlyArrayRef size,
std::optional<torch::headeronly::ScalarType> dtype = std::nullopt) {
std::optional<c10::ScalarType> dtype = std::nullopt) {
int32_t device_type;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
@ -108,7 +108,7 @@ inline torch::stable::Tensor new_empty(
inline torch::stable::Tensor new_zeros(
const torch::stable::Tensor& self,
torch::headeronly::IntHeaderOnlyArrayRef size,
std::optional<torch::headeronly::ScalarType> dtype = std::nullopt) {
std::optional<c10::ScalarType> dtype = std::nullopt) {
int32_t device_type;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
@ -306,66 +306,6 @@ inline uint32_t get_num_threads() {
return num_threads;
}
// We expect this to be the stable version of the empty op that takes in
// device and dtype parameters. The empty op creates a tensor with uninitialized
// values of the specified size, dtype, and device.
inline torch::stable::Tensor empty(
torch::headeronly::IntHeaderOnlyArrayRef size,
std::optional<torch::headeronly::ScalarType> dtype = std::nullopt,
std::optional<torch::stable::Device> device = std::nullopt,
std::optional<bool> pin_memory = std::nullopt) {
const auto num_args = 6;
std::array<StableIValue, num_args> stack{
torch::stable::detail::from(size),
torch::stable::detail::from(dtype),
torch::stable::detail::from(std::nullopt),
torch::stable::detail::from(device),
torch::stable::detail::from(pin_memory),
torch::stable::detail::from(std::nullopt)};
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::empty", "memory_format", stack.data(), TORCH_ABI_VERSION));
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
// We expect this to be the stable version of the flatten.using_ints op.
inline torch::stable::Tensor flatten(
const torch::stable::Tensor& self,
int64_t start_dim = 0,
int64_t end_dim = -1) {
const auto num_args = 3;
std::array<StableIValue, num_args> stack{
torch::stable::detail::from(self),
torch::stable::detail::from(start_dim),
torch::stable::detail::from(end_dim)};
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::flatten", "using_ints", stack.data(), TORCH_ABI_VERSION));
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
// We expect this to be the stable version of the reshape op.
inline torch::stable::Tensor reshape(
const torch::stable::Tensor& self,
torch::headeronly::IntHeaderOnlyArrayRef shape) {
const auto num_args = 2;
std::array<StableIValue, num_args> stack{
torch::stable::detail::from(self), torch::stable::detail::from(shape)};
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::reshape", "", stack.data(), TORCH_ABI_VERSION));
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
// We expect this to be the stable version of the view op.
inline torch::stable::Tensor view(
const torch::stable::Tensor& self,
torch::headeronly::IntHeaderOnlyArrayRef size) {
const auto num_args = 2;
std::array<StableIValue, num_args> stack{
torch::stable::detail::from(self), torch::stable::detail::from(size)};
TORCH_ERROR_CODE_CHECK(
torch_call_dispatcher("aten::view", "", stack.data(), TORCH_ABI_VERSION));
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
#endif
HIDDEN_NAMESPACE_END(torch, stable)

View File

@ -1,7 +1,6 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import copy
import inspect
import warnings
from collections.abc import Callable, Sequence
@ -97,23 +96,16 @@ class _ToTorchTensor(torch.autograd.Function):
)
tensor_stride = tuple(tensor_stride)
grad_placements = grad_placements or dtensor_spec.placements
if (
tensor_stride == dtensor_meta.stride
and grad_placements == dtensor_spec.placements
):
# Avoid actual sharing of specs in case they're modified during (e.g.)
# sharding propagation.
grad_spec = copy.copy(dtensor_spec)
else:
grad_spec = DTensorSpec(
mesh,
grad_placements,
tensor_meta=TensorMeta(
shape=dtensor_meta.shape,
stride=tensor_stride,
dtype=dtensor_meta.dtype,
),
)
grad_spec = DTensorSpec(
mesh,
grad_placements,
tensor_meta=TensorMeta(
shape=dtensor_meta.shape,
stride=tensor_stride,
dtype=dtensor_meta.dtype,
),
)
return (
# pyrefly: ignore [bad-argument-type]
DTensor(

View File

@ -405,15 +405,10 @@ def cumsum_strategy(op_schema: OpSchema) -> OpStrategy:
@register_op_strategy(
[
aten.std.correction,
aten.std.correction_out,
aten.var.correction,
aten.var.correction_out,
],
[aten.var.correction, aten.var.correction_out],
schema_info=RuntimeSchemaInfo(1, ["keepdim"]),
)
def std_var_reduction_strategy(op_schema: OpSchema) -> OpStrategy:
def var_reduction_strategy(op_schema: OpSchema) -> OpStrategy:
args_schema = op_schema.args_schema
input_strategy = args_schema[0]
if not isinstance(input_strategy, OpStrategy):

View File

@ -14,7 +14,7 @@ import torch
log = logging.getLogger(__name__)
__all__ = ["varlen_attn", "AuxRequest"]
__all__ = ["varlen_attn", "AuxRequest", "VarlenMetadata"]
@lru_cache(maxsize=8)
@ -23,6 +23,18 @@ def _should_use_cudnn(device_index: int) -> bool:
return False
class VarlenMetadata(NamedTuple):
"""
Cumulative sequence positions for queries and keys/values.
"""
cu_seq_q: torch.Tensor
cu_seq_k: torch.Tensor
max_q: int
max_k: int
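
For context, cumulative sequence tensors of this kind are conventionally a zero-prefixed cumulative sum of the per-sequence lengths, with max_q/max_k being the largest individual lengths; a hedged sketch (the lengths below are made up):

import torch

seq_lens_q = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seq_q = torch.zeros(len(seq_lens_q) + 1, dtype=torch.int32)
cu_seq_q[1:] = torch.cumsum(seq_lens_q, dim=0)  # [0, 3, 8, 10]
max_q = int(seq_lens_q.max())  # 5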
class AuxRequest(NamedTuple):
"""
Request which auxiliary outputs to compute from varlen_attn.