Compare commits

...

10 Commits

SHA1        Message                                    Date
2405279681  Update [ghstack-poisoned]                  2025-11-17 22:11:28 -08:00
fbcb2b3090  Update (base update) [ghstack-poisoned]    2025-11-17 22:11:28 -08:00
8fc150ec9a  Update [ghstack-poisoned]                  2025-11-17 18:51:23 -08:00
88a971f1e0  Update (base update) [ghstack-poisoned]    2025-11-17 18:51:22 -08:00
6671232705  Update [ghstack-poisoned]                  2025-11-13 15:45:39 -08:00
0c149af21b  Update (base update) [ghstack-poisoned]    2025-11-13 15:45:39 -08:00
60b393237f  Update [ghstack-poisoned]                  2025-11-13 14:23:26 -08:00
cad0bf0b0e  Update (base update) [ghstack-poisoned]    2025-11-13 14:23:26 -08:00
6d78e0fcd8  Update [ghstack-poisoned]                  2025-11-13 10:24:38 -08:00
8e177eee0e  Update (base update) [ghstack-poisoned]    2025-11-13 10:24:38 -08:00
3 changed files with 68 additions and 1 deletion

View File

@@ -585,6 +585,10 @@ class GraphModule(torch.nn.Module):
         # Annotation: {'stream': 1}
         mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
 
+        # No stacktrace found for following nodes
+        record_event_default = torch.ops.streams.record_event.default(2, 1); record_event_default = None
+        wait_event_default = torch.ops.streams.wait_event.default(2, 0); wait_event_default = None
+
         # Annotation: {'stream': 0}
         add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
         return (add_3, add_2)
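For orientation, the record_event/wait_event pair in the expected graph above encodes the usual cross-stream ordering: the producer stream records an event after mul_3, and the consumer stream waits on that event before add_3 reads the result. Below is a minimal eager-mode sketch of the same pattern using plain torch.cuda APIs rather than the graph ops; the stream and event numbers in the comments map onto the annotations above and are purely illustrative:

    import torch

    if torch.cuda.is_available():
        producer = torch.cuda.Stream()   # plays the role of stream 1
        consumer = torch.cuda.Stream()   # plays the role of stream 0
        done = torch.cuda.Event()        # plays the role of event 2

        with torch.cuda.stream(producer):
            t = torch.randn(2, 2, device="cuda")
            y = t * 2                    # the work mul_3 stands for
            done.record(producer)        # analogous to record_event(2, 1)

        consumer.wait_event(done)        # analogous to wait_event(2, 0)
        with torch.cuda.stream(consumer):
            z = y + y                    # safe: consumer is ordered after producer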

View File

@@ -33,7 +33,7 @@ from .graph_capture_wrappers import (
     handle_effect_tokens_fn,
 )
 from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta
-from .streams import assign_backward_streams
+from .streams import assign_backward_streams, insert_backward_syncs
 from .utils import (
     call_and_expect_output_descs,
     copy_fwd_metadata_to_bw_nodes,
@@ -477,6 +477,8 @@ def aot_dispatch_autograd_graph(
     # After copying metadata, assign streams to gradient accumulation nodes
     assign_backward_streams(fx_g)
+    insert_backward_syncs(fx_g)
+
     fx_g.graph.eliminate_dead_code()
 
     if not aot_config.disable_functionalization:
         # There should be *NO* mutating ops in the graph at this point.
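As context for the eliminate_dead_code() call above: FX's DCE drops nodes that have no users and no side effects, and the record/wait nodes still appear in the expected graph in the first file, so the inserted sync ops evidently survive it. A self-contained toy example of the mechanism itself (the module and names here are illustrative, unrelated to the streams ops):

    import torch
    import torch.fx

    class M(torch.nn.Module):
        def forward(self, x):
            unused = x * 3        # result has no users -> dead code
            return x + 1

    gm = torch.fx.symbolic_trace(M())
    gm.graph.eliminate_dead_code()
    gm.recompile()
    print(gm.code)                # the mul is gone; only the add remains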

View File

@@ -3,6 +3,7 @@ from typing import Optional, TypeAlias
 
 import torch.fx
 import torch.fx.traceback
 from torch._dynamo.graph_utils import _get_flat_args
+from torch._dynamo.variables.streams import get_current_stream, new_event
 
 Node: TypeAlias = torch.fx.Node
@@ -12,6 +13,14 @@ def is_gradient_acc(node: Node) -> bool:
     return node.meta.get("is_gradient_acc", False)
 
 
+def is_bwd_node(node: Node) -> bool:
+    return node.meta.get("partitioner_tag") == "is_backward"
+
+
+def get_device(node: Node) -> torch.device:
+    return node.meta["val"].device
+
+
 def get_stream(node: Node) -> Optional[int]:
     maybe_annotation = node.meta.get("custom", None)
     if maybe_annotation is not None:
@@ -20,6 +29,13 @@ def get_stream(node: Node) -> Optional[int]:
     return None
 
 
+def get_stream_or_current_stream(node: Node) -> int:
+    ind = get_stream(node)
+    if ind is None:
+        ind = get_current_stream(get_device(node))
+    return ind
+
+
 def set_stream(node: Node, ind: int) -> None:
     if "custom" in node.meta:
         node.meta["custom"].update({"stream": ind})
@@ -27,6 +43,36 @@ def set_stream(node: Node, ind: int) -> None:
     else:
         node.meta["custom"] = {"stream": ind}
 
 
+def insert_sync(
+    graph: torch.fx.Graph,
+    consumer: Node,
+    producer: Node,
+    node_to_wait_event_ind: dict[Node, int],
+) -> None:
+    if producer not in node_to_wait_event_ind:
+        node_to_wait_event_ind[producer] = new_event()
+
+    with graph.inserting_after(producer):
+        node = graph.call_function(
+            torch.ops.streams.record_event.default,
+            (
+                node_to_wait_event_ind[producer],
+                get_stream_or_current_stream(producer),
+            ),
+        )
+        node.meta["partitioner_tag"] = "must_be_in_backward"
+
+    with graph.inserting_before(consumer):
+        node = graph.call_function(
+            torch.ops.streams.wait_event.default,
+            (
+                node_to_wait_event_ind[producer],
+                get_stream_or_current_stream(consumer),
+            ),
+        )
+        node.meta["partitioner_tag"] = "must_be_in_backward"
+
+
 def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
     """Assigns backward streams to gradient accumulation nodes"""
@@ -51,3 +97,18 @@ def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
                 if ind is not None:
                     set_stream(node, ind)
                     break
+
+
+def insert_backward_syncs(gm: torch.fx.GraphModule) -> None:
+    """Inserts stream syncs for backward nodes if consumer and producer are on different streams"""
+    node_to_wait_event_ind = {}
+    for node in gm.graph.nodes:
+        if is_bwd_node(node):
+            flat_args = _get_flat_args(node, {})
+            cur_node_stream = get_stream(node)
+
+            for arg in flat_args:
+                if is_bwd_node(arg):
+                    arg_stream = get_stream(arg)
+                    if arg_stream != cur_node_stream and get_device(arg).type != "cpu":
+                        insert_sync(gm.graph, node, arg, node_to_wait_event_ind)
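As a closing note on the mechanism: insert_sync relies on FX's insertion contexts, graph.inserting_after(producer) and graph.inserting_before(consumer), to splice record/wait nodes around a cross-stream edge. The sketch below demonstrates that splicing on a toy traced module; fake_record and fake_wait are hypothetical no-op stand-ins for the streams ops, not real operators:

    import operator

    import torch
    import torch.fx

    def fake_record(event: int, stream: int) -> None:  # stand-in for record_event
        pass

    def fake_wait(event: int, stream: int) -> None:    # stand-in for wait_event
        pass

    class M(torch.nn.Module):
        def forward(self, x):
            y = x * 2      # pretend this runs on stream 1
            return y + 1   # pretend this runs on stream 0

    gm = torch.fx.symbolic_trace(M())
    producer = next(n for n in gm.graph.nodes if n.target is operator.mul)
    consumer = next(n for n in gm.graph.nodes if n.target is operator.add)

    with gm.graph.inserting_after(producer):
        gm.graph.call_function(fake_record, (0, 1))  # record event 0 on stream 1

    with gm.graph.inserting_before(consumer):
        gm.graph.call_function(fake_wait, (0, 0))    # wait on event 0 from stream 0

    gm.recompile()
    print(gm.code)  # fake_record/fake_wait now sit between the mul and the add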