Compare commits

...

8 Commits

SHA1       Message                                   Date
cafea25ce4 Update [ghstack-poisoned]                 2025-11-13 15:45:38 -08:00
405e7fa2d8 Update (base update) [ghstack-poisoned]   2025-11-13 15:45:38 -08:00
2dd6ddc5dd Update [ghstack-poisoned]                 2025-11-13 14:23:25 -08:00
cdb9f8ea4d Update (base update) [ghstack-poisoned]   2025-11-13 14:23:25 -08:00
6e26abbc5e Update [ghstack-poisoned]                 2025-11-13 10:24:38 -08:00
f87cb783da Update (base update) [ghstack-poisoned]   2025-11-13 10:24:38 -08:00
ddae908d2b Update [ghstack-poisoned]                 2025-11-12 20:43:18 -08:00
8824cc6a88 Update (base update) [ghstack-poisoned]   2025-11-12 20:43:18 -08:00
6 changed files with 171 additions and 35 deletions

View File

@@ -470,7 +470,7 @@ class <lambda>(torch.nn.Module):
)
@requires_cuda
- def test_stream_backward(self) -> None:
+ def test_stream_backward_simple(self) -> None:
def fn(x, y):
s2 = torch.Stream()
s0 = torch.Stream()
@@ -524,7 +524,68 @@ class GraphModule(torch.nn.Module):
# Annotation: {'stream': 1}
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
# Annotation: {'stream': 0}
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
return (add_3, add_2)
""",
)
@requires_cuda
def test_stream_backward_sync(self) -> None:
def fn(x, y):
s2 = torch.Stream()
s0 = torch.Stream()
with s0:
y0 = 2 * x + y
with s2:
z = 2 * x + y
return y0, z
inp = (
torch.ones(2, 2, device="cuda:0", requires_grad=True) + 1,
torch.ones(2, 2, device="cuda:0", requires_grad=True),
)
expected = fn(*inp)
(
actual,
_,
fw_graphs,
bw_graphs,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
# Annotation: {'stream': 1}
mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)
# Annotation: {'stream': 0}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
return (add, add_1)
""",
)
actual[1].sum().backward()
self.assertExpectedInline(
print_graph(bw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
# Annotation: {'stream': 0}
mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)
#
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None
# Annotation: {'stream': 1}
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
# Annotation: {'stream': 0}
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
return (add_3, add_2)
""",

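For readers skimming the diff, the pattern these tests exercise boils down to the following standalone sketch (extract_graph and print_graph are test-suite helpers, so plain torch.compile stands in here, and a CUDA device is assumed):

import torch

def fn(x, y):
    s0 = torch.Stream()
    s1 = torch.Stream()
    with s0:
        a = 2 * x + y  # ops traced in this block carry a {'stream': ...} annotation
    with s1:
        b = 2 * x + y
    return a, b

x = torch.ones(2, 2, device="cuda", requires_grad=True)
y = torch.ones(2, 2, device="cuda", requires_grad=True)
a, b = torch.compile(fn)(x, y)
b.sum().backward()  # backward nodes inherit stream annotations (see assign_backward_streams below)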
View File

@@ -25,6 +25,11 @@ def has_user_objects() -> bool:
return bool(index_to_bytecode_constructor)
def stash_graph_created_object(obj: Any) -> Any:
keep_alive.append(obj)
return obj
def get_external_object_by_index(index: int) -> Any:
assert index in index_to_external_object_weakref, (
"Index not registered in index_to_user_object_weakref"

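For context, these helpers implement a small index registry: objects created during tracing are stashed to keep them alive for the frame, and generated graph code recovers them by index. A minimal sketch of the pattern, with illustrative bodies where the hunk is truncated:

import weakref
from typing import Any

keep_alive: list[Any] = []  # strong refs so graph-created objects outlive tracing
index_to_external_object_weakref: dict[int, "weakref.ref[Any]"] = {}

def stash_graph_created_object(obj: Any) -> Any:
    keep_alive.append(obj)  # keep the object alive for the rest of the frame
    return obj

def store_user_object_weakrefs(*objs: Any) -> None:
    # record each object at its dynamo-generated index (illustrative body)
    for index, obj in enumerate(objs):
        index_to_external_object_weakref[index] = weakref.ref(obj)

def get_external_object_by_index(index: int) -> Any:
    assert index in index_to_external_object_weakref, (
        "Index not registered in index_to_external_object_weakref"
    )
    obj = index_to_external_object_weakref[index]()
    assert obj is not None, "registered object was garbage collected"
    return obj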
View File

@@ -1527,37 +1527,6 @@ class OutputGraph(OutputGraphCommon):
from .decorators import disable
if has_user_objects():
# NB: This is where we store possible user objects before running the graph
# index_to_user_object_weakref is the function used in the graph to translate
# the dynamo-generated index into the actual object passed to the compiled function.
# We generate bytecode to store all user objects at the proper index in the below
# call.
codegen = PyCodegen(
self.root_tx, root, overridden_sources=overridden_sources
)
codegen.add_push_null(
lambda: codegen.load_import_from(
torch._dynamo.graph_bytecode_inputs.__name__,
"store_user_object_weakrefs",
)
)
tmp_vars = []
for constructor in index_to_bytecode_constructor.values():
constructor(codegen)
var_name = (
self.new_var()
) # keep alive any temp objects for the rest of the frame
codegen.store(var_name)
tmp_vars.append(var_name)
for var_name in tmp_vars:
codegen.append_output(codegen.create_load(var_name))
codegen.call_function(len(index_to_bytecode_constructor), False)
codegen.pop_top()
self.add_output_instructions(codegen.get_instructions())
# to handle random calls
if len(self.random_calls) > 0:
random_calls_instructions = []
@@ -2342,6 +2311,25 @@ class OutputGraph(OutputGraphCommon):
assert self.root_tx is not None
cg = PyCodegen(self.root_tx)
if has_user_objects():
# NB: This is where we store possible user objects before running the graph
# index_to_user_object_weakref is the function used in the graph to translate
# the dynamo-generated index into the actual object passed to the compiled function.
# We generate bytecode to store all user objects at the proper index in the below
# call.
cg.add_push_null(
lambda: cg.load_import_from(
torch._dynamo.graph_bytecode_inputs.__name__,
"store_user_object_weakrefs",
)
)
for constructor in index_to_bytecode_constructor.values():
constructor(cg)
cg.call_function(len(index_to_bytecode_constructor), False)
cg.pop_top()
for idx, arg in enumerate(self.graphargs):
self.export_metadata.graph_input_idx_to_local_source[idx] = arg.source
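Conceptually, the prologue emitted by this codegen is equivalent to the hand-written call below; the real code pushes each constructed object with raw bytecode via PyCodegen rather than source text, and the two objects here are illustrative:

import torch
from torch._dynamo.graph_bytecode_inputs import store_user_object_weakrefs

# Each registered constructor produces one user object (e.g. a Stream or an Event);
# a single call then records weakrefs to them at indices 0..N-1 so the compiled
# graph can recover them via get_external_object_by_index.
s = torch.Stream()
e = torch.Event()
store_user_object_weakrefs(s, e)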
@@ -3008,7 +2996,7 @@ class SubgraphTracer(fx.Tracer):
self.tracked_tensor_or_symint_vt: OrderedSet[VariableTracker] = OrderedSet()
- def record_tensor_or_symint_vt(self, vt):
+ def record_tensor_or_symint_vt(self, vt: VariableTracker):
self.tracked_tensor_or_symint_vt.add(vt)
# preserve original meta if it is available

View File

@@ -52,10 +52,21 @@ def new_stream(*args: tuple[Any], **kwargs: Any) -> int:
)
def _codegen_current_stream(device: torch.device, cg: "PyCodegen") -> None:
cg.add_push_null(
lambda: cg.load_import_from(
torch._dynamo.graph_bytecode_inputs.__name__, # type: ignore[implicit-imports]
"stash_graph_created_object",
)
)
cg(CurrentStreamSource(device))
cg.extend_output(create_call_function(1, False))
def get_current_stream(device: torch.device) -> int:
- stream = torch.accelerator.current_stream()
+ stream = torch.accelerator.current_stream(device)
return register_graph_created_object(
- stream, lambda _, cg: cg(CurrentStreamSource(device))
+ stream, lambda _, cg: _codegen_current_stream(device, cg)
)
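At frame entry, the bytecode emitted by _codegen_current_stream evaluates to roughly the following sketch (CurrentStreamSource reconstruction is shown as the accelerator query it represents; the device is an example):

import torch
from torch._dynamo.graph_bytecode_inputs import stash_graph_created_object

device = torch.device("cuda", 0)  # example device
# Query the device's current stream and stash it so a strong reference survives
# for the rest of the frame; the stashed object is what the graph indexes into.
stream = stash_graph_created_object(torch.accelerator.current_stream(device))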
@@ -362,6 +373,12 @@ class StreamVariable(StreamContextVariable):
args: TupleVariable, kwargs: ConstDictVariable
) -> Callable[[int, "PyCodegen"], None]:
def fn(index: int, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.load_import_from(
torch._dynamo.graph_bytecode_inputs.__name__, # type: ignore[implicit-imports]
"stash_graph_created_object",
)
)
codegen.add_push_null(
lambda: codegen.load_import_from(
torch._dynamo.utils.__name__, "build_stream"
@@ -370,6 +387,7 @@ class StreamVariable(StreamContextVariable):
codegen(args)
codegen(kwargs)
codegen.extend_output(create_call_function(2, False))
codegen.extend_output(create_call_function(1, False))
return fn
@@ -473,6 +491,12 @@ class EventVariable(VariableTracker):
args: TupleVariable, kwargs: ConstDictVariable
) -> Callable[[int, "PyCodegen"], None]:
def fn(index: int, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.load_import_from(
torch._dynamo.graph_bytecode_inputs.__name__, # type: ignore[implicit-imports]
"stash_graph_created_object",
)
)
codegen.add_push_null(
lambda: codegen.load_import_from(
torch._dynamo.utils.__name__, "build_event"
@@ -481,6 +505,7 @@ class EventVariable(VariableTracker):
codegen(args)
codegen(kwargs)
codegen.extend_output(create_call_function(2, False))
codegen.extend_output(create_call_function(1, False))
return fn
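Both reconstruct hooks follow the same shape: the freshly built object is passed through stash_graph_created_object so a strong reference survives the frame. In effect the emitted bytecode evaluates something like this sketch (args/kwargs stand for the recorded constructor arguments; build_event is used analogously for events):

from torch._dynamo.graph_bytecode_inputs import stash_graph_created_object
from torch._dynamo.utils import build_stream

args, kwargs = (), {}  # recorded constructor arguments (illustrative)
stream = stash_graph_created_object(build_stream(args, kwargs))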

View File

@@ -33,6 +33,7 @@ from .graph_capture_wrappers import (
handle_effect_tokens_fn,
)
from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta
from .streams import assign_backward_streams
from .utils import (
call_and_expect_output_descs,
copy_fwd_metadata_to_bw_nodes,
@@ -473,6 +474,9 @@ def aot_dispatch_autograd_graph(
# fw node match might be erased
copy_fwd_metadata_to_bw_nodes(fx_g)
# After copying metadata, assign streams to gradient accumulation nodes
assign_backward_streams(fx_g)
fx_g.graph.eliminate_dead_code()
if not aot_config.disable_functionalization:
# There should be *NO* mutating ops in the graph at this point.

View File

@@ -0,0 +1,53 @@
from typing import Optional, TypeAlias
import torch.fx
import torch.fx.traceback
from torch._dynamo.graph_utils import _get_flat_args
Node: TypeAlias = torch.fx.Node
def is_gradient_acc(node: Node) -> bool:
return node.meta.get("is_gradient_acc", False)
def get_stream(node: Node) -> Optional[int]:
maybe_annotation = node.meta.get("custom", None)
if maybe_annotation is not None:
return node.meta["custom"].get("stream", None)
else:
return None
def set_stream(node: Node, ind: int) -> None:
if "custom" in node.meta:
node.meta["custom"].update({"stream": ind})
else:
node.meta["custom"] = {"stream": ind}
def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
"""Assigns backward streams to gradient accumulation nodes"""
# NB: iterate in reverse order to more closely match eager;
# the user node's stream will be populated first
for node in reversed(list(gm.graph.nodes)):
if is_gradient_acc(node):
# Accumulation stream selection. Follow these rules from top to bottom to determine the accumulation stream:
# 1. Match the first stream assignment of the first user with a stream
# 2. Match the first stream assignment encountered in the args, from left to right
# This differs from eager in some cases: the eager code uses the autograd node to determine
# the stream, and crucially that node does not necessarily correspond to an FX graph node.
# For example, in the backward of an add with a constant we pass through, and during backward
# tracing no op is added to the FX graph, so our stream assignment will differ in this case.
gradients = _get_flat_args(node, {})
users = list(node.users.keys())
# All gradients will be on the same device; if they were not, they are coerced with a .to() node
for neighbor in users + gradients:
ind = get_stream(neighbor)
if ind is not None:
set_stream(node, ind)
break
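A toy illustration of the selection rule on hand-built metadata; the module path is assumed from the aot_dispatch_autograd_graph import above, and real graphs get these metas from backward tracing rather than by hand:

import operator
import torch.fx
from torch._functorch._aot_autograd.streams import assign_backward_streams  # path assumed

def f(a, b):
    return a + b

gm = torch.fx.symbolic_trace(f)
acc = next(n for n in gm.graph.nodes if n.target is operator.add)
acc.meta["is_gradient_acc"] = True  # pretend this add is a gradient accumulation node
for grad in acc.all_input_nodes:  # give the incoming "gradients" a stream index
    grad.meta["custom"] = {"stream": 3}

assign_backward_streams(gm)
assert acc.meta["custom"] == {"stream": 3}  # the accumulation node picked up stream 3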