[dynamo, nested graph breaks] add nested graph break tests (#144516)

Note: nested graph break tests (and wrapped tests) are xfailed/skipped for now - we will enable them iteratively as more of the nested graph break implementation is completed.
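
For context, a graph break is "nested" when it is hit inside a function that Dynamo has inlined into the frame it is compiling. A minimal illustrative sketch (not part of this PR's test suite):

import torch

def inner(x):
    x = x + 1
    torch._dynamo.graph_break()  # break hit inside an inlined (nested) call
    return x + 1

@torch.compile(backend="eager")
def outer(x):
    return inner(x) * 2  # Dynamo inlines `inner`, so the break above is nested

outer(torch.ones(3))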

Differential Revision: [D81084809](https://our.internmc.facebook.com/intern/diff/D81084809)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144516
Approved by: https://github.com/anijain2305
William Wen
2025-08-26 15:57:22 -07:00
committed by PyTorch MergeBot
parent b36a20d368
commit 8b78ba07b1
11 changed files with 568 additions and 10 deletions

View File

@@ -0,0 +1,18 @@
import torch
global1 = torch.ones(3)
def reset_state():
global global1
global1 = torch.ones(3)
def fn(val, call):
global global1
global1 += 1
val = val + global1
val = call(val)
val = val + 1
return val

View File

@@ -0,0 +1,424 @@
# Owner(s): ["module: dynamo"]
import unittest
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo import config
from torch._dynamo.testing import make_test_cls_with_patches
try:
# from . import test_ctx_manager
pass
except ImportError:
# import test_aot_autograd
# import test_ctx_manager
# import test_export
# import test_functions
# import test_higher_order_ops
# import test_misc
# import test_modules
# import test_repros
# import test_sdpa
# import test_subgraphs
pass
test_classes = {}
def make_nested_cls(cls):
suffix = "_nested_graph_breaks"
cls_prefix = "NestedGraphBreaks"
test_class = make_test_cls_with_patches(
cls,
cls_prefix,
suffix,
(config, "debug_force_nested_calls", True),
(config, "debug_force_graph_break_on_leaf_return", True),
(config, "debug_disable_compile_counter", True),
xfail_prop="_expected_failure_nested_graph_breaks",
)
test_classes[test_class.__name__] = test_class
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
# globals()[test_class.__name__] = test_class
test_class.__module__ = __name__
return test_class
tests = [
# test_ctx_manager.CtxManagerTests,
# test_functions.FunctionTests,
# test_misc.MiscTests,
# test_repros.ReproTests,
# test_modules.NNModuleTests,
# test_subgraphs.SubGraphTests,
# test_higher_order_ops.HigherOrderOpTests,
# test_higher_order_ops.FuncTorchHigherOrderOpTests,
# test_aot_autograd.AotAutogradFallbackTests,
# test_sdpa.TestSDPA,
]
test = None
for test in tests:
make_nested_cls(test)
del test
global_val = 0
class CustomizedCtxManager:
def __init__(self, val):
self.val = val
def __enter__(self):
global global_val
global_val += self.val
def __exit__(self, exc_type, exc_value, traceback):
global global_val
global_val -= self.val
# for use in test_side_effects_globals
global1, global2, global3, global4 = (torch.zeros(3),) * 4
class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
def setUp(self):
super().setUp()
torch._dynamo.config.nested_graph_breaks = True
def tearDown(self):
super().tearDown()
torch._dynamo.config.nested_graph_breaks = False
@unittest.expectedFailure
def test_single_graph_break(self):
def f1(x1):
x1 = x1 + 1
torch._dynamo.graph_break()
return x1 + 2
def f2(x2):
return f1(x2 + 4) + 8
def f3(x3):
return f2(x3 + 16) + 32
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
x = torch.zeros(3)
res = f3(x)
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
@unittest.expectedFailure
def test_single_graph_break_repeat(self):
def f1(x1):
x1 = x1 + 1
torch._dynamo.graph_break()
return x1 + 2
def f2(x2):
tmp1 = f1(x2 + 4)
tmp2 = f1(x2 + 8) << 4
return tmp1 + tmp2
def f3(x3):
return f2(x3 + 256) + 512
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
x = torch.zeros(3, dtype=torch.long)
res = f3(x)
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 3)
@unittest.expectedFailure
def test_doubly_nested_graph_break(self):
def f1(x1):
x1 = x1 + 1
torch._dynamo.graph_break()
return x1 + 2
def f2(x2):
x2 = x2 + 4
torch._dynamo.graph_break()
return f1(x2 + 8) + 16
def f3(x3):
return f2(x3 + 32) + 64
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
x = torch.zeros(3)
res = f3(x)
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 3)
@unittest.expectedFailure
def test_differing_arg_nums(self):
def f1(x1, x2):
x = x1 + x2
torch._dynamo.graph_break()
return x + 1
def f2(x3, x4, x5, x6):
return f1(x3 + x4, x5 + x6) + 2
def f3(x7, x8):
return f2(x7, x7 + 4, x8, x8 + 8) + 16
def f4(x9):
return f3(x9, x9 + 32) + 64
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(f4)
x = torch.zeros(3)
res = f4(x)
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
@unittest.expectedFailure
def test_differing_locals_nums(self):
def f1(x1):
loc1 = x1 + 1
torch._dynamo.graph_break()
return loc1 + 2
def f2(x2):
loc1 = x2 + 4
loc2 = x2 + 8
return f1(x2) + loc1 + loc2
def f3(x3):
loc1 = x3 + 16
loc2 = x3 + 32
loc3 = x3 + 64
loc4 = x3 + 128
return f2(x3) + loc1 + loc2 + loc3 + loc4
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
x = torch.zeros(3)
res = f3(x)
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
@unittest.expectedFailure
def test_ctx_manager(self):
global global_val
global_val = 0
@torch._dynamo.disable
def f1():
return global_val
def f2(x2):
with CustomizedCtxManager(8):
x2 = x2 + (1 << 4)
x2 = x2 + f1() # 15
x2 = x2 + (1 << 5)
x2 = x2 << 2
x2 = x2 + global_val # 3
with CustomizedCtxManager(4):
x2 = x2 << 4
x2 = x2 + f1() # 7
x2 = x2 + (1 << 3)
return x2
def f3(x3):
with CustomizedCtxManager(2):
return f2(x3)
def f4(x4):
with CustomizedCtxManager(1):
return f3(x4)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(f4)
x = torch.zeros(3, dtype=torch.long)
res = f4(x)
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 3)
@unittest.expectedFailure
def test_cells(self):
def f1(x1):
cell1 = x1 + 1
cell2 = x1 + 2
def f2(x2, x3):
nonlocal cell1
cell3 = x2 + x3 + 4
cell1 += 8
def f3(x4):
nonlocal cell2, cell3
cell2 += 16
cell3 += 32
torch._dynamo.graph_break()
return x4 + cell1 + cell2 + cell3
return f3(x2 + x3), cell3
return f2(x1 + 64, x1 + 128) + (cell1, cell2)
def outer(x):
return f1(x)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(outer)
x = torch.zeros(3)
res = outer(x)
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
@unittest.expectedFailure
def test_side_effects_cells(self):
cell1, cell2, cell3, cell4 = (torch.zeros(3),) * 4
def f1():
nonlocal cell1
cell1 += 1
torch._dynamo.graph_break()
return cell1 + cell2
def f2():
nonlocal cell3
cell3 += 2
return f1() + cell3 + cell4
def f3():
return f2()
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
cell1 = torch.zeros(3)
cell2 = torch.zeros(3) + 4
cell3 = torch.zeros(3)
cell4 = torch.zeros(3) + 8
res = f3()
res = (res,) + tuple(x.clone() for x in (cell1, cell2, cell3, cell4))
cell1 = torch.zeros(3)
cell2 = torch.zeros(3) + 4
cell3 = torch.zeros(3)
cell4 = torch.zeros(3) + 8
ref = opt_fn()
ref = (ref,) + tuple(x.clone() for x in (cell1, cell2, cell3, cell4))
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
@unittest.expectedFailure
def test_side_effects_globals(self):
global global1, global2, global3, global4
def f1():
global global1
global1 += 1
torch._dynamo.graph_break()
return global1 + global2
def f2():
global global3
global3 += 2
return f1() + global3 + global4
def f3(x):
return x + f2()
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
x = torch.ones(3)
global1 = torch.zeros(3)
global2 = torch.zeros(3) + 4
global3 = torch.zeros(3)
global4 = torch.zeros(3) + 8
res = (f3(x), global1.clone(), global2, global3.clone(), global4)
global1 = torch.zeros(3)
global2 = torch.zeros(3) + 4
global3 = torch.zeros(3)
global4 = torch.zeros(3) + 8
ref = (opt_fn(x), global1.clone(), global2, global3.clone(), global4)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
@unittest.expectedFailure
def test_side_effects_globals_different_module(self):
try:
from . import _test_nested_graph_breaks_helper
except ImportError:
import _test_nested_graph_breaks_helper
def f1(x):
x = x + 1
torch._dynamo.graph_break()
return x + 1
def f2(x):
x = x + 1
x = _test_nested_graph_breaks_helper.fn(x, f1)
return x + 1
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(f2)
_test_nested_graph_breaks_helper.reset_state()
x = torch.zeros(3)
res = (f2(x), _test_nested_graph_breaks_helper.global1.clone())
_test_nested_graph_breaks_helper.reset_state()
ref = (opt_fn(x), _test_nested_graph_breaks_helper.global1.clone())
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
@unittest.expectedFailure
def test_nested_graph_break_in_loop(self):
def f1(x, i):
if i == 5:
torch._dynamo.graph_break()
return x + 1
def f2(x):
for i in range(8):
x = f1(x, i)
return x
def f3(x):
x = x + 1
x = f2(x)
x = x + 1
return x
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
x = torch.zeros(3)
res = f3(x)
ref = opt_fn(x)
self.assertEqual(ref, res)
# skip frame due to nested graph break in for loop
self.assertEqual(cnts.frame_count, 0)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@@ -25,6 +25,7 @@ from collections.abc import Iterable, Iterator, Mapping, Sequence
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
from ..utils._backport_slots import dataclass_slots
from . import config
from .bytecode_analysis import (
get_indexof,
propagate_line_nums,
@@ -1200,6 +1201,50 @@ def remove_fused_load_store(instructions: list[Instruction]) -> None:
instructions[:] = new_insts
# adds GRAPH_BREAK_IF_LEAF (not a real instruction) before RETURN_* instructions
# for testing purposes
def add_graph_break_if_leaf_instructions(instructions: list[Instruction]) -> None:
new_insts = []
for inst in instructions:
if "RETURN" in inst.opname:
replace_insts = [
create_instruction("NOP", argval="GRAPH_BREAK_IF_LEAF"),
create_instruction(inst.opname, argval=inst.argval),
]
new_insts.extend(overwrite_instruction(inst, replace_insts))
else:
new_insts.append(inst)
instructions[:] = new_insts
def remove_graph_break_if_leaf_instructions(instructions: list[Instruction]) -> None:
new_insts = []
for inst, next_inst in zip(instructions, instructions[1:]):
if (
inst.opname == "NOP"
and inst.argval == "GRAPH_BREAK_IF_LEAF"
and next_inst.opname.startswith("RETURN")
):
# remove this instruction and update all other instructions' jump targets
for i in range(len(instructions)):
if instructions[i].target is inst:
instructions[i].target = next_inst
if instructions[i].exn_tab_entry:
# linter mistakenly complains that None has no attribute "start"/"end"/"target",
# but this codepath only runs if instructions[i].exn_tab_entry is not None
if instructions[i].exn_tab_entry.start is inst: # type: ignore[union-attr]
instructions[i].exn_tab_entry.start = next_inst # type: ignore[union-attr]
if instructions[i].exn_tab_entry.end is inst: # type: ignore[union-attr]
instructions[i].exn_tab_entry.end = next_inst # type: ignore[union-attr]
if instructions[i].exn_tab_entry.target is inst: # type: ignore[union-attr]
instructions[i].exn_tab_entry.target = next_inst # type: ignore[union-attr]
else:
new_insts.append(inst)
new_insts.append(instructions[-1])
instructions[:] = new_insts
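
The two helpers are intended to be symmetric: the marker NOP is injected before every RETURN_* when instructions are cleaned, the symbolic interpreter reacts to it, and the marker is stripped again before the bytecode is reassembled, so it never reaches the emitted code object. A simplified round-trip sketch using a hypothetical stand-in for Dynamo's Instruction dataclass (illustration only, not the real implementation):

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeInstruction:  # hypothetical, simplified stand-in for Instruction
    opname: str
    argval: Optional[str] = None

def add_marker(insts):
    out = []
    for inst in insts:
        if "RETURN" in inst.opname:
            out.append(FakeInstruction("NOP", "GRAPH_BREAK_IF_LEAF"))
        out.append(inst)
    return out

def remove_marker(insts):
    return [
        inst
        for i, inst in enumerate(insts)
        if not (
            inst.opname == "NOP"
            and inst.argval == "GRAPH_BREAK_IF_LEAF"
            and i + 1 < len(insts)
            and insts[i + 1].opname.startswith("RETURN")
        )
    ]

insts = [FakeInstruction("LOAD_FAST", "x"), FakeInstruction("RETURN_VALUE")]
assert remove_marker(add_marker(insts)) == insts  # adding then removing is a no-op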
def explicit_super(code: types.CodeType, instructions: list[Instruction]) -> None:
"""convert super() with no args into explicit arg form"""
cell_and_free = (code.co_cellvars or ()) + (code.co_freevars or ())
@@ -1521,6 +1566,7 @@ def transform_code_object(
def clean_and_assemble_instructions(
instructions: list[Instruction], keys: list[str], code_options: dict[str, Any]
) -> tuple[list[Instruction], types.CodeType]:
remove_graph_break_if_leaf_instructions(instructions)
# also implicitly checks for no duplicate instructions
check_inst_exn_tab_entries_valid(instructions)
@@ -1636,6 +1682,8 @@ def _cached_cleaned_instructions(
remove_binary_store_slice(instructions)
if sys.version_info >= (3, 13):
remove_fused_load_store(instructions)
if config.debug_force_graph_break_on_leaf_return:
add_graph_break_if_leaf_instructions(instructions)
if sys.version_info >= (3, 11):
update_offsets(instructions)
devirtualize_jumps(instructions)

View File

@@ -481,6 +481,18 @@ issue_3_13_0_warning = True
# traced FX graph is empty when RETURN_* is traced.
allow_empty_graphs = False
# Used for testing - forces all top-level functions to be nested when traced with Dynamo
debug_force_nested_calls = False
# Used for testing - forces a graph break when a function
# that doesn't make any Dynamo-inlined calls returns
debug_force_graph_break_on_leaf_return = False
# Used for testing - makes CompileCounter.frame_count compare equal to anything,
# so assertions like self.assertEqual(CompileCounter.frame_count, n) always pass.
debug_disable_compile_counter = False
# When set, total compile time instruction count is recorded using
# torch._dynamo.utils.CompileTimeInstructionCounter.
record_compile_time_instruction_count = False
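
For reference, the new flags are ordinary module-level config options, so tests (or local debugging) can toggle them with the usual config.patch context manager. A minimal sketch, assuming a build that already contains this PR:

import torch
from torch._dynamo import config

with config.patch(
    debug_force_nested_calls=True,
    debug_force_graph_break_on_leaf_return=True,
    debug_disable_compile_counter=True,
):
    @torch.compile(backend="eager")
    def f(x):
        return x + 1

    print(f(torch.ones(3)))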

View File

@@ -36,6 +36,7 @@ import textwrap
import threading
import traceback
import types
import unittest
import warnings
import weakref
from dataclasses import dataclass
@@ -739,7 +740,9 @@ class _TorchDynamoContext:
filename = inspect.getsourcefile(fn)
except TypeError:
filename = None
if config.wrap_top_frame or (
if config.debug_force_nested_calls:
fn = external_utils.wrap_inline(fn)
elif config.wrap_top_frame or (
(filename is None or trace_rules.check(fn))
and (
getattr(fn, "__name__", "")
@@ -1219,7 +1222,8 @@ def _optimize(
),
hooks,
backend_ctx_ctor,
error_on_graph_break=nopython,
error_on_graph_break=nopython
and not config.debug_force_graph_break_on_leaf_return,
dynamic=dynamic,
compiler_config=(
backend.get_compiler_config()
@@ -1760,6 +1764,9 @@ def export(
Note - this headerdoc was authored by ChatGPT, with slight modifications by the author.
"""
if config.debug_force_graph_break_on_leaf_return:
raise unittest.SkipTest("Cannot force graph break on export")
if _log_export_usage:
log_export_usage(event="export.private_api", flags={"_dynamo"})

View File

@@ -2690,5 +2690,15 @@
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
}
],
"GB0269": [
{
"Gb_type": "Forced graph break on leaf function",
"Context": "",
"Explanation": "Forced graph break for nested graph break testing purposes",
"Hints": [
"Set torch._dynamo.config.debug_force_graph_break_on_leaf_return = False"
]
}
]
}

View File

@@ -1267,6 +1267,7 @@ class InstructionTranslatorBase(
"""
A call to some user defined function by inlining it.
"""
self.is_leaf_tracer = False
if config.enable_faithful_generator_behavior and is_generator(fn.get_code()): # type: ignore[attr-defined]
return self.inline_generator_function(fn, args, kwargs)
else:
@@ -2927,8 +2928,22 @@ class InstructionTranslatorBase(
hints=[*graph_break_hints.USER_ERROR],
)
@break_graph_if_unsupported(push=0)
def graph_break_on_leaf_function(self, inst: Instruction) -> None:
if self.is_leaf_tracer:
unimplemented_v2(
gb_type="Forced graph break on leaf function",
context="",
explanation="Forced graph break for nested graph break testing purposes",
hints=[
"Set torch._dynamo.config.debug_force_graph_break_on_leaf_return = False",
],
)
def NOP(self, inst: Instruction) -> None:
pass
# Dynamo-specific testing behavior
if inst.argval == "GRAPH_BREAK_IF_LEAF":
self.graph_break_on_leaf_function(inst)
def POP_TOP(self, inst: Instruction) -> None:
self.pop()

View File

@@ -101,6 +101,18 @@ class TestCase(TorchTestCase):
log.warning("Running test changed grad mode")
torch.set_grad_enabled(self._prior_is_grad_enabled)
def assertEqual(self, x: Any, y: Any, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
if config.debug_disable_compile_counter and (
isinstance(x, utils.CompileCounterInt)
or isinstance(y, utils.CompileCounterInt)
):
return
return super().assertEqual(x, y, *args, **kwargs)
# assertExpectedInline might also need to be disabled for wrapped nested
# graph break tests
class CPythonTestCase(TestCase):
"""

View File

@@ -42,7 +42,7 @@ from .bytecode_transformation import (
)
from .guards import CheckFunctionManager, CompileId, GuardedCode
from .types import ConvertFrameReturn, DynamoFrameType, wrap_guarded_code
from .utils import same
from .utils import CompileCounterInt, same
np: Optional[types.ModuleType] = None
@@ -227,8 +227,8 @@ class CompileCounter:
class CompileCounter:
def __init__(self) -> None:
self.frame_count = 0
self.op_count = 0
self.frame_count: Union[int, CompileCounterInt] = 0
self.clear()
def __call__(
self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
@@ -240,16 +240,19 @@ class CompileCounterWithBackend:
return gm.forward
def clear(self) -> None:
self.frame_count = 0
if config.debug_disable_compile_counter:
self.frame_count = CompileCounterInt(0)
else:
self.frame_count = 0
self.op_count = 0
class CompileCounterWithBackend:
def __init__(self, backend: str) -> None:
self.frame_count = 0
self.op_count = 0
self.frame_count: Union[int, CompileCounterInt] = 0
self.backend = backend
self.graphs: list[torch.fx.GraphModule] = []
self.clear()
def __call__(
self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
@@ -264,7 +267,10 @@ class CompileCounterWithBackend:
return lookup_backend(self.backend)(gm, example_inputs)
def clear(self) -> None:
self.frame_count = 0
if config.debug_disable_compile_counter:
self.frame_count = CompileCounterInt(0)
else:
self.frame_count = 0
self.op_count = 0
self.graphs = []

View File

@@ -3404,6 +3404,7 @@ MOD_INLINELIST = [
"torch._dynamo.compiled_autograd",
"torch._dynamo.comptime",
"torch._dynamo.polyfills",
"torch._dynamo.test_case",
"torch._functorch._aot_autograd.subclass_parametrization",
"torch._functorch.autograd_function",
"torch._functorch.eager_transforms",

View File

@@ -4734,6 +4734,11 @@ class CompileTimeInstructionCounter:
cls.end()
class CompileCounterInt(int):
def __add__(self, other: Any) -> CompileCounterInt:
return CompileCounterInt(super().__add__(other))
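
Since __add__ returns the subclass, frame_count stays a CompileCounterInt as the backend increments it, which is what keeps the assertEqual override in test_case.py short-circuiting. A tiny standalone illustration (the class is repeated only to make the snippet self-contained):

class CompileCounterInt(int):
    def __add__(self, other):
        return CompileCounterInt(super().__add__(other))

count = CompileCounterInt(0)
for _ in range(3):
    count += 1  # augmented assignment goes through __add__, preserving the type
assert type(count) is CompileCounterInt and count == 3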
def set_feature_use(feature: str, usage: bool) -> None:
"""
Records whether we are using a feature