[user-streams] Assign streams to gradient accum in bwd

ghstack-source-id: 788de728e52037bfd09ae567398b8137b939dbfa
Pull-Request: https://github.com/pytorch/pytorch/pull/167513

more bwd changes
This commit is contained in:
Michael Lazos
2025-11-13 15:45:37 -08:00
parent e5eb89e111
commit ad6bbc5d86
3 changed files with 119 additions and 1 deletion


@@ -470,7 +470,7 @@ class <lambda>(torch.nn.Module):
        )

    @requires_cuda
    def test_stream_backward(self) -> None:
    def test_stream_backward_simple(self) -> None:
        def fn(x, y):
            s2 = torch.Stream()
            s0 = torch.Stream()
@@ -524,7 +524,68 @@ class GraphModule(torch.nn.Module):
        # Annotation: {'stream': 1}
        mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
        # Annotation: {'stream': 0}
        add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
        return (add_3, add_2)
""",
        )

    @requires_cuda
    def test_stream_backward_sync(self) -> None:
        def fn(x, y):
            s2 = torch.Stream()
            s0 = torch.Stream()
            with s0:
                y0 = 2 * x + y
            with s2:
                z = 2 * x + y
            return y0, z

        inp = (
            torch.ones(2, 2, device="cuda:0", requires_grad=True) + 1,
            torch.ones(2, 2, device="cuda:0", requires_grad=True),
        )
        expected = fn(*inp)
        (
            actual,
            _,
            fw_graphs,
            bw_graphs,
        ) = extract_graph(fn, *inp)
        self.assertEqual(len(fw_graphs), 1)
        self.assertEqual(expected, actual)
        self.assertExpectedInline(
            print_graph(fw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
        # Annotation: {'stream': 1}
        mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
        add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)
        # Annotation: {'stream': 0}
        add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
        return (add, add_1)
""",
        )

        actual[1].sum().backward()
        self.assertExpectedInline(
            print_graph(bw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
        # Annotation: {'stream': 0}
        mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)
        #
        add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None
        # Annotation: {'stream': 1}
        mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
        # Annotation: {'stream': 0}
        add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
        return (add_3, add_2)
""",
        )

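Reading the expected backward graph above against the selection rules introduced in streams.py below (a worked reading, not part of the diff): add_3 accumulates the gradient for x; its only user is the graph output, which carries no stream annotation, so the pass falls back to its args and adopts stream 0 from mul_2, the first annotated arg. add_2 accumulates the gradient for y, but neither its user nor its args (the tangent placeholders) carry a stream annotation, so it is left unannotated, which is why it prints a bare '#' line.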

@@ -33,6 +33,7 @@ from .graph_capture_wrappers import (
    handle_effect_tokens_fn,
)
from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta
from .streams import assign_backward_streams
from .utils import (
    call_and_expect_output_descs,
    copy_fwd_metadata_to_bw_nodes,
@@ -473,6 +474,9 @@ def aot_dispatch_autograd_graph(
    # fw node match might be erased
    copy_fwd_metadata_to_bw_nodes(fx_g)

    # After copying metadata, assign streams to gradient accumulation nodes
    assign_backward_streams(fx_g)

    fx_g.graph.eliminate_dead_code()
    if not aot_config.disable_functionalization:
        # There should be *NO* mutating ops in the graph at this point.

@@ -0,0 +1,53 @@
from typing import Optional, TypeAlias

import torch.fx
import torch.fx.traceback
from torch._dynamo.graph_utils import _get_flat_args

Node: TypeAlias = torch.fx.Node


def is_gradient_acc(node: Node) -> bool:
    return node.meta.get("is_gradient_acc", False)


def get_stream(node: Node) -> Optional[int]:
    maybe_annotation = node.meta.get("custom", None)
    if maybe_annotation is not None:
        return node.meta["custom"].get("stream", None)
    else:
        return None


def set_stream(node: Node, ind: int) -> None:
    if "custom" in node.meta:
        node.meta["custom"].update({"stream": ind})
    else:
        node.meta["custom"] = {"stream": ind}


def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
    """Assigns backward streams to gradient accumulation nodes"""
    # NB: iterate in reverse order to more closely match eager;
    # the user node's stream will be populated first
    for node in reversed(list(gm.graph.nodes)):
        if is_gradient_acc(node):
            # Accumulation stream selection. Follow the rules from top to bottom to determine the accumulation stream:
            # 1. Match the first stream assignment of the first user with a stream
            # 2. Match the first stream assignment encountered in the args, from left to right
            # This differs from eager in some cases:
            # specifically, the eager code uses the autograd node to determine the stream,
            # and crucially this does not necessarily correspond to an FX graph node. For example,
            # in the backward for an add node with a constant we pass through, and during backward tracing
            # no op will be added to the FX graph, so our stream assignment will differ in this case.
            gradients = _get_flat_args(node, {})
            users = list(node.users.keys())

            # All gradients will be on the same device; if they were not,
            # they will have been coerced with a .to() node.
            for neighbor in users + gradients:
                ind = get_stream(neighbor)
                if ind is not None:
                    set_stream(node, ind)
                    break
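
A minimal sketch (not part of the diff) of how these selection rules behave, assuming assign_backward_streams and get_stream from the new streams.py above are in scope (the module path is not shown here); the graph is hand-built and the is_gradient_acc / custom meta keys are set manually purely for illustration:

import torch.fx

def f(a, b):
    c = a * 2      # pretend this op was annotated with stream 0
    d = b * 2      # pretend this op was annotated with stream 1
    return c + d   # stand-in for a gradient accumulation node

gm = torch.fx.symbolic_trace(f)
mul_a, mul_b, acc = [n for n in gm.graph.nodes if n.op == "call_function"]

mul_a.meta["custom"] = {"stream": 0}
mul_b.meta["custom"] = {"stream": 1}
acc.meta["is_gradient_acc"] = True  # normally set during backward tracing

assign_backward_streams(gm)

# The accumulation node's only user (the output) has no stream, so rule 2
# applies: the first annotated arg wins and the node lands on stream 0.
assert get_stream(acc) == 0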