[partitioners][ac][dynamic] Fix output signature of fwd with symints (#105771)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105771
Approved by: https://github.com/Chillee

Committed by: PyTorch MergeBot
Commit: 0b11da0ccb (parent: 0148db6765)
@@ -331,6 +331,36 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
         body_function = getattr(cnt.graphs[0], wrap_node.args[0].name)
         self.assertEqual(op_count(body_function), 2)

+    @requires_cuda()
+    def test_symints_location(self):
+        def gn(x, y):
+            return torch.matmul(x, torch.nn.functional.dropout(y, 0.5))
+
+        def fn(x, y):
+            return torch.utils.checkpoint.checkpoint(gn, x, y)
+
+        backend = "aot_eager"
+        cnt = CompileCounterWithBackend(backend)
+        opt_fn = torch.compile(fn, backend=cnt)
+
+        x = torch.randn(4, 4, requires_grad=True)
+        y = torch.randn(4, 4, requires_grad=True)
+        args = (x, y)
+        expected = fn(*args)
+        result = opt_fn(*args)
+
+        x = torch.randn(5, 5, requires_grad=True)
+        y = torch.randn(5, 5, requires_grad=True)
+        args = (x, y)
+        expected = fn(*args)
+        result = opt_fn(*args)
+
+        self.assertEqual(result.shape, expected.shape)
+        self.assertEqual(cnt.frame_count, 2)
+        self.assertEqual(len(cnt.graphs), 2)
+        wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
+        self.assertEqual(len(wrap_node.args), 3)
+

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
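Reviewer note: the new test pins down the recompile-with-symints scenario end to end. Below is a standalone sketch of the same scenario; it assumes CompileCounterWithBackend from torch._dynamo.testing (the same helper the test uses) and, like the @requires_cuda() gate above, expects a CUDA-capable build. The second call's 5x5 inputs force a recompile with dynamic shapes, which is what puts symints into the checkpointed forward's output signature.

    import torch
    from torch._dynamo.testing import CompileCounterWithBackend

    def gn(x, y):
        return torch.matmul(x, torch.nn.functional.dropout(y, 0.5))

    def fn(x, y):
        return torch.utils.checkpoint.checkpoint(gn, x, y)

    cnt = CompileCounterWithBackend("aot_eager")
    opt_fn = torch.compile(fn, backend=cnt)

    # First call compiles with static 4x4 shapes.
    opt_fn(torch.randn(4, 4, requires_grad=True), torch.randn(4, 4, requires_grad=True))
    # Shape change triggers a second compile, now with dynamic shapes.
    opt_fn(torch.randn(5, 5, requires_grad=True), torch.randn(5, 5, requires_grad=True))
    assert cnt.frame_count == 2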
@@ -130,7 +130,7 @@ def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs):
     return fwd_outputs, bwd_outputs


-def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_sym_nodes=(), *, num_fwd_outputs):
+def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_sym_nodes, *, num_fwd_outputs):
     fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
     primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
     tangent_inputs = list(filter(_is_tangent, joint_module.graph.nodes))
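Reviewer note: dropping the saved_sym_nodes=() default is deliberate. The function now reports the reordered symbols back to the caller by mutating the argument in place (next hunk), and a tuple default cannot be mutated. A toy sketch, not code from the PR:

    def extract(saved_sym_nodes=()):
        # With the old () default this would raise:
        # AttributeError: 'tuple' object has no attribute 'clear'
        saved_sym_nodes.clear()
        saved_sym_nodes.extend([])

    # extract()  # would crash with the old default; the argument is now
    # required, so every call site passes its own mutable list.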
@@ -199,9 +199,11 @@ def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_s
         saved_symbols |= new_symbols


-    # Update saved_sym_nodes that are now reordered to have all bindings
-    # at front
-    saved_sym_nodes = saved_sym_nodes_binding + saved_sym_nodes_derived
+    # Update saved_sym_nodes that are now reordered to have all bindings at
+    # front. This can also be used later on to figure out the position of saved
+    # sym nodes in the output of fwd graph.
+    saved_sym_nodes.clear()
+    saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived)

     # Now, we re-generate the fwd/bwd graphs.
     # NB: This might increase compilation time, but I doubt it matters
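Reviewer note: the switch from rebinding to clear()/extend() is the heart of the fix. Rebinding a parameter only changes the local name, so the caller's saved_sym_nodes list never reflected the bindings-first reordering. A minimal sketch of the difference, with illustrative names only:

    def rebind(nodes):
        nodes = sorted(nodes)        # rebinds the local name only

    def mutate(nodes):
        ordered = sorted(nodes)      # stand-in for bindings-first reordering
        nodes.clear()                # mutates the caller's list in place
        nodes.extend(ordered)

    caller = ["s1", "s0"]
    rebind(caller)
    assert caller == ["s1", "s0"]    # unchanged: caller never saw the reorder
    mutate(caller)
    assert caller == ["s0", "s1"]    # caller now observes the new order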
@@ -480,7 +482,7 @@ def reordering_to_mimic_autograd_engine(gm):
     return new_gm


-def functionalize_rng_ops(joint_module, fw_module, bw_module):
+def functionalize_rng_ops(joint_module, fw_module, bw_module, num_sym_nodes):
     # During user-driven activation checkpointing, we have to ensure that a rng
     # op in fwd yields the same output as the recomputed rng op in the bwd. To
     # do this, we use functionalize wrappers to wrap the random ops and share
@@ -491,7 +493,9 @@ def functionalize_rng_ops(joint_module, fw_module, bw_module):
     # Step 2 - Modify the fwd pass such that
     #   1) Replace rand with run_and_save_rng_state wrapper
     #   2) Replace the users of the original op with the output[1] of this op.
-    #   3) Collect all the rng_state - output[0] of each op, and make them output nodes.
+    #   3) Collect all the rng_state - output[0] of each op, and make them
+    #   output nodes. Special care needs to be taken here because fwd outputs
+    #   has symints at the very end.
     # Step 3 - Modify the bwd pass such that
     #   1) Add the input nodes just before the tangents for the stashed rng states
     #   2) Replace rand with run_with_save_rng_state wrappers
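Reviewer note: to make Step 2 concrete, each random op gets replaced by a wrapper that returns a pair, so output[0] (the stashed rng state) can become an extra forward output while output[1] replaces the op's original users. The wrapper below is an illustrative stand-in for that contract, not the actual PyTorch internal:

    import torch

    def run_and_save_rng_state_sketch(op, *args, **kwargs):
        # output[0]: rng state captured before the op runs, so the backward
        # recompute can replay exactly the same randomness
        rng_state = torch.get_rng_state()
        # output[1]: the op's ordinary result, which replaces the original users
        return rng_state, op(*args, **kwargs)

    state, out = run_and_save_rng_state_sketch(torch.rand, 4, 4)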
@@ -574,11 +578,15 @@ def functionalize_rng_ops(joint_module, fw_module, bw_module):
             bw_graph.erase_node(bw_node)


-    # Add the rng states in the output of the fwd graph
-    fw_output = [node for node in fw_module.graph.nodes if node.op == "output"][0]
-    outputs = fw_output.args[0] + fw_rng_state_outputs
+    # Add the rng states in the output of the fwd graph. AOT Autograd assumes
+    # that symints are at the end of forward graph outputs. So, insert the new
+    # rng states accordingly.
+    fw_output_node = [node for node in fw_module.graph.nodes if node.op == "output"][0]
+    fw_outputs = fw_output_node.args[0]
+    sym_node_start_idx = len(fw_outputs) - num_sym_nodes
+    outputs = fw_outputs[:sym_node_start_idx] + fw_rng_state_outputs + fw_outputs[sym_node_start_idx:]
     fw_module.graph.output(outputs)
-    fw_module.graph.erase_node(fw_output)
+    fw_module.graph.erase_node(fw_output_node)
     fw_module.recompile()
     bw_module.recompile()
     return fw_module, bw_module
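Reviewer note: the old code appended fw_rng_state_outputs after the symints, violating the symints-last layout that AOT Autograd relies on; the new splice reduces to ordinary sequence slicing. With plain lists standing in for fx nodes:

    fw_outputs = ["out0", "saved0", "s0", "s1"]    # symints at the very end
    fw_rng_state_outputs = ["rng0"]
    num_sym_nodes = 2

    idx = len(fw_outputs) - num_sym_nodes
    outputs = fw_outputs[:idx] + fw_rng_state_outputs + fw_outputs[idx:]
    assert outputs == ["out0", "saved0", "rng0", "s0", "s1"]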
@@ -866,13 +874,14 @@ def min_cut_rematerialization_partition(
     # save_for_backward on tensors and stashes symints in autograd .ctx
     saved_sym_nodes = list(filter(lambda n: is_sym_node(n), saved_values))
     saved_values = list(filter(lambda n: not is_sym_node(n), saved_values))
+    # NB: saved_sym_nodes will be mutated to reflect the actual saved symbols
     fw_module, bw_module = _extract_fwd_bwd_modules(
         joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs)

     if graph_has_recomputable_ops:
         if graph_has_recomputable_rng_ops:
             fw_module, bw_module = functionalize_rng_ops(
-                joint_module, fw_module, bw_module
+                joint_module, fw_module, bw_module, len(saved_sym_nodes)
             )
     bw_module = reordering_to_mimic_autograd_engine(bw_module)