mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-14 14:15:07 +08:00
[user-streams] Allow new events to be created and registered during compilation (#167510)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167510 Approved by: https://github.com/williamwen42
This commit is contained in:
committed by
PyTorch MergeBot
parent
fe841a1db4
commit
d8ada1ee76
@ -335,6 +335,34 @@ class <lambda>(torch.nn.Module):
|
||||
""",
|
||||
)
|
||||
|
||||
@requires_cuda
|
||||
@requires_multigpu()
|
||||
def test_new_event_api(self) -> None:
|
||||
from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
|
||||
from torch._dynamo.variables.streams import new_event
|
||||
|
||||
def event_generation_backend(gm, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
e0_ind = new_event()
|
||||
with torch.Stream(device="cuda:1"):
|
||||
get_external_object_by_index(e0_ind).record()
|
||||
e1_ind = new_event()
|
||||
self.assertNotEqual(e0_ind, e1_ind)
|
||||
self.assertNotEqual(
|
||||
get_external_object_by_index(e0_ind),
|
||||
get_external_object_by_index(e1_ind),
|
||||
)
|
||||
with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
|
||||
gm.graph.call_function(
|
||||
get_external_object_by_index, args=(1,), kwargs={}
|
||||
)
|
||||
return gm
|
||||
|
||||
@torch.compile(backend=event_generation_backend)
|
||||
def fn(x):
|
||||
return x + 1
|
||||
|
||||
fn(torch.ones(2, 2, device="cuda:0"))
|
||||
|
||||
@requires_cuda
|
||||
def test_stream_with_mutation(self):
|
||||
def fn(x, y):
|
||||
|
||||
@ -10,7 +10,10 @@ from torch.fx import has_side_effect, Proxy
|
||||
from .. import graph_break_hints
|
||||
from ..bytecode_transformation import create_call_function
|
||||
from ..exc import TYPE_CHECKING, unimplemented
|
||||
from ..graph_bytecode_inputs import get_external_object_by_index
|
||||
from ..graph_bytecode_inputs import (
|
||||
get_external_object_by_index,
|
||||
register_graph_created_object,
|
||||
)
|
||||
from .base import VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .ctx_manager import FxTracebackAnnotateVariable
|
||||
@ -28,6 +31,16 @@ from torch._library.custom_ops import custom_op
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
def new_event(*args: Any, **kwargs: Any) -> int:
|
||||
event = torch.Event(*args, **kwargs)
|
||||
return register_graph_created_object(
|
||||
event,
|
||||
EventVariable.make_construct_in_graph_event_fn(
|
||||
TupleVariable([]), ConstDictVariable({})
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _get_stream_by_index(index: int) -> torch.Stream:
|
||||
stream = get_external_object_by_index(index)
|
||||
assert isinstance(stream, torch.Stream), (
|
||||
|
||||
Reference in New Issue
Block a user