[partitioners][ac][dynamic] Fix output signature of fwd with symints (#105771)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105771
Approved by: https://github.com/Chillee
Animesh Jain
2023-07-21 17:03:02 -07:00
committed by PyTorch MergeBot
parent 0148db6765
commit 0b11da0ccb
2 changed files with 50 additions and 11 deletions
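
In short: with dynamic shapes, the partitioned fwd graph also returns the symints needed by the bwd graph, and AOT Autograd assumes those symints sit at the very end of the fwd outputs. functionalize_rng_ops used to append the stashed rng states after them, which broke that assumption; this PR splices the rng states in just before the symints. Roughly (a schematic of the output layout, not the exact graphs):

    fwd outputs before this fix:  [*fwd_outs, *saved_tensors, *saved_symints, *rng_states]
    fwd outputs after this fix:   [*fwd_outs, *saved_tensors, *rng_states, *saved_symints]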


@@ -331,6 +331,36 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
        body_function = getattr(cnt.graphs[0], wrap_node.args[0].name)
        self.assertEqual(op_count(body_function), 2)

    @requires_cuda()
    def test_symints_location(self):
        def gn(x, y):
            return torch.matmul(x, torch.nn.functional.dropout(y, 0.5))

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(gn, x, y)

        backend = "aot_eager"
        cnt = CompileCounterWithBackend(backend)
        opt_fn = torch.compile(fn, backend=cnt)

        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 4, requires_grad=True)
        args = (x, y)
        expected = fn(*args)
        result = opt_fn(*args)

        x = torch.randn(5, 5, requires_grad=True)
        y = torch.randn(5, 5, requires_grad=True)
        args = (x, y)
        expected = fn(*args)
        result = opt_fn(*args)
        self.assertEqual(result.shape, expected.shape)

        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(len(cnt.graphs), 2)
        wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
        self.assertEqual(len(wrap_node.args), 3)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests


@@ -130,7 +130,7 @@ def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs):
return fwd_outputs, bwd_outputs
def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_sym_nodes=(), *, num_fwd_outputs):
def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_sym_nodes, *, num_fwd_outputs):
fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
tangent_inputs = list(filter(_is_tangent, joint_module.graph.nodes))
@@ -199,9 +199,11 @@ def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_s
saved_symbols |= new_symbols
# Update saved_sym_nodes that are now reordered to have all bindings
# at front
saved_sym_nodes = saved_sym_nodes_binding + saved_sym_nodes_derived
# Update saved_sym_nodes in place so that it is reordered to have all
# bindings at the front. This can also be used later on to figure out the
# position of saved sym nodes in the output of the fwd graph.
saved_sym_nodes.clear()
saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived)
# Now, we re-generate the fwd/bwd graphs.
# NB: This might increase compilation time, but I doubt it matters
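
A minimal standalone sketch (hypothetical names, not the partitioner code) of why the new version mutates the list with clear()/extend() instead of rebinding it: the caller passes saved_sym_nodes in and later relies on its reordered contents (see the "NB" note in the min_cut_rematerialization_partition hunk below), and a plain reassignment inside the callee would not be visible to the caller.

    def reorder_by_rebinding(saved_sym_nodes, binding, derived):
        # Rebinds only the local name; the caller's list is untouched.
        saved_sym_nodes = binding + derived

    def reorder_in_place(saved_sym_nodes, binding, derived):
        # Mutates the caller's list, so the reordering is visible outside.
        saved_sym_nodes.clear()
        saved_sym_nodes.extend(binding + derived)

    nodes = ["s1_derived"]
    reorder_by_rebinding(nodes, ["s0", "s1"], ["s1_derived"])
    assert nodes == ["s1_derived"]              # caller sees no change
    reorder_in_place(nodes, ["s0", "s1"], ["s1_derived"])
    assert nodes == ["s0", "s1", "s1_derived"]  # caller sees the reordering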
@@ -480,7 +482,7 @@ def reordering_to_mimic_autograd_engine(gm):
return new_gm
def functionalize_rng_ops(joint_module, fw_module, bw_module):
def functionalize_rng_ops(joint_module, fw_module, bw_module, num_sym_nodes):
# During user-driven activation checkpointing, we have to ensure that a rng
# op in fwd yields the same output as the recomputed rng op in the bwd. To
# do this, we use functionalize wrappers to wrap the random ops and share
@@ -491,7 +493,9 @@ def functionalize_rng_ops(joint_module, fw_module, bw_module):
# Step 2 - Modify the fwd pass such that
# 1) Replace rand with run_and_save_rng_state wrapper
# 2) Replace the users of the original op with the output[1] of this op.
# 3) Collect all the rng_state - output[0] of each op, and make them output nodes.
# 3) Collect all the rng_state - output[0] of each op, and make them
# output nodes. Special care needs to be taken here because the fwd
# outputs have symints at the very end.
# Step 3 - Modify the bwd pass such that
# 1) Add the input nodes just before the tangents for the stashed rng states
# 2) Replace rand with run_with_save_rng_state wrappers
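
A minimal eager-mode sketch of the invariant these steps maintain; the two helpers below are simplified stand-ins using plain torch RNG-state APIs, not the actual wrapper ops inserted into the graphs:

    import torch

    def run_and_save_rng_state_sketch(op, *args):
        # Stand-in for the fwd wrapper: returns (rng_state, op output).
        state = torch.random.get_rng_state()
        return state, op(*args)

    def run_with_rng_state_sketch(state, op, *args):
        # Stand-in for the bwd wrapper: replays the op under the stashed state.
        current = torch.random.get_rng_state()
        torch.random.set_rng_state(state)
        out = op(*args)
        torch.random.set_rng_state(current)
        return out

    # Fwd: output[0] (the rng state) is stashed, output[1] is the op result.
    state, fwd_out = run_and_save_rng_state_sketch(torch.rand, 4)
    # Bwd recompute: replaying under the saved state reproduces the fwd values.
    recomputed = run_with_rng_state_sketch(state, torch.rand, 4)
    assert torch.equal(fwd_out, recomputed)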
@@ -574,11 +578,15 @@ def functionalize_rng_ops(joint_module, fw_module, bw_module):
bw_graph.erase_node(bw_node)
# Add the rng states in the output of the fwd graph
fw_output = [node for node in fw_module.graph.nodes if node.op == "output"][0]
outputs = fw_output.args[0] + fw_rng_state_outputs
# Add the rng states to the output of the fwd graph. AOT Autograd assumes
# that symints are at the end of the forward graph outputs. So, insert the
# new rng states accordingly.
fw_output_node = [node for node in fw_module.graph.nodes if node.op == "output"][0]
fw_outputs = fw_output_node.args[0]
sym_node_start_idx = len(fw_outputs) - num_sym_nodes
outputs = fw_outputs[:sym_node_start_idx] + fw_rng_state_outputs + fw_outputs[sym_node_start_idx:]
fw_module.graph.output(outputs)
fw_module.graph.erase_node(fw_output)
fw_module.graph.erase_node(fw_output_node)
fw_module.recompile()
bw_module.recompile()
return fw_module, bw_module
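
A small self-contained illustration (made-up node names) of the index arithmetic above: the last num_sym_nodes entries must stay at the very end, so the rng states are spliced in just before them rather than appended:

    # Hypothetical fwd output nodes: real outputs, saved tensors, then symints.
    fw_outputs = ["mm", "dropout_mask", "sym_size_0", "sym_size_1"]
    fw_rng_state_outputs = ["rng_state_0"]
    num_sym_nodes = 2

    sym_node_start_idx = len(fw_outputs) - num_sym_nodes
    outputs = (
        fw_outputs[:sym_node_start_idx]
        + fw_rng_state_outputs
        + fw_outputs[sym_node_start_idx:]
    )
    assert outputs == ["mm", "dropout_mask", "rng_state_0", "sym_size_0", "sym_size_1"]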
@@ -866,13 +874,14 @@ def min_cut_rematerialization_partition(
# save_for_backward on tensors and stashes symints in autograd .ctx
saved_sym_nodes = list(filter(lambda n: is_sym_node(n), saved_values))
saved_values = list(filter(lambda n: not is_sym_node(n), saved_values))
# NB: saved_sym_nodes will be mutated to reflect the actual saved symbols
fw_module, bw_module = _extract_fwd_bwd_modules(
joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs)
if graph_has_recomputable_ops:
if graph_has_recomputable_rng_ops:
fw_module, bw_module = functionalize_rng_ops(
joint_module, fw_module, bw_module
joint_module, fw_module, bw_module, len(saved_sym_nodes)
)
bw_module = reordering_to_mimic_autograd_engine(bw_module)