Compare commits

...

1 Commit

6 changed files with 129 additions and 6 deletions

View File

@@ -1066,6 +1066,8 @@ coverage_ignore_functions = [
     "set_current_meta",
     "set_grad_fn_seq_nr",
     "set_stack_trace",
+    "set_current_replay_node",
+    "get_current_replay_node",
     # torch.jit.annotations
     "ann_to_type",
     "check_fn",

View File

@@ -1016,6 +1016,59 @@ class inner_f(torch.nn.Module):
         self.assertFalse("self._opoverload" in foo_node.meta.get("stack_trace", None))
         self.assertFalse("self._opoverload" in gm.print_readable(print_output=False))
 
+    def test_preserve_annotate_replay_view(self):
+        """Test stack trace and annotation are correct on nodes regenerated in functionalization"""
+
+        def _unpermute(out, input_shape, permuted_indices):
+            """
+            Unpermute operation from torchtitan MoE utils.
+            """
+            out_unpermuted = out.new_empty(input_shape)
+            out_unpermuted[permuted_indices, :] = out
+            out = out_unpermuted[:-1]
+            return out
+
+        class Module(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.input_shape = (5, 3)
+                self.permuted_indices = torch.tensor([2, 0, 3, 1])
+
+            def forward(self, x):
+                with fx_traceback.annotate({"pp_stage": 0}):
+                    routed_output = _unpermute(
+                        x, self.input_shape, self.permuted_indices
+                    )
+                return routed_output.cos()
+
+        inputs = (torch.randn(4, 3, requires_grad=True),)
+        model = Module()
+        graph_module = graph_capture(model, inputs, True)
+        custom_metadata = fx_traceback._get_custom_metadata(graph_module)
+
+        slice_nodes = graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten.slice.Tensor
+        )
+        self.assertEqual(len(slice_nodes), 1)
+        slice_backward_nodes = graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten.slice_backward.default
+        )
+        self.assertEqual(len(slice_backward_nodes), 1)
+
+        slice_node = slice_nodes[0]
+        slice_backward_node = slice_backward_nodes[0]
+        self.assertEqual(slice_node.meta["seq_nr"], slice_backward_node.meta["seq_nr"])
+        self.assertTrue("out = out_unpermuted[:-1]" in slice_node.meta["stack_trace"])
+
+        self.assertExpectedInline(
+            str(custom_metadata),
+            """\
+('call_function', 'new_empty', {'pp_stage': 0})
+('call_function', 'index_put', {'pp_stage': 0})
+('call_function', 'slice_2', {'pp_stage': 0})
+('call_function', 'slice_backward', {'pp_stage': 0})
+('call_function', 'index', {'pp_stage': 0})""",
+        )
+
 
 if __name__ == "__main__":
     run_tests()

View File

@@ -3245,8 +3245,8 @@ def forward(self, primals_1):
     as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
     add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
     as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
-    as_strided_8 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
-    view_1 = torch.ops.aten.view.default(as_strided_8, [4]); as_strided_8 = None
+    as_strided_9 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
+    view_1 = torch.ops.aten.view.default(as_strided_9, [4]); as_strided_9 = None
     return (as_strided_scatter, view_1)""",
 )  # noqa: B950
@@ -3409,13 +3409,13 @@ def forward(self, primals_1, primals_2, primals_3):
     as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
     add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
     as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
-    add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None
     as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
-    unsqueeze_1 = torch.ops.aten.unsqueeze.default(as_strided_5, 0); as_strided_5 = None
-    add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze_1); add_1 = None
+    unsqueeze = torch.ops.aten.unsqueeze.default(as_strided_5, 0); as_strided_5 = None
+    add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None
+    add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze); add_1 = None
     as_strided_14 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
     view_2 = torch.ops.aten.view.default(as_strided_14, [-1]); as_strided_14 = None
-    return (as_strided_scatter, add_2, view_2, unsqueeze_1)""",
+    return (as_strided_scatter, add_2, view_2, unsqueeze)""",
 )  # noqa: B950
 
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")

View File

@@ -8,6 +8,7 @@ from contextlib import AbstractContextManager
 from typing import Any, Optional, Union
 
 import torch
 import torch.fx.traceback as fx_traceback
+import torch.utils._pytree as pytree
 from torch._C import _functionalization_reapply_views_tls as _reapply_views
 from torch._ops import _get_dispatch_mode_pre_dispatch
@@ -504,6 +505,30 @@ class FunctionalTensorMode(TorchDispatchMode):
                     torch.Tensor, wrap, outs_unwrapped
                 )
             else:
+                # Note: [Functionalization View Replay Annotation]
+                # When functionalization encounters a mutation, it handles aliases by lazily regenerating
+                # them the next time they are used.
+                # This is a problem when plumbing user annotations during tracing. We want the view ops from
+                # view replay to carry the same annotation that the user specified on the original views. But
+                # view replay in functionalization happens the next time the alias is used (e.g.
+                # second_op(alias_with_pending_mutation)), so if we simply regenerate views before calling
+                # into second_op, those views would end up getting the metadata for second_op!
+                #
+                # Instead, we need to remember the node metadata from the original views, and ensure that
+                # this node metadata is globally set when we lazily perform view replay.
+                # The globally set metadata will be used to populate the fx node created for the replayed
+                # operation.
+                if m := torch._C._get_dispatch_mode(
+                    torch._C._TorchDispatchModeKey.PROXY
+                ):
+                    for a in pytree.tree_leaves([args, kwargs]):
+                        if not isinstance(a, FunctionalTensor):
+                            continue
+                        curr_node = m.tracer.tensor_tracker[
+                            torch._from_functional_tensor(a.elem)
+                        ].proxy.node
+                        with fx_traceback.set_current_replay_node(curr_node):
+                            torch._sync(a)
                 # When we dispatch to the C++ functionalization kernel, we might need to jump back to the
                 # PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
                 # FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
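
For context on the note above, a minimal eager-mode sketch (not part of this diff) of the ordering problem it describes; `f` and its tensors are made up for illustration:

import torch
import torch.fx.traceback as fx_traceback

def f(base):
    # The view is created while the user's annotation is active.
    with fx_traceback.annotate({"pp_stage": 0}):
        alias = base[:-1]
    # The mutation leaves `alias` stale; under functionalization, its
    # regeneration is deferred...
    base.add_(1)
    # ...until this use. Without set_current_replay_node, the replayed
    # view ops would be created under whatever metadata is active here,
    # not under the original {"pp_stage": 0} annotation.
    return alias.cos()

f(torch.randn(5, 3))

The new test above exercises exactly this pattern via the `out_unpermuted[:-1]` slice in `_unpermute`.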

View File

@@ -206,6 +206,21 @@ class TracerBase:
             if current_meta.get("in_grad_fn", 0) > 0:
                 annotation_log.debug("seq_nr from current_meta")
                 new_seq_nr = current_meta["grad_fn_seq_nr"][-1]
+
+            # See Note [Functionalization View Replay Annotation]
+            # Override some of this node's meta with the meta of the
+            # original node that is being regenerated.
+            replay_node: Optional[Node] = fx_traceback.get_current_replay_node()
+            if replay_node is not None:
+                node.meta["is_functional_regenerated"] = True
+                if "seq_nr" in replay_node.meta:
+                    annotation_log.debug("seq_nr from replay_node")
+                    new_seq_nr = replay_node.meta["seq_nr"]
+                if "custom" in replay_node.meta:
+                    node.meta["custom"] = replay_node.meta.get("custom")
+                if "stack_trace" in replay_node.meta:
+                    node.stack_trace = replay_node.meta.get("stack_trace")
+
             annotation_log.debug("Assigning new_seq_nr %s to %s", new_seq_nr, node.name)
             node.meta["seq_nr"] = new_seq_nr
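
As a reading aid, the intended end state for a regenerated node, sketched with assumed values (the field names come from the code above; the concrete values are hypothetical):

# Hypothetical meta of a slice node recreated during view replay:
node.meta["is_functional_regenerated"] = True  # marked by TracerBase
node.meta["seq_nr"] = 42                       # copied from the original view node
node.meta["custom"] = {"pp_stage": 0}          # user annotation on the original view
node.stack_trace = "... out = out_unpermuted[:-1] ..."  # original creation site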

View File

@@ -30,9 +30,12 @@ __all__ = [
     "NodeSource",
     "NodeSourceAction",
     "get_graph_provenance_json",
+    "set_current_replay_node",
+    "get_current_replay_node",
 ]
 
 current_meta: dict[str, Any] = {}
+current_replay_node: Optional[Node] = None
 should_preserve_node_meta = False
@@ -400,6 +403,31 @@ def get_current_meta() -> dict[str, Any]:
     return current_meta
 
 
+@compatibility(is_backward_compatible=False)
+@contextmanager
+def set_current_replay_node(node):
+    """
+    Set the current replay node. While `current_replay_node` is not None,
+    we are regenerating `current_replay_node` in FunctionalTensorMode.
+    """
+    # See Note [Functionalization View Replay Annotation] for more details.
+    global current_replay_node
+    saved_current_replay_node = current_replay_node
+    try:
+        current_replay_node = node
+        yield
+    finally:
+        current_replay_node = saved_current_replay_node
+
+
+@compatibility(is_backward_compatible=False)
+def get_current_replay_node():
+    """
+    Get the current replay node.
+    """
+    return current_replay_node
+
+
 @compatibility(is_backward_compatible=False)
 def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
     """