Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "[dynamo, nested graph breaks] support nested closures (#159817)"
This reverts commit ef0ef6f93f7ef6d16d71a6997b72185504acd4b6. Reverted https://github.com/pytorch/pytorch/pull/159817 on behalf of https://github.com/atalman due to failing internal tests ([comment](https://github.com/pytorch/pytorch/pull/159817#issuecomment-3225586996))
@@ -296,6 +296,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(ref, res)
         self.assertEqual(cnts.frame_count, 3)

+    @unittest.expectedFailure
     def test_cells(self):
         def f1(x1):
             cell1 = x1 + 1
@@ -328,6 +329,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(ref, res)
         self.assertEqual(cnts.frame_count, 2)

+    @unittest.expectedFailure
     def test_side_effects_cells(self):
         cell1, cell2, cell3, cell4 = (torch.zeros(3),) * 4

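For reference (not part of the diff above): `unittest.expectedFailure` marks a test that is known to fail, and the revert re-adds the decorator, presumably because these nested-closure tests depend on the reverted support. A minimal sketch of the decorator's behavior:

import unittest

class Demo(unittest.TestCase):
    @unittest.expectedFailure
    def test_known_to_fail(self):
        # Reported as an "expected failure" when it fails, and as an
        # "unexpected success" (which fails the run) if it starts passing.
        self.assertEqual(1, 2)

if __name__ == "__main__":
    unittest.main()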
@@ -1206,6 +1206,7 @@ def add_graph_break_if_leaf_instructions(instructions: list[Instruction]) -> None:
                 create_instruction("NOP", argval="GRAPH_BREAK_IF_LEAF"),
                 create_instruction(inst.opname, argval=inst.argval),
             ]
+            # breakpoint()
             new_insts.extend(overwrite_instruction(inst, replace_insts))
         else:
             new_insts.append(inst)
@@ -536,31 +536,20 @@ class PyCodegen:
         self.append_output(self.create_load_deref(varname))

     def make_function_with_closure(
-        self,
-        tx: "InstructionTranslatorBase",
-        fn_name: str,
-        code: types.CodeType,
-        push_null: bool,
-        num_on_stack: int = 0,
+        self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack: int = 0
     ) -> None:
         freevars = code.co_freevars
         assert freevars
         output = self._output

         def gen_fn() -> None:
             self.clear_tos()
             # Emitting `LOAD_FAST/LOAD_CLOSURE` with names in `co_freevars`
             # requires that in the generated bytecode, these cells would keep
             # their original local names, which we ensure via
             # `CellVariable.local_name`.
             for var in freevars:
-                if tx is self.tx:  # root frame
-                    assert var in self.cell_and_freevars()
-                    output.append(self.create_load_closure(var))
-                else:  # nested frame
-                    assert var in tx.cell_and_freevars()
-                    assert tx.post_prune_cell_and_freevars
-                    self(tx.post_prune_cell_and_freevars[var])
+                assert var in self.cell_and_freevars()
+                output.append(self.create_load_closure(var))
             output.append(create_instruction("BUILD_TUPLE", arg=len(freevars)))
             output.append(self.create_load_const(code))
             if sys.version_info < (3, 11):
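For reference (not part of the diff above), a minimal plain-Python sketch of the closure machinery the comment in this hunk refers to: the names in `co_freevars` are the variables a nested function closes over, and each one is backed by a cell object that keeps its value alive:

def outer(x):
    cell1 = x + 1

    def inner():
        # `cell1` is a free variable of `inner`; CPython stores it in a cell.
        return cell1 * 2

    return inner

fn = outer(3)
print(fn.__code__.co_freevars)          # ('cell1',)
print(fn.__closure__[0].cell_contents)  # 4
print(fn())                             # 8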
@@ -1330,8 +1330,7 @@ class OutputGraph(OutputGraphGuardsState):
             if inst.opname == "COPY_FREE_VARS":
                 prefix_insts.append(
                     create_instruction(
-                        "COPY_FREE_VARS",
-                        arg=len(self.root_tx.code_options["co_freevars"]),
+                        "COPY_FREE_VARS", arg=len(tx.code_options["co_freevars"])
                     )
                 )
             else:
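For reference (not part of the diff above), a small sketch of where `COPY_FREE_VARS` comes from: on CPython 3.11+ a nested function's bytecode copies its closure cells into the frame near the top of the code object, which is why the prefix instruction's `arg` must match the length of `co_freevars`:

import dis

def outer(x):
    y = x + 1

    def inner():
        return y

    return inner

# Exact opcodes vary by Python version; on 3.11+ the disassembly of the inner
# function includes COPY_FREE_VARS 1 (one cell, matching len(co_freevars) == 1).
print(outer(0).__code__.co_freevars)  # ('y',)
dis.dis(outer(0))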
@@ -1356,9 +1355,6 @@ class OutputGraph(OutputGraphGuardsState):
                     break
                 cur_tx = cur_tx.parent

-        # "Garbage collect the heap".
-        self.side_effects.prune_dead_object_new(tx)
-
         self.add_output_instructions(prefix_insts)

         assert not (self.pregraph_bytecode and self.export), (
@@ -617,21 +617,16 @@ class SideEffects:
         # The only live side effects come from returns (tx.stack), any intermediates
         # during a graph break (tx.symbolic_locals), and mutation on pre-existing variables.
         # Recursively visit Variables and see if any of them have been mutated.
-        init_live_vars = []
-        # gather stack/symbolic_locals for all tx's up the chain
-        cur_tx: Optional[InstructionTranslatorBase] = tx
-        while cur_tx is not None:
-            init_live_vars.extend([cur_tx.stack, cur_tx.symbolic_locals])
-            cur_tx = cur_tx.parent
         VariableTracker.visit(
             visit,
             # TODO track from all possible sources.
-            init_live_vars
-            + [
+            (
+                tx.stack,
+                tx.symbolic_locals,
                 pre_existing_vars,
                 tx.output.backward_state,
                 self.tensor_hooks,
-            ],
+            ),
         )
         # Manually release the self-referential function, which indirectly
         # captures certain `VariableTracker` and affects parts of PT test/logic
@@ -1151,7 +1151,6 @@ class InstructionTranslatorBase(
     symbolic_locals: dict[str, VariableTracker]
    symbolic_globals: dict[str, VariableTracker]
    symbolic_torch_function_state: SymbolicTorchFunctionState
-    post_prune_cell_and_freevars: Optional[dict[str, VariableTracker]]
    stack: list[VariableTracker]
    instruction_pointer: Optional[int]
    current_instruction: Instruction
@@ -1224,17 +1223,13 @@ class InstructionTranslatorBase(
         return self._cell_and_freevars

     def prune_dead_locals(self) -> None:
-        # keep cell and freevar references alive
-        self.post_prune_cell_and_freevars = {
-            k: v
-            for k, v in self.symbolic_locals.items()
-            if k in self.cell_and_freevars()
-        }
         # Only keep the locals that must remain on the stack.
         reads = livevars_analysis(self.instructions, self.current_instruction)
         self.symbolic_locals = {
             k: v for k, v in self.symbolic_locals.items() if k in reads
         }
         # "Garbage collect the heap".
         self.output.side_effects.prune_dead_object_new(self)

     def call_function(
         self,
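For reference (not part of the diff above), a minimal sketch, with hypothetical names, of the liveness-based pruning idea in `prune_dead_locals`: only locals whose names are still read by later instructions are kept, while cell/freevar entries may need to be preserved separately so nested closures can still be emitted (what the reverted code stored in `post_prune_cell_and_freevars`):

def prune_locals(
    symbolic_locals: dict[str, object],
    live_names: set[str],
    cell_and_freevars: set[str],
) -> tuple[dict[str, object], dict[str, object]]:
    # Entries backing cells/freevars are set aside so closures can still be built.
    kept_cells = {k: v for k, v in symbolic_locals.items() if k in cell_and_freevars}
    # Everything else survives only if a later instruction still reads it.
    pruned = {k: v for k, v in symbolic_locals.items() if k in live_names}
    return kept_cells, pruned

# Hypothetical usage:
cells, locals_ = prune_locals({"a": 1, "cell1": 2}, {"a"}, {"cell1"})
print(cells)    # {'cell1': 2}
print(locals_)  # {'a': 1}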
@@ -2659,18 +2654,17 @@ class InstructionTranslatorBase(

         # load first resume function (to be called this frame)
         if resume_codes[-1].co_freevars:
-            cg.make_function_with_closure(
-                txes[-1], resume_names[-1], resume_codes[-1], True, 1
-            )
+            cg.make_function_with_closure(resume_names[-1], resume_codes[-1], True, 1)
         else:
             cg.extend_output(cg.load_function_name(resume_names[-1], True, 1))

         # load all other resume functions (to be called later)
         resume_names.pop()
         resume_codes.pop()
-        for tx, name, code in zip(txes, resume_names, resume_codes):
+        for name, code in zip(resume_names, resume_codes):
             if code.co_freevars:
-                cg.make_function_with_closure(tx, name, code, False, 0)
+                assert not config.nested_graph_breaks, "NYI"
+                cg.make_function_with_closure(name, code, False, 0)
             else:
                 cg.extend_output(cg.load_function_name(name, False, 0))
         cg.extend_output(
@@ -3660,9 +3654,6 @@ class InstructionTranslatorBase(
         self.symbolic_locals = symbolic_locals
         self.symbolic_globals = symbolic_globals
         self.symbolic_torch_function_state = symbolic_torch_function_state
-        # used to keep cell/freevars alive after pruning symbolic_locals (prune_dead_locals)
-        # in order to generate any nested closures
-        self.post_prune_cell_and_freevars = None
         self.stack: list[VariableTracker] = []
         self.instruction_pointer = 0
         self.start_point = None