[dynamo, nested graph breaks] add nested graph break tests (#144516)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144516 Approved by: https://github.com/anijain2305 ghstack dependencies: #157971, #159281
Committed by: PyTorch MergeBot
Parent: 504a6445a4
Commit: 9a756c2d71
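For orientation, a minimal sketch of the pattern these tests exercise (illustration only, not part of the commit): a graph break raised inside a function that Dynamo has inlined into its caller, so compilation must resume across the nested frame.

import torch

def inner(x):
    x = x + 1
    torch._dynamo.graph_break()  # break inside the inlined (nested) call
    return x + 2

def outer(x):
    # Dynamo inlines `inner`, so the break above is a *nested* graph break
    return inner(x + 4) + 8

compiled = torch.compile(outer, backend="eager")
print(compiled(torch.zeros(3)))  # matches eager: tensor([15., 15., 15.])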
test/dynamo/_test_nested_graph_breaks_helper.py (new file, 18 lines)
@@ -0,0 +1,18 @@
import torch


global1 = torch.ones(3)


def reset_state():
    global global1
    global1 = torch.ones(3)


def fn(val, call):
    global global1
    global1 += 1
    val = val + global1
    val = call(val)
    val = val + 1
    return val
test/dynamo/test_nested_graph_breaks.py (new file, 424 lines)
@@ -0,0 +1,424 @@
# Owner(s): ["module: dynamo"]
import unittest

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo import config
from torch._dynamo.testing import make_test_cls_with_patches


try:
    # from . import test_ctx_manager
    pass
except ImportError:
    # import test_aot_autograd
    # import test_ctx_manager

    # import test_export
    # import test_functions
    # import test_higher_order_ops
    # import test_misc
    # import test_modules
    # import test_repros
    # import test_sdpa
    # import test_subgraphs
    pass


test_classes = {}


def make_nested_cls(cls):
    suffix = "_nested_graph_breaks"

    cls_prefix = "NestedGraphBreaks"

    test_class = make_test_cls_with_patches(
        cls,
        cls_prefix,
        suffix,
        (config, "debug_force_nested_calls", True),
        (config, "debug_force_graph_break_on_leaf_return", True),
        (config, "debug_disable_compile_counter", True),
        xfail_prop="_expected_failure_nested_graph_breaks",
    )

    test_classes[test_class.__name__] = test_class
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    # globals()[test_class.__name__] = test_class
    test_class.__module__ = __name__
    return test_class


tests = [
    # test_ctx_manager.CtxManagerTests,
    # test_functions.FunctionTests,
    # test_misc.MiscTests,
    # test_repros.ReproTests,
    # test_modules.NNModuleTests,
    # test_subgraphs.SubGraphTests,
    # test_higher_order_ops.HigherOrderOpTests,
    # test_higher_order_ops.FuncTorchHigherOrderOpTests,
    # test_aot_autograd.AotAutogradFallbackTests,
    # test_sdpa.TestSDPA,
]
test = None
for test in tests:
    make_nested_cls(test)
del test

global_val = 0


class CustomizedCtxManager:
    def __init__(self, val):
        self.val = val

    def __enter__(self):
        global global_val
        global_val += self.val

    def __exit__(self, exc_type, exc_value, traceback):
        global global_val
        global_val -= self.val


# for use in test_side_effects_globals
global1, global2, global3, global4 = (torch.zeros(3),) * 4


class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
    def setUp(self):
        super().setUp()
        torch._dynamo.config.nested_graph_breaks = True

    def tearDown(self):
        super().tearDown()
        torch._dynamo.config.nested_graph_breaks = False

    @unittest.expectedFailure
    def test_single_graph_break(self):
        def f1(x1):
            x1 = x1 + 1
            torch._dynamo.graph_break()
            return x1 + 2

        def f2(x2):
            return f1(x2 + 4) + 8

        def f3(x3):
            return f2(x3 + 16) + 32

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    @unittest.expectedFailure
    def test_single_graph_break_repeat(self):
        def f1(x1):
            x1 = x1 + 1
            torch._dynamo.graph_break()
            return x1 + 2

        def f2(x2):
            tmp1 = f1(x2 + 4)
            tmp2 = f1(x2 + 8) << 4
            return tmp1 + tmp2

        def f3(x3):
            return f2(x3 + 256) + 512

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3, dtype=torch.long)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 3)

    @unittest.expectedFailure
    def test_doubly_nested_graph_break(self):
        def f1(x1):
            x1 = x1 + 1
            torch._dynamo.graph_break()
            return x1 + 2

        def f2(x2):
            x2 = x2 + 4
            torch._dynamo.graph_break()
            return f1(x2 + 8) + 16

        def f3(x3):
            return f2(x3 + 32) + 64

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 3)

    @unittest.expectedFailure
    def test_differing_arg_nums(self):
        def f1(x1, x2):
            x = x1 + x2
            torch._dynamo.graph_break()
            return x + 1

        def f2(x3, x4, x5, x6):
            return f1(x3 + x4, x5 + x6) + 2

        def f3(x7, x8):
            return f2(x7, x7 + 4, x8, x8 + 8) + 16

        def f4(x9):
            return f3(x9, x9 + 32) + 64

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f4)
        x = torch.zeros(3)
        res = f4(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    @unittest.expectedFailure
    def test_differing_locals_nums(self):
        def f1(x1):
            loc1 = x1 + 1
            torch._dynamo.graph_break()
            return loc1 + 2

        def f2(x2):
            loc1 = x2 + 4
            loc2 = x2 + 8
            return f1(x2) + loc1 + loc2

        def f3(x3):
            loc1 = x3 + 16
            loc2 = x3 + 32
            loc3 = x3 + 64
            loc4 = x3 + 128
            return f2(x3) + loc1 + loc2 + loc3 + loc4

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    @unittest.expectedFailure
    def test_ctx_manager(self):
        global global_val
        global_val = 0

        @torch._dynamo.disable
        def f1():
            return global_val

        def f2(x2):
            with CustomizedCtxManager(8):
                x2 = x2 + (1 << 4)
                x2 = x2 + f1()  # 15
                x2 = x2 + (1 << 5)
            x2 = x2 << 2
            x2 = x2 + global_val  # 3
            with CustomizedCtxManager(4):
                x2 = x2 << 4
                x2 = x2 + f1()  # 7
                x2 = x2 + (1 << 3)
            return x2

        def f3(x3):
            with CustomizedCtxManager(2):
                return f2(x3)

        def f4(x4):
            with CustomizedCtxManager(1):
                return f3(x4)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f4)
        x = torch.zeros(3, dtype=torch.long)
        res = f4(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 3)

    @unittest.expectedFailure
    def test_cells(self):
        def f1(x1):
            cell1 = x1 + 1
            cell2 = x1 + 2

            def f2(x2, x3):
                nonlocal cell1
                cell3 = x2 + x3 + 4
                cell1 += 8

                def f3(x4):
                    nonlocal cell2, cell3
                    cell2 += 16
                    cell3 += 32
                    torch._dynamo.graph_break()
                    return x4 + cell1 + cell2 + cell3

                return f3(x2 + x3), cell3

            return f2(x1 + 64, x1 + 128) + (cell1, cell2)

        def outer(x):
            return f1(x)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(outer)
        x = torch.zeros(3)
        res = outer(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    @unittest.expectedFailure
    def test_side_effects_cells(self):
        cell1, cell2, cell3, cell4 = (torch.zeros(3),) * 4

        def f1():
            nonlocal cell1
            cell1 += 1
            torch._dynamo.graph_break()
            return cell1 + cell2

        def f2():
            nonlocal cell3
            cell3 += 2
            return f1() + cell3 + cell4

        def f3():
            return f2()

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)

        cell1 = torch.zeros(3)
        cell2 = torch.zeros(3) + 4
        cell3 = torch.zeros(3)
        cell4 = torch.zeros(3) + 8
        res = f3()
        res = (res,) + tuple(x.clone() for x in (cell1, cell2, cell3, cell4))

        cell1 = torch.zeros(3)
        cell2 = torch.zeros(3) + 4
        cell3 = torch.zeros(3)
        cell4 = torch.zeros(3) + 8
        ref = opt_fn()
        ref = (ref,) + tuple(x.clone() for x in (cell1, cell2, cell3, cell4))

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    @unittest.expectedFailure
    def test_side_effects_globals(self):
        global global1, global2, global3, global4

        def f1():
            global global1
            global1 += 1
            torch._dynamo.graph_break()
            return global1 + global2

        def f2():
            global global3
            global3 += 2
            return f1() + global3 + global4

        def f3(x):
            return x + f2()

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.ones(3)

        global1 = torch.zeros(3)
        global2 = torch.zeros(3) + 4
        global3 = torch.zeros(3)
        global4 = torch.zeros(3) + 8
        res = (f3(x), global1.clone(), global2, global3.clone(), global4)

        global1 = torch.zeros(3)
        global2 = torch.zeros(3) + 4
        global3 = torch.zeros(3)
        global4 = torch.zeros(3) + 8
        ref = (opt_fn(x), global1.clone(), global2, global3.clone(), global4)

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    @unittest.expectedFailure
    def test_side_effects_globals_different_module(self):
        try:
            from . import _test_nested_graph_breaks_helper
        except ImportError:
            import _test_nested_graph_breaks_helper

        def f1(x):
            x = x + 1
            torch._dynamo.graph_break()
            return x + 1

        def f2(x):
            x = x + 1
            x = _test_nested_graph_breaks_helper.fn(x, f1)
            return x + 1

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f2)

        _test_nested_graph_breaks_helper.reset_state()
        x = torch.zeros(3)
        res = (f2(x), _test_nested_graph_breaks_helper.global1.clone())

        _test_nested_graph_breaks_helper.reset_state()
        ref = (opt_fn(x), _test_nested_graph_breaks_helper.global1.clone())

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    @unittest.expectedFailure
    def test_nested_graph_break_in_loop(self):
        def f1(x, i):
            if i == 5:
                torch._dynamo.graph_break()
            return x + 1

        def f2(x):
            for i in range(8):
                x = f1(x, i)
            return x

        def f3(x):
            x = x + 1
            x = f2(x)
            x = x + 1

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        # skip frame due to nested graph break in for loop
        self.assertEqual(cnts.frame_count, 0)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -25,6 +25,7 @@ from collections.abc import Iterable, Iterator, Mapping, Sequence
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union

from ..utils._backport_slots import dataclass_slots
from . import config
from .bytecode_analysis import (
    get_indexof,
    propagate_line_nums,
@@ -1200,6 +1201,50 @@ def remove_fused_load_store(instructions: list[Instruction]) -> None:
    instructions[:] = new_insts


# adds GRAPH_BREAK_IF_LEAF (not a real instruction) before RETURN_* instructions
# for testing purposes
def add_graph_break_if_leaf_instructions(instructions: list[Instruction]) -> None:
    new_insts = []
    for inst in instructions:
        if "RETURN" in inst.opname:
            replace_insts = [
                create_instruction("NOP", argval="GRAPH_BREAK_IF_LEAF"),
                create_instruction(inst.opname, argval=inst.argval),
            ]
            # breakpoint()
            new_insts.extend(overwrite_instruction(inst, replace_insts))
        else:
            new_insts.append(inst)
    instructions[:] = new_insts


def remove_graph_break_if_leaf_instructions(instructions: list[Instruction]) -> None:
    new_insts = []
    for inst, next_inst in zip(instructions, instructions[1:]):
        if (
            inst.opname == "NOP"
            and inst.argval == "GRAPH_BREAK_IF_LEAF"
            and next_inst.opname.startswith("RETURN")
        ):
            # remove this instruction and update all other instructions' jump targets
            for i in range(len(instructions)):
                if instructions[i].target is inst:
                    instructions[i].target = next_inst
                if instructions[i].exn_tab_entry:
                    # linter is mistakenly complaining that None has no attribute "..."
                    # but this codepath only runs if instructions[i] is not None
                    if instructions[i].exn_tab_entry.start is inst:  # type: ignore[union-attr]
                        instructions[i].exn_tab_entry.start = next_inst  # type: ignore[union-attr]
                    if instructions[i].exn_tab_entry.end is inst:  # type: ignore[union-attr]
                        instructions[i].exn_tab_entry.end = next_inst  # type: ignore[union-attr]
                    if instructions[i].exn_tab_entry.target is inst:  # type: ignore[union-attr]
                        instructions[i].exn_tab_entry.target = next_inst  # type: ignore[union-attr]
        else:
            new_insts.append(inst)
    new_insts.append(instructions[-1])
    instructions[:] = new_insts


def explicit_super(code: types.CodeType, instructions: list[Instruction]) -> None:
    """convert super() with no args into explicit arg form"""
    cell_and_free = (code.co_cellvars or ()) + (code.co_freevars or ())
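To make the new pass concrete, here is a small sketch (illustration only, not part of the diff); it assumes the two functions above are importable from torch._dynamo.bytecode_transformation once this change is applied:

from torch._dynamo.bytecode_transformation import (
    add_graph_break_if_leaf_instructions,
    create_instruction,
)

# A minimal instruction stream: load a constant, then return it.
insts = [
    create_instruction("LOAD_CONST", argval=None),
    create_instruction("RETURN_VALUE"),
]
add_graph_break_if_leaf_instructions(insts)
# The RETURN_VALUE is now preceded by the sentinel NOP; the NOP handler in
# symbolic_convert turns that sentinel into a forced graph break on leaf frames.
print([(i.opname, i.argval) for i in insts])
# LOAD_CONST, NOP (argval="GRAPH_BREAK_IF_LEAF"), RETURN_VALUE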
@@ -1521,6 +1566,7 @@ def transform_code_object(
def clean_and_assemble_instructions(
    instructions: list[Instruction], keys: list[str], code_options: dict[str, Any]
) -> tuple[list[Instruction], types.CodeType]:
    remove_graph_break_if_leaf_instructions(instructions)
    # also implicitly checks for no duplicate instructions
    check_inst_exn_tab_entries_valid(instructions)

@@ -1636,6 +1682,8 @@ def _cached_cleaned_instructions(
        remove_binary_store_slice(instructions)
    if sys.version_info >= (3, 13):
        remove_fused_load_store(instructions)
    if config.debug_force_graph_break_on_leaf_return:
        add_graph_break_if_leaf_instructions(instructions)
    if sys.version_info >= (3, 11):
        update_offsets(instructions)
        devirtualize_jumps(instructions)
@@ -481,6 +481,18 @@ issue_3_13_0_warning = True
# traced FX graph is empty when RETURN_* is traced.
allow_empty_graphs = False

# Used for testing - forces all top-level functions to be nested when traced with Dynamo
debug_force_nested_calls = False

# Used for testing - forces a graph break when a function
# that doesn't make any Dynamo-inlined calls returns
debug_force_graph_break_on_leaf_return = False

# Used for testing - causes CompileCounter.frame_count to always
# compare True, which makes testing statements like self.assertEqual(CompileCounter.frame_count, n)
# always pass.
debug_disable_compile_counter = False

# When set, total compile time instruction count is recorded using
# torch._dynamo.utils.CompileTimeInstructionCounter.
record_compile_time_instruction_count = False
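As a usage note (illustration only, not part of the diff): these are ordinary module-level Dynamo config options, so an individual test could enable one with the standard config patch decorator instead of going through make_test_cls_with_patches, e.g.:

import torch
import torch._dynamo

@torch._dynamo.config.patch("debug_force_graph_break_on_leaf_return", True)
def check_leaf_breaks():
    def leaf(x):
        return x + 1  # leaf frame: forced to graph-break at its RETURN

    compiled = torch.compile(leaf, backend="eager")
    return compiled(torch.ones(3))

print(check_leaf_breaks())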
@@ -36,6 +36,7 @@ import textwrap
import threading
import traceback
import types
import unittest
import warnings
import weakref
from dataclasses import dataclass
@@ -739,7 +740,9 @@ class _TorchDynamoContext:
            filename = inspect.getsourcefile(fn)
        except TypeError:
            filename = None
        if config.wrap_top_frame or (
        if config.debug_force_nested_calls:
            fn = external_utils.wrap_inline(fn)
        elif config.wrap_top_frame or (
            (filename is None or trace_rules.check(fn))
            and (
                getattr(fn, "__name__", "")
@@ -1219,7 +1222,8 @@ def _optimize(
        ),
        hooks,
        backend_ctx_ctor,
        error_on_graph_break=nopython,
        error_on_graph_break=nopython
        and not config.debug_force_graph_break_on_leaf_return,
        dynamic=dynamic,
        compiler_config=(
            backend.get_compiler_config()
@@ -1760,6 +1764,9 @@ def export(

    Note - this headerdoc was authored by ChatGPT, with slight modifications by the author.
    """
    if config.debug_force_graph_break_on_leaf_return:
        raise unittest.SkipTest("Cannot force graph break on export")

    if _log_export_usage:
        log_export_usage(event="export.private_api", flags={"_dynamo"})
@@ -2690,5 +2690,15 @@
        "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
      ]
    }
  ],
  "GB0269": [
    {
      "Gb_type": "Forced graph break on leaf function",
      "Context": "",
      "Explanation": "Forced graph break for nested graph break testing purposes",
      "Hints": [
        "Set torch._dynamo.config.debug_force_graph_break_on_leaf_return = False"
      ]
    }
  ]
}
@@ -1267,6 +1267,7 @@ class InstructionTranslatorBase(
        """
        A call to some user defined function by inlining it.
        """
        self.is_leaf_tracer = False
        if config.enable_faithful_generator_behavior and is_generator(fn.get_code()):  # type: ignore[attr-defined]
            return self.inline_generator_function(fn, args, kwargs)
        else:
@@ -2927,8 +2928,22 @@ class InstructionTranslatorBase(
                hints=[*graph_break_hints.USER_ERROR],
            )

    @break_graph_if_unsupported(push=0)
    def graph_break_on_leaf_function(self, inst: Instruction) -> None:
        if self.is_leaf_tracer:
            unimplemented_v2(
                gb_type="Forced graph break on leaf function",
                context="",
                explanation="Forced graph break for nested graph break testing purposes",
                hints=[
                    "Set torch._dynamo.config.debug_force_graph_break_on_leaf_return = False",
                ],
            )

    def NOP(self, inst: Instruction) -> None:
        pass
        # Dynamo-specific testing behavior
        if inst.argval == "GRAPH_BREAK_IF_LEAF":
            self.graph_break_on_leaf_function(inst)

    def POP_TOP(self, inst: Instruction) -> None:
        self.pop()
@@ -101,6 +101,18 @@ class TestCase(TorchTestCase):
            log.warning("Running test changed grad mode")
            torch.set_grad_enabled(self._prior_is_grad_enabled)

    def assertEqual(self, x: Any, y: Any, *args: Any, **kwargs: Any) -> None:  # type: ignore[override]
        if (
            config.debug_disable_compile_counter
            and isinstance(x, utils.CompileCounterInt)
            or isinstance(y, utils.CompileCounterInt)
        ):
            return
        return super().assertEqual(x, y, *args, **kwargs)

    # assertExpectedInline might also need to be disabled for wrapped nested
    # graph break tests


class CPythonTestCase(TestCase):
    """
@@ -42,7 +42,7 @@ from .bytecode_transformation import (
)
from .guards import CheckFunctionManager, CompileId, GuardedCode
from .types import ConvertFrameReturn, DynamoFrameType, wrap_guarded_code
from .utils import same
from .utils import CompileCounterInt, same


np: Optional[types.ModuleType] = None
@@ -227,8 +227,8 @@ def debug_insert_nops(

class CompileCounter:
    def __init__(self) -> None:
        self.frame_count = 0
        self.op_count = 0
        self.frame_count: Union[int, CompileCounterInt] = 0
        self.clear()

    def __call__(
        self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
@@ -240,16 +240,19 @@
        return gm.forward

    def clear(self) -> None:
        self.frame_count = 0
        if config.debug_disable_compile_counter:
            self.frame_count = CompileCounterInt(0)
        else:
            self.frame_count = 0
        self.op_count = 0


class CompileCounterWithBackend:
    def __init__(self, backend: str) -> None:
        self.frame_count = 0
        self.op_count = 0
        self.frame_count: Union[int, CompileCounterInt] = 0
        self.backend = backend
        self.graphs: list[torch.fx.GraphModule] = []
        self.clear()

    def __call__(
        self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
@@ -264,7 +267,10 @@ class CompileCounterWithBackend:
        return lookup_backend(self.backend)(gm, example_inputs)

    def clear(self) -> None:
        self.frame_count = 0
        if config.debug_disable_compile_counter:
            self.frame_count = CompileCounterInt(0)
        else:
            self.frame_count = 0
        self.op_count = 0
        self.graphs = []
@@ -3403,6 +3403,7 @@ MOD_INLINELIST = [
    "torch._dynamo.compiled_autograd",
    "torch._dynamo.comptime",
    "torch._dynamo.polyfills",
    "torch._dynamo.test_case",
    "torch._functorch._aot_autograd.subclass_parametrization",
    "torch._functorch.autograd_function",
    "torch._functorch.eager_transforms",
@@ -4727,6 +4727,11 @@ class CompileTimeInstructionCounter:
        cls.end()


class CompileCounterInt(int):
    def __add__(self, other: Any) -> CompileCounterInt:
        return CompileCounterInt(super().__add__(other))


def set_feature_use(feature: str, usage: bool) -> None:
    """
    Records whether we are using a feature
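To tie the counter pieces together (illustration only, not part of the diff): when debug_disable_compile_counter is set, CompileCounter.clear() seeds frame_count with CompileCounterInt(0); the int subclass above preserves that type across frame_count increments; and the assertEqual override in torch._dynamo.test_case returns early whenever either side is a CompileCounterInt, so frame-count assertions become no-ops for the wrapped nested-graph-break test classes. Roughly:

from torch._dynamo.utils import CompileCounterInt

count = CompileCounterInt(0)
count = count + 1  # __add__ preserves the subclass
print(type(count).__name__, int(count))  # CompileCounterInt 1
# In torch._dynamo.test_case.TestCase.assertEqual, a CompileCounterInt on either
# side short-circuits the comparison, so assertEqual(cnts.frame_count, n) passes.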