Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 21:14:14 +08:00
Support propagating custom meta field to backward graph nodes (#164174)
# Propagate custom metadata to backward

Support propagating user annotation tags to the backward graph by extending the `copy_fwd_metadata_to_bw_nodes` utils (recommended by @xmfan, thanks!).

Example annotation API (added in https://github.com/pytorch/pytorch/pull/163673):

```python
class M(torch.nn.Module):
    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            with fx_traceback.annotate({"fdsp_bucket": 0}):
                x = x + 1
            x = x - 2
            with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
                x = x * 2
        x = x / 3
        return x
```

Assumptions (some inherited from https://github.com/pytorch/pytorch/pull/126573):

- I am trusting the seq_nr mapping introduced to aot_autograd nodes in https://github.com/pytorch/pytorch/pull/103129.
- I am also trusting that the forward is single-threaded, since seq_nr is thread-local. If this isn't always true, we'll need to also plumb thread_id through the same machinery that populates seq_nr.
- **(This is changed in this PR!) I assume all backward graph nodes have "is_backward" as their 'partitioner_tag', and all other nodes are forward graph nodes.** If we don't run export before `aot_export_joint_with_descriptors`, then none of the nodes have "nn_module_stack" in node meta. If we do run export first, then we don't need this change.
- I copy "custom" node meta from forward to backward graph nodes.

Questions:

- Is it a good idea to copy all "custom" node meta? Or should we create a dedicated key in custom node meta to be copied? @SherlockNoMad
- Do we expect people to run export before using `aot_export_joint_with_descriptors`?
- Can we assume the following for a graph produced by `aot_export_joint_with_descriptors`: "all backward graph nodes have 'is_backward' as their 'partitioner_tag', and all other nodes are forward graph nodes"? Maybe this is a question for @ezyang.

Tests:

```
python test/functorch/test_aot_joint_with_descriptors.py -k test_preserve_
python test/export/test_export.py -k preserve_anno
python test/distributed/tensor/test_dtensor_export.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164174
Approved by: https://github.com/xmfan, https://github.com/SherlockNoMad
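To make the annotation flow concrete, here is a minimal sketch (not part of this commit; `Tiny`, `inputs`, and the printed fields are illustrative) of annotating a forward region and then inspecting the propagated `custom` meta on the joint graph produced by `aot_export_joint_with_descriptors`:

```python
# Minimal sketch: annotate part of a forward, export a joint graph, and
# inspect node.meta. `Tiny` and `inputs` are hypothetical; the API calls
# mirror the tests touched by this PR.
from contextlib import ExitStack

import torch
import torch.fx.traceback as fx_traceback
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors


class Tiny(torch.nn.Module):
    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            x = x + 1  # this node (and its backward) should carry the tag
        return x * 2  # this node should not


inputs = (torch.randn(4, requires_grad=True),)
with ExitStack() as stack, fx_traceback.preserve_node_meta():
    joint = aot_export_joint_with_descriptors(stack, Tiny(), inputs)

for node in joint.graph_module.graph.nodes:
    if node.op == "call_function":
        # After this change, a backward node inherits "custom" from the
        # forward node that shares its seq_nr.
        print(node.name, node.meta.get("partitioner_tag"), node.meta.get("custom"))
```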
Committed by: PyTorch MergeBot
Parent: 35c4130fd1
Commit: 6b768e1890
```diff
@@ -5,6 +5,7 @@ import unittest
 
 import torch
 import torch.distributed as dist
+import torch.fx.traceback as fx_traceback
 from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
 from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
 from torch._functorch.partitioners import min_cut_rematerialization_partition
```
```diff
@@ -37,6 +38,18 @@ class SimpleModel(torch.nn.Module):
         return self.mlp_1(self.mlp_0(input))
 
 
+class SimpleModelAnnotated(torch.nn.Module):
+    def __init__(self, device):
+        super().__init__()
+        self.mlp_0 = MLPModule(device)
+        self.mlp_1 = MLPModule(device)
+
+    def forward(self, input):
+        with fx_traceback.annotate({"pp_stage": 0}):
+            x = self.mlp_0(input)
+        return self.mlp_1(x)
+
+
 def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
     # needed for stric export
     torch.utils._pytree.register_constant(DTensorSpec)
```
```diff
@@ -90,7 +103,7 @@ class DTensorExportTest(TestCase):
         )
         self.device_type = "cuda"
 
-    def _run_test(self, export_fn):
+    def _run_test(self, export_fn, test_annotation=False):
         dp_degree = 2
         tp_degree = self.world_size // dp_degree
 
```
```diff
@@ -101,7 +114,11 @@ class DTensorExportTest(TestCase):
             mesh_dim_names=["dp", "tp"],
         )
 
-        model = SimpleModel(self.device_type)
+        model = None
+        if test_annotation:
+            model = SimpleModelAnnotated(self.device_type)
+        else:
+            model = SimpleModel(self.device_type)
         parallelize_plan = {
             "mlp_0.net1": ColwiseParallel(),
             "mlp_0.net2": RowwiseParallel(),
```
```diff
@@ -131,6 +148,116 @@ class DTensorExportTest(TestCase):
             1,
         )
 
+        if test_annotation:
+
+            def has_tag(node):
+                return "custom" in node.meta and node.meta["custom"] == {"pp_stage": 0}
+
+            def marked_nodes(gm):
+                return [
+                    node.name
+                    for node in gm.graph.nodes
+                    if has_tag(node) and node.op == "call_function"
+                ]
+
+            def unmarked_nodes(gm):
+                return [
+                    node.name
+                    for node in gm.graph.nodes
+                    if not has_tag(node) and node.op == "call_function"
+                ]
+
+            marked_nodes_fw = [
+                "t",
+                "addmm",
+                "view",
+                "relu",
+                "view_1",
+                "t_1",
+                "div",
+                "addmm_1",
+                "all_reduce",
+                "wait_tensor",
+                "view_2",
+                "t_12",
+            ]
+            unmarked_nodes_fw = [
+                "view_3",
+                "t_2",
+                "addmm_2",
+                "view_4",
+                "relu_1",
+                "view_5",
+                "t_3",
+                "div_1",
+                "addmm_3",
+                "all_reduce_1",
+                "wait_tensor_1",
+                "view_6",
+                "t_4",
+                "t_8",
+            ]
+
+            marked_nodes_bw = [
+                "mm_4",
+                "t_13",
+                "view_1",
+                "mm_5",
+                "t_14",
+                "sum_3",
+                "view_9",
+                "t_15",
+                "detach",
+                "detach_1",
+                "detach_6",
+                "detach_7",
+                "threshold_backward_1",
+                "t_16",
+                "mm_6",
+                "t_17",
+                "sum_4",
+                "view_10",
+                "t_18",
+            ]
+            unmarked_nodes_bw = [
+                "mm",
+                "t_5",
+                "view_5",
+                "mm_1",
+                "t_6",
+                "sum_1",
+                "view_7",
+                "t_7",
+                "detach_2",
+                "detach_3",
+                "detach_4",
+                "detach_5",
+                "threshold_backward",
+                "mm_2",
+                "t_9",
+                "mm_3",
+                "t_10",
+                "sum_2",
+                "view_8",
+                "t_11",
+                "all_reduce_2",
+                "wait_tensor_2",
+            ]
+
+            self.assertEqual(marked_nodes(fw_gm), marked_nodes_fw)
+            self.assertEqual(unmarked_nodes(fw_gm), unmarked_nodes_fw)
+
+            self.assertEqual(marked_nodes(bw_gm), marked_nodes_bw)
+            self.assertEqual(unmarked_nodes(bw_gm), unmarked_nodes_bw)
+
+            self.assertEqual(
+                set(marked_nodes(joint_gm)), set(marked_nodes_fw + marked_nodes_bw)
+            )
+            self.assertEqual(
+                set(unmarked_nodes(joint_gm)),
+                set(unmarked_nodes_fw + unmarked_nodes_bw),
+            )
+
     @parametrize(
         "export_fn",
         [
```
```diff
@@ -150,6 +277,9 @@ class DTensorExportTest(TestCase):
     def test_strict_export_parallelize_module_with_dtensor_input(self):
        self._run_test(strict_export_and_aot_export_joint_with_descriptors)
 
+    def test_annotate_aot_export_joint_with_descriptors_alone(self):
+        self._run_test(aot_export_joint_with_descriptors_alone, True)
+
 
 instantiate_parametrized_tests(DTensorExportTest)
 
```
```diff
@@ -910,7 +910,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
 7|aten.view.default||l__self___fc1
 6|aten.t.default||l__self___fc1
 5|aten.view.default||l__self___fc1
-4|aten.view.default||
+4|aten.view.default||flatten
 2|aten.detach.default||l__self___relu1
 2|aten.detach.default||l__self___relu1
 2|aten.threshold_backward.default||l__self___relu1
```
```diff
@@ -15573,6 +15573,11 @@ def forward(self, x):
             test_serdes=True,
         )
 
+    @testing.expectedFailureTrainingIRToRunDecomp
+    @testing.expectedFailureRetraceability
+    @testing.expectedFailureStrictV2
+    @testing.expectedFailureStrict  # annotation needs to be handled in dynamo
+    @testing.expectedFailureSerDer
     def test_preserve_annotation(self):
         class M(torch.nn.Module):
             def forward(self, x):
```
```diff
@@ -15591,17 +15596,22 @@ def forward(self, x):
         ep = export(m, (torch.randn(10),))
 
         for node in ep.graph.nodes:
-            if node.target == torch.ops.aten.add.default:
+            if node.op in ("placeholder", "output"):
+                continue
+            if node.target == torch.ops.aten.add.Tensor:
                 self.assertTrue(node.meta["custom"], {"pp_stage": 0, "fdsp_bucket": 0})
-            if node.target == torch.ops.aten.sub.default:
+            elif node.target == torch.ops.aten.sub.Tensor:
                 self.assertTrue(node.meta["custom"], {"pp_stage": 0})
-            if node.target == torch.ops.aten.mul.default:
+            elif node.target == torch.ops.aten.mul.Tensor:
                 self.assertTrue(
                     node.meta["custom"],
                     {"pp_stage": 0, "cuda_stream": 2, "fsdp_bucket": 1},
                 )
-            if node.target == torch.ops.aten.div.default:
-                self.assertTrue(node.meta["custom"], {})
+            elif node.target == torch.ops.aten.div.Tensor:
+                if "custom" in node.meta:
+                    self.assertTrue(node.meta["custom"], {})
+            else:
+                raise AssertionError(f"Node not checked: {node}, {node.target}")
 
     def test_dynamic_shapes_serdes_generic(self):
         from torch._export.serde.dynamic_shapes import (
```
```diff
@@ -13,6 +13,7 @@ import torch.fx.traceback as fx_traceback
 import torch.nn as nn
 import torch.utils._pytree as pytree
 from torch._decomp import decomposition_table
+from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
 from torch._dynamo.testing import normalize_gm
 from torch._functorch._aot_autograd.descriptors import (
     BufferAOTInput,
```
```diff
@@ -777,36 +778,34 @@ class inner_f(torch.nn.Module):
 
         inputs = (torch.randn(4, 3),)
 
-        for with_export in [True, False]:
+        for with_export in [False]:  # TODO: make dynamo work for annotation
             with ExitStack() as stack:
                 model = None
                 with fx_traceback.preserve_node_meta():
                     if with_export:
-                        ep = torch.export.export(SimpleLinear(), inputs)
-                        model = ep.module()
+                        with torch._dynamo.config.patch(install_free_tensors=True):
+                            # TODO: switch to use the official graph_capture API once it is ready
+                            model = _dynamo_graph_capture_for_export(model)(*inputs)
                     else:
                         model = SimpleLinear()
 
                     joint_with_descriptors = aot_export_joint_with_descriptors(
-                        stack, model, inputs, decompositions=decomposition_table
+                        stack, model, inputs, decompositions={}
                     )
 
                 for node in joint_with_descriptors.graph_module.graph.nodes:
-                    if (
-                        node.target
-                        in (
-                            torch.ops.prims.transpose.default,
-                            torch.ops.aten.mm.default,
-                            torch.ops.prims.mul.default,
-                            torch.ops.prims.broadcast_in_dim.default,
-                            torch.ops.prims.add.default,
-                        )
-                        # TODO: add annotation to backward graph nodes
-                        and node.meta.get("partitioner_tag") != "is_backward"
+                    if node.op in ("placeholder", "output"):
+                        continue
+                    if node.target != torch.ops.aten.sub.Tensor and node.op not in (
+                        "placeholder",
+                        "output",
                     ):
                         self.assertTrue(node.meta["custom"], {"pp_stage": 0})
-                    if node.target == torch.ops.aten.sub.default:
-                        self.assertTrue(node.meta.get("custom", {}), {})
+                    elif node.target == torch.ops.aten.sub.Tensor:
+                        if "custom" in node.meta:
+                            self.assertTrue(node.meta.get("custom", {}), {})
+                    else:
+                        raise AssertionError(f"Node not checked: {node}, {node.target}")
 
 
 if __name__ == "__main__":
```
```diff
@@ -422,14 +422,18 @@ def copy_fwd_metadata_to_bw_nodes(fx_g):
         # the descendants of graph inputs corresponding to fwd inputs, didn't
         # seem obvious at first glance on how to partition graph inputs into
         # fwd vs bwd without relying on string names.
-        return "nn_module_stack" in node.meta and "seq_nr" in node.meta
+        return (
+            node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta
+        )
 
     def _is_backward_node_with_seq_nr(node):
         # For now, assume that if nn_module_stack_metadata is not populated,
         # this node is from the backward. Ignore nodes without `seq_nr`.
         # TODO(future): there is likely a less brittle way to do this, same
         # as with the forward.
-        return ("nn_module_stack" not in node.meta) and "seq_nr" in node.meta
+        return (
+            node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta
+        )
 
     fwd_seq_nr_to_node = {}
     for node in fx_g.graph.nodes:
@@ -449,8 +453,10 @@ def copy_fwd_metadata_to_bw_nodes(fx_g):
         # fwd_node should always exist, but handle non-existence just in case
         fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"])
         if fwd_node is not None:
-            node.meta["fwd_nn_module_stack"] = fwd_node.meta["nn_module_stack"]
+            node.meta["fwd_nn_module_stack"] = fwd_node.meta.get("nn_module_stack")
             node.meta["fwd_source_fn_stack"] = fwd_node.meta.get("source_fn_stack")
+            # TODO: better to change to a specific field of custom?
+            node.meta["custom"] = fwd_node.meta.get("custom")
 
 
 def register_buffer_assignment_hook(mod, assigned_buffers):
```
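For reference, a condensed sketch of the logic `copy_fwd_metadata_to_bw_nodes` implements after this change (simplified from the hunks above, not a drop-in replacement): nodes are classified by `partitioner_tag`, forward nodes are indexed by `seq_nr`, and each matching backward node receives the forward stacks plus the `custom` field.

```python
# Condensed sketch of the updated copy logic (simplified from the diff above).
def copy_fwd_metadata_to_bw_nodes_sketch(fx_g):
    def is_fwd(node):
        # A node is treated as forward if it is not tagged "is_backward"
        # and carries an autograd seq_nr.
        return node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta

    def is_bwd(node):
        return node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta

    # Index forward nodes by seq_nr (assumes a single-threaded forward,
    # since seq_nr is thread-local).
    fwd_seq_nr_to_node = {
        node.meta["seq_nr"]: node for node in fx_g.graph.nodes if is_fwd(node)
    }

    for node in fx_g.graph.nodes:
        if not is_bwd(node):
            continue
        fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"])
        if fwd_node is not None:
            node.meta["fwd_nn_module_stack"] = fwd_node.meta.get("nn_module_stack")
            node.meta["fwd_source_fn_stack"] = fwd_node.meta.get("source_fn_stack")
            # The new behavior: propagate user annotations to the backward.
            node.meta["custom"] = fwd_node.meta.get("custom")
```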