Compare commits

...

8 Commits

Author SHA1 Message Date
f1fc0f62e0 [user-streams] Switch to fx annotations at trace time
ghstack-source-id: 78387255490e4ea1a17a077139dfed0c14b63a67
Pull-Request: https://github.com/pytorch/pytorch/pull/166472
2025-11-02 00:14:18 -07:00
75e545f81d [user-streams] cleanup StreamVariable signature
ghstack-source-id: 5d0ae060fba88a93233158b6b7c820b55cdc8c85
Pull-Request: https://github.com/pytorch/pytorch/pull/166471
2025-11-02 00:14:18 -07:00
6e6378784f [dynamo] Remove retrieving objects by ID
ghstack-source-id: 1616bfcfebfe3e39a28102942948723843e21dbe
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162905
2025-11-02 00:14:17 -07:00
abe6cbecca [user-streams] Add basic stream tests
ghstack-source-id: 76e57c28fc23bdfaf17bc1ac55b866fbcb48fd46
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164523

merge into streams suite
2025-11-02 00:14:17 -07:00
20e0aff9e5 [user-streams] Handle returning the current stream with/without device index
ghstack-source-id: 08ca30503d3c128fb09c30fa0956c2dc856bccf5
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165356
2025-11-02 00:14:16 -07:00
9113637f2d [user-streams] Track symbolic current stream
merge into stream tests

ghstack-source-id: 2b76bd383834639019014dca5cbf81ffd04724bf
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165212

symbolic streams update
2025-11-02 00:14:16 -07:00
075a01ec48 [user-streams] Add current stream source
ghstack-source-id: 5a9cef10d316ecd6e6eac0f01630a9d57938e999
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165211
2025-11-02 00:14:15 -07:00
f640d0c072 [user-streams] Fix stream graph output semantics
ghstack-source-id: eac4dc71e41d2dd0169da1a7491f42f96f2f71f9
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164819
2025-11-02 00:14:15 -07:00
18 changed files with 502 additions and 173 deletions

View File

@@ -230,7 +230,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
res = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 20)
self.assertExpectedInline(str(cnts.op_count), """9""")
@unittest.expectedFailure # https://github.com/pytorch/pytorch/issues/118204
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
@@ -335,7 +335,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
res = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 37)
self.assertExpectedInline(str(cnts.op_count), """15""")
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_cuda_stream_compared_with_constant(self):
@@ -517,7 +517,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
res = opt_fn(x, cur_stream, new_stream)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 27)
self.assertExpectedInline(str(cnts.op_count), """16""")
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_cuda_event_method(self):
@@ -537,7 +537,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
with torch.cuda.stream(new_stream):
x = torch.add(x, 4)
new_event = torch.cuda.Event()
new_event = torch.Event()
new_event.record(new_stream)
new_event.wait(cur_stream)
@@ -557,7 +557,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
res = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 27)
self.assertExpectedInline(str(cnts.op_count), """16""")
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_cuda_device(self):

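Note: the op-count assertions above were migrated from hard-coded integers to assertExpectedInline, which is backed by expecttest, so the inline string can be regenerated automatically instead of hand-edited. A standalone sketch of the mechanism (outside the test class, using the CompileCounter backend these tests rely on):

import torch
import torch._dynamo.testing

cnts = torch._dynamo.testing.CompileCounter()

@torch.compile(backend=cnts)
def fn(x):
    return x + 1

fn(torch.ones(2))
# assertExpectedInline compares against the inline string; running the
# suite with EXPECTTEST_ACCEPT=1 rewrites """1""" in place if the traced
# op count drifts again.
assert str(cnts.op_count) == "1"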
View File

@@ -1540,7 +1540,7 @@ cannot resume from torch._dynamo.step_unsupported()
Developer debug context:
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0283.html
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0284.html
from user code:
File "test_error_messages.py", line N, in fn

View File

@@ -1,12 +1,20 @@
# Owner(s): ["module: dynamo"]
import functools
import unittest
import weakref
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import requires_cuda
requires_multigpu = functools.partial(
unittest.skipIf, not TEST_MULTIGPU, "requires multiple cuda devices"
)
class TestStreams(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
@@ -26,6 +34,144 @@ class TestStreams(torch._dynamo.test_case.TestCase):
e = torch.Event()
weakref.ref(e)
@requires_cuda
def test_stream_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
with s1:
z1 = torch.add(x, y)
with s2:
z = torch.add(x, y)
y = z + 2 + z1
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
fn_opt = torch.compile(fn, fullgraph=True)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
@requires_cuda
def test_stream_context_graph_break(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
with s1:
z1 = torch.add(x, y)
with s2:
z = torch.add(x, y)
y = z + 2 + z1
torch._dynamo.graph_break()
y = y + 1
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
fn_opt = torch.compile(fn)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
@requires_cuda
def test_stream_input(self):
def fn(x, y, s):
z = torch.add(x, y)
y = z + 2
return y, s
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(device="cuda"))
expected = fn(*inp)
fn_opt = torch.compile(fn, fullgraph=True)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
@requires_cuda
def test_local_stream_return(self):
def fn(x, y):
s = torch.Stream()
z = torch.add(x, y)
y = z + 2
return y, s
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
fn_opt = torch.compile(fn, fullgraph=True)
_, s0 = fn_opt(*inp)
_, s1 = fn_opt(*inp)
# Streams will be different values for each invocation
# so don't check for equality
self.assertIsInstance(s0, torch.Stream)
# Stream should be newly allocated on each call
self.assertNotEqual(s0, s1)
@requires_cuda
def test_get_current_stream_return(self):
def fn(x, s):
with s:
s0 = torch.accelerator.current_stream()
return x, s0
s_inp = torch.Stream(device="cuda")
inp = (torch.ones(2, 2) + 1, s_inp)
fn_opt = torch.compile(fn, fullgraph=True)
_, s0 = fn_opt(*inp)
_, s1 = fn_opt(*inp)
self.assertEqual(s_inp, s0)
self.assertEqual(s0, s1)
@requires_cuda
@requires_multigpu()
def test_get_current_stream_return_different_device(self):
def fn(x, s0, s1):
with s1:
with s0:
s = torch.accelerator.current_stream(torch.device("cuda:1"))
return s
s0 = torch.Stream(device="cuda:0")
s1 = torch.Stream(device="cuda:1")
inp = (torch.ones(2, 2) + 1, s0, s1)
fn_opt = torch.compile(fn, fullgraph=True)
s_act = fn_opt(*inp)
s_exp = fn(*inp)
self.assertEqual(s_act, s_exp)
@requires_cuda
@requires_multigpu()
def test_get_current_stream_return_no_index(self):
def fn(x, s0, s1):
with s1:
with s0:
s = torch.accelerator.current_stream(torch.device("cuda"))
return s
s0 = torch.Stream(device="cuda:0")
s1 = torch.Stream(device="cuda:1")
inp = (torch.ones(2, 2) + 1, s0, s1)
fn_opt = torch.compile(fn, fullgraph=True)
s_act = fn_opt(*inp)
s_exp = fn(*inp)
self.assertEqual(s_act, s_exp)
def test_nested_stream_enter_exit(self):
pass
def test_stream_enter_exit_graph_break(self):
pass
def test_nested_stream_enter_exit_graph_break(self):
pass
def test_local_stream_enter_exit(self):
pass
def test_local_stream_nested_enter_exit(self):
pass
def test_stream_with_mutation(self):
pass
@requires_cuda
def test_run_opcheck(self):
from torch._dynamo.variables.streams import fork_stream, join_stream

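test_run_opcheck presumably validates the new fork/join stream ops with torch.library.opcheck; their schemas are not visible in this hunk, so the sketch below uses a stand-in custom op to show what opcheck verifies (fake-kernel and registration consistency):

import torch
from torch.library import custom_op, opcheck

@custom_op("demo::noop", mutates_args=())
def noop(x: torch.Tensor) -> torch.Tensor:
    return x.clone()

@noop.register_fake
def _(x):
    # Fake (meta) kernel so the op is traceable; opcheck exercises this.
    return torch.empty_like(x)

opcheck(noop, (torch.randn(2, 2),))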
View File

@@ -466,6 +466,7 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
"handle_cudnn_is_acceptable", # No global state
"handle_assert", # No global state (constant)
"handle_nested_tensor", # No global state
"handle_current_stream", # Safely implemented
)
for fn in handlers:
if isinstance(fn, staticmethod) or inspect.ismethod(fn):

View File

@@ -2,7 +2,7 @@
import sys
import torch
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
class DummyPrivateUse1Module:
@@ -60,6 +60,9 @@ class TestExtensionUtils(TestCase):
with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
torch._register_device_module("privateuseone", DummyPrivateUse1Module)
@skipIfTorchDynamo(
"accelerator doesn't compose with privateuse1 : https://github.com/pytorch/pytorch/issues/166696"
)
def test_external_module_register_with_renamed_backend(self):
torch.utils.rename_privateuse1_backend("foo")
with self.assertRaisesRegex(RuntimeError, "has already been set"):

View File

@@ -1,7 +1,7 @@
# Owner(s): ["module: PrivateUse1"]
import torch
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
class DummyPrivateUse1Module:
@@ -31,6 +31,9 @@ class DummyPrivateUse1Module:
class TestRenamePrivateuseoneToExistingBackend(TestCase):
@skipIfTorchDynamo(
"TorchDynamo exposes https://github.com/pytorch/pytorch/issues/166696"
)
def test_external_module_register_with_existing_backend(self):
torch.utils.rename_privateuse1_backend("maia")
with self.assertRaisesRegex(RuntimeError, "has already been set"):

View File

@@ -155,7 +155,6 @@ def reset() -> None:
GenerationTracker.clear()
TensorifyState.clear()
torch._dynamo.utils.warn_once_cache.clear()
torch._dynamo.utils.user_obj_id_to_weakref.clear()
torch._C._autograd._saved_tensors_hooks_set_tracing(False)

View File

@@ -2495,6 +2495,30 @@
}
],
"GB0249": [
{
"Gb_type": "bad device argument to torch.accelerator.current_stream",
"Context": "args={args}, kwargs={kwargs}",
"Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
"Hints": [
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
},
{
"Gb_type": "bad device argument to torch.get_device_module",
"Context": "args={args}, kwargs={kwargs}",
"Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
"Hints": [
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
},
{
"Gb_type": "bad device argument to torch.accelerator.current_stream",
"Context": "args={args}, kwargs={kwargs}",
"Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
"Hints": [
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
},
{
"Gb_type": "bad device argument to torch.get_device_module",
"Context": "args={args}, kwargs={kwargs}",
@@ -2853,6 +2877,14 @@
}
],
"GB0283": [
{
"Gb_type": "Failed to make weakref to graph-created external object",
"Context": "user_object: {example_value}",
"Explanation": "Object does not allow us to make a weakref to it",
"Hints": []
}
],
"GB0284": [
{
"Gb_type": "cannot resume from torch._dynamo.step_unsupported()",
"Context": "",
@@ -2863,5 +2895,25 @@
"This is likely to be a Dynamo bug. Please report an issue to PyTorch."
]
}
],
"GB0285": [
{
"Gb_type": "unsupported arguments to torch.accelerator.current_stream",
"Context": "args={args}, kwargs={kwargs}",
"Explanation": "torch.accelerator.current_stream accepts one optional argument `device`",
"Hints": [
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
}
],
"GB0286": [
{
"Gb_type": "bad device argument to torch.get_device_module",
"Context": "args={args}, kwargs={kwargs}",
"Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
"Hints": [
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
}
]
}
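GB0285 corresponds to the argument check in handle_current_stream further down in this PR; a hypothetical minimal trigger, consistent with the registered entry:

import torch

@torch.compile(fullgraph=True)
def f(x):
    # current_stream takes at most one optional `device` argument; a second
    # positional argument raises a TypeError in eager and trips GB0285 here.
    s = torch.accelerator.current_stream("cuda", 0)
    return x + 1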

View File

@@ -1,9 +1,11 @@
import weakref
from typing import Any
from typing import Any, Callable
from torch._dynamo.source import Source
PyCodegen = Any
# This file is to handle types that we don't want to support
# as explicit FX graph inputs. This uses a sidetable which
# we populate in bytecode and is loaded during graph execution
@@ -11,44 +13,70 @@ from torch._dynamo.source import Source
# We use a dynamo-generated index as a level of indirection
# this allows us to register objects externally in pre-graph bytecode that we want
# to pass to the graph, but not support their types as graph inputs
index_to_source: dict[int, Source] = {}
index_to_bytecode_constructor: dict[int, Callable[[PyCodegen], None]] = {}
index_to_user_object_weakref: dict[int, weakref.ReferenceType[Any]] = {}
index_to_external_object_weakref: dict[int, weakref.ReferenceType[Any]] = {}
keep_alive: list[Any] = []
def has_user_objects() -> bool:
return bool(index_to_source)
return bool(index_to_bytecode_constructor)
def get_user_object_by_index(index: int) -> Any:
assert index in index_to_user_object_weakref, (
def get_external_object_by_index(index: int) -> Any:
assert index in index_to_external_object_weakref, (
"Index not registered in index_to_user_object_weakref"
)
obj = index_to_user_object_weakref[index]()
obj = index_to_external_object_weakref[index]()
assert obj is not None, "User object is no longer alive"
return index_to_user_object_weakref[index]()
return index_to_external_object_weakref[index]()
def store_user_object_weakrefs(*args: Any) -> None:
global index_to_user_object_weakref
index_to_user_object_weakref.clear()
index_to_user_object_weakref.update(
global index_to_external_object_weakref
index_to_external_object_weakref.clear()
index_to_external_object_weakref.update(
{i: weakref.ref(arg) for i, arg in enumerate(args)}
)
def reset_user_object_tracking() -> None:
index_to_source.clear()
index_to_user_object_weakref.clear()
index_to_bytecode_constructor.clear()
index_to_external_object_weakref.clear()
keep_alive.clear()
def register_graph_created_object(
example_value: Any, construct_fn: Callable[[int, PyCodegen], None]
) -> int:
global index_to_bytecode_constructor
global keep_alive
keep_alive.append(example_value)
index = len(index_to_bytecode_constructor)
index_to_bytecode_constructor[index] = lambda cg: construct_fn(index, cg)
try:
index_to_external_object_weakref[index] = weakref.ref(example_value)
except TypeError as e:
from .exc import unimplemented_v2
unimplemented_v2(
gb_type="Failed to make weakref to graph-created external object",
context=f"user_object: {example_value}",
explanation="Object does not allow us to make a weakref to it",
hints=[],
from_exc=e,
)
return index
# Register a user object to be used in the graph
def register_user_object(value: Any, source: Source) -> int:
global index_to_source
index = len(index_to_source)
index_to_source[index] = source
global index_to_bytecode_constructor
index = len(index_to_bytecode_constructor)
index_to_bytecode_constructor[index] = lambda cg: cg(source)
try:
index_to_user_object_weakref[index] = weakref.ref(value)
index_to_external_object_weakref[index] = weakref.ref(value)
except TypeError as e:
from .exc import unimplemented_v2

View File

@@ -132,6 +132,7 @@ from .source import (
CodeSource,
ConstantSource,
ConstDictKeySource,
CurrentStreamSource,
DataclassFieldsSource,
DefaultsSource,
DictGetItemSource,
@@ -181,6 +182,7 @@ from .utils import (
common_constant_types,
dataclass_fields,
dict_keys,
get_current_stream,
get_custom_getattr,
get_torch_function_mode_stack,
get_torch_function_mode_stack_at,
@@ -759,6 +761,7 @@ def _get_closure_vars() -> dict[str, object]:
"___dataclass_fields": dataclass_fields,
"___namedtuple_fields": lambda x: x._fields,
"___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
"___get_current_stream": get_current_stream,
"__math_isnan": math.isnan,
"__numpy_isnan": None if np is None else np.isnan,
"inf": float("inf"),
@@ -1448,6 +1451,13 @@ class GuardBuilder(GuardBuilderBase):
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, CurrentStreamSource):
out = root_guard_manager.lambda_manager(
python_lambda=lambda _: get_current_stream(source.device),
source=source_name,
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, GradSource):
assert base_guard_manager # to make mypy happy
out = base_guard_manager.grad_manager(

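The lambda_manager branch above amounts to re-sampling the current stream at guard-evaluation time; in effect (simplified sketch of the check the guard performs):

import torch

def current_stream_guard_holds(device: torch.device, recorded: torch.Stream) -> bool:
    # Recompile when the ambient stream for `device` no longer matches the
    # stream recorded as example_value while tracing.
    return torch.accelerator.current_stream(device) == recorded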
View File

@@ -101,7 +101,7 @@ from .exc import (
unimplemented_v2,
unimplemented_v2_with_warning,
)
from .graph_bytecode_inputs import has_user_objects, index_to_source
from .graph_bytecode_inputs import has_user_objects, index_to_bytecode_constructor
from .graph_deduplication import apply_graph_deduplication
from .graph_region_tracker import GraphRegionTracker
from .guards import GuardBuilder, install_guard
@@ -1541,9 +1541,19 @@ class OutputGraph(OutputGraphCommon):
"store_user_object_weakrefs",
)
)
for source in reversed(index_to_source.values()):
codegen(source)
codegen.call_function(len(index_to_source), False)
tmp_vars = []
for constructor in reversed(index_to_bytecode_constructor.values()):
constructor(codegen)
var_name = (
self.new_var()
) # keep alive any temp objects for the rest of the frame
codegen.store(var_name)
tmp_vars.append(var_name)
for var_name in tmp_vars:
codegen.append_output(codegen.create_load(var_name))
codegen.call_function(len(index_to_bytecode_constructor), False)
codegen.pop_top()
self.add_output_instructions(codegen.get_instructions())

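In Python terms, the emitted pre-graph bytecode now behaves roughly like the sketch below; the temps exist only to pin each object as a frame-local so the sidetable weakrefs stay live for the rest of the frame:

from torch._dynamo.graph_bytecode_inputs import store_user_object_weakrefs

class Dummy:
    # Stand-in for a registered object, e.g. a graph-created torch.Stream.
    pass

tmp_0 = Dummy()  # constructor or source-load emitted by codegen
tmp_1 = Dummy()
store_user_object_weakrefs(tmp_0, tmp_1)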
View File

@@ -23,6 +23,7 @@ import functools
from collections.abc import Callable
from typing import Any, Optional, TYPE_CHECKING, Union
from torch import device as device_type
from torch._guards import ChainedSource, Guard, GuardSource, Source
from . import utils
@@ -1082,6 +1083,30 @@ class ShapeEnvSource(Source):
return GuardSource.SHAPE_ENV
@dataclasses.dataclass(frozen=True)
class CurrentStreamSource(Source):
device: device_type
def name(self) -> str:
return f"___get_current_stream(torch.device('{self.device.type}', {self.device.index}))"
def reconstruct(self, codegen: "PyCodegen") -> None:
num_args = 1
codegen.add_push_null(
lambda: codegen.load_import_from(utils.__name__, "get_current_stream")
)
codegen.add_push_null(lambda: codegen.load_import_from("torch", "device"))
codegen.extend_output([codegen.create_load_const(self.device.type)])
if self.device.index is not None:
num_args += 1
codegen.extend_output([codegen.create_load_const(self.device.index)])
codegen.extend_output(create_call_function(num_args, False))
codegen.extend_output(create_call_function(1, False))
def guard_source(self) -> GuardSource:
return GuardSource.GLOBAL
@dataclasses.dataclass(frozen=True)
class BackwardStateSource(Source):
def name(self) -> str:

View File

@@ -171,6 +171,7 @@ from .variables.misc import (
UnknownVariable,
)
from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable
from .variables.streams import SymbolicStreamState
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
from .variables.torch_function import (
SymbolicTorchFunctionState,
@@ -1104,6 +1105,7 @@ class InstructionTranslatorBase(
symbolic_locals: dict[str, VariableTracker]
symbolic_globals: dict[str, VariableTracker]
symbolic_torch_function_state: SymbolicTorchFunctionState
symbolic_stream_state: SymbolicStreamState
post_prune_cell_and_freevars: Optional[dict[str, VariableTracker]]
stack: list[VariableTracker]
instruction_pointer: Optional[int]
@@ -4243,6 +4245,7 @@ class InstructionTranslatorBase(
symbolic_locals: dict[str, VariableTracker],
symbolic_globals: dict[str, VariableTracker],
symbolic_torch_function_state: SymbolicTorchFunctionState,
symbolic_stream_state: SymbolicStreamState,
f_code: types.CodeType,
export: bool,
inline_depth: int,
@@ -4262,6 +4265,7 @@ class InstructionTranslatorBase(
self.symbolic_locals = symbolic_locals
self.symbolic_globals = symbolic_globals
self.symbolic_torch_function_state = symbolic_torch_function_state
self.symbolic_stream_state = symbolic_stream_state
# used to keep cell/freevars alive after pruning symbolic_locals (prune_dead_locals)
# in order to generate any nested closures
self.post_prune_cell_and_freevars = None
@@ -4416,6 +4420,7 @@ class InstructionTranslator(InstructionTranslatorBase):
# A global var is inserted only after a STORE_GLOBAL happens to it
symbolic_globals={},
symbolic_torch_function_state=None, # type: ignore[arg-type] # set below
symbolic_stream_state=None, # type: ignore[arg-type] # set below
f_code=f_code,
export=export,
inline_depth=0,
@@ -4520,6 +4525,8 @@ class InstructionTranslator(InstructionTranslatorBase):
torch_function_mode_stack
)
self.symbolic_stream_state = SymbolicStreamState()
if export:
# export gets confused if we never realize unused inputs
# in export mode just eagerly realize everything
@@ -4846,6 +4853,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
sub_locals,
parent.symbolic_globals,
parent.symbolic_torch_function_state,
parent.symbolic_stream_state,
func,
)
else:
@@ -4857,6 +4865,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
sub_locals,
parent.symbolic_globals,
parent.symbolic_torch_function_state,
parent.symbolic_stream_state,
# pyrefly: ignore [bad-argument-type]
func,
)
@@ -4940,6 +4949,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
symbolic_locals: dict[str, VariableTracker],
symbolic_globals: dict[str, VariableTracker],
symbolic_torch_function_state: SymbolicTorchFunctionState,
symbolic_stream_state: SymbolicStreamState,
funcvar: BaseUserFunctionVariable,
) -> None:
f_globals = funcvar.get_globals() # type: ignore[attr-defined]
@@ -4973,6 +4983,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
symbolic_locals=symbolic_locals,
symbolic_globals=symbolic_globals,
symbolic_torch_function_state=symbolic_torch_function_state,
symbolic_stream_state=symbolic_stream_state,
instructions=instructions,
code_options={k: getattr(code, k) for k in get_code_keys()},
f_code=code,

View File

@@ -4695,6 +4695,10 @@ def clear_torch_function_mode_stack() -> None:
_pop_torch_function_stack()
def get_current_stream(device: torch.device) -> torch.Stream:
return torch.accelerator.current_stream(device)
# call from C dynamo in order to inspect values in pdb
def _breakpoint_for_c_dynamo(*args: Any) -> None:
breakpoint()
@@ -4759,33 +4763,8 @@ def _extract_tensor_dict(t: torch.Tensor) -> dict[str, Any]:
return tensor_dict
# This is useful for reconstructing within the Dynamo graph the non-graph-input objects
# whose lifetime is governed by the user.
# e.g. torch.cuda.Event is a prime example.
user_obj_id_to_weakref: dict[int, weakref.ReferenceType[object]] = {}
# TODO: mlazos to remove after replacing w/ above API
def get_user_object_from_id(obj_id: int) -> Any:
obj = user_obj_id_to_weakref[obj_id]()
assert obj is not None, "User object is no longer alive"
return obj
def store_user_object_weakref(obj: object) -> None:
obj_id = id(obj)
try:
user_obj_id_to_weakref[obj_id] = weakref.ref(obj)
except TypeError as e:
from .exc import unimplemented_v2
unimplemented_v2(
gb_type="Failed to make weakref to User Object when storing by ID",
context=f"user_objected: {obj}",
explanation="Object does not allow us to make a weakref to it",
hints=[],
from_exc=e,
)
def build_stream(args: tuple[Any], kwargs: dict[Any, Any]) -> torch.Stream:
return torch._C.Stream(*args, **kwargs)
class CompileTimeInstructionCounter:

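build_stream above is the codegen target used to reconstruct a stream that was created inside the compiled region (see make_construct_in_graph_stream_fn later in this PR); on an accelerator build:

import torch
from torch._dynamo.utils import build_stream

# Equivalent to torch._C.Stream(device="cuda"); args/kwargs are replayed
# from the constructor call recorded in the compiled frame.
s = build_stream((), {"device": "cuda"})
assert isinstance(s, torch.Stream)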
View File

@@ -46,7 +46,7 @@ import torch
from torch import SymInt
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.graph_bytecode_inputs import (
get_user_object_by_index,
get_external_object_by_index,
register_user_object,
)
from torch._dynamo.utils import (
@@ -1057,13 +1057,12 @@ class VariableBuilder:
self.install_guards(GuardBuilder.TYPE_MATCH)
index = register_user_object(value, self.source)
stream_proxy = self.tx.output.create_proxy(
"call_function", get_user_object_by_index, (index,), {}
"call_function", get_external_object_by_index, (index,), {}
)
set_example_value(stream_proxy.node, value)
var = StreamVariable(
stream_proxy,
value,
value.device,
source=self.source,
)
return self.tx.output.side_effects.track_object_existing(value, var)
@@ -1078,7 +1077,7 @@ class VariableBuilder:
index = register_user_object(value, self.source)
event_proxy = self.tx.output.create_proxy(
"call_function",
get_user_object_by_index,
get_external_object_by_index,
(index,),
{},
)
@@ -3006,14 +3005,15 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
set_example_value(proxy.node, example_value)
return SymNodeVariable(proxy, example_value, **options)
elif (
inspect.isclass(proxy.node.target)
and issubclass(proxy.node.target, torch.Stream)
isinstance(example_value, torch.Stream)
and proxy.node.target
in (get_external_object_by_index, torch.accelerator.current_stream)
) or proxy.node.target in [
device_interface.current_stream
for _, device_interface in get_registered_device_interfaces()
]:
set_example_value(proxy.node, example_value)
return StreamVariable(proxy, example_value, example_value.device, **options)
return StreamVariable(proxy, example_value, **options)
elif (
inspect.isclass(proxy.node.target)
and issubclass(proxy.node.target, torch.Event)

View File

@@ -1,15 +1,18 @@
from typing import Any, Optional
import collections
from typing import Any, Callable, Optional
import torch
from torch._dynamo.variables.dicts import ConstDictVariable
from torch._dynamo.variables.lists import TupleVariable
from torch.fx import Proxy
from .. import graph_break_hints
from ..device_interface import get_interface_for_device
from ..bytecode_transformation import create_call_function
from ..exc import TYPE_CHECKING, unimplemented_v2
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import ContextWrappingVariable
from .misc import GetAttrVariable
from .ctx_manager import FxTracebackAnnotateVariable
from .lazy import LazyVariableTracker
if TYPE_CHECKING:
@@ -63,104 +66,89 @@ def _(
pass
class StreamContextVariable(ContextWrappingVariable):
class SymbolicStreamState:
"""Track the currently entered stream if any"""
def __init__(self) -> None:
from ..source import CurrentStreamSource
cur_stack: list[StreamVariable] = []
if torch.accelerator.is_available():
stream_var = LazyVariableTracker.create(
torch.accelerator.current_stream(),
source=CurrentStreamSource(torch.accelerator.current_stream().device),
)
cur_stack = [stream_var] # type: ignore[list-item]
self.cur_stream_stack: collections.deque[StreamVariable] = collections.deque(
cur_stack
)
def enter_stream(self, stream: "StreamVariable") -> None:
self.cur_stream_stack.append(stream)
def exit_stream(self) -> None:
self.cur_stream_stack.pop()
def cur_stream(self, device: Optional[torch.device] = None) -> "StreamVariable":
if device is not None:
for stream in reversed(self.cur_stream_stack):
if stream.device == device:
return stream
return self.cur_stream_stack[-1]
def in_stream_context(self) -> bool:
return len(self.cur_stream_stack) > 0
class StreamContextVariable(FxTracebackAnnotateVariable):
"""This represents torch.cuda.StreamContext"""
@staticmethod
def create(
tx: "InstructionTranslator",
target_value: "StreamVariable",
stream_to_enter: "StreamVariable",
**kwargs: dict[str, Any],
) -> "StreamContextVariable":
return StreamContextVariable(
target_values=[target_value],
initial_values=[
StreamContextVariable._get_current_stream(target_value.device, tx)
],
device=target_value.device,
stream_to_enter,
**kwargs,
)
def __init__(
self,
target_values: list["StreamVariable"],
device: torch.device,
initial_values: Optional[list["StreamVariable"]] = None,
stream: Optional["StreamVariable"],
**kwargs: dict[str, Any],
) -> None:
self.stream = stream
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
target_values={"stream": self.get_stream().user_object_index},
initial_values=None,
**kwargs,
)
# pyrefly: ignore [read-only]
self.device = device
def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
def enter(
self, tx: "InstructionTranslator", *args: tuple[Any]
) -> "VariableTracker":
# to stream, from stream is the order of the arguments
# we are entering the target, and leaving the initial stream
tx.output.create_proxy(
"call_function",
torch.ops.streams.fork.default,
self._target_stream_proxies() + self._initial_stream_proxies(),
{},
)
return ConstantVariable.create(None)
tx.symbolic_stream_state.enter_stream(self.get_stream())
return super().enter(tx)
def exit(self, tx: "InstructionTranslator", *args: tuple[Any]) -> "VariableTracker":
# to stream, from stream is the order of the arguments
# we are leaving the target, and entering the initial stream
tx.output.create_proxy(
"call_function",
torch.ops.streams.join.default,
self._initial_stream_proxies() + self._target_stream_proxies(),
{},
)
return ConstantVariable.create(None)
def _initial_stream_proxies(self) -> tuple[Proxy, Proxy]:
assert self.initial_values, "No initial stream to move from"
return StreamContextVariable._extract_stream_properties(
self.initial_values[0].as_proxy()
)
def _target_stream_proxies(self) -> tuple[Proxy, Proxy]:
return StreamContextVariable._extract_stream_properties(
self._get_target_values()[0].as_proxy()
)
@staticmethod
def _extract_stream_properties(stream_proxy: Proxy) -> tuple[Proxy, Proxy]:
stream_index = GetAttrVariable.create_getattr_proxy(stream_proxy, "stream_id")
stream_device = GetAttrVariable.create_getattr_proxy(stream_proxy, "device")
return stream_index, stream_device
@staticmethod
def _get_current_stream(
device: torch.device, tx: "InstructionTranslator"
) -> "StreamVariable":
from .builder import wrap_fx_proxy_cls
current_stream_method = get_interface_for_device(device).current_stream
current_stream = wrap_fx_proxy_cls(
StreamVariable,
tx,
tx.output.create_proxy(
"call_function",
current_stream_method,
(None,),
{},
),
)
return current_stream
def _get_target_values(self) -> list["StreamVariable"]:
# We need this to be overridable, since StreamVariable does
# not store target values (it does not require any arguments)
# and captures the current stream at the time of entering the context
return self.target_values
tx.symbolic_stream_state.exit_stream()
return super().exit(tx, *args)
def supports_graph_breaks(self) -> bool:
return True
def get_stream(self) -> "StreamVariable":
assert self.stream, "Stream context should have a separate stream"
return self.stream
class StreamVariable(StreamContextVariable):
"""Represents the device-agnostic torch.Stream class"""
@@ -169,19 +157,21 @@ class StreamVariable(StreamContextVariable):
self,
proxy: Proxy,
value: torch.Stream,
device: torch.device,
**kwargs: Any,
) -> None:
# Index into the user object table
# used to pass arbitrary objects to the graph
user_object_index = kwargs.pop("user_obj_index", None)
if proxy is not None and "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == value
assert value.device.type == device.type, (
"stream value is not equal to the passed device"
)
super().__init__(target_values=[], initial_values=None, device=device, **kwargs)
self.proxy = proxy
self.value = value
# pyrefly: ignore [read-only]
self.device = device
self.device = value.device
# pyrefly: ignore [read-only]
self.user_object_index = user_object_index
super().__init__(None, **kwargs)
def python_type(self) -> type:
return torch.Stream
@@ -240,15 +230,6 @@ class StreamVariable(StreamContextVariable):
return super().call_method(tx, name, args, kwargs)
def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
# NB: Set initial values when we enter
# Don't do this at object creation, as we need to record the current stream
# at the time the context is entered.
self.initial_values = [
StreamContextVariable._get_current_stream(self.device, tx)
]
return super().enter(tx)
def as_proxy(self) -> Proxy:
return self.proxy
@@ -262,18 +243,39 @@ class StreamVariable(StreamContextVariable):
# If we got here, this stream is fully subsumed by the graph - this means it is
# not an input or global
assert not self.source
# Since we just proved that - for other such structures, like lists and dicts, reconstruction
# is fine and sound according to dynamo principles of treating collectives. However,
# streams are special in that we want to preserve the identity of the stream as the same as in the graph
# Normally, we would do this via codegen for the proxy mapping to an output - we cannot do this yet, as we do not
# yet have a plan for how we want to handle the case where the stream is used as an input or an output. Pending
# design, to unblock current work, we lift the stream into a global and then codegen bytecode to load it from there.
prefix = f"_stream_{self.device}"
name = codegen.tx.output.install_global_by_id(prefix, self.value)
codegen.append_output(codegen.create_load_global(name, add=True))
if self.user_object_index is not None:
codegen.add_push_null(
lambda: codegen.load_import_from(
torch._dynamo.graph_bytecode_inputs.__name__,
"get_external_object_by_index",
)
)
codegen.append_output(codegen.create_load_const(self.user_object_index))
codegen.extend_output(create_call_function(1, False))
else:
# TODO mlazos: evaluate if we still need this
prefix = f"_stream_{self.device}"
name = codegen.tx.output.install_global_by_id(prefix, self.value)
codegen.append_output(codegen.create_load_global(name, add=True))
def _get_target_values(self) -> list["StreamVariable"]:
return [self]
def get_stream(self) -> "StreamVariable":
return self
@staticmethod
def make_construct_in_graph_stream_fn(
args: TupleVariable, kwargs: ConstDictVariable
) -> Callable[[int, "PyCodegen"], None]:
def fn(index: int, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.load_import_from(
torch._dynamo.utils.__name__, "build_stream"
)
)
codegen(args)
codegen(kwargs)
codegen.extend_output(create_call_function(2, False))
return fn
class EventVariable(VariableTracker):

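Behaviorally, SymbolicStreamState is a device-aware stack; its lookup rule in isolation (sketch mirroring cur_stream above):

import collections
from typing import Optional

import torch

class StreamStackSketch:
    def __init__(self, base: torch.Stream) -> None:
        self.stack: collections.deque = collections.deque([base])

    def enter(self, stream: torch.Stream) -> None:  # `with s:` pushes
        self.stack.append(stream)

    def exit(self) -> None:  # leaving the context pops
        self.stack.pop()

    def current(self, device: Optional[torch.device] = None) -> torch.Stream:
        # Innermost stream wins; with a device filter, the innermost stream
        # on that device wins, falling back to the top of the stack.
        if device is not None:
            for s in reversed(self.stack):
                if s.device == device:
                    return s
        return self.stack[-1]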
View File

@@ -1268,6 +1268,35 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
# pyrefly: ignore [unbound-name]
return VariableTracker.build(tx, module, new_source)
@register(torch.accelerator.current_stream)
def handle_current_stream(self, tx: "InstructionTranslator", *args, **kwargs):
if len(args) + len(kwargs) > 1 or (kwargs and "device" not in kwargs):
unimplemented_v2(
gb_type="unsupported arguments to torch.accelerator.current_stream",
context=f"args={args}, kwargs={kwargs}",
explanation="torch.accelerator.current_stream accepts one optional argument `device`",
hints=[
*graph_break_hints.USER_ERROR,
],
)
try:
if kwargs:
device = torch.device(kwargs["device"].as_python_constant())
elif args:
device = torch.device(args[0].as_python_constant())
else:
device = None
return tx.symbolic_stream_state.cur_stream(device)
except Exception as e:
unimplemented_v2(
gb_type="bad device argument to torch.accelerator.current_stream",
context=f"args={args}, kwargs={kwargs}",
explanation="Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
hints=[*graph_break_hints.USER_ERROR],
from_exc=e,
)
@register(torch.set_default_device)
def handle_set_default_device(
self, tx: "InstructionTranslator", *args, **kwargs

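Together with SymbolicStreamState, this handler resolves current_stream symbolically at trace time, which is what test_get_current_stream_return exercises; on a CUDA build:

import torch

@torch.compile(fullgraph=True)
def f(x, s):
    with s:
        # Answered from the symbolic stream stack during tracing; no runtime
        # current_stream query is baked into the graph.
        return x + 1, torch.accelerator.current_stream()

s = torch.Stream(device="cuda")
_, out = f(torch.ones(2, device="cuda"), s)
assert out == s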
View File

@@ -58,6 +58,7 @@ from ..exc import (
raise_observed_exception,
unimplemented_v2,
)
from ..graph_bytecode_inputs import get_external_object_by_index
from ..guards import GuardBuilder, install_guard
from ..source import (
AttrSource,
@@ -94,7 +95,7 @@ from ..utils import (
unpatched_nn_module_getattr,
)
from .base import raise_type_error_exc, ValueMutationNew, VariableTracker
from .dicts import DefaultDictVariable
from .dicts import ConstDictVariable, DefaultDictVariable
from .lists import SizeVariable
@@ -809,14 +810,44 @@ class UserDefinedClassVariable(UserDefinedVariable):
)
args = [stacked]
tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
*proxy_args_kwargs(args, kwargs),
),
)
if issubclass(self.value, torch.Stream):
from .constant import ConstantVariable
from .lists import TupleVariable
# Register newly created stream for reconstruction
var_kwargs = ConstDictVariable(
{ConstantVariable(k): v for k, v in kwargs.items()}
)
var_args = TupleVariable(list(args))
stream = self.value(
*(var_args.as_python_constant()),
**(var_kwargs.as_python_constant()),
)
from ..graph_bytecode_inputs import register_graph_created_object
from .streams import StreamVariable
ind = register_graph_created_object(
stream,
StreamVariable.make_construct_in_graph_stream_fn(
var_args, var_kwargs
),
)
tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function", get_external_object_by_index, (ind,), {}
),
user_obj_index=ind,
)
else:
tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
*proxy_args_kwargs(args, kwargs),
),
)
return tensor_variable
elif self.value is random.Random:
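The new branch registers streams constructed inside the compiled region for bytecode reconstruction instead of tracing the constructor directly; the observable effect, matching test_local_stream_return earlier in this PR (accelerator build assumed):

import torch

@torch.compile(fullgraph=True)
def f(x):
    s = torch.Stream()  # registered via register_graph_created_object
    return x + 1, s

_, s0 = f(torch.ones(2))
_, s1 = f(torch.ones(2))
assert isinstance(s0, torch.Stream)
assert s0 != s1  # a fresh stream is allocated on every call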