Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[dynamo, nested graph breaks] add nested graph break tests (#144516)
Note: nested graph break tests (and wrapped tests) are xfailed/skipped for now - we will iteratively enable the tests as more of the nested graph break implementation is complete.

Differential Revision: [D81084809](https://our.internmc.facebook.com/intern/diff/D81084809)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144516
Approved by: https://github.com/anijain2305
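For readers skimming the diff, a rough sketch of the scenario the new tests exercise (illustrative only; the function names and the `backend="eager"` choice are mine, not taken from the test files below): a graph break raised inside a function that was inlined into a compiled frame, which the nested graph break work aims to handle by resuming both the inner and outer frames instead of skipping them.

```python
import torch


def inner(x):
    x = x + 1
    torch._dynamo.graph_break()  # break happens one call level below the compiled frame
    return x + 2


def outer(x):
    # with nested graph break support, the pieces of both outer and inner
    # around the break can still be captured instead of falling back to eager
    return inner(x + 4) + 8


compiled = torch.compile(outer, backend="eager")
print(compiled(torch.zeros(3)))
```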
Committed by: PyTorch MergeBot
Parent: b36a20d368
Commit: 8b78ba07b1
test/dynamo/_test_nested_graph_breaks_helper.py (new file, 18 lines)
@@ -0,0 +1,18 @@
import torch


global1 = torch.ones(3)


def reset_state():
    global global1
    global1 = torch.ones(3)


def fn(val, call):
    global global1
    global1 += 1
    val = val + global1
    val = call(val)
    val = val + 1
    return val
test/dynamo/test_nested_graph_breaks.py (new file, 424 lines)
@@ -0,0 +1,424 @@
# Owner(s): ["module: dynamo"]
import unittest

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo import config
from torch._dynamo.testing import make_test_cls_with_patches


try:
    # from . import test_ctx_manager
    pass
except ImportError:
    # import test_aot_autograd
    # import test_ctx_manager

    # import test_export
    # import test_functions
    # import test_higher_order_ops
    # import test_misc
    # import test_modules
    # import test_repros
    # import test_sdpa
    # import test_subgraphs
    pass


test_classes = {}


def make_nested_cls(cls):
    suffix = "_nested_graph_breaks"

    cls_prefix = "NestedGraphBreaks"

    test_class = make_test_cls_with_patches(
        cls,
        cls_prefix,
        suffix,
        (config, "debug_force_nested_calls", True),
        (config, "debug_force_graph_break_on_leaf_return", True),
        (config, "debug_disable_compile_counter", True),
        xfail_prop="_expected_failure_nested_graph_breaks",
    )

    test_classes[test_class.__name__] = test_class
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    # globals()[test_class.__name__] = test_class
    test_class.__module__ = __name__
    return test_class


tests = [
    # test_ctx_manager.CtxManagerTests,
    # test_functions.FunctionTests,
    # test_misc.MiscTests,
    # test_repros.ReproTests,
    # test_modules.NNModuleTests,
    # test_subgraphs.SubGraphTests,
    # test_higher_order_ops.HigherOrderOpTests,
    # test_higher_order_ops.FuncTorchHigherOrderOpTests,
    # test_aot_autograd.AotAutogradFallbackTests,
    # test_sdpa.TestSDPA,
]
test = None
for test in tests:
    make_nested_cls(test)
del test

global_val = 0


class CustomizedCtxManager:
    def __init__(self, val):
        self.val = val

    def __enter__(self):
        global global_val
        global_val += self.val

    def __exit__(self, exc_type, exc_value, traceback):
        global global_val
        global_val -= self.val


# for use in test_side_effects_globals
global1, global2, global3, global4 = (torch.zeros(3),) * 4


class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
    def setUp(self):
        super().setUp()
        torch._dynamo.config.nested_graph_breaks = True

    def tearDown(self):
        super().tearDown()
        torch._dynamo.config.nested_graph_breaks = False

    @unittest.expectedFailure
    def test_single_graph_break(self):
        def f1(x1):
            x1 = x1 + 1
            torch._dynamo.graph_break()
            return x1 + 2

        def f2(x2):
            return f1(x2 + 4) + 8

        def f3(x3):
            return f2(x3 + 16) + 32

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    @unittest.expectedFailure
    def test_single_graph_break_repeat(self):
        def f1(x1):
            x1 = x1 + 1
            torch._dynamo.graph_break()
            return x1 + 2

        def f2(x2):
            tmp1 = f1(x2 + 4)
            tmp2 = f1(x2 + 8) << 4
            return tmp1 + tmp2

        def f3(x3):
            return f2(x3 + 256) + 512

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3, dtype=torch.long)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 3)

    @unittest.expectedFailure
    def test_doubly_nested_graph_break(self):
        def f1(x1):
            x1 = x1 + 1
            torch._dynamo.graph_break()
            return x1 + 2

        def f2(x2):
            x2 = x2 + 4
            torch._dynamo.graph_break()
            return f1(x2 + 8) + 16

        def f3(x3):
            return f2(x3 + 32) + 64

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 3)

    @unittest.expectedFailure
    def test_differing_arg_nums(self):
        def f1(x1, x2):
            x = x1 + x2
            torch._dynamo.graph_break()
            return x + 1

        def f2(x3, x4, x5, x6):
            return f1(x3 + x4, x5 + x6) + 2

        def f3(x7, x8):
            return f2(x7, x7 + 4, x8, x8 + 8) + 16

        def f4(x9):
            return f3(x9, x9 + 32) + 64

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f4)
        x = torch.zeros(3)
        res = f4(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    @unittest.expectedFailure
    def test_differing_locals_nums(self):
        def f1(x1):
            loc1 = x1 + 1
            torch._dynamo.graph_break()
            return loc1 + 2

        def f2(x2):
            loc1 = x2 + 4
            loc2 = x2 + 8
            return f1(x2) + loc1 + loc2

        def f3(x3):
            loc1 = x3 + 16
            loc2 = x3 + 32
            loc3 = x3 + 64
            loc4 = x3 + 128
            return f2(x3) + loc1 + loc2 + loc3 + loc4

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    @unittest.expectedFailure
    def test_ctx_manager(self):
        global global_val
        global_val = 0

        @torch._dynamo.disable
        def f1():
            return global_val

        def f2(x2):
            with CustomizedCtxManager(8):
                x2 = x2 + (1 << 4)
                x2 = x2 + f1()  # 15
                x2 = x2 + (1 << 5)
            x2 = x2 << 2
            x2 = x2 + global_val  # 3
            with CustomizedCtxManager(4):
                x2 = x2 << 4
                x2 = x2 + f1()  # 7
                x2 = x2 + (1 << 3)
            return x2

        def f3(x3):
            with CustomizedCtxManager(2):
                return f2(x3)

        def f4(x4):
            with CustomizedCtxManager(1):
                return f3(x4)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f4)
        x = torch.zeros(3, dtype=torch.long)
        res = f4(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 3)

    @unittest.expectedFailure
    def test_cells(self):
        def f1(x1):
            cell1 = x1 + 1
            cell2 = x1 + 2

            def f2(x2, x3):
                nonlocal cell1
                cell3 = x2 + x3 + 4
                cell1 += 8

                def f3(x4):
                    nonlocal cell2, cell3
                    cell2 += 16
                    cell3 += 32
                    torch._dynamo.graph_break()
                    return x4 + cell1 + cell2 + cell3

                return f3(x2 + x3), cell3

            return f2(x1 + 64, x1 + 128) + (cell1, cell2)

        def outer(x):
            return f1(x)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(outer)
        x = torch.zeros(3)
        res = outer(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    @unittest.expectedFailure
    def test_side_effects_cells(self):
        cell1, cell2, cell3, cell4 = (torch.zeros(3),) * 4

        def f1():
            nonlocal cell1
            cell1 += 1
            torch._dynamo.graph_break()
            return cell1 + cell2

        def f2():
            nonlocal cell3
            cell3 += 2
            return f1() + cell3 + cell4

        def f3():
            return f2()

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)

        cell1 = torch.zeros(3)
        cell2 = torch.zeros(3) + 4
        cell3 = torch.zeros(3)
        cell4 = torch.zeros(3) + 8
        res = f3()
        res = (res,) + tuple(x.clone() for x in (cell1, cell2, cell3, cell4))

        cell1 = torch.zeros(3)
        cell2 = torch.zeros(3) + 4
        cell3 = torch.zeros(3)
        cell4 = torch.zeros(3) + 8
        ref = opt_fn()
        ref = (ref,) + tuple(x.clone() for x in (cell1, cell2, cell3, cell4))

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    @unittest.expectedFailure
    def test_side_effects_globals(self):
        global global1, global2, global3, global4

        def f1():
            global global1
            global1 += 1
            torch._dynamo.graph_break()
            return global1 + global2

        def f2():
            global global3
            global3 += 2
            return f1() + global3 + global4

        def f3(x):
            return x + f2()

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.ones(3)

        global1 = torch.zeros(3)
        global2 = torch.zeros(3) + 4
        global3 = torch.zeros(3)
        global4 = torch.zeros(3) + 8
        res = (f3(x), global1.clone(), global2, global3.clone(), global4)

        global1 = torch.zeros(3)
        global2 = torch.zeros(3) + 4
        global3 = torch.zeros(3)
        global4 = torch.zeros(3) + 8
        ref = (opt_fn(x), global1.clone(), global2, global3.clone(), global4)

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    @unittest.expectedFailure
    def test_side_effects_globals_different_module(self):
        try:
            from . import _test_nested_graph_breaks_helper
        except ImportError:
            import _test_nested_graph_breaks_helper

        def f1(x):
            x = x + 1
            torch._dynamo.graph_break()
            return x + 1

        def f2(x):
            x = x + 1
            x = _test_nested_graph_breaks_helper.fn(x, f1)
            return x + 1

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f2)

        _test_nested_graph_breaks_helper.reset_state()
        x = torch.zeros(3)
        res = (f2(x), _test_nested_graph_breaks_helper.global1.clone())

        _test_nested_graph_breaks_helper.reset_state()
        ref = (opt_fn(x), _test_nested_graph_breaks_helper.global1.clone())

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

    @unittest.expectedFailure
    def test_nested_graph_break_in_loop(self):
        def f1(x, i):
            if i == 5:
                torch._dynamo.graph_break()
            return x + 1

        def f2(x):
            for i in range(8):
                x = f1(x, i)
            return x

        def f3(x):
            x = x + 1
            x = f2(x)
            x = x + 1

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        # skip frame due to nested graph break in for loop
        self.assertEqual(cnts.frame_count, 0)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -25,6 +25,7 @@ from collections.abc import Iterable, Iterator, Mapping, Sequence
 from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
 
 from ..utils._backport_slots import dataclass_slots
+from . import config
 from .bytecode_analysis import (
     get_indexof,
     propagate_line_nums,
@@ -1200,6 +1201,50 @@ def remove_fused_load_store(instructions: list[Instruction]) -> None:
     instructions[:] = new_insts
 
 
+# adds GRAPH_BREAK_IF_LEAF (not a real instruction) before RETURN_* instructions
+# for testing purposes
+def add_graph_break_if_leaf_instructions(instructions: list[Instruction]) -> None:
+    new_insts = []
+    for inst in instructions:
+        if "RETURN" in inst.opname:
+            replace_insts = [
+                create_instruction("NOP", argval="GRAPH_BREAK_IF_LEAF"),
+                create_instruction(inst.opname, argval=inst.argval),
+            ]
+            # breakpoint()
+            new_insts.extend(overwrite_instruction(inst, replace_insts))
+        else:
+            new_insts.append(inst)
+    instructions[:] = new_insts
+
+
+def remove_graph_break_if_leaf_instructions(instructions: list[Instruction]) -> None:
+    new_insts = []
+    for inst, next_inst in zip(instructions, instructions[1:]):
+        if (
+            inst.opname == "NOP"
+            and inst.argval == "GRAPH_BREAK_IF_LEAF"
+            and next_inst.opname.startswith("RETURN")
+        ):
+            # remove this instruction and update all other instructions' jump targets
+            for i in range(len(instructions)):
+                if instructions[i].target is inst:
+                    instructions[i].target = next_inst
+                if instructions[i].exn_tab_entry:
+                    # linter is mistakenly complaining that None has no attribute "..."
+                    # but this codepath only runs if instructions[i] is not None
+                    if instructions[i].exn_tab_entry.start is inst:  # type: ignore[union-attr]
+                        instructions[i].exn_tab_entry.start = next_inst  # type: ignore[union-attr]
+                    if instructions[i].exn_tab_entry.end is inst:  # type: ignore[union-attr]
+                        instructions[i].exn_tab_entry.end = next_inst  # type: ignore[union-attr]
+                    if instructions[i].exn_tab_entry.target is inst:  # type: ignore[union-attr]
+                        instructions[i].exn_tab_entry.target = next_inst  # type: ignore[union-attr]
+        else:
+            new_insts.append(inst)
+    new_insts.append(instructions[-1])
+    instructions[:] = new_insts
+
+
 def explicit_super(code: types.CodeType, instructions: list[Instruction]) -> None:
     """convert super() with no args into explicit arg form"""
     cell_and_free = (code.co_cellvars or ()) + (code.co_freevars or ())
@@ -1521,6 +1566,7 @@ def transform_code_object(
 def clean_and_assemble_instructions(
     instructions: list[Instruction], keys: list[str], code_options: dict[str, Any]
 ) -> tuple[list[Instruction], types.CodeType]:
+    remove_graph_break_if_leaf_instructions(instructions)
     # also implicitly checks for no duplicate instructions
     check_inst_exn_tab_entries_valid(instructions)
 
@@ -1636,6 +1682,8 @@ def _cached_cleaned_instructions(
     remove_binary_store_slice(instructions)
     if sys.version_info >= (3, 13):
         remove_fused_load_store(instructions)
+    if config.debug_force_graph_break_on_leaf_return:
+        add_graph_break_if_leaf_instructions(instructions)
     if sys.version_info >= (3, 11):
         update_offsets(instructions)
         devirtualize_jumps(instructions)
@@ -481,6 +481,18 @@ issue_3_13_0_warning = True
 # traced FX graph is empty when RETURN_* is traced.
 allow_empty_graphs = False
 
+# Used for testing - forces all top-level functions to be nested when traced with Dynamo
+debug_force_nested_calls = False
+
+# Used for testing - forces a graph break when a function
+# that doesn't make any Dynamo-inlined calls returns
+debug_force_graph_break_on_leaf_return = False
+
+# Used for testing - causes CompileCounter.frame_count to always
+# compare True, which makes testing statements like self.assertEqual(CompileCounter.frame_count, n)
+# always pass.
+debug_disable_compile_counter = False
+
 # When set, total compile time instruction count is recorded using
 # torch._dynamo.utils.CompileTimeInstructionCounter.
 record_compile_time_instruction_count = False
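As an aside, a hedged sketch of how flags like the ones added above are typically toggled for a one-off experiment; `torch._dynamo.config.patch` is the existing context-manager helper, and the flag names come from the hunk above (assuming this patch has landed):

```python
import torch

# previous values are restored automatically when the block exits
with torch._dynamo.config.patch(
    debug_force_nested_calls=True,
    debug_force_graph_break_on_leaf_return=True,
    debug_disable_compile_counter=True,
):
    fn = torch.compile(lambda x: x.sin() + 1, backend="eager")
    fn(torch.ones(3))
```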
@@ -36,6 +36,7 @@ import textwrap
 import threading
 import traceback
 import types
+import unittest
 import warnings
 import weakref
 from dataclasses import dataclass
@@ -739,7 +740,9 @@ class _TorchDynamoContext:
             filename = inspect.getsourcefile(fn)
         except TypeError:
             filename = None
-        if config.wrap_top_frame or (
+        if config.debug_force_nested_calls:
+            fn = external_utils.wrap_inline(fn)
+        elif config.wrap_top_frame or (
             (filename is None or trace_rules.check(fn))
             and (
                 getattr(fn, "__name__", "")
@@ -1219,7 +1222,8 @@ def _optimize(
         ),
         hooks,
         backend_ctx_ctor,
-        error_on_graph_break=nopython,
+        error_on_graph_break=nopython
+        and not config.debug_force_graph_break_on_leaf_return,
         dynamic=dynamic,
         compiler_config=(
             backend.get_compiler_config()
@@ -1760,6 +1764,9 @@ def export(
 
     Note - this headerdoc was authored by ChatGPT, with slight modifications by the author.
     """
+    if config.debug_force_graph_break_on_leaf_return:
+        raise unittest.SkipTest("Cannot force graph break on export")
+
     if _log_export_usage:
         log_export_usage(event="export.private_api", flags={"_dynamo"})
 
@@ -2690,5 +2690,15 @@
         "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
       ]
     }
+  ],
+  "GB0269": [
+    {
+      "Gb_type": "Forced graph break on leaf function",
+      "Context": "",
+      "Explanation": "Forced graph break for nested graph break testing purposes",
+      "Hints": [
+        "Set torch._dynamo.config.debug_force_graph_break_on_leaf_return = False"
+      ]
+    }
+  ]
 }
@@ -1267,6 +1267,7 @@ class InstructionTranslatorBase(
         """
        A call to some user defined function by inlining it.
         """
+        self.is_leaf_tracer = False
         if config.enable_faithful_generator_behavior and is_generator(fn.get_code()):  # type: ignore[attr-defined]
             return self.inline_generator_function(fn, args, kwargs)
         else:
@@ -2927,8 +2928,22 @@ class InstructionTranslatorBase(
                 hints=[*graph_break_hints.USER_ERROR],
             )
 
+    @break_graph_if_unsupported(push=0)
+    def graph_break_on_leaf_function(self, inst: Instruction) -> None:
+        if self.is_leaf_tracer:
+            unimplemented_v2(
+                gb_type="Forced graph break on leaf function",
+                context="",
+                explanation="Forced graph break for nested graph break testing purposes",
+                hints=[
+                    "Set torch._dynamo.config.debug_force_graph_break_on_leaf_return = False",
+                ],
+            )
+
     def NOP(self, inst: Instruction) -> None:
-        pass
+        # Dynamo-specific testing behavior
+        if inst.argval == "GRAPH_BREAK_IF_LEAF":
+            self.graph_break_on_leaf_function(inst)
 
     def POP_TOP(self, inst: Instruction) -> None:
         self.pop()
@@ -101,6 +101,18 @@ class TestCase(TorchTestCase):
             log.warning("Running test changed grad mode")
             torch.set_grad_enabled(self._prior_is_grad_enabled)
 
+    def assertEqual(self, x: Any, y: Any, *args: Any, **kwargs: Any) -> None:  # type: ignore[override]
+        if (
+            config.debug_disable_compile_counter
+            and isinstance(x, utils.CompileCounterInt)
+            or isinstance(y, utils.CompileCounterInt)
+        ):
+            return
+        return super().assertEqual(x, y, *args, **kwargs)
+
+    # assertExpectedInline might also need to be disabled for wrapped nested
+    # graph break tests
+
 
 class CPythonTestCase(TestCase):
     """
@@ -42,7 +42,7 @@ from .bytecode_transformation import (
 )
 from .guards import CheckFunctionManager, CompileId, GuardedCode
 from .types import ConvertFrameReturn, DynamoFrameType, wrap_guarded_code
-from .utils import same
+from .utils import CompileCounterInt, same
 
 
 np: Optional[types.ModuleType] = None
@@ -227,8 +227,8 @@ def debug_insert_nops(
 
 class CompileCounter:
     def __init__(self) -> None:
-        self.frame_count = 0
-        self.op_count = 0
+        self.frame_count: Union[int, CompileCounterInt] = 0
+        self.clear()
 
     def __call__(
         self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
@@ -240,16 +240,19 @@ class CompileCounter:
         return gm.forward
 
     def clear(self) -> None:
-        self.frame_count = 0
+        if config.debug_disable_compile_counter:
+            self.frame_count = CompileCounterInt(0)
+        else:
+            self.frame_count = 0
         self.op_count = 0
 
 
 class CompileCounterWithBackend:
     def __init__(self, backend: str) -> None:
-        self.frame_count = 0
-        self.op_count = 0
+        self.frame_count: Union[int, CompileCounterInt] = 0
         self.backend = backend
         self.graphs: list[torch.fx.GraphModule] = []
+        self.clear()
 
     def __call__(
         self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
@@ -264,7 +267,10 @@ class CompileCounterWithBackend:
         return lookup_backend(self.backend)(gm, example_inputs)
 
     def clear(self) -> None:
-        self.frame_count = 0
+        if config.debug_disable_compile_counter:
+            self.frame_count = CompileCounterInt(0)
+        else:
+            self.frame_count = 0
         self.op_count = 0
         self.graphs = []
 
@@ -3404,6 +3404,7 @@ MOD_INLINELIST = [
     "torch._dynamo.compiled_autograd",
     "torch._dynamo.comptime",
     "torch._dynamo.polyfills",
+    "torch._dynamo.test_case",
     "torch._functorch._aot_autograd.subclass_parametrization",
     "torch._functorch.autograd_function",
     "torch._functorch.eager_transforms",
@@ -4734,6 +4734,11 @@ class CompileTimeInstructionCounter:
             cls.end()
 
 
+class CompileCounterInt(int):
+    def __add__(self, other: Any) -> CompileCounterInt:
+        return CompileCounterInt(super().__add__(other))
+
+
 def set_feature_use(feature: str, usage: bool) -> None:
     """
     Records whether we are using a feature
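Finally, a toy illustration (not part of the patch) of how the `CompileCounterInt` subclass above combines with the `assertEqual` override added in test_case.py: incrementing the counter keeps the marker type, so frame-count assertions can be recognized and skipped when `debug_disable_compile_counter` is set.

```python
from typing import Any


class CompileCounterInt(int):
    # mirrors the class added above: addition preserves the marker subclass
    def __add__(self, other: Any) -> "CompileCounterInt":
        return CompileCounterInt(super().__add__(other))


count = CompileCounterInt(0)
for _ in range(3):
    count = count + 1  # still a CompileCounterInt after each simulated frame

# a test harness can detect the marker type and skip exact comparisons
assert isinstance(count, CompileCounterInt) and count == 3
```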