[user-streams] Assign streams to gradient accum in bwd
ghstack-source-id: 788de728e52037bfd09ae567398b8137b939dbfa
Pull-Request: https://github.com/pytorch/pytorch/pull/167513

more bwd changes
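For context, a minimal, hypothetical sketch (not part of the patch) of the pattern this change targets; it assumes a CUDA build recent enough that torch.Stream works as a context manager under torch.compile, as in the tests below. When a tensor feeds branches that run on different streams, autograd must accumulate its gradient contributions in the backward pass, and this change decides which annotated stream that accumulation node is assigned to in the compiled backward graph.

import torch

# Hypothetical repro (mirrors the tests added below): x feeds two branches
# that run on different streams, so the backward pass must accumulate x's
# gradient from both branches somewhere.
def fn(x, y):
    s0 = torch.Stream()
    s1 = torch.Stream()
    with s0:
        a = 2 * x + y
    with s1:
        b = 2 * x + y
    return a, b

x = torch.ones(2, 2, device="cuda", requires_grad=True)
y = torch.ones(2, 2, device="cuda", requires_grad=True)
a, b = torch.compile(fn)(x, y)
# With this patch, the gradient accumulation node for x carries a stream
# annotation in the compiled backward graph (per the selection rules in
# torch/_functorch/_aot_autograd/streams.py below).
(a + b).sum().backward()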
@@ -470,7 +470,7 @@ class <lambda>(torch.nn.Module):
         )
 
     @requires_cuda
-    def test_stream_backward(self) -> None:
+    def test_stream_backward_simple(self) -> None:
         def fn(x, y):
             s2 = torch.Stream()
             s0 = torch.Stream()
@@ -524,7 +524,68 @@ class GraphModule(torch.nn.Module):
        # Annotation: {'stream': 1}
        mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None

        # Annotation: {'stream': 0}
        add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
        return (add_3, add_2)
""",
        )

    @requires_cuda
    def test_stream_backward_sync(self) -> None:
        def fn(x, y):
            s2 = torch.Stream()
            s0 = torch.Stream()
            with s0:
                y0 = 2 * x + y
            with s2:
                z = 2 * x + y

            return y0, z

        inp = (
            torch.ones(2, 2, device="cuda:0", requires_grad=True) + 1,
            torch.ones(2, 2, device="cuda:0", requires_grad=True),
        )
        expected = fn(*inp)
        (
            actual,
            _,
            fw_graphs,
            bw_graphs,
        ) = extract_graph(fn, *inp)
        self.assertEqual(len(fw_graphs), 1)
        self.assertEqual(expected, actual)
        self.assertExpectedInline(
            print_graph(fw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
        # Annotation: {'stream': 1}
        mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
        add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)

        # Annotation: {'stream': 0}
        add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
        return (add, add_1)
""",
        )

        actual[1].sum().backward()
        self.assertExpectedInline(
            print_graph(bw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
        # Annotation: {'stream': 0}
        mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)

        #
        add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None

        # Annotation: {'stream': 1}
        mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None

        # Annotation: {'stream': 0}
        add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
        return (add_3, add_2)
""",
@@ -33,6 +33,7 @@ from .graph_capture_wrappers import (
     handle_effect_tokens_fn,
 )
 from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta
+from .streams import assign_backward_streams
 from .utils import (
     call_and_expect_output_descs,
     copy_fwd_metadata_to_bw_nodes,
@@ -473,6 +474,9 @@ def aot_dispatch_autograd_graph(
     # fw node match might be erased
     copy_fwd_metadata_to_bw_nodes(fx_g)
 
+    # After copying metadata, assign streams to gradient accumulation nodes
+    assign_backward_streams(fx_g)
+
     fx_g.graph.eliminate_dead_code()
     if not aot_config.disable_functionalization:
         # There should be *NO* mutating ops in the graph at this point.
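The ordering above matters: per the added comment, the stream annotations that the new pass reads are only available on backward nodes after copy_fwd_metadata_to_bw_nodes has propagated metadata from their forward counterparts. A rough, hypothetical illustration of the metadata shapes involved (the dict keys are the ones used by streams.py below; the nodes themselves are made up):

# Before the pass: the accumulation node is only tagged, with no stream yet,
# while one of its users already carries a stream annotation.
acc_node_meta = {"is_gradient_acc": True}
user_node_meta = {"custom": {"stream": 0}}

# After assign_backward_streams(fx_g): the accumulation node inherits the first
# stream found on its users (checked first) or, failing that, on its gradient args.
acc_node_meta = {"is_gradient_acc": True, "custom": {"stream": 0}}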
torch/_functorch/_aot_autograd/streams.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from typing import Optional, TypeAlias

import torch.fx
import torch.fx.traceback
from torch._dynamo.graph_utils import _get_flat_args


Node: TypeAlias = torch.fx.Node


def is_gradient_acc(node: Node) -> bool:
    return node.meta.get("is_gradient_acc", False)


def get_stream(node: Node) -> Optional[int]:
    maybe_annotation = node.meta.get("custom", None)
    if maybe_annotation is not None:
        return node.meta["custom"].get("stream", None)
    else:
        return None


def set_stream(node: Node, ind: int) -> None:
    if "custom" in node.meta:
        node.meta["custom"].update({"stream": ind})
    else:
        node.meta["custom"] = {"stream": ind}


def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
    """Assigns backward streams to gradient accumulation nodes"""

    # NB: iterate in reverse order to more closely match eager;
    # the user node's stream will be populated first
    for node in reversed(list(gm.graph.nodes)):
        if is_gradient_acc(node):
            # Accumulation stream selection. Follow these rules from top to bottom
            # to determine the accumulation stream:
            # 1. Use the stream of the first user that has a stream assigned
            # 2. Otherwise, use the first stream assignment encountered in the args,
            #    left to right
            # This differs from eager in some cases: eager uses the autograd node to
            # determine the stream, and that node does not necessarily correspond to
            # an FX graph node. For example, the backward of an add with a constant
            # is a passthrough, so no op is added to the FX graph during backward
            # tracing and the stream assignment can differ in that case.
            gradients = _get_flat_args(node, {})
            users = list(node.users.keys())

            # All gradients will be on the same device; if they were not, they would
            # have been coerced with a .to() node
            for neighbor in users + gradients:
                ind = get_stream(neighbor)
                if ind is not None:
                    set_stream(node, ind)
                    break
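A hypothetical, self-contained usage sketch (not part of the patch; it assumes this change is applied so that torch._functorch._aot_autograd.streams is importable). It builds a tiny FX graph by hand, tags one node as a gradient accumulation, annotates one of its inputs with a stream, and checks that assign_backward_streams copies that stream index; here the only user (the output node) has no stream, so the rule falls through to the args:

import torch
import torch.fx

from torch._functorch._aot_autograd.streams import assign_backward_streams


def f(a, b):
    return a + b


gm = torch.fx.symbolic_trace(f)
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")
a_node, b_node = add_node.args

# Pretend `add` is a gradient accumulation whose first arg was produced on stream 1.
add_node.meta["is_gradient_acc"] = True
a_node.meta["custom"] = {"stream": 1}

assign_backward_streams(gm)
print(add_node.meta["custom"])  # expected, per the rules above: {'stream': 1}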