Compare commits

...

17 Commits

Author SHA1 Message Date
f70a6ac1a6 [user-streams] Fix stream graph output semantics
ghstack-source-id: a1761206efa02920945b94e1ff811abeed6e470b
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164819
2025-10-06 22:09:05 -07:00
020bd1f830 [dynamo] Remove retrieving objects by ID
ghstack-source-id: f09cb7bc515fb4f7e195d75ef0dff2340584c473
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162905
2025-10-06 22:09:00 -07:00
c84e73df6e [user-streams] Add basic stream tests
ghstack-source-id: a0860743dd23356a9b69889c0762799a6c848b47
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164523

merge into streams suite
2025-10-06 22:09:00 -07:00
6120d39fdb [User-streams] Make torch.Event weakref compatible
ghstack-source-id: 49e3de1c6f1f57bd33330bffb257e01cc28bdda1
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164522
2025-10-06 22:08:59 -07:00
17ed117a90 [user-streams] Make cuda streams weakref compatible
ghstack-source-id: 7f211d3af308e5924e4725efed914adf23613727
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164304
2025-10-06 22:08:59 -07:00
3e941aebb7 [user-cuda-streams] Add cuda streams test suite
ghstack-source-id: 782f3ac95798625c54d71a41249f4d31786831c9
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162901
2025-10-06 22:08:58 -07:00
e403203714 [user-streams] Support streams as contexts
ghstack-source-id: d95de5ef14d5be19d536d53b142aa4126dcdf1e0
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164507
2025-10-06 22:08:58 -07:00
1e80b1ad7d [user-streams] Have StreamVariable inherit from StreamContextVariable
ghstack-source-id: e97a7966236489f10bb71c0178f973b0790a8170
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164344
2025-10-06 22:08:57 -07:00
3dbe758856 [user-streams] Move StreamContextVariable into streams module
finish moving

ghstack-source-id: bc16a138acb41b68164322ce8748d1da74318332
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164343
2025-10-06 22:08:57 -07:00
7faa4842f8 [user-streams] Exclude non-tensor nodes from stream args
ghstack-source-id: bff24c337acc3c38d8c23c80543bf763b645e506
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164818
2025-10-06 22:08:56 -07:00
45200769d1 [user-streams] Track external/internal nodes for stream context
ghstack-source-id: 9bbbf8cbc520f0f7a52acee50ee99771e711434e
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162904
2025-10-06 22:08:52 -07:00
38695fbfb4 [user-streams] update stream context to use fork/join
ghstack-source-id: 47010b1cf4a6ff3fe80165fbbed29dddb3b33cc9
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162903
2025-10-06 22:08:52 -07:00
0c01bec755 [user-streams] Add stream state manager
ghstack-source-id: aff52f66920a421884232d4e96a2ac1a80ec68d5
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162902
2025-10-06 22:08:51 -07:00
02cbfe3469 [user-cuda-streams] Add fork/join custom ops
Make custom ops inplace

ghstack-source-id: 3e41664853ac60d49b60bdebd8b3859d227925ad
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162900
2025-10-06 22:08:50 -07:00
3e8fcf18ab [user-streams] Handle aliasing properly
ghstack-source-id: 9d1810e19ec99c3b148f446b56f85354549104a2
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163028
2025-10-06 22:08:50 -07:00
3e5b122936 [user-cuda-streams] Pass streams/events to the graph via lookup table
ghstack-source-id: 72a6321c9de91f6c6c5e8b27a998c60853ffe5d2
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162899

test fixes
2025-10-06 22:08:49 -07:00
422233fde2 [user-streams] Move stream code to streams module
ghstack-source-id: 627a90d386ed00706ad9d04f607732a3a6a79fc4
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163027
2025-10-06 22:08:49 -07:00
19 changed files with 760 additions and 269 deletions

146
test/dynamo/test_streams.py Normal file
View File

@ -0,0 +1,146 @@
# Owner(s): ["module: dynamo"]
import weakref
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
class TestStreams(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
@classmethod
def tearDownClass(cls):
super().tearDownClass()
def test_stream_weakref(self):
s = torch.Stream()
weakref.ref(s)
def test_event_weakref(self):
e = torch.Event()
weakref.ref(e)
def test_stream_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
with s1:
z1 = torch.add(x, y)
with s2:
z = torch.add(x, y)
y = z + 2 + z1
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
fn_opt = torch.compile(fn, fullgraph=True)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_stream_context_graph_break(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
with s1:
z1 = torch.add(x, y)
with s2:
z = torch.add(x, y)
y = z + 2 + z1
torch._dynamo.graph_break()
y = y + 1
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
fn_opt = torch.compile(fn)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_stream_input(self):
def fn(x, y, s):
z = torch.add(x, y)
y = z + 2
return y, s
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(device="cuda"))
expected = fn(*inp)
fn_opt = torch.compile(fn, fullgraph=True)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_local_stream_return(self):
def fn(x, y):
s = torch.Stream()
z = torch.add(x, y)
y = z + 2
return y, s
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
fn_opt = torch.compile(fn, fullgraph=True)
_, s0 = fn_opt(*inp)
_, s1 = fn_opt(*inp)
# Streams will be different values for each invocation
# so don't check for equality
self.assertIsInstance(s0, torch.Stream)
# Stream should be newly allocated on each call
self.assertNotEqual(s0, s1)
def test_get_current_stream_return(self):
def fn(x, s):
with s:
s0 = torch.cuda.current_stream()
return x, s0
s_inp = torch.Stream(device="cuda")
inp = (torch.ones(2, 2) + 1, s_inp)
fn_opt = torch.compile(fn, fullgraph=True)
_, s0 = fn_opt(*inp)
_, s1 = fn_opt(*inp)
self.assertEqual(s_inp, s0)
self.assertEqual(s0, s1)
def test_nested_stream_enter_exit(self):
pass
def test_stream_enter_exit_graph_break(self):
pass
def test_nested_stream_enter_exit_graph_break(self):
pass
def test_local_stream_enter_exit(self):
pass
def test_local_stream_nested_enter_exit(self):
pass
def test_stream_with_mutation(self):
pass
def test_run_opcheck(self):
from torch._dynamo.variables.streams import fork_stream_, join_stream_
from torch.library import opcheck
sample_inputs = [
(1, torch.device("cuda:0"), 1, [torch.randn(3), torch.randn(3)]),
(
2,
torch.device("cuda:0"),
0,
[torch.randn(2, 3, device="cuda"), torch.randn(2, 3, device="cuda")],
),
]
for args in sample_inputs:
opcheck(fork_stream_, args)
opcheck(join_stream_, args)
if __name__ == "__main__":
    # Hand off to the Dynamo test runner when this file is executed directly.
    import torch._dynamo.test_case as _dynamo_test_case

    _dynamo_test_case.run_tests()

View File

@ -149,7 +149,6 @@ def reset() -> None:
GenerationTracker.clear()
TensorifyState.clear()
torch._dynamo.utils.warn_once_cache.clear()
torch._dynamo.utils.user_obj_id_to_weakref.clear()
torch._C._autograd._saved_tensors_hooks_set_tracing(False)

View File

@ -116,6 +116,7 @@ from .exc import (
unimplemented_v2,
Unsupported,
)
from .graph_bytecode_inputs import reset_user_object_tracking
from .guards import (
CheckFunctionManager,
get_and_maybe_log_recompilation_reasons,
@ -310,6 +311,7 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
torch.fx._symbolic_trace._maybe_revert_all_patches()
)
exit_stack.enter_context(torch_function_mode_stack_state_mgr)
reset_user_object_tracking()
try:
return fn(*args, **kwargs)
finally:

View File

@ -2734,6 +2734,12 @@
}
],
"GB0272": [
{
"Gb_type": "Failed to make weakref to User Object when storing by ID",
"Context": "user_objected: {obj}",
"Explanation": "Object does not allow us to make a weakref to it",
"Hints": []
},
{
"Gb_type": "Failed to make weakref to User Object",
"Context": "user_objected: {obj}",
@ -2763,5 +2769,13 @@
"This is likely to be a Dynamo bug. Please report an issue to PyTorch."
]
}
],
"GB0275": [
{
"Gb_type": "Failed to make weakref to User Object",
"Context": "user_object: {value}",
"Explanation": "Object does not allow us to make a weakref to it",
"Hints": []
}
]
}

View File

@ -0,0 +1,65 @@
import weakref
from typing import Any
from torch._dynamo.source import Source
# This file handles objects whose types we don't want to support as explicit
# FX graph inputs. It uses a side table which is populated in bytecode and
# read during graph execution.
#
# A dynamo-generated index serves as a level of indirection: it lets us
# register objects externally, in pre-graph bytecode, that we want to pass to
# the graph without supporting their types as graph inputs.

# Maps each dynamo-generated index to the Source of the registered object.
index_to_source: dict[int, Source] = {}
# Maps each dynamo-generated index to a weak reference to the object itself.
index_to_user_object_weakref: dict[int, weakref.ReferenceType[Any]] = {}
def has_user_objects() -> bool:
    """Return True when at least one user object source has been registered."""
    return len(index_to_source) > 0
def get_user_object_by_index(index: int) -> Any:
    """Return the live user object stored under ``index``.

    Args:
        index: key into ``index_to_user_object_weakref``.

    Returns:
        The object the stored weak reference points to.

    Raises:
        AssertionError: if ``index`` is not in the table, or if the
            referenced object has already been garbage collected.
    """
    assert index in index_to_user_object_weakref, (
        "Index not registered in index_to_user_object_weakref"
    )
    # Dereference exactly once and return that strong reference, so the value
    # handed back is the same one that was just checked for liveness.
    obj = index_to_user_object_weakref[index]()
    assert obj is not None, "User object is no longer alive"
    return obj
def store_user_object_weakrefs(*args: Any) -> None:
    """Replace the weakref table with weak references to ``args``.

    The positional index of each argument becomes its key in
    ``index_to_user_object_weakref``.
    """
    index_to_user_object_weakref.clear()
    # Build every weak reference first so the table is only filled once all of
    # them have been created successfully.
    fresh_refs: dict[int, weakref.ReferenceType[Any]] = {}
    for position, user_obj in enumerate(args):
        fresh_refs[position] = weakref.ref(user_obj)
    index_to_user_object_weakref.update(fresh_refs)
def reset_user_object_tracking() -> None:
    """Empty both side tables (sources and weak references)."""
    for table in (index_to_source, index_to_user_object_weakref):
        table.clear()
# Register a user object to be used in the graph
def register_user_object(value: Any, source: Source) -> int:
    """Register ``value`` and its ``source`` in the side tables.

    Args:
        value: the user object; it must support weak references.
        source: the Source recorded for the object under the new index.

    Returns:
        The index under which the object was registered.

    If ``value`` cannot be weakly referenced, ``unimplemented_v2`` is invoked
    with the originating ``TypeError`` (it is expected to raise) and nothing
    is added to either table.
    """
    # Create the weak reference *before* touching the tables: otherwise a
    # failure would leave an entry in index_to_source with no matching weakref.
    try:
        ref = weakref.ref(value)
    except TypeError as e:
        from .exc import unimplemented_v2

        unimplemented_v2(
            gb_type="Failed to make weakref to User Object",
            context=f"user_object: {value}",
            explanation="Object does not allow us to make a weakref to it",
            hints=[],
            from_exc=e,
        )

    index = len(index_to_source)
    index_to_source[index] = source
    index_to_user_object_weakref[index] = ref
    return index

View File

@ -2162,6 +2162,8 @@ class GuardBuilder(GuardBuilderBase):
range,
dict_keys,
torch.Size,
torch.Stream,
torch.cuda.streams.Stream,
*np_types,
*ok_mutable_types,
}

View File

@ -99,6 +99,7 @@ from .exc import (
unimplemented_v2,
unimplemented_v2_with_warning,
)
from .graph_bytecode_inputs import has_user_objects, index_to_source
from .graph_deduplication import apply_graph_deduplication
from .graph_region_tracker import GraphRegionTracker
from .guards import GuardBuilder, install_guard
@ -1508,6 +1509,27 @@ class OutputGraph(OutputGraphCommon):
from .decorators import disable
if has_user_objects():
# NB: This is where we store possible user objects before running the graph.
# index_to_user_object_weakref is the side table that get_user_object_by_index
# reads in the graph to translate the dynamo-generated index into the actual
# object passed to the compiled function.
# We generate bytecode to store all user objects at the proper index in the
# call below.
codegen = PyCodegen(
self.root_tx, root, overridden_sources=overridden_sources
)
codegen.add_push_null(
lambda: codegen.load_import_from(
torch._dynamo.graph_bytecode_inputs.__name__,
"store_user_object_weakrefs",
)
)
for source in reversed(index_to_source.values()):
codegen(source)
codegen.call_function(len(index_to_source), False)
codegen.pop_top()
self.add_output_instructions(codegen.get_instructions())
# to handle random calls
if len(self.random_calls) > 0:
random_calls_instructions = []
@ -1645,7 +1667,7 @@ class OutputGraph(OutputGraphCommon):
)
elif (
vt.source is not None
and (source := getattr(vt.source, "base", None))
and (source := getattr(vt.source, "base", None)) # type: ignore[assignment]
and source.is_input
):
self.export_metadata.output_return_type[idx] = (

View File

@ -3333,6 +3333,7 @@ def get_fake_value(
UserError,
UserErrorType,
)
from .variables.streams import stream_state_mgr
op = node.op
@ -3347,13 +3348,27 @@ def get_fake_value(
if (
torch._dynamo.config.use_graph_deduplication
or torch._dynamo.config.track_nodes_for_deduplication
or stream_state_mgr.in_stream_context()
):
flat_args_kwargs = get_fake_values_from_nodes(
tx, _get_flat_args(node, {}), allow_non_graph_fake
)
id_to_initial_version = {
id(arg): arg._version for arg in flat_args_kwargs if is_fake(arg)
}
flat_args = _get_flat_args(node, {})
if stream_state_mgr.in_stream_context():
for arg in flat_args:
if isinstance(arg, torch.fx.Node):
stream_state_mgr.track_node(arg)
if (
torch._dynamo.config.use_graph_deduplication
or torch._dynamo.config.track_nodes_for_deduplication
):
flat_args_kwargs = get_fake_values_from_nodes(
tx, flat_args, allow_non_graph_fake
)
id_to_initial_version = {
id(arg): arg._version for arg in flat_args_kwargs if is_fake(arg)
}
else:
flat_args_kwargs = []
id_to_initial_version = {}
else:
flat_args_kwargs = []
id_to_initial_version = {}
@ -3502,6 +3517,9 @@ def get_fake_value(
torch.Tensor, functools.partial(ensure_graph_fake, tx=tx), ret_val
)
if stream_state_mgr.in_stream_context():
stream_state_mgr.track_internal_node(node)
if (
torch._dynamo.config.use_graph_deduplication
or torch._dynamo.config.track_nodes_for_deduplication
@ -4701,34 +4719,6 @@ def _extract_tensor_dict(t: torch.Tensor) -> dict[str, Any]:
return tensor_dict
# This is useful for reconstructing within the Dynamo graph the non-graph-input objects
# whose lifetime is governed by the user.
# e.g. torch.cuda.Event is a prime example.
user_obj_id_to_weakref: dict[int, weakref.ReferenceType[object]] = {}
def get_user_object_from_id(obj_id: int) -> Any:
obj = user_obj_id_to_weakref[obj_id]()
assert obj is not None, "User object is no longer alive"
return obj
def store_user_object_weakref(obj: object) -> None:
obj_id = id(obj)
try:
user_obj_id_to_weakref[obj_id] = weakref.ref(obj)
except TypeError as e:
from .exc import unimplemented_v2
unimplemented_v2(
gb_type="Failed to make weakref to User Object",
context=f"user_objected: {obj}",
explanation="Object does not allow us to make a weakref to it",
hints=[],
from_exc=e,
)
class CompileTimeInstructionCounter:
_counter: int = 0
_id: int = -1

View File

@ -36,8 +36,6 @@ from .ctx_manager import (
JvpIncrementNestingCtxManagerVariable,
SDPAKernelVariable,
SetFwdGradEnabledContextManager,
StreamContextVariable,
StreamVariable,
TemporarilyPopInterpreterStackCtxManagerVariable,
VmapIncrementNestingCtxManagerVariable,
WithEnterFunctionVariable,
@ -130,6 +128,7 @@ from .nn_module import (
)
from .optimizer import OptimizerVariable
from .sdpa import SDPAParamsVariable
from .streams import EventVariable, StreamContextVariable, StreamVariable
from .tensor import (
DataPtrVariable,
FakeItemVariable,

View File

@ -45,6 +45,10 @@ import sympy
import torch
from torch import SymInt
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.graph_bytecode_inputs import (
get_user_object_by_index,
register_user_object,
)
from torch._dynamo.utils import (
get_metrics_context,
is_int_specialization_case,
@ -172,11 +176,8 @@ from .ctx_manager import (
AutocastModeVariable,
DynamoConfigPatchVariable,
ErrorOnGraphBreakVariable,
EventVariable,
NullContextVariable,
PreserveVersionContextVariable,
StreamContextVariable,
StreamVariable,
)
from .dicts import (
ConstDictVariable,
@ -257,6 +258,7 @@ from .nn_module import (
from .optimizer import OptimizerVariable
from .script_object import TorchScriptObjectVariable
from .sdpa import SDPAParamsVariable
from .streams import EventVariable, StreamContextVariable, StreamVariable
from .tensor import (
NumpyNdarrayVariable,
supported_const_comparison_op_values,
@ -1036,24 +1038,20 @@ class VariableBuilder:
stream_var = VariableBuilder(self.tx, stream_source)(value.stream)
return StreamContextVariable.create(self.tx, stream_var)
elif isinstance(value, torch.Stream):
self.install_guards(GuardBuilder.ID_MATCH)
# This refers to the device-agnostic torch.Stream
self.install_guards(GuardBuilder.TYPE_MATCH)
index = register_user_object(value, self.source)
stream_proxy = self.tx.output.create_proxy(
"call_function",
type(value),
(),
{
"stream_id": value.stream_id,
"device_index": value.device_index,
"device_type": value.device_type,
},
"call_function", get_user_object_by_index, (index,), {}
)
set_example_value(stream_proxy.node, value)
return StreamVariable(
var = StreamVariable(
stream_proxy,
value,
value.device,
source=self.source,
)
return self.tx.output.side_effects.track_object_existing(value, var)
elif isinstance(value, (torch._C._SDPAParams)):
self.install_guards(GuardBuilder.TYPE_MATCH)
return SDPAParamsVariable.create(self.tx, value, self.source)
@ -1061,12 +1059,12 @@ class VariableBuilder:
self.install_guards(GuardBuilder.ID_MATCH)
return FuncTorchInterpreterVariable(value)
elif isinstance(value, torch.Event):
self.install_guards(GuardBuilder.ID_MATCH)
torch._dynamo.utils.store_user_object_weakref(value)
self.install_guards(GuardBuilder.TYPE_MATCH)
index = register_user_object(value, self.source)
event_proxy = self.tx.output.create_proxy(
"call_function",
torch._dynamo.utils.get_user_object_from_id,
(id(value),),
get_user_object_by_index,
(index,),
{},
)
set_example_value(event_proxy.node, value)
@ -2980,8 +2978,8 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
set_example_value(proxy.node, example_value)
return SymNodeVariable(proxy, example_value, **options)
elif (
inspect.isclass(proxy.node.target)
and issubclass(proxy.node.target, torch.Stream)
isinstance(example_value, torch.Stream)
and proxy.node.target == get_user_object_by_index
) or proxy.node.target in [
device_interface.current_stream
for _, device_interface in get_registered_device_interfaces()
@ -3069,6 +3067,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
set_example_value(proxy.node, example_value)
return ConstantVariable.create(example_value, **options)
else:
breakpoint()
unimplemented_v2(
gb_type="torch.* op returned non-Tensor",
context=f"example_value type: {typestr(example_value)}; op: {proxy.node.op}; target: {proxy.node.target}",

View File

@ -83,7 +83,6 @@ from ..utils import (
)
from .base import AsPythonConstantNotImplementedError, ValueMutationNew, VariableTracker
from .constant import ConstantVariable
from .ctx_manager import EventVariable, StreamVariable
from .dicts import (
ConstDictVariable,
DefaultDictVariable,
@ -101,6 +100,7 @@ from .lists import (
TupleIteratorVariable,
TupleVariable,
)
from .streams import EventVariable, StreamVariable
from .tensor import (
FakeItemVariable,
supported_comparison_ops,

View File

@ -34,7 +34,6 @@ from ..bytecode_transformation import (
create_instruction,
create_setup_with,
)
from ..device_interface import get_interface_for_device
from ..exc import unimplemented_v2
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalStateSource
@ -991,70 +990,6 @@ class ProfilerContextVariable(ContextWrappingVariable):
)
class StreamContextVariable(ContextWrappingVariable):
@staticmethod
def create(tx: "InstructionTranslator", target_value, **kwargs):
from .builder import wrap_fx_proxy_cls
current_stream_method = get_interface_for_device(
target_value.device
).current_stream
current_stream = wrap_fx_proxy_cls(
StreamVariable,
tx,
tx.output.create_proxy(
"call_function",
current_stream_method,
(None,),
{},
),
)
return StreamContextVariable(
target_values=[target_value],
initial_values=[current_stream],
device=target_value.device,
**kwargs,
)
def __init__(self, target_values, device, initial_values=None, **kwargs) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.device = device
self.set_stream = get_interface_for_device(self.device).set_stream
self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id
def enter(self, tx):
# stream generated inside the traced function
if self.target_values[0].as_proxy() is not None:
tx.output.create_proxy(
"call_function",
self.set_stream,
(self.target_values[0].as_proxy(),),
{},
)
# stream passed from outside the traced function
else:
stream = self.target_values[0].value
tx.output.create_proxy(
"call_function",
self.set_stream_id,
(stream.stream_id, stream.device_index, stream.device_type),
{},
)
self.set_stream(self.target_values[0].value)
self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value))
def exit(self, tx: "InstructionTranslator", *args):
tx.output.create_proxy(
"call_function",
self.set_stream,
(self.initial_values[0].as_proxy(),),
{},
)
self.cleanup_assert()
class PreserveVersionContextVariable(ContextWrappingVariable):
"""
Wraps torch.autograd._unsafe_preserve_version_counter
@ -1262,142 +1197,6 @@ class SDPAKernelVariable(ContextWrappingVariable):
return "_sdpa_kernel_variadic"
class StreamVariable(VariableTracker):
def __init__(self, proxy, value, device, **kwargs) -> None:
if proxy is not None and "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == value
assert value.device.type == device.type, (
"stream value is not equal to the passed device"
)
super().__init__(**kwargs)
self.proxy = proxy
self.value = value
self.device = device
def python_type(self):
return torch.Stream
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
assert hasattr(self.value, name), f"no stream method found named {name}"
from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs
from .builder import wrap_fx_proxy_cls
if name in ("wait_stream", "synchronize", "wait_event"):
tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
)
return variables.ConstantVariable(None)
elif name == "query":
return wrap_fx_proxy_cls(
target_cls=variables.ConstantVariable,
tx=tx,
proxy=tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
),
)
elif name == "record_event":
return wrap_fx_proxy_cls(
target_cls=EventVariable,
tx=tx,
proxy=tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
),
)
elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
# NB : Checking for mutation is necessary because we compare
# constant values
other = args[0]
if not isinstance(other, StreamVariable):
return variables.ConstantVariable.create(NotImplemented)
return variables.ConstantVariable.create(
cmp_name_to_op_mapping[name](self.value, other.value)
)
return super().call_method(tx, name, args, kwargs)
def as_proxy(self):
return self.proxy
def reconstruct(self, codegen: "PyCodegen"):
# If we got here, this stream is fully subsumed by the graph - this means it is
# not an input or global
assert not self.source
# Since we just proved that - for other such structures, like lists and dicts, reconstruction
# is fine and sound according to dynamo principles of treating collectives. However,
# streams are special in that we want to preserve the identity of the stream as the same as in the graph
# Normally, we would do this via codegen for the proxy mapping to an output - we cannot do this yet, as we do not
# yet have a plan for how we want to handle the case where the stream is used as an input or an output. Pending
# design, to unblock current work, we lift the stream into a global and then codegen bytecode to load it from there.
prefix = f"_stream_{self.device}"
name = codegen.tx.output.install_global_by_id(prefix, self.value)
codegen.append_output(codegen.create_load_global(name, add=True))
class EventVariable(VariableTracker):
def __init__(self, proxy, value, **kwargs) -> None:
if proxy is not None and "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == value
super().__init__(**kwargs)
self.proxy = proxy
self.value = value
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
from ..utils import proxy_args_kwargs
from .builder import wrap_fx_proxy_cls
if name in ("wait", "record", "synchronize"):
tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
)
return variables.ConstantVariable(None)
elif name == "query":
return wrap_fx_proxy_cls(
target_cls=variables.ConstantVariable,
tx=tx,
proxy=tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
),
)
else:
method_name = (
f"{type(self.value).__module__}.{type(self.value).__qualname__}.{name}"
)
unimplemented_v2(
gb_type="Unsupported event method",
context=str(name),
explanation=f"Dynamo doesn't support tracing the {method_name} method. "
f"We currently support wait, record, synchronize, and query.",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
def as_proxy(self):
return self.proxy
def reconstruct(self, codegen: "PyCodegen"):
# If we got here, this event is fully subsumed by the graph - this means it is
# not an input or global
assert not self.source
# Similar to stream handling, we lift the event into a global and then codegen bytecode to load it from there.
prefix = "_event"
name = codegen.tx.output.install_global_by_id(prefix, self.value)
codegen.append_output(codegen.create_load_global(name, add=True))
class DynamoConfigPatchVariable(ContextWrappingVariable):
"""represents torch._dynamo.patch_dynamo_config"""

View File

@ -0,0 +1,420 @@
from collections import deque
from dataclasses import dataclass
from typing import Any, Optional
import torch
from torch.fx import Node, Proxy
from torch.utils._ordered_set import OrderedSet
from .. import graph_break_hints
from ..bytecode_transformation import create_call_function
from ..device_interface import get_interface_for_device
from ..exc import TYPE_CHECKING, unimplemented_v2
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import ContextWrappingVariable
from .misc import GetAttrVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from ..codegen import PyCodegen
from torch._library.custom_ops import custom_op
# Avoid circular dependency for the dataclass
TensorVariable = Any
Tensor = torch.Tensor
@custom_op("streams::fork", mutates_args={"args"})
def fork_stream_(
index: int, device: torch.device, device_index: int, args: list[Tensor]
) -> None:
pass
@fork_stream_.register_fake
def _(index: int, device: torch.device, device_index: int, args: list[Tensor]) -> None:
pass
@custom_op("streams::join", mutates_args={"args"})
def join_stream_(
index: int, device: torch.device, device_index: int, args: list[Tensor]
) -> None:
pass
@join_stream_.register_fake
def _(index: int, device: torch.device, device_index: int, args: list[Tensor]) -> None:
pass
# Strong references to objects handed to add_dynamo_owned_stream.
_keep_alive = []


def add_dynamo_owned_stream(s):
    """Append ``s`` to the module-level ``_keep_alive`` list."""
    _keep_alive.append(s)
# Stream state consists of the fork stream node, the nodes external to the
# stream that are accessed from within the stream, and the nodes created
# within the stream.
@dataclass
class StreamState:
    """Per-context record kept on the StreamStateManager stack."""

    # the fork node that initiated the creation of this stream state
    # we will finalize it once the stream state is popped
    fork_node: Node
    # Nodes not created within the stream
    external_nodes: OrderedSet[Node]
    # Nodes created within the stream
    internal_nodes: OrderedSet[Node]
class StreamStateManager:
    """
    Class used to track the current stream context we are in and identify
    any used tensors as external (created outside the stream context) or
    internal (created within the stream context). We use this information to
    ensure the fork op is dependent on any external tensors, so that it will not
    be reordered before them or after ops which use the externally created tensors.
    Analogously, we use the internal tensors to ensure that the join op is not
    reordered before any internally created tensors or after ops which use the
    internally created tensors.

    To actually implement this, we have a stack of stream states which track any external tensors that
    have not yet been seen within the stream context and any tensors created within the stream context.
    Once we exit the stream context we populate the args of fork with all external tensors which have been used,
    and join with any internal tensors that were created.
    """

    def __init__(self) -> None:
        # Innermost stream context is at the right end of the deque.
        self.state_stack: deque[StreamState] = deque()

    def in_stream_context(self) -> bool:
        """Return True if at least one stream state is on the stack."""
        return bool(self.state_stack)

    def track_internal_node(self, node: Node) -> None:
        """Record ``node`` as created inside the current stream context.

        Does nothing outside a stream context or when the node's
        ``example_value`` is not a tensor.
        """
        if not self.in_stream_context():
            return
        val = node.meta.get("example_value")
        # Only tensor-valued nodes are tracked.
        if isinstance(val, torch.Tensor):
            self._cur_state().internal_nodes.add(node)

    def track_node(self, node: Node) -> None:
        """Record ``node`` as external if it is a tensor node not created in
        the current stream context.

        Does nothing outside a stream context.
        """
        if not self.in_stream_context():
            return
        val = node.meta.get("example_value")
        # Args of ops may be external: anything tensor-valued that was not
        # recorded as internal to the current state counts as external.
        if isinstance(val, torch.Tensor) and node not in self._internal_nodes():
            self._external_nodes().add(node)

    def push_stream_state(self, node: Node) -> None:
        """Start a new stream state whose fork node is ``node``."""
        self.state_stack.append(StreamState(node, OrderedSet(), OrderedSet()))

    def pop_stream_state(self) -> StreamState:
        """Remove and return the innermost stream state."""
        assert self.state_stack, "No stream state to pop"
        return self.state_stack.pop()

    def _cur_state(self) -> StreamState:
        """Return the innermost stream state without removing it."""
        assert self.state_stack, "No active stream state; not in a stream context"
        return self.state_stack[-1]

    def _internal_nodes(self) -> OrderedSet[Node]:
        return self._cur_state().internal_nodes

    def _external_nodes(self) -> OrderedSet[Node]:
        return self._cur_state().external_nodes


# Module-level instance shared by the stream variable trackers in this file.
stream_state_mgr = StreamStateManager()
class StreamContextVariable(ContextWrappingVariable):
"""This represents torch.cuda.StreamContext"""
@staticmethod
def create(
tx: "InstructionTranslator",
target_value: "StreamVariable",
**kwargs: dict[str, Any],
) -> "StreamContextVariable":
return StreamContextVariable(
target_values=[target_value],
initial_values=[
StreamContextVariable._get_current_stream(target_value.device, tx)
],
device=target_value.device,
**kwargs,
)
def __init__(
self,
target_values: list["StreamVariable"],
device: torch.device,
initial_values: Optional[list["StreamVariable"]] = None,
**kwargs: dict[str, Any],
) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.device = device
self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id
def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
stream_id, device, device_index = (
StreamContextVariable._extract_stream_properties(
self._get_target_values()[0].as_proxy()
)
)
proxy = tx.output.create_proxy(
"call_function",
torch.ops.streams.fork.default,
(stream_id, device, device_index, []),
{},
)
stream_state_mgr.push_stream_state(proxy.node)
return ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args: tuple[Any]) -> "VariableTracker":
state = stream_state_mgr.pop_stream_state()
initial_stream_proxy = self.initial_values[0].as_proxy()
stream_id, device, device_index = (
StreamContextVariable._extract_stream_properties(initial_stream_proxy)
)
tx.output.create_node(
"call_function",
torch.ops.streams.join.default,
(
stream_id.node,
device.node,
device_index.node,
list(state.internal_nodes),
),
{},
)
state.fork_node.args = (
state.fork_node.args[0],
state.fork_node.args[1],
state.fork_node.args[2],
list(state.external_nodes),
)
return ConstantVariable.create(None)
@staticmethod
def _extract_stream_properties(stream_proxy: Proxy) -> tuple[Proxy, Proxy, Proxy]:
stream_index = GetAttrVariable.create_getattr_proxy(stream_proxy, "stream_id")
stream_device = GetAttrVariable.create_getattr_proxy(stream_proxy, "device")
stream_device_index = GetAttrVariable.create_getattr_proxy(
stream_proxy, "device_index"
)
return stream_index, stream_device, stream_device_index
@staticmethod
def _get_current_stream(
    device: torch.device, tx: "InstructionTranslator"
) -> "StreamVariable":
    """Trace a call to the device interface's ``current_stream``.

    The call is recorded in the graph with a single ``None`` positional
    argument and the resulting proxy is wrapped as a ``StreamVariable``.
    """
    from .builder import wrap_fx_proxy_cls

    device_interface = get_interface_for_device(device)
    stream_proxy = tx.output.create_proxy(
        "call_function", device_interface.current_stream, (None,), {}
    )
    return wrap_fx_proxy_cls(StreamVariable, tx, stream_proxy)
def _get_target_values(self) -> list["StreamVariable"]:
# We need this to be overridable, since StreamVariable does
# not store target values (it does not require any arguments)
# and captures the current stream at the time of entering the context
return self.target_values
def supports_graph_breaks(self) -> bool:
    """Report that this context manager supports graph breaks (always True)."""
    supported = True
    return supported
class StreamVariable(StreamContextVariable):
    """Represents the device-agnostic torch.Stream class"""

    def __init__(
        self,
        proxy: Proxy,
        value: torch.Stream,
        device: torch.device,
        **kwargs: Any,
    ) -> None:
        """Wrap ``value`` (a ``torch.Stream``) together with its graph ``proxy``.

        An optional ``user_obj_index`` keyword is popped from ``kwargs`` and
        kept as ``self.user_ind``; ``reconstruct`` uses it to reload the stream
        via ``get_user_object_by_index``.
        """
        user_ind = kwargs.pop("user_obj_index", None)
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        assert value.device.type == device.type, (
            "stream value is not equal to the passed device"
        )
        # No target/initial values are stored up front: `_get_target_values`
        # returns this variable itself and `enter` fills in `initial_values`.
        super().__init__(target_values=[], initial_values=None, device=device, **kwargs)
        self.proxy = proxy
        self.value = value
        self.device = device
        self.user_ind = user_ind

    def python_type(self) -> type:
        """Return ``torch.Stream``."""
        return torch.Stream

    def call_method(
        self,
        tx: "InstructionTranslator",
        name: str,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> "VariableTracker":
        """Trace a method call on the wrapped stream.

        * ``wait_stream`` / ``synchronize`` / ``wait_event`` are recorded as
          ``call_method`` nodes and yield a ``None`` constant.
        * ``query`` is recorded and its result wrapped as a ``ConstantVariable``.
        * ``record_event`` is recorded and its result wrapped as an
          ``EventVariable``.
        * Comparison operators with exactly one positional argument are
          evaluated at trace time on the underlying stream values, installing
          ``EQUALS_MATCH`` guards on every operand that has a source.
        * Anything else is delegated to the parent class.
        """
        assert hasattr(self.value, name), f"no stream method found named {name}"

        from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait_stream", "synchronize", "wait_event"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name == "record_event":
            return wrap_fx_proxy_cls(
                target_cls=EventVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
            from ..guards import GuardBuilder, install_guard

            if self.source:
                install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH))

            # NB : Checking for mutation is necessary because we compare
            # constant values
            other = args[0]
            if not isinstance(other, StreamVariable):
                return ConstantVariable.create(NotImplemented)

            if other.source:
                # Guard the *other* operand. Guarding `self.source` here would
                # leave `other` unguarded and fail when `self` has no source.
                install_guard(other.source.make_guard(GuardBuilder.EQUALS_MATCH))
            return ConstantVariable.create(
                cmp_name_to_op_mapping[name](self.value, other.value)  # type: ignore[arg-type]
            )

        return super().call_method(tx, name, args, kwargs)

    def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
        """Record the current stream as the initial value, then enter the context."""
        # NB: Set initial values and target values when we enter
        # Don't do this at object creation, as we need to record the current stream
        # at the time the context is entered.
        self.initial_values = [
            StreamContextVariable._get_current_stream(self.device, tx)
        ]
        return super().enter(tx)

    def as_proxy(self) -> Proxy:
        """Return the graph proxy backing this stream."""
        return self.proxy

    def module_name(self) -> str:
        """Return the module name ``"torch._C"``."""
        return "torch._C"

    def fn_name(self) -> str:
        """Return the name ``"Stream"``."""
        return "Stream"

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode that pushes this stream onto the stack.

        Streams registered with a user-object index are reloaded through
        ``get_user_object_by_index``; otherwise the stream is installed as a
        global and loaded by name.
        """
        # If we got here, this stream is fully subsumed by the graph - this means it is
        # not an input or global
        assert not self.source
        if self.user_ind is not None:
            codegen.add_push_null(
                lambda: codegen.load_import_from(
                    torch._dynamo.graph_bytecode_inputs.__name__,
                    "get_user_object_by_index",
                )
            )
            codegen.append_output(codegen.create_load_const(self.user_ind))
            codegen.extend_output(create_call_function(1, False))
        else:
            # TODO mlazos: evaluate if we still need this
            prefix = f"_stream_{self.device}"
            name = codegen.tx.output.install_global_by_id(prefix, self.value)
            codegen.append_output(codegen.create_load_global(name, add=True))

    def _get_target_values(self) -> list["StreamVariable"]:
        """Return ``[self]``: a stream used as a context targets itself."""
        return [self]
class EventVariable(VariableTracker):
    """Tracks a ``torch.Event`` value together with its graph proxy."""

    def __init__(self, proxy: Proxy, value: torch.Event, **kwargs: Any) -> None:
        """Store ``proxy`` and ``value``; extra kwargs go to the parent class.

        When the proxy's node carries an ``example_value``, it must compare
        equal to ``value``.
        """
        if proxy is not None:
            node_meta = proxy.node.meta
            if "example_value" in node_meta:
                assert node_meta["example_value"] == value
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value

    def call_method(
        self,
        tx: "InstructionTranslator",
        name: str,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        """Trace a method call on the wrapped event.

        ``wait``, ``record`` and ``synchronize`` are recorded as ``call_method``
        nodes and yield a ``None`` constant; ``query`` is recorded and wrapped as
        a ``ConstantVariable``. Any other method is reported via
        ``unimplemented_v2``.
        """
        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait", "record", "synchronize", "query"):
            call_proxy = tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            if name == "query":
                return wrap_fx_proxy_cls(
                    target_cls=ConstantVariable, tx=tx, proxy=call_proxy
                )
            return ConstantVariable(None)

        # Unsupported method: report it with the fully qualified method name.
        event_type = type(self.value)
        method_name = f"{event_type.__module__}.{event_type.__qualname__}.{name}"
        unimplemented_v2(
            gb_type="Unsupported event method",
            context=str(name),
            explanation=f"Dynamo doesn't support tracing the {method_name} method. "
            "We currently support wait, record, synchronize, and query.",
            hints=[
                *graph_break_hints.SUPPORTABLE,
            ],
        )

    def as_proxy(self) -> Proxy:
        """Return the graph proxy backing this event."""
        return self.proxy

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode that loads this event from an installed global."""
        # Reaching this point means the event lives entirely inside the graph:
        # it is neither an input nor a global.
        assert not self.source
        # As with streams, lift the event into a global and load it by name.
        global_name = codegen.tx.output.install_global_by_id("_event", self.value)
        codegen.append_output(codegen.create_load_global(global_name, add=True))

View File

@ -245,6 +245,7 @@ class BaseTorchVariable(VariableTracker):
def __init__(self, value, **kwargs) -> None:
    """Store the wrapped object on ``self.value``.

    Remaining keyword arguments are forwarded to the parent constructor.
    The constructor writes nothing to stdout; the stray debug ``print`` of
    ``value`` has been removed.
    """
    super().__init__(**kwargs)
    self.value = value
def reconstruct(self, codegen: "PyCodegen"):

View File

@ -58,6 +58,7 @@ from ..exc import (
raise_observed_exception,
unimplemented_v2,
)
from ..graph_bytecode_inputs import get_user_object_by_index, register_user_object
from ..guards import GuardBuilder, install_guard
from ..source import (
AttrSource,
@ -66,6 +67,7 @@ from ..source import (
DictGetItemSource,
GetItemSource,
RandomValueSource,
TorchSource,
TypeDictSource,
TypeMROSource,
TypeSource,
@ -792,14 +794,33 @@ class UserDefinedClassVariable(UserDefinedVariable):
)
args = [stacked]
tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
*proxy_args_kwargs(args, kwargs),
),
)
if issubclass(self.value, torch.Stream):
    torch_stream_src = CallFunctionNoArgsSource(
        AttrSource(TorchSource(), "Stream")
    )
    # Register newly created stream for reconstruction
    # NOTE(review): the stream is constructed with no arguments, so any
    # user-supplied args/kwargs look like they are dropped here - TODO confirm.
    stream = self.value()
    from .streams import add_dynamo_owned_stream

    add_dynamo_owned_stream(stream)
    ind = register_user_object(stream, torch_stream_src)
    # The graph fetches the registered stream by index rather than
    # constructing a new one; the index is also passed on for reconstruction.
    tensor_variable = wrap_fx_proxy(
        tx=tx,
        proxy=tx.output.create_proxy(
            "call_function", get_user_object_by_index, (ind,), {}
        ),
        user_obj_index=ind,
    )
else:
    tensor_variable = wrap_fx_proxy(
        tx=tx,
        proxy=tx.output.create_proxy(
            "call_function",
            self.value,
            *proxy_args_kwargs(args, kwargs),
        ),
    )
return tensor_variable
elif self.value is random.Random:

View File

@ -49,6 +49,7 @@ static PyObject* THPEvent_pynew(
}
THPEvent* self = (THPEvent*)ptr.get();
self->weakreflist = nullptr;
// TODO: blocking and interprocess are not supported yet. To support them, the
// flag system of c10::Event needs to be refactored. C10::Event should also
@ -73,6 +74,7 @@ PyObject* THPEvent_new(c10::DeviceType device_type, c10::EventFlag flag) {
auto self = THPObjectPtr{type->tp_alloc(type, 0)};
TORCH_CHECK(self, "Failed to allocate memory for Event");
auto self_ = reinterpret_cast<THPEvent*>(self.get());
self_->weakreflist = nullptr;
new (&self_->event) c10::Event(device_type, flag);
return self.release();
}
@ -82,6 +84,9 @@ static void THPEvent_dealloc(THPEvent* self) {
pybind11::gil_scoped_release no_gil{};
self->event.~Event();
}
if (self->weakreflist != nullptr) {
PyObject_ClearWeakRefs((PyObject*)self);
}
Py_TYPE(self)->tp_free((PyObject*)self);
}
@ -300,7 +305,7 @@ PyTypeObject THPEventType = {
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
offsetof(THPEvent, weakreflist), /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
THPEvent_methods, /* tp_methods */

View File

@ -7,6 +7,7 @@
// Python object layout wrapping a c10::Event.
struct TORCH_API THPEvent {
PyObject_HEAD
// The wrapped event. THPEvent_new constructs it with placement new and
// THPEvent_dealloc destroys it explicitly.
c10::Event event;
// Weak-reference list slot; THPEventType's tp_weaklistoffset points here.
// Set to nullptr in THPEvent_pynew / THPEvent_new and cleared with
// PyObject_ClearWeakRefs in THPEvent_dealloc when non-null.
PyObject* weakreflist;
};
TORCH_API extern PyTypeObject* THPEventClass;
TORCH_API extern PyTypeObject THPEventType;

View File

@ -95,6 +95,7 @@ static PyObject* THPStream_pynew(
self->device_index = static_cast<int64_t>(stream_opt->device_index());
self->device_type = static_cast<int64_t>(stream_opt->device_type());
self->context = nullptr;
self->weakreflist = nullptr;
return (PyObject*)ptr.release();
END_HANDLE_TH_ERRORS
@ -114,11 +115,15 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) {
self->device_index = static_cast<int64_t>(stream.device_index());
self->device_type = static_cast<int64_t>(stream.device_type());
self->context = nullptr;
self->weakreflist = nullptr;
return ptr.release();
END_HANDLE_TH_ERRORS
}
// Releases a THPStream: clears any outstanding weak references to the object,
// then frees it through its type's tp_free slot.
static void THPStream_dealloc(THPStream* self) {
  PyObject* py_self = (PyObject*)self;
  if (self->weakreflist != nullptr) {
    PyObject_ClearWeakRefs(py_self);
  }
  Py_TYPE(self)->tp_free(py_self);
}
@ -436,7 +441,7 @@ static PyTypeObject THPStreamType = {
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
THPStream_richcompare, /* tp_richcompare */
0, /* tp_weaklistoffset */
offsetof(THPStream, weakreflist), /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
// NOLINTNEXTLINE(*const-cast)

View File

@ -13,6 +13,7 @@ struct THPStream {
int64_t device_index;
// Used to switch stream context management, initialized lazily.
PyObject* context;
PyObject* weakreflist;
};
extern TORCH_API PyTypeObject* THPStreamClass;