[user-streams] Allow new events to be created and registered during compilation (#167510)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167510
Approved by: https://github.com/williamwen42
This commit is contained in:
Michael Lazos
2025-11-11 01:54:35 -08:00
committed by PyTorch MergeBot
parent fe841a1db4
commit d8ada1ee76
2 changed files with 42 additions and 1 deletions

View File

@ -335,6 +335,34 @@ class <lambda>(torch.nn.Module):
""",
)
@requires_cuda
@requires_multigpu()
def test_new_event_api(self) -> None:
from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
from torch._dynamo.variables.streams import new_event
def event_generation_backend(gm, *args, **kwargs): # type: ignore[no-untyped-def]
e0_ind = new_event()
with torch.Stream(device="cuda:1"):
get_external_object_by_index(e0_ind).record()
e1_ind = new_event()
self.assertNotEqual(e0_ind, e1_ind)
self.assertNotEqual(
get_external_object_by_index(e0_ind),
get_external_object_by_index(e1_ind),
)
with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
gm.graph.call_function(
get_external_object_by_index, args=(1,), kwargs={}
)
return gm
@torch.compile(backend=event_generation_backend)
def fn(x):
return x + 1
fn(torch.ones(2, 2, device="cuda:0"))
@requires_cuda
def test_stream_with_mutation(self):
def fn(x, y):

View File

@ -10,7 +10,10 @@ from torch.fx import has_side_effect, Proxy
from .. import graph_break_hints
from ..bytecode_transformation import create_call_function
from ..exc import TYPE_CHECKING, unimplemented
from ..graph_bytecode_inputs import get_external_object_by_index
from ..graph_bytecode_inputs import (
get_external_object_by_index,
register_graph_created_object,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import FxTracebackAnnotateVariable
@ -28,6 +31,16 @@ from torch._library.custom_ops import custom_op
Tensor = torch.Tensor
def new_event(*args: Any, **kwargs: Any) -> int:
event = torch.Event(*args, **kwargs)
return register_graph_created_object(
event,
EventVariable.make_construct_in_graph_event_fn(
TupleVariable([]), ConstDictVariable({})
),
)
def _get_stream_by_index(index: int) -> torch.Stream:
stream = get_external_object_by_index(index)
assert isinstance(stream, torch.Stream), (