Compare commits

...

2 Commits

Author SHA1 Message Date
91b5fee902 [user-streams] Trace events with the new ops
ghstack-source-id: c79dd58a1c6f2e59d210d9305beef1053d43bc4b
Pull-Request: https://github.com/pytorch/pytorch/pull/167177
2025-11-06 22:13:36 -08:00
f247cab939 [user-streams] Add fallbacks for record and wait event
ghstack-source-id: c1d0439ca25a0de4d9856ba0bf58f0cad3216663
Pull-Request: https://github.com/pytorch/pytorch/pull/167260
2025-11-06 22:13:35 -08:00
8 changed files with 178 additions and 9 deletions

View File

@ -408,6 +408,9 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
self.assertEqual(ref0, res0)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
@unittest.skip(
"Will not support external events for now: https://github.com/pytorch/pytorch/issues/167257"
)
def test_cuda_event_reconstruct(self):
def fn(x):
e = torch.cuda.Event()
@ -425,6 +428,9 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
self.assertEqual(cnts.op_count, 3)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
@unittest.skip(
"Will not support external events for now: https://github.com/pytorch/pytorch/issues/167257"
)
def test_cuda_event_across_graph_break(self):
def fn(x):
e = torch.cuda.Event()
@ -446,9 +452,12 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
res = opt_fn(x)
self.assertEqual(ref[0], res[0])
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 9)
self.assertEqual(cnts.op_count, 10)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
@unittest.skip(
"Will not support external events for now: https://github.com/pytorch/pytorch/issues/167257"
)
def test_cuda_event_created_outside_of_graph(self):
user_stream = torch.cuda.Stream()
event = torch.cuda.Event()
@ -478,9 +487,12 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
res = run_iters(func, compile=True)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 3)
self.assertEqual(cnts.op_count, 4)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
@unittest.skip(
"Will not support external events for now: https://github.com/pytorch/pytorch/issues/167257"
)
def test_cuda_event_method_create_stream_outside_of_compile(self):
def fn(x, cur_stream, new_stream):
x = torch.mul(x, 1)

View File

@ -3,6 +3,7 @@ import functools
import re
import unittest
import weakref
from unittest.mock import patch
import torch
import torch._dynamo.test_case
@ -445,6 +446,37 @@ class GraphModule(torch.nn.Module):
""",
)
@requires_cuda
def test_event_tracing(self):
def fn(x) -> None:
e = torch.Event()
e.record()
x.add_(1)
return x
inp = (torch.ones(2, 2, device="cuda"),)
(
_,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]"):
#
record_event = torch.ops.streams.record_event.default(0, 1); record_event = None
#
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, 1)
copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = None
return (copy_,)
""",
)
@requires_cuda
def test_run_opcheck_fork_join(self):
from torch._dynamo.variables.streams import fork_stream, join_stream
@ -491,6 +523,20 @@ class GraphModule(torch.nn.Module):
torch.accelerator.set_stream(original_stream)
reset_user_object_tracking()
@requires_cuda
def test_inductor_lowering(self):
with patch("torch._inductor.config.implicit_fallbacks", False):
@torch.compile()
def fn(x):
e = torch.Event()
x += x + 1
e.record()
return x
inp = (torch.ones(2, 2, device="cuda"),)
fn(*inp)
def test_is_marked_side_effectful(self):
self.assertIn(
torch.ops.streams.fork.default, torch.fx.node._side_effectful_functions

View File

@ -4768,6 +4768,10 @@ def build_stream(args: tuple[Any], kwargs: dict[Any, Any]) -> torch.Stream:
return torch._C.Stream(*args, **kwargs)
def build_event(args: tuple[Any], kwargs: dict[Any, Any]) -> torch.Event:
return torch._C.Event(*args, **kwargs)
class CompileTimeInstructionCounter:
_counter: int = 0
_id: int = -1

View File

@ -1083,6 +1083,7 @@ class VariableBuilder:
return EventVariable(
event_proxy,
value,
index,
source=self.source,
)
elif (
@ -3004,16 +3005,28 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
return SymNodeVariable(proxy, example_value, **options)
elif (
isinstance(example_value, torch.Stream)
and proxy.node.target == get_external_object_by_index
and proxy.node.target is get_external_object_by_index
) or proxy.node.target in [
device_interface.current_stream
for _, device_interface in get_registered_device_interfaces()
]:
set_example_value(proxy.node, example_value)
index = None
if proxy.node.target == get_external_object_by_index:
if proxy.node.target is get_external_object_by_index:
index = proxy.node.args[0]
return StreamVariable(proxy, example_value, index, **options)
elif (
isinstance(example_value, torch.Event)
and proxy.node.target is get_external_object_by_index
) or proxy.node.target in [
device_interface.current_stream
for _, device_interface in get_registered_device_interfaces()
]:
index = None
if proxy.node.target is get_external_object_by_index:
index = proxy.node.args[0]
set_example_value(proxy.node, example_value)
return EventVariable(proxy, example_value, index, **options)
elif (
inspect.isclass(proxy.node.target)
and issubclass(proxy.node.target, torch.Event)
@ -3022,7 +3035,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
for _, device_interface in get_registered_device_interfaces()
]:
set_example_value(proxy.node, example_value)
return EventVariable(proxy, example_value, **options)
return EventVariable(proxy, example_value, None, **options)
elif proxy.node.target == "query" and proxy.node.op == "call_method":
set_example_value(proxy.node, example_value)
return ConstantVariable(example_value, **options)
@ -3033,7 +3046,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
and proxy.node.op == "call_method"
):
set_example_value(proxy.node, example_value)
return EventVariable(proxy, example_value, **options)
return EventVariable(proxy, example_value, None, **options)
elif isinstance(example_value, int) and (
proxy.node.target
in [

View File

@ -326,12 +326,19 @@ class StreamVariable(StreamContextVariable):
class EventVariable(VariableTracker):
def __init__(self, proxy: Proxy, value: torch.Event, **kwargs: Any) -> None:
def __init__(
self,
proxy: Proxy,
value: torch.Event,
user_object_index: Optional[int],
**kwargs: Any,
) -> None:
if proxy is not None and "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == value
super().__init__(**kwargs)
self.proxy = proxy
self.value = value
self.user_object_index = user_object_index
def call_method(
self,
@ -343,7 +350,29 @@ class EventVariable(VariableTracker):
from ..utils import proxy_args_kwargs
from .builder import wrap_fx_proxy_cls
if name in ("wait", "record", "synchronize"):
if name == "wait":
tx.output.create_proxy(
"call_function",
torch.ops.streams.wait_event,
(
self.user_object_index,
EventVariable._get_stream_arg(tx, args, kwargs).user_object_index,
),
{},
)
return ConstantVariable(None)
elif name == "record":
tx.output.create_proxy(
"call_function",
torch.ops.streams.record_event,
(
self.user_object_index,
EventVariable._get_stream_arg(tx, args, kwargs).user_object_index,
),
{},
)
return ConstantVariable(None)
elif name == "synchronize":
tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
)
@ -373,6 +402,39 @@ class EventVariable(VariableTracker):
def as_proxy(self) -> Proxy:
return self.proxy
@staticmethod
def _get_stream_arg(
tx: "InstructionTranslator",
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> "StreamVariable":
stream_arg = None
if args:
stream_arg = args[0]
elif kwargs:
stream_arg = kwargs.get("stream")
if not stream_arg:
stream_arg = tx.symbolic_stream_state.cur_stream()
return stream_arg # type: ignore[return-value]
@staticmethod
def make_construct_in_graph_event_fn(
args: TupleVariable, kwargs: ConstDictVariable
) -> Callable[[int, "PyCodegen"], None]:
def fn(index: int, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.load_import_from(
torch._dynamo.utils.__name__, "build_event"
)
)
codegen(args)
codegen(kwargs)
codegen.extend_output(create_call_function(2, False))
return fn
def reconstruct(self, codegen: "PyCodegen") -> None:
# If we got here, this event is fully subsumed by the graph - this means it is
# not an input or global

View File

@ -1270,7 +1270,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
# pyrefly: ignore [unbound-name]
return VariableTracker.build(tx, module, new_source)
@register(torch.accelerator.current_stream)
@register(torch.accelerator.current_stream, torch.cuda.current_stream)
def handle_current_stream(self, tx: "InstructionTranslator", *args, **kwargs):
if len(args) + len(kwargs) > 1 or (kwargs and "device" not in kwargs):
unimplemented_v2(

View File

@ -839,6 +839,34 @@ class UserDefinedClassVariable(UserDefinedVariable):
"call_function", get_external_object_by_index, (ind,), {}
),
)
elif issubclass(self.value, torch.Event):
from .constant import ConstantVariable
from .lists import TupleVariable
# Register newly created event for reconstruction
var_kwargs = ConstDictVariable(
{ConstantVariable(k): v for k, v in kwargs.items()}
)
var_args = TupleVariable(list(args))
event = self.value(
*(var_args.as_python_constant()),
**(var_kwargs.as_python_constant()),
)
from ..graph_bytecode_inputs import register_graph_created_object
from .streams import EventVariable
ind = register_graph_created_object(
event,
EventVariable.make_construct_in_graph_event_fn(
var_args, var_kwargs
),
)
tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function", get_external_object_by_index, (ind,), {}
),
)
else:
tensor_variable = wrap_fx_proxy(
tx=tx,

View File

@ -2379,6 +2379,10 @@ fallback_randn_default = fallback_handler(aten.randn.default)
fallback_randn_generator = fallback_handler(aten.randn.generator)
make_fallback(aten.randint)
# TODO: mlazos reevaluate if we want to codegen something different
make_fallback(torch.ops.streams.record_event.default)
make_fallback(torch.ops.streams.wait_event.default)
@register_lowering(aten.rand)
def rand(*args, **kwargs):