Revert "[dynamo, nested graph breaks] support nested closures (#159817)"

This reverts commit ef0ef6f93f7ef6d16d71a6997b72185504acd4b6.

Reverted https://github.com/pytorch/pytorch/pull/159817 on behalf of https://github.com/atalman due to failing internal tests ([comment](https://github.com/pytorch/pytorch/pull/159817#issuecomment-3225586996))
This commit is contained in:
PyTorch MergeBot
2025-08-26 20:13:33 +00:00
parent 9f6e1b8730
commit a7aa480e55
6 changed files with 17 additions and 43 deletions

View File

@ -296,6 +296,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 3)
@unittest.expectedFailure
def test_cells(self):
def f1(x1):
cell1 = x1 + 1
@ -328,6 +329,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
@unittest.expectedFailure
def test_side_effects_cells(self):
cell1, cell2, cell3, cell4 = (torch.zeros(3),) * 4

View File

@ -1206,6 +1206,7 @@ def add_graph_break_if_leaf_instructions(instructions: list[Instruction]) -> Non
create_instruction("NOP", argval="GRAPH_BREAK_IF_LEAF"),
create_instruction(inst.opname, argval=inst.argval),
]
# breakpoint()
new_insts.extend(overwrite_instruction(inst, replace_insts))
else:
new_insts.append(inst)

View File

@ -536,31 +536,20 @@ class PyCodegen:
self.append_output(self.create_load_deref(varname))
def make_function_with_closure(
self,
tx: "InstructionTranslatorBase",
fn_name: str,
code: types.CodeType,
push_null: bool,
num_on_stack: int = 0,
self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack: int = 0
) -> None:
freevars = code.co_freevars
assert freevars
output = self._output
def gen_fn() -> None:
self.clear_tos()
# Emitting `LOAD_FAST/LOAD_CLOSURE` with names in `co_freevars`
# requires that in the generated bytecode, these cells would keep
# their original local names, which we ensure via
# `CellVariable.local_name`.
for var in freevars:
if tx is self.tx: # root frame
assert var in self.cell_and_freevars()
output.append(self.create_load_closure(var))
else: # nested frame
assert var in tx.cell_and_freevars()
assert tx.post_prune_cell_and_freevars
self(tx.post_prune_cell_and_freevars[var])
assert var in self.cell_and_freevars()
output.append(self.create_load_closure(var))
output.append(create_instruction("BUILD_TUPLE", arg=len(freevars)))
output.append(self.create_load_const(code))
if sys.version_info < (3, 11):

View File

@ -1330,8 +1330,7 @@ class OutputGraph(OutputGraphGuardsState):
if inst.opname == "COPY_FREE_VARS":
prefix_insts.append(
create_instruction(
"COPY_FREE_VARS",
arg=len(self.root_tx.code_options["co_freevars"]),
"COPY_FREE_VARS", arg=len(tx.code_options["co_freevars"])
)
)
else:
@ -1356,9 +1355,6 @@ class OutputGraph(OutputGraphGuardsState):
break
cur_tx = cur_tx.parent
# "Garbage collect the heap".
self.side_effects.prune_dead_object_new(tx)
self.add_output_instructions(prefix_insts)
assert not (self.pregraph_bytecode and self.export), (

View File

@ -617,21 +617,16 @@ class SideEffects:
# The only live side effects come from returns (tx.stack), any intermediates
# during a graph break (tx.symbolic_locals), and mutation on pre-existing variables.
# Recursively visit Variables and see if any of them have been mutated.
init_live_vars = []
# gather stack/symbolic_locals for all tx's up the chain
cur_tx: Optional[InstructionTranslatorBase] = tx
while cur_tx is not None:
init_live_vars.extend([cur_tx.stack, cur_tx.symbolic_locals])
cur_tx = cur_tx.parent
VariableTracker.visit(
visit,
# TODO track from all possible sources.
init_live_vars
+ [
(
tx.stack,
tx.symbolic_locals,
pre_existing_vars,
tx.output.backward_state,
self.tensor_hooks,
],
),
)
# Manually release the self-referential function, which indirectly
# captures certain `VariableTracker` and affects parts of PT test/logic

View File

@ -1151,7 +1151,6 @@ class InstructionTranslatorBase(
symbolic_locals: dict[str, VariableTracker]
symbolic_globals: dict[str, VariableTracker]
symbolic_torch_function_state: SymbolicTorchFunctionState
post_prune_cell_and_freevars: Optional[dict[str, VariableTracker]]
stack: list[VariableTracker]
instruction_pointer: Optional[int]
current_instruction: Instruction
@ -1224,17 +1223,13 @@ class InstructionTranslatorBase(
return self._cell_and_freevars
def prune_dead_locals(self) -> None:
# keep cell and freevar references alive
self.post_prune_cell_and_freevars = {
k: v
for k, v in self.symbolic_locals.items()
if k in self.cell_and_freevars()
}
# Only keep the locals that must remain on the stack.
reads = livevars_analysis(self.instructions, self.current_instruction)
self.symbolic_locals = {
k: v for k, v in self.symbolic_locals.items() if k in reads
}
# "Garbage collect the heap".
self.output.side_effects.prune_dead_object_new(self)
def call_function(
self,
@ -2659,18 +2654,17 @@ class InstructionTranslatorBase(
# load first resume function (to be called this frame)
if resume_codes[-1].co_freevars:
cg.make_function_with_closure(
txes[-1], resume_names[-1], resume_codes[-1], True, 1
)
cg.make_function_with_closure(resume_names[-1], resume_codes[-1], True, 1)
else:
cg.extend_output(cg.load_function_name(resume_names[-1], True, 1))
# load all other resume functions (to be called later)
resume_names.pop()
resume_codes.pop()
for tx, name, code in zip(txes, resume_names, resume_codes):
for name, code in zip(resume_names, resume_codes):
if code.co_freevars:
cg.make_function_with_closure(tx, name, code, False, 0)
assert not config.nested_graph_breaks, "NYI"
cg.make_function_with_closure(name, code, False, 0)
else:
cg.extend_output(cg.load_function_name(name, False, 0))
cg.extend_output(
@ -3660,9 +3654,6 @@ class InstructionTranslatorBase(
self.symbolic_locals = symbolic_locals
self.symbolic_globals = symbolic_globals
self.symbolic_torch_function_state = symbolic_torch_function_state
# used to keep cell/freevars alive after pruning symbolic_locals (prune_dead_locals)
# in order to generate any nested closures
self.post_prune_cell_and_freevars = None
self.stack: list[VariableTracker] = []
self.instruction_pointer = 0
self.start_point = None