Compare commits

...

15 Commits

Author SHA1 Message Date
09a0f58faf Update on "[DTensor][Partial] fixing adding scalar to Partial"
**Summary:** This fixes adding a scalar to a Partial DTensor, as reported in https://github.com/pytorch/pytorch/issues/149768 and https://github.com/pytorch/pytorch/issues/163193. We accomplish this by adding a function that checks whether the Partial placements are supported before adding them to the output_spec. For the scalar-addition case specifically, the check tests whether the op is aten.add.Tensor and, if so, whether arg_schemas contains a scalar argument; if both hold, it returns False to force replication. We currently do not need this check for aten.mul.Tensor, since multiplication works with the reduction ops of every existing Partial placement. However, new Partial placements could be added in the future whose reduce op requires redistribution; in that case we replicate and warn the user that we are doing so.

**Test Cases**
1. pytest test/distributed/tensor/test_pointwise_ops.py -k test_add_partial_scalar
2. pytest test/distributed/tensor/test_pointwise_ops.py -k test_unverified_custom_partial
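For intuition on why the scalar case needs replication, here is a torch-free sketch (plain Python ints stand in for the per-rank shards of a Partial("sum") DTensor; all names are illustrative, not DTensor API): adding the scalar to each unreduced shard adds it world_size times instead of once.

```
# Conceptual illustration only: plain ints stand in for the per-rank local
# shards of a Partial("sum") DTensor; none of these names are DTensor API.
world_size = 4
local_shards = [0, 1, 2, 3]               # reduced ("full") value is 6

# Naively applying "x + 1" shard-by-shard while keeping the Partial placement:
naive = sum(s + 1 for s in local_shards)  # 6 + world_size = 10, not 7

# Forcing replication first (what the fix does) gives the expected result:
full_value = sum(local_shards)            # materialize the reduced value: 6
correct = full_value + 1                  # 7
print(naive, correct)                     # 10 7
```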




cc wanchaol tianyu-l wz337 XilunWu d4l3k pragupta SherlockNoMad H-Huang awgu fegin fduwjj wconstab msaroufim dcci

[ghstack-poisoned]
2025-11-14 13:49:08 -08:00
aa9b4f836e Update base for Update on "[DTensor][Partial] fixing adding scalar to Partial"
**Summary:** This fixes adding a scalar to a Partial DTensor, as reported in https://github.com/pytorch/pytorch/issues/149768 and https://github.com/pytorch/pytorch/issues/163193. We accomplish this by adding a function that checks whether the Partial placements are supported before adding them to the output_spec. For the scalar-addition case specifically, the check tests whether the op is aten.add.Tensor and, if so, whether arg_schemas contains a scalar argument; if both hold, it returns False to force replication. We currently do not need this check for aten.mul.Tensor, since multiplication works with the reduction ops of every existing Partial placement. However, new Partial placements could be added in the future whose reduce op requires redistribution; in that case we replicate and warn the user that we are doing so.

**Test Cases**
1. pytest test/distributed/tensor/test_pointwise_ops.py -k test_add_partial_scalar
2. pytest test/distributed/tensor/test_pointwise_ops.py -k test_unverified_custom_partial




cc wanchaol tianyu-l wz337 XilunWu d4l3k pragupta SherlockNoMad H-Huang awgu fegin fduwjj wconstab msaroufim dcci

[ghstack-poisoned]
2025-11-14 13:49:08 -08:00
7aa210d215 Revert "[CodeClean] Remove the Unused MACRO for AOT Inductor Runtime (#165139)"
This reverts commit fcd5f8c352b5b75bd32e57fa044ec5df095032da.

Reverted https://github.com/pytorch/pytorch/pull/165139 on behalf of https://github.com/jeanschmidt due to trying to revert in the hopes it fixes internal errors; will land it back ([comment](https://github.com/pytorch/pytorch/pull/165139#issuecomment-3534662138))
2025-11-14 21:35:37 +00:00
5a368b8010 Revert "[CodeClean] Replace std::runtime_error with TORCH_CHECK (#165119)"
This reverts commit 398775a43e9808205f75c81d36f5087117d3f3f4.

Reverted https://github.com/pytorch/pytorch/pull/165119 on behalf of https://github.com/jeanschmidt due to trying to revert in the hopes it fixes internal errors; will land it back ([comment](https://github.com/pytorch/pytorch/pull/165139#issuecomment-3534662138))
2025-11-14 21:35:37 +00:00
602102be50 Revert "Hide all symbols (except stable/headeronly/shim) if TORCH_STABLE_ONLY is defined (#167496)"
This reverts commit bc09a84150eaadaadab8a8ecd76cd9afc60d8a19.

Reverted https://github.com/pytorch/pytorch/pull/167496 on behalf of https://github.com/jeanschmidt due to trying to revert 165139; my intention is to land it again, so I will land this once both are reverted ([comment](https://github.com/pytorch/pytorch/pull/167496#issuecomment-3534641209))
2025-11-14 21:33:02 +00:00
2387cfb3bf Update on "[DTensor][Partial] fixing adding scalar to Partial"
**Summary:** This fixes adding a scalar to a Partial DTensor, as reported in https://github.com/pytorch/pytorch/issues/149768 and https://github.com/pytorch/pytorch/issues/163193. We accomplish this by adding a function that checks whether the Partial placements are supported before adding them to the output_spec. For the scalar-addition case specifically, the check tests whether the op is aten.add.Tensor and, if so, whether arg_schemas contains a scalar argument; if both hold, it returns False to force replication. We currently do not need this check for aten.mul.Tensor, since multiplication works with the reduction ops of every existing Partial placement. However, new Partial placements could be added in the future whose reduce op requires redistribution; in that case we replicate and warn the user that we are doing so.

**Test Cases**
1. pytest test/distributed/tensor/test_pointwise_ops.py -k test_add_partial_scalar
2. pytest test/distributed/tensor/test_pointwise_ops.py -k test_unverified_custom_partial




cc wanchaol tianyu-l wz337 XilunWu d4l3k pragupta SherlockNoMad H-Huang awgu fegin fduwjj wconstab msaroufim dcci

[ghstack-poisoned]
2025-11-14 13:31:24 -08:00
b92b703b6c Update on "[DTensor][Partial] fixing adding scalar to Partial"
**Summary:** This fixes adding a scalar to a Partial DTensor, as reported in https://github.com/pytorch/pytorch/issues/149768 and https://github.com/pytorch/pytorch/issues/163193. We accomplish this by adding a function that checks whether the Partial placements are supported before adding them to the output_spec. For the scalar-addition case specifically, the check tests whether the op is aten.add.Tensor and, if so, whether arg_schemas contains a scalar argument; if both hold, it returns False to force replication. We currently do not need this check for aten.mul.Tensor, since multiplication works with the reduction ops of every existing Partial placement. However, new Partial placements could be added in the future whose reduce op requires redistribution; in that case we replicate and warn the user that we are doing so.

**Test Cases**
1. pytest test/distributed/tensor/test_pointwise_ops.py -k test_add_partial_scalar
2. pytest test/distributed/tensor/test_pointwise_ops.py -k test_unverified_custom_partial




cc wanchaol tianyu-l wz337 XilunWu d4l3k pragupta SherlockNoMad H-Huang awgu fegin fduwjj wconstab msaroufim dcci

[ghstack-poisoned]
2025-11-14 13:25:44 -08:00
200156e385 DTensor: avoid unnecessary DTensorSpec creation in _ToTorchTensor.backward (#167588)
Looks like the check here is cheap and has a potentially large payoff.
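A rough, hedged sketch of the reuse pattern (hypothetical stand-in classes; the actual change to _ToTorchTensor.backward appears in the DTensor diff further down and compares the gradient's stride and placements before deciding):

```
import copy
from dataclasses import dataclass

# Hypothetical stand-ins for TensorMeta / DTensorSpec; illustration only.
@dataclass
class Meta:
    shape: tuple
    stride: tuple
    dtype: str

@dataclass
class Spec:
    placements: tuple
    tensor_meta: Meta

def grad_spec_for(dtensor_spec: Spec, grad_stride: tuple, grad_placements: tuple) -> Spec:
    # Fast path: nothing would change, so shallow-copy the existing spec
    # (copying avoids aliasing a spec that later passes might mutate).
    if (grad_stride == dtensor_spec.tensor_meta.stride
            and grad_placements == dtensor_spec.placements):
        return copy.copy(dtensor_spec)
    # Slow path: build a fresh spec with the new metadata.
    meta = dtensor_spec.tensor_meta
    return Spec(grad_placements, Meta(meta.shape, grad_stride, meta.dtype))
```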

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167588
Approved by: https://github.com/ezyang
2025-11-14 21:08:12 +00:00
a2daf3fc86 [Inductor] Add support for bound methods in pattern matcher (#167795)
Fixes: #167776

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167795
Approved by: https://github.com/mlazos
2025-11-14 20:55:51 +00:00
52b45c16de Add reshape, view, flatten to torch/csrc/stable (#167600)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167600
Approved by: https://github.com/janeyx99
ghstack dependencies: #167592
2025-11-14 20:35:53 +00:00
2ef85bed5a Add empty to stable ops (#167592)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167592
Approved by: https://github.com/janeyx99
2025-11-14 20:35:53 +00:00
d99c6bcf69 [export] Disable side effects on dynamo_graph_capture_for_export and warn user. (#167763)
Summary:
as title.

Test Plan:
test_dynamo_graph_capture_side_effects


Pull Request resolved: https://github.com/pytorch/pytorch/pull/167763
Approved by: https://github.com/tugsbayasgalan
2025-11-14 20:35:22 +00:00
8378abda84 [torch.export] Fix for flaky test_annotate_on_assert (#167805)
Summary: test_annotate_on_assert became flaky with PR 166341 (details in https://github.com/pytorch/pytorch/issues/167432). Torchdynamo-related metadata can vary depending on the caller, so the test now removes that metadata before the comparison.
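A minimal sketch of the normalization described above (it mirrors the filtering added in the test diff below; `graph` is the exported program's FX graph):

```
def strip_dynamo_metadata(graph):
    # Drop caller-dependent torchdynamo keys from each node's custom metadata
    # so the expected-output comparison no longer depends on who invoked export.
    for node in graph.nodes:
        custom = node.meta.get("custom")
        if custom is not None:
            node.meta["custom"] = {
                k: v for k, v in custom.items()
                if "_torchdynamo_disable" not in k
            }
```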

Test Plan:
```
buck test mode/opt caffe2/test:test_export -- 'test_annotate_on_assert'
```
https://www.internalfb.com/intern/testinfra/testrun/7036874728749661

Differential Revision: D87036890

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167805
Approved by: https://github.com/yushangdi
2025-11-14 19:56:51 +00:00
48b8b9539d Update on "[DTensor][Partial] fixing adding scalar to Partial"
**Summary:** This fixes adding a scalar to a Partial DTensor, as reported in https://github.com/pytorch/pytorch/issues/149768 and https://github.com/pytorch/pytorch/issues/163193. We accomplish this by adding a function that checks whether the Partial placements are supported before adding them to the output_spec. For the scalar-addition case specifically, the check tests whether the op is aten.add.Tensor and, if so, whether arg_schemas contains a scalar argument; if both hold, it returns False to force replication. We currently do not need this check for aten.mul.Tensor, since multiplication works with the reduction ops of every existing Partial placement. However, new Partial placements could be added in the future whose reduce op requires redistribution; in that case we replicate and warn the user that we are doing so.

**Test Cases**
1. pytest test/distributed/tensor/test_pointwise_ops.py -k test_add_partial_scalar
2. pytest test/distributed/tensor/test_pointwise_ops.py -k test_unverified_custom_partial




cc wanchaol tianyu-l wz337 XilunWu d4l3k pragupta SherlockNoMad H-Huang awgu fegin fduwjj wconstab msaroufim dcci

[ghstack-poisoned]
2025-11-14 10:55:55 -08:00
47e80120ac [DTensor][Partial] fixing adding scalar to Partial
[ghstack-poisoned]
2025-11-13 22:47:15 -08:00
28 changed files with 654 additions and 496 deletions

View File

@ -100,337 +100,6 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
)
def _compile_and_extract_symbols(
cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None
) -> list[str]:
"""
Helper to compile a C++ file and extract all symbols.
Args:
cpp_content: C++ source code to compile
compile_flags: Compilation flags
exclude_list: List of symbol names to exclude. Defaults to ["main"].
Returns:
List of all symbols found in the object file (excluding those in exclude_list).
"""
import subprocess
import tempfile
if exclude_list is None:
exclude_list = ["main"]
with tempfile.TemporaryDirectory() as tmpdir:
tmppath = Path(tmpdir)
cpp_file = tmppath / "test.cpp"
obj_file = tmppath / "test.o"
cpp_file.write_text(cpp_content)
result = subprocess.run(
compile_flags + [str(cpp_file), "-o", str(obj_file)],
capture_output=True,
text=True,
timeout=60,
)
if result.returncode != 0:
raise RuntimeError(f"Compilation failed: {result.stderr}")
symbols = get_symbols(str(obj_file))
# Return all symbol names, excluding those in the exclude list
return [name for _addr, _stype, name in symbols if name not in exclude_list]
def check_stable_only_symbols(install_root: Path) -> None:
"""
Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.
This approach tests:
1. WITHOUT macros -> many torch symbols exposed
2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
4. WITH both macros -> zero torch symbols (all hidden)
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
test_cpp_content = """
// Main torch C++ API headers
#include <torch/torch.h>
#include <torch/all.h>
// ATen tensor library
#include <ATen/ATen.h>
// Core c10 headers (commonly used)
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Optional.h>
int main() { return 0; }
"""
base_compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c", # Compile only, don't link
]
# Compile WITHOUT any macros
symbols_without = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=base_compile_flags,
)
# We expect constexpr symbols, inline functions used by other headers etc.
# to produce symbols
num_symbols_without = len(symbols_without)
print(f"Found {num_symbols_without} symbols without any macros defined")
assert num_symbols_without != 0, (
"Expected a non-zero number of symbols without any macros"
)
# Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"]
symbols_with_stable_only = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_stable_only,
)
num_symbols_with_stable_only = len(symbols_with_stable_only)
assert num_symbols_with_stable_only == 0, (
f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}"
)
# Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
compile_flags_with_target_version = base_compile_flags + [
"-DTORCH_TARGET_VERSION=1"
]
symbols_with_target_version = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_target_version,
)
num_symbols_with_target_version = len(symbols_with_target_version)
assert num_symbols_with_target_version == 0, (
f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}"
)
# Compile WITH both macros (expect 0 symbols)
compile_flags_with_both = base_compile_flags + [
"-DTORCH_STABLE_ONLY",
"-DTORCH_TARGET_VERSION=1",
]
symbols_with_both = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_both,
)
num_symbols_with_both = len(symbols_with_both)
assert num_symbols_with_both == 0, (
f"Expected no symbols with both macros, but found {num_symbols_with_both}"
)
def check_stable_api_symbols(install_root: Path) -> None:
"""
Test that stable API headers still expose symbols with TORCH_STABLE_ONLY.
The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
stable_dir = include_dir / "torch" / "csrc" / "stable"
assert stable_dir.exists(), f"Expected {stable_dir} to be present"
stable_headers = list(stable_dir.rglob("*.h"))
if not stable_headers:
raise RuntimeError("Could not find any stable headers")
includes = []
for header in stable_headers:
rel_path = header.relative_to(include_dir)
includes.append(f"#include <{rel_path.as_posix()}>")
includes_str = "\n".join(includes)
test_stable_content = f"""
{includes_str}
int main() {{ return 0; }}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_stable = _compile_and_extract_symbols(
cpp_content=test_stable_content,
compile_flags=compile_flags,
)
num_symbols_stable = len(symbols_stable)
print(f"Found {num_symbols_stable} symbols in torch/csrc/stable")
assert num_symbols_stable > 0, (
f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_stable} symbols"
)
def check_headeronly_symbols(install_root: Path) -> None:
"""
Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# Find all headers in torch/headeronly
headeronly_dir = include_dir / "torch" / "headeronly"
assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present"
headeronly_headers = list(headeronly_dir.rglob("*.h"))
if not headeronly_headers:
raise RuntimeError("Could not find any headeronly headers")
# Filter out platform-specific headers that may not compile everywhere
platform_specific_keywords = [
"cpu/vec",
]
filtered_headers = []
for header in headeronly_headers:
rel_path = header.relative_to(include_dir).as_posix()
if not any(
keyword in rel_path.lower() for keyword in platform_specific_keywords
):
filtered_headers.append(header)
includes = []
for header in filtered_headers:
rel_path = header.relative_to(include_dir)
includes.append(f"#include <{rel_path.as_posix()}>")
includes_str = "\n".join(includes)
test_headeronly_content = f"""
{includes_str}
int main() {{ return 0; }}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_headeronly = _compile_and_extract_symbols(
cpp_content=test_headeronly_content,
compile_flags=compile_flags,
)
num_symbols_headeronly = len(symbols_headeronly)
print(f"Found {num_symbols_headeronly} symbols in torch/headeronly")
assert num_symbols_headeronly > 0, (
f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_headeronly} symbols"
)
def check_aoti_shim_symbols(install_root: Path) -> None:
"""
Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# There are no constexpr symbols etc., so we need to actually use functions
# so that some symbols are found.
test_shim_content = """
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
int main() {
int32_t (*fp1)() = &aoti_torch_device_type_cpu;
int32_t (*fp2)() = &aoti_torch_dtype_float32;
(void)fp1; (void)fp2;
return 0;
}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_shim = _compile_and_extract_symbols(
cpp_content=test_shim_content,
compile_flags=compile_flags,
)
num_symbols_shim = len(symbols_shim)
assert num_symbols_shim > 0, (
f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_shim} symbols"
)
def check_stable_c_shim_symbols(install_root: Path) -> None:
"""
Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# Check if the stable C shim exists
stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
if not stable_shim.exists():
raise RuntimeError("Could not find stable c shim")
# There are no constexpr symbols etc., so we need to actually use functions
# so that some symbols are found.
test_stable_shim_content = """
#include <torch/csrc/stable/c/shim.h>
int main() {
// Reference stable C API functions to create undefined symbols
AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
(void)fp1; (void)fp2;
return 0;
}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_stable_shim = _compile_and_extract_symbols(
cpp_content=test_stable_shim_content,
compile_flags=compile_flags,
)
num_symbols_stable_shim = len(symbols_stable_shim)
assert num_symbols_stable_shim > 0, (
f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_stable_shim} symbols"
)
def check_lib_symbols_for_abi_correctness(lib: str) -> None:
print(f"lib: {lib}")
cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)
@ -460,13 +129,6 @@ def main() -> None:
check_lib_symbols_for_abi_correctness(libtorch_cpu_path)
check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path)
# Check symbols when TORCH_STABLE_ONLY is defined
check_stable_only_symbols(install_root)
check_stable_api_symbols(install_root)
check_headeronly_symbols(install_root)
check_aoti_shim_symbols(install_root)
check_stable_c_shim_symbols(install_root)
if __name__ == "__main__":
main()

View File

@ -1358,45 +1358,6 @@ class concat_license_files:
# Need to create the proper LICENSE.txt for the wheel
class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
def _wrap_headers_with_macro(self, bdist_dir: Path) -> None:
"""Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION).
Excludes:
- torch/include/torch/headeronly/*
- torch/include/torch/csrc/stable/*
- torch/include/torch/csrc/inductor/aoti_torch/c/ (only shim headers)
- torch/include/torch/csrc/inductor/aoti_torch/generated/
"""
header_extensions = (".h", ".hpp", ".cuh")
header_files = [
f for ext in header_extensions for f in bdist_dir.rglob(f"*{ext}")
]
# Paths to exclude from wrapping
exclude_dir_patterns = [
"torch/include/torch/headeronly/",
"torch/include/torch/csrc/stable/",
"torch/include/torch/csrc/inductor/aoti_torch/c/",
"torch/include/torch/csrc/inductor/aoti_torch/generated/",
]
for header_file in header_files:
rel_path = header_file.relative_to(bdist_dir).as_posix()
if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns):
report(f"Skipping header: {rel_path}")
continue
original_content = header_file.read_text(encoding="utf-8")
wrapped_content = (
"#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
f"{original_content}"
"\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
)
header_file.write_text(wrapped_content, encoding="utf-8")
report(f"Wrapped header: {rel_path}")
def run(self) -> None:
with concat_license_files(include_files=True):
super().run()
@ -1419,14 +1380,6 @@ class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
# need an __init__.py file otherwise we wouldn't have a package
(bdist_dir / "torch" / "__init__.py").touch()
# Wrap all header files with TORCH_STABLE_ONLY macro
assert self.bdist_dir is not None, "bdist_dir should be set during wheel build"
bdist_dir = Path(self.bdist_dir)
report(
"-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)"
)
self._wrap_headers_with_macro(bdist_dir)
class clean(Command):
user_options: ClassVar[list[tuple[str, str | None, str]]] = []

View File

@ -634,3 +634,38 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("test_parallel_for", &boxed_test_parallel_for);
m.impl("test_get_num_threads", &boxed_test_get_num_threads);
}
Tensor my_empty(
torch::headeronly::HeaderOnlyArrayRef<int64_t> size,
std::optional<torch::headeronly::ScalarType> dtype,
std::optional<torch::stable::Device> device,
std::optional<bool> pin_memory) {
return empty(size, dtype, device, pin_memory);
}
Tensor my_flatten(Tensor t, int64_t start_dim, int64_t end_dim) {
return flatten(t, start_dim, end_dim);
}
Tensor my_reshape(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> shape) {
return reshape(t, shape);
}
Tensor my_view(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> size) {
return view(t, size);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def(
"my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor");
m.def("my_flatten(Tensor t, int start_dim=0, int end_dim=-1) -> Tensor");
m.def("my_reshape(Tensor t, int[] shape) -> Tensor");
m.def("my_view(Tensor t, int[] size) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_empty", TORCH_BOX(&my_empty));
m.impl("my_flatten", TORCH_BOX(&my_flatten));
m.impl("my_reshape", TORCH_BOX(&my_reshape));
m.impl("my_view", TORCH_BOX(&my_view));
}

View File

@ -487,3 +487,58 @@ def test_get_num_threads() -> int:
Returns: int - the number of threads for the parallel backend
"""
return torch.ops.libtorch_agnostic.test_get_num_threads.default()
def my_empty(size, dtype=None, device=None, pin_memory=None) -> Tensor:
"""
Creates an empty tensor with the specified size, dtype, device, and pin_memory.
Args:
size: list[int] - size of the tensor to create
dtype: ScalarType or None - data type of the tensor
device: Device or None - device on which to create the tensor
pin_memory: bool or None - whether to use pinned memory
Returns: Tensor - an uninitialized tensor with the specified properties
"""
return torch.ops.libtorch_agnostic.my_empty.default(size, dtype, device, pin_memory)
def my_flatten(t, start_dim=0, end_dim=-1) -> Tensor:
"""
Flattens the input tensor from start_dim to end_dim into a single dimension.
Args:
t: Tensor - tensor to flatten
start_dim: int - first dimension to flatten (default: 0)
end_dim: int - last dimension to flatten (default: -1)
Returns: Tensor - flattened tensor
"""
return torch.ops.libtorch_agnostic.my_flatten.default(t, start_dim, end_dim)
def my_reshape(t, shape) -> Tensor:
"""
Returns a tensor with the same data but different shape.
Args:
t: Tensor - tensor to reshape
shape: list[int] - new shape for the tensor
Returns: Tensor - reshaped tensor
"""
return torch.ops.libtorch_agnostic.my_reshape.default(t, shape)
def my_view(t, size) -> Tensor:
"""
Returns a new tensor with the same data as the input tensor but of a different shape.
Args:
t: Tensor - tensor to view
size: list[int] - new size for the tensor
Returns: Tensor - tensor with new view
"""
return torch.ops.libtorch_agnostic.my_view.default(t, size)

View File

@ -33,7 +33,7 @@ class clean(distutils.command.clean.clean):
def get_extension():
extra_compile_args = {
"cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
"cxx": ["-fdiagnostics-color=always"],
}
extension = CppExtension

View File

@ -525,6 +525,97 @@ if not IS_WINDOWS:
expected_num_threads = torch.get_num_threads()
self.assertEqual(num_threads, expected_num_threads)
def test_my_empty(self, device):
import libtorch_agnostic
deterministic = torch.are_deterministic_algorithms_enabled()
try:
# set use_deterministic_algorithms to fill uninitialized memory
torch.use_deterministic_algorithms(True)
size = [2, 3]
result = libtorch_agnostic.ops.my_empty(size, None, None, None)
expected = torch.empty(size)
self.assertEqual(result, expected, exact_device=True)
result_float = libtorch_agnostic.ops.my_empty(
size, torch.float32, None, None
)
expected_float = torch.empty(size, dtype=torch.float32)
self.assertEqual(result_float, expected_float, exact_device=True)
result_with_device = libtorch_agnostic.ops.my_empty(
size, torch.float64, device, None
)
expected_with_device = torch.empty(
size, dtype=torch.float64, device=device
)
self.assertEqual(
result_with_device, expected_with_device, exact_device=True
)
if device == "cuda":
result_pinned = libtorch_agnostic.ops.my_empty(
size, torch.float32, "cpu", True
)
expected_pinned = torch.empty(
size, dtype=torch.float32, device="cpu", pin_memory=True
)
self.assertEqual(result_pinned, expected_pinned)
self.assertTrue(result_pinned.is_pinned())
finally:
torch.use_deterministic_algorithms(deterministic)
def test_my_flatten(self, device):
import libtorch_agnostic
t = torch.randn(2, 3, 4, device=device)
result = libtorch_agnostic.ops.my_flatten(t)
expected = torch.flatten(t)
self.assertEqual(result, expected)
result_start = libtorch_agnostic.ops.my_flatten(t, 1)
expected_start = torch.flatten(t, 1)
self.assertEqual(result_start, expected_start)
result_range = libtorch_agnostic.ops.my_flatten(t, 2, -1)
expected_range = torch.flatten(t, 2, -1)
self.assertEqual(result_range, expected_range)
def test_my_reshape(self, device):
import libtorch_agnostic
t = torch.randn(2, 3, 4, device=device)
result = libtorch_agnostic.ops.my_reshape(t, [6, 4])
expected = torch.reshape(t, [6, 4])
self.assertEqual(result, expected)
result_infer = libtorch_agnostic.ops.my_reshape(t, [-1, 4])
expected_infer = torch.reshape(t, [-1, 4])
self.assertEqual(result_infer, expected_infer)
result_flat = libtorch_agnostic.ops.my_reshape(t, [-1])
expected_flat = torch.reshape(t, [-1])
self.assertEqual(result_flat, expected_flat)
def test_my_view(self, device):
import libtorch_agnostic
t = torch.randn(2, 3, 4, device=device)
result = libtorch_agnostic.ops.my_view(t, [6, 4])
expected = t.view([6, 4])
self.assertEqual(result, expected)
result_infer = libtorch_agnostic.ops.my_view(t, [-1, 4])
expected_infer = t.view([-1, 4])
self.assertEqual(result_infer, expected_infer)
result_flat = libtorch_agnostic.ops.my_view(t, [-1])
expected_flat = t.view([-1])
self.assertEqual(result_flat, expected_flat)
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
if __name__ == "__main__":

View File

@ -0,0 +1,67 @@
import distutils.command.clean
import shutil
from pathlib import Path
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
ROOT_DIR = Path(__file__).parent
CSRC_DIR = ROOT_DIR / "torch_stable_test" / "csrc"
class clean(distutils.command.clean.clean):
def run(self):
# Run default behavior first
distutils.command.clean.clean.run(self)
# Remove extension
for path in (ROOT_DIR / "torch_stable_test").glob("**/*.so"):
path.unlink()
# Remove build and dist and egg-info directories
dirs = [
ROOT_DIR / "build",
ROOT_DIR / "dist",
ROOT_DIR / "torch_stable_test.egg-info",
]
for path in dirs:
if path.exists():
shutil.rmtree(str(path), ignore_errors=True)
def get_extension():
extra_compile_args = {
"cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
}
sources = list(CSRC_DIR.glob("**/*.cpp"))
return [
CppExtension(
"torch_stable_test._C",
sources=sorted(str(s) for s in sources),
py_limited_api=True,
extra_compile_args=extra_compile_args,
extra_link_args=[],
)
]
setup(
name="torch_stable_test",
version="0.0",
author="PyTorch Core Team",
description="Test extension to verify TORCH_STABLE_ONLY flag",
packages=find_packages(exclude=("test",)),
package_data={"torch_stable_test": ["*.dll", "*.dylib", "*.so"]},
install_requires=[
"torch",
],
ext_modules=get_extension(),
cmdclass={
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
"clean": clean,
},
options={"bdist_wheel": {"py_limited_api": "cp39"}},
)

View File

@ -0,0 +1 @@
#include <ATen/core/TensorBase.h> // This should trigger the TORCH_STABLE_ONLY error

View File

@ -0,0 +1,22 @@
# Owner(s): ["module: cpp"]
from pathlib import Path
from torch.testing._internal.common_utils import (
install_cpp_extension,
IS_WINDOWS,
run_tests,
TestCase,
)
if not IS_WINDOWS:
class TestTorchStable(TestCase):
def test_setup_fails(self):
with self.assertRaisesRegex(RuntimeError, "build failed for cpp extension"):
install_cpp_extension(extension_root=Path(__file__).parent.parent)
if __name__ == "__main__":
run_tests()

View File

@ -350,6 +350,27 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
):
partial_dt.clamp_(max=10)
def test_add_partial_scalar(self):
mesh = self.build_device_mesh()
rank = self.rank
local_tensor = torch.tensor([rank])
dt = DTensor.from_local(
local_tensor, device_mesh=mesh, placements=[Partial("sum")]
)
res = dt + 1
self.assertEqual(res, 7)
local_tensor = torch.tensor([1.0, 1.0, 7.0, 7.0])
dt = distribute_tensor(local_tensor, mesh, [Shard(0)])
norm = dt.norm()
norm = norm + 1
self.assertEqual(norm, 11)
if __name__ == "__main__":
run_tests()

View File

@ -3,6 +3,7 @@
import copy
import types
import unittest
import warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple
@ -18,6 +19,9 @@ from torch.testing import FileCheck
from torch.testing._internal.common_utils import TEST_CUDA
GLOBAL_LIST = []
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExperiment(TestCase):
def test_joint_basic(self) -> None:
@ -585,9 +589,9 @@ def forward(self, args_0):
_tree_leaf_0, _tree_leaf_1, = pytree.tree_leaves((self, args_0,))
L_args_0_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1)
l_args_0_ = L_args_0_
add = l_args_0_ + 1
add = l_args_0_ + 1; add = None
mul = l_args_0_ * 2; l_args_0_ = None
return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, mul, add), self._out_spec)""",
return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, mul), self._out_spec)""",
)
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
@ -611,6 +615,34 @@ def forward(self, args_0):
self.assertEqual(len(list(gm.buffers())), len(list(foo.buffers())))
self.assertEqual(len(list(gm.parameters())), len(list(foo.parameters())))
def test_dynamo_graph_capture_side_effects(self):
GLOBAL_LIST.clear()
def foo(x):
z = x + 1
GLOBAL_LIST.append(z)
return z
def make_inputs():
return (torch.randn(2, 3),)
trace_inputs = make_inputs()
with warnings.catch_warnings(record=True) as w:
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
cnt = 0
for entry in w:
if "While compiling, we found certain side effects happened" in str(
entry.message
):
cnt += 1
self.assertEqual(cnt, 1)
self.assertEqual(len(GLOBAL_LIST), 0)
test_inputs = make_inputs()
gm_results = gm(*test_inputs)
self.assertEqual(len(GLOBAL_LIST), 0)
self.assertEqual(gm_results, foo(*test_inputs))
self.assertEqual(len(GLOBAL_LIST), 1)
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
def test_dynamo_graph_capture_fx_graph_annotate_overlap_pass(self):
class DummyOp(torch.autograd.Function):

View File

@ -740,18 +740,26 @@ class TestExport(TestCase):
dynamic_shapes={"x": {0: Dim("b")}, "y": None},
)
# clean up _torchdynamo related meta data as it could vary depending on the caller
# https://github.com/pytorch/pytorch/issues/167432
for node in ep.graph.nodes:
if "custom" in node.meta:
node.meta["custom"] = {
k: v
for k, v in node.meta["custom"].items()
if "_torchdynamo_disable" not in k
}
custom_metadata = torch.fx.traceback._get_custom_metadata(ep.module())
self.assertExpectedInline(
str(custom_metadata),
"""\
('placeholder', 'x', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
('placeholder', 'y', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
('call_function', 'cat', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'item', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'ge_1', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', '_assert_scalar_default', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'mul', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('output', 'output', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})""",
('call_function', 'cat', {'moo': 0})
('call_function', 'item', {'moo': 0})
('call_function', 'ge_1', {'moo': 0})
('call_function', '_assert_scalar_default', {'moo': 0})
('call_function', 'mul', {'moo': 0})""",
)
@requires_gpu

View File

@ -1897,6 +1897,73 @@ class TestPatternMatcher(TestCase):
f"to be >= view count with remove_noop enabled ({view_count_default})",
)
def test_bound_method_pattern_matcher(self):
class ReluSumPattern:
def __init__(self, e: float):
self.e = e
def pattern(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
return x.pow(self.e) + y.pow(self.e) + z.pow(self.e)
def replacement(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
return (x + y + z).pow(self.e)
def inputs(self):
return [
torch.empty(5, 5), # x
torch.empty(5, 5), # y
torch.empty(5, 5), # z
]
def register(self, pm: PatternMatcherPass):
register_replacement(
self.pattern, self.replacement, self.inputs(), fwd_only, pm
)
my_patterns = PatternMatcherPass()
ReluSumPattern(4).register(my_patterns)
count = 0
def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph:
nonlocal count
count = my_patterns.apply(graph)
graph.eliminate_dead_code()
return graph
def custom_backend(graph: torch.fx.GraphModule, example_inputs):
from torch._inductor import config
current_config = config.shallow_copy_dict()
from torch._inductor.compile_fx import compile_fx
current_config["post_grad_custom_post_pass"] = custom_pass
current_config["enable_auto_functionalized_v2"] = False
return compile_fx(graph, example_inputs, config_patches=current_config)
@torch.compile(fullgraph=True, backend=custom_backend)
def fn(x):
y = x.relu()
z = y.tanh()
z2 = x.pow(2) + y.pow(2) + z.pow(2)
z3 = x.pow(3) + y.pow(3) + z2.pow(3)
z4 = x.pow(4) + y.pow(4) + z3.pow(4)
return z4 + 5
def fn_replaced(x):
y = x.relu()
z = y.tanh()
z2 = x.pow(2) + y.pow(2) + z.pow(2)
z3 = x.pow(3) + y.pow(3) + z2.pow(3)
z4 = (x + y + z3).pow(4)
return z4 + 5
x = [torch.ones((5, 4))]
fn_result = fn(*x)
fn_replaced_result = fn_replaced(*x)
self.assertEqual(count, 1)
self.assertEqual(fn_result, fn_replaced_result)
if __name__ == "__main__":
if IS_LINUX and HAS_GPU:

View File

@ -614,6 +614,8 @@ def dynamo_graph_capture_for_export(
def inner(*args: Any, **kwargs: Any) -> Any:
assert not torch._dynamo.config.install_free_tensors
with (
torch._dynamo.config.patch(replay_side_effects=False),
torch._dynamo.config.patch(side_effect_replay_policy="warn"),
get_metrics_context(),
dynamo_timed("fullgraph_capture"),
):

View File

@ -2538,7 +2538,7 @@ class CppKernel(Kernel):
@property
def assert_function(self) -> str:
if V.graph.aot_mode:
return "STD_TORCH_CHECK"
return "AOTI_TORCH_CHECK"
else:
return "TORCH_CHECK"

View File

@ -1442,6 +1442,13 @@ def register_replacement(
"""
argnames_static = [*inspect.signature(search_fn).parameters.keys()]
if inspect.ismethod(search_fn):
search_fn = _wrap_bound_method(search_fn, argnames_static)
if inspect.ismethod(replace_fn):
replace_argnames = [*inspect.signature(replace_fn).parameters.keys()]
replace_fn = _wrap_bound_method(replace_fn, replace_argnames)
def check_fn(match: Match) -> bool:
"""
Often shapes get burned into the pattern, so our initial match ran with
@ -1933,6 +1940,22 @@ def compute_mutation_region_ids(graph: torch.fx.Graph) -> None:
nd.meta["mutation_region_id"] = mutation_region_id
def _wrap_bound_method(fn: Any, argnames: list[str]) -> Any:
"""
Wrap a bound method to remove 'self' from its signature for FX tracing.
"""
def wrapper(*args: Any, **kwargs: Any) -> Any:
return fn(*args, **kwargs)
params = [
inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD)
for name in argnames
]
wrapper.__signature__ = inspect.Signature(params) # type: ignore[attr-defined]
return wrapper
class PatternMatcherPass:
def __init__(
self,

View File

@ -836,9 +836,10 @@ class AOTInductorModelBase {
}
void update_constants_array_from_map() {
STD_TORCH_CHECK(
constants_map_,
"constants_map_ was not ready when constants_ is trying to be constructed from it!");
if (!constants_map_) {
throw std::runtime_error{
"constants_map_ was not ready when constants_ is trying to be constructed from it!"};
}
if (!constants_) {
constants_ =
std::make_shared<std::vector<ConstantHandle>>(constants_info_.size());
@ -874,7 +875,9 @@ class AOTInductorModelBase {
/// Returns true if the model is complete.
bool is_finished() {
#ifdef USE_CUDA
STD_TORCH_CHECK(run_finished_, "Model CUDA event was not initialized");
if (!run_finished_) {
throw std::runtime_error{"Model CUDA event was not initialized"};
}
auto event_status = cudaEventQuery(*run_finished_);
if (event_status == cudaSuccess) {
@ -883,13 +886,13 @@ class AOTInductorModelBase {
return false;
}
STD_TORCH_CHECK(
false,
"The model did not finish successfully. Error: ",
throw std::runtime_error(
std::string("The model did not finish successfully. Error: ") +
cudaGetErrorString(cudaGetLastError()));
#elif defined(USE_XPU)
STD_TORCH_CHECK(run_finished_, "Model XPU event was not initialized");
if (!run_finished_) {
throw std::runtime_error{"Model XPU event was not initialized"};
}
using namespace sycl::info;
return (*run_finished_)->get_info<event::command_execution_status>() ==
event_command_status::complete;
@ -901,14 +904,19 @@ class AOTInductorModelBase {
/// Synchronizes completion event.
void wait_for_completion() {
STD_TORCH_CHECK(run_finished_, "Model event was not initialized");
#ifdef USE_CUDA
if (!run_finished_) {
throw std::runtime_error{"Model event was not initialized"};
}
AOTI_RUNTIME_CUDA_CHECK(cudaEventSynchronize(*run_finished_));
#endif // USE_CUDA
#ifdef USE_XPU
if (!run_finished_) {
throw std::runtime_error{"Model event was not initialized"};
}
(*run_finished_)->wait_and_throw();
#endif // USE_XPU
#endif
}
protected:

View File

@ -123,10 +123,8 @@ class AOTInductorModelContainer {
constants_folding_lk.unlock();
model_lk.lock();
} else if (const_folded != ConstantState::FOLDED) {
STD_TORCH_CHECK(
false,
"Unknown constant state: ",
toStringConstantState(constant_folded_));
throw std::runtime_error(
"Unknown constant state: " + toStringConstantState(constant_folded_));
}
try {
@ -169,10 +167,8 @@ class AOTInductorModelContainer {
/* validate_full_update = */ false);
const_folded = ConstantState::FOLDED;
} else if (constant_folded_ != ConstantState::FOLDED) {
STD_TORCH_CHECK(
false,
"Unknown constant state: ",
toStringConstantState(constant_folded_));
throw std::runtime_error(
"Unknown constant state: " + toStringConstantState(constant_folded_));
}
model->run_single_threaded(
@ -206,56 +202,56 @@ class AOTInductorModelContainer {
}
size_t num_constants() const {
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
return models_[0]->num_constants();
}
// retrieve the constant name of constants_info_[idx]
const char* constant_name(size_t idx) const {
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
return models_[0]->constant_name(static_cast<int64_t>(idx));
}
// retrieve original FQN of constants_info_[idx]
const char* constant_original_fqn(size_t idx) const {
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
return models_[0]->constant_original_fqn(static_cast<int64_t>(idx));
}
// retrieve whether constant is from folded of constants_info_[idx]
bool constant_from_folded(size_t idx) const {
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
return models_[0]->constant_from_folded(static_cast<int64_t>(idx));
}
size_t constant_data_size(size_t idx) const {
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
return models_[0]->constant_data_size(static_cast<int64_t>(idx));
}
// retrieve type of constants_info_[idx]
int32_t constant_type(size_t idx) const {
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
return models_[0]->constant_type(static_cast<int64_t>(idx));
}
// retrieve dtype of constants_info_[idx]
int32_t constant_dtype(size_t idx) const {
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
return models_[0]->constant_dtype(static_cast<int64_t>(idx));
}
@ -387,12 +383,9 @@ class AOTInductorModelContainer {
<< " in model, but not provided by user!\n";
continue;
}
STD_TORCH_CHECK(
false,
"Cannot find constants ",
constant_name,
" in constants_map!");
throw std::runtime_error(
std::string("Cannot find constants ") + constant_name +
std::string(" in constants_map!"));
}
}
}
@ -402,8 +395,9 @@ class AOTInductorModelContainer {
std::unordered_map<std::string, AtenTensorHandle>&& constants_map,
bool use_inactive,
bool validate_full_update) {
STD_TORCH_CHECK(
this->num_models() != 0, "No available models in container!");
if (this->num_models() == 0) {
throw std::runtime_error("No model available in container!");
}
if (validate_full_update) {
assert_all_constants(constants_map);
}
@ -449,9 +443,9 @@ class AOTInductorModelContainer {
bool use_inactive,
bool validate_full_update,
bool user_managed = false) {
STD_TORCH_CHECK(
this->num_models() != 0, "No model available in container!");
if (this->num_models() == 0) {
throw std::runtime_error("No model available in container!");
}
if (validate_full_update) {
assert_all_constants(constants_map);
}

View File

@ -7,7 +7,7 @@ namespace torch::aot_inductor {
template <typename T>
inline RAIIAtenTensorHandle scalar_to_tensor_handle(T value) {
STD_TORCH_CHECK(false, "Unsupported scalar_to_tensor_handle");
throw std::runtime_error("Unsupported scalar_to_tensor_handle");
}
// Specialize for supported C++ primitive types

View File

@ -11,11 +11,11 @@ template <>
struct ThreadLocalCachedOutputTensor<RAIIAtenTensorHandle> {
explicit ThreadLocalCachedOutputTensor(const RAIIAtenTensorHandle&) {}
void copy_data_from(const RAIIAtenTensorHandle& handle) {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
AtenTensorHandle tensor() const {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
};
@ -23,11 +23,11 @@ template <>
struct ThreadLocalCachedOutputTensor<AtenTensorHandle> {
explicit ThreadLocalCachedOutputTensor(const AtenTensorHandle&) {}
void copy_data_from(const AtenTensorHandle& handle) {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
AtenTensorHandle tensor() const {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
};
@ -35,11 +35,11 @@ template <>
struct ThreadLocalCachedOutputTensor<ConstantHandle> {
explicit ThreadLocalCachedOutputTensor(const ConstantHandle&) {}
void copy_data_from(const ConstantHandle& handle) {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
AtenTensorHandle tensor() const {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
};
@ -92,18 +92,18 @@ struct ThreadLocalCachedOutputArray;
template <>
struct ThreadLocalCachedOutputArray<RAIIAtenTensorHandle> {
explicit ThreadLocalCachedOutputArray(const RAIIAtenTensorHandle&) {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
// Not supported yet! We would need to put contiguous() or
// expect_contiguous() into the ABI.
void copy_data_from(const RAIIAtenTensorHandle&) {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
template <typename U>
ArrayRefTensor<U> arrayref_tensor() const {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
};
@ -111,18 +111,18 @@ struct ThreadLocalCachedOutputArray<RAIIAtenTensorHandle> {
template <>
struct ThreadLocalCachedOutputArray<ConstantHandle> {
explicit ThreadLocalCachedOutputArray(const ConstantHandle&) {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
// Not supported yet! We would need to put contiguous() or
// expect_contiguous() into the ABI.
void copy_data_from(const ConstantHandle&) {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
template <typename U>
ArrayRefTensor<U> arrayref_tensor() const {
STD_TORCH_CHECK(false, "can't happen");
throw std::runtime_error("can't happen");
}
};

View File

@ -38,10 +38,9 @@
// The following files are implemented in a header-only way and are guarded by
// test/cpp/aoti_abi_check
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/Half.h>
#include <torch/headeronly/util/complex.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>
#ifdef __cplusplus
extern "C" {
@ -622,8 +621,34 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_proxy_executor_call_function(
int num_tensors,
AtenTensorHandle* flatten_tensor_args);
// Preserve for BC and will delete it later, using the STD_TORCH_CHECK directly
#define AOTI_TORCH_CHECK(cond, ...) STD_TORCH_CHECK(cond, ##__VA_ARGS__)
AOTI_TORCH_EXPORT void aoti_torch_check(
bool cond,
const char* func,
const char* file,
uint32_t line,
const char* msg);
#ifdef STRIP_ERROR_MESSAGES
#define AOTI_TORCH_CHECK(cond, ...) \
if (!(cond)) { \
aoti_torch_check( \
false, \
__func__, \
__FILE__, \
static_cast<uint32_t>(__LINE__), \
TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \
}
#else
#define AOTI_TORCH_CHECK(cond, ...) \
if (!(cond)) { \
aoti_torch_check( \
false, \
__func__, \
__FILE__, \
static_cast<uint32_t>(__LINE__), \
TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); \
}
#endif
AOTI_TORCH_EXPORT void aoti_torch_warn(
const char* func,

View File

@ -1339,14 +1339,13 @@ AOTITorchError aoti_torch_proxy_executor_call_function(
int num_tensors,
AtenTensorHandle* flatten_tensor_args) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
TORCH_CHECK(
proxy_executor != nullptr,
"Unable to find a proxy executor to run custom ops.",
"Please check if there is a json file generated",
"in the same directory as the so,",
"or use torch._inductor.aoti_compile_and_package",
"to package everything into a PT2 artifact.");
if (!proxy_executor) {
throw std::runtime_error(
"Unable to find a proxy executor to run custom ops. Please check if "
"there is a json file generated in the same directory as the so, or use "
"torch._inductor.aoti_compile_and_package to package everything into a "
"PT2 artifact.");
}
ProxyExecutor* executor = reinterpret_cast<ProxyExecutor*>(proxy_executor);
executor->call_function(
extern_node_index,
@ -1357,6 +1356,17 @@ AOTITorchError aoti_torch_proxy_executor_call_function(
});
}
void aoti_torch_check(
bool cond,
const char* func,
const char* file,
uint32_t line,
const char* msg) {
if (C10_UNLIKELY_OR_CONST(!cond)) {
::c10::detail::torchCheckFail(func, file, line, msg);
}
}
void aoti_torch_warn(
const char* func,
const char* file,

View File

@ -10,7 +10,9 @@ AOTITorchError aoti_torch_mps_set_arg_tensor(
AtenTensorHandle tensor) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto t = tensor_handle_to_tensor_pointer(tensor);
TORCH_CHECK(t != nullptr, "Tensor is null.");
if (t == nullptr) {
throw std::runtime_error("Tensor is null.");
}
auto func = reinterpret_cast<at::native::mps::MetalKernelFunction*>(handle);
func->setArg(idx, *t);
});

View File

@ -92,11 +92,13 @@ inline void assert_inf_and_nan(
const std::string& tensor_name,
at::Tensor& check_tensor) {
auto isnan_tensor = check_tensor.isnan();
TORCH_CHECK(
!isnan_tensor.any().item<bool>(), "At least one NaN in ", tensor_name);
if (isnan_tensor.any().item<bool>()) {
throw std::runtime_error("At least one NaN in " + tensor_name);
}
auto isinf_tensor = check_tensor.isinf();
TORCH_CHECK(
!isinf_tensor.any().item<bool>(), "At least one INF in ", tensor_name);
if (isinf_tensor.any().item<bool>()) {
throw std::runtime_error("At least one INF in " + tensor_name);
}
}
// utility functions to convert a pointer to an optional value

View File

@ -69,7 +69,7 @@ inline torch::stable::Tensor narrow(
inline torch::stable::Tensor new_empty(
const torch::stable::Tensor& self,
torch::headeronly::IntHeaderOnlyArrayRef size,
std::optional<c10::ScalarType> dtype = std::nullopt) {
std::optional<torch::headeronly::ScalarType> dtype = std::nullopt) {
int32_t device_type;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
@ -108,7 +108,7 @@ inline torch::stable::Tensor new_empty(
inline torch::stable::Tensor new_zeros(
const torch::stable::Tensor& self,
torch::headeronly::IntHeaderOnlyArrayRef size,
std::optional<c10::ScalarType> dtype = std::nullopt) {
std::optional<torch::headeronly::ScalarType> dtype = std::nullopt) {
int32_t device_type;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
@ -306,6 +306,66 @@ inline uint32_t get_num_threads() {
return num_threads;
}
// We expect this to be the stable version of the empty op that takes in
// device and dtype parameters. The empty op creates a tensor with uninitialized
// values of the specified size, dtype, and device.
inline torch::stable::Tensor empty(
torch::headeronly::IntHeaderOnlyArrayRef size,
std::optional<torch::headeronly::ScalarType> dtype = std::nullopt,
std::optional<torch::stable::Device> device = std::nullopt,
std::optional<bool> pin_memory = std::nullopt) {
const auto num_args = 6;
std::array<StableIValue, num_args> stack{
torch::stable::detail::from(size),
torch::stable::detail::from(dtype),
torch::stable::detail::from(std::nullopt),
torch::stable::detail::from(device),
torch::stable::detail::from(pin_memory),
torch::stable::detail::from(std::nullopt)};
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::empty", "memory_format", stack.data(), TORCH_ABI_VERSION));
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
// We expect this to be the stable version of the flatten.using_ints op.
inline torch::stable::Tensor flatten(
const torch::stable::Tensor& self,
int64_t start_dim = 0,
int64_t end_dim = -1) {
const auto num_args = 3;
std::array<StableIValue, num_args> stack{
torch::stable::detail::from(self),
torch::stable::detail::from(start_dim),
torch::stable::detail::from(end_dim)};
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::flatten", "using_ints", stack.data(), TORCH_ABI_VERSION));
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
// We expect this to be the stable version of the reshape op.
inline torch::stable::Tensor reshape(
const torch::stable::Tensor& self,
torch::headeronly::IntHeaderOnlyArrayRef shape) {
const auto num_args = 2;
std::array<StableIValue, num_args> stack{
torch::stable::detail::from(self), torch::stable::detail::from(shape)};
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::reshape", "", stack.data(), TORCH_ABI_VERSION));
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
// We expect this to be the stable version of the view op.
inline torch::stable::Tensor view(
const torch::stable::Tensor& self,
torch::headeronly::IntHeaderOnlyArrayRef size) {
const auto num_args = 2;
std::array<StableIValue, num_args> stack{
torch::stable::detail::from(self), torch::stable::detail::from(size)};
TORCH_ERROR_CODE_CHECK(
torch_call_dispatcher("aten::view", "", stack.data(), TORCH_ABI_VERSION));
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
#endif
HIDDEN_NAMESPACE_END(torch, stable)

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import copy
import inspect
import warnings
from collections.abc import Callable, Sequence
@ -96,16 +97,23 @@ class _ToTorchTensor(torch.autograd.Function):
)
tensor_stride = tuple(tensor_stride)
grad_placements = grad_placements or dtensor_spec.placements
grad_spec = DTensorSpec(
mesh,
grad_placements,
tensor_meta=TensorMeta(
shape=dtensor_meta.shape,
stride=tensor_stride,
dtype=dtensor_meta.dtype,
),
)
if (
tensor_stride == dtensor_meta.stride
and grad_placements == dtensor_spec.placements
):
# Avoid actual sharing of specs in case they're modified during (e.g.)
# sharding propagation.
grad_spec = copy.copy(dtensor_spec)
else:
grad_spec = DTensorSpec(
mesh,
grad_placements,
tensor_meta=TensorMeta(
shape=dtensor_meta.shape,
stride=tensor_stride,
dtype=dtensor_meta.dtype,
),
)
return (
# pyrefly: ignore [bad-argument-type]
DTensor(

View File

@ -25,6 +25,7 @@ from torch.distributed.tensor.placement_types import (
Replicate,
Shard,
)
from torch.types import _Number
from torch.utils._typing_utils import not_none
@ -465,6 +466,7 @@ def pointwise_strategy(op_schema: OpSchema, linearity: int = -1) -> OpStrategy:
f"no strategy to follow for {op_schema}!"
)
return common_pointwise_strategy(
op_schema.op,
op_schema.args_schema,
followed_strategy,
followed_strategy_index,
@ -489,6 +491,7 @@ def linear_pointwise_strategy(op_schema: OpSchema) -> StrategyType:
def common_pointwise_strategy(
op,
args_schema: Sequence[object],
followed_strategy: OpStrategy,
followed_strategy_index: int,
@ -530,10 +533,16 @@ def common_pointwise_strategy(
new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim
out_placements.append(Shard(new_shard_dim))
elif isinstance(placement, Partial):
# partial + scalar doesn't work
addition_ops = [aten.add.Tensor, aten.add_.Tensor]
# note that only partial-sum and partial-avg are supported for linearity
partial_supports_linearity = placement.is_partial(
"sum"
) or placement.is_partial("avg")
partial_supports_linearity = (
placement.is_partial("sum") or placement.is_partial("avg")
) and not (
op in addition_ops
and any(isinstance(arg, _Number) for arg in args_schema)
)
if linearity > 0 and partial_supports_linearity:
# propagate the partial placement
out_placements.append(placement)
@ -748,6 +757,7 @@ def list_pointwise_strategy(
for arg_strategy in args_strategies
]
pointwise_strategy: OpStrategy = common_pointwise_strategy(
op_schema.op,
args_schema,
child_strtgy,
linearity,