Compare commits

...

20 Commits

SHA1 Message Date
bcf1a53297 backward working kind of 2 2025-10-14 10:37:37 -07:00
75c6f9b93b backward working kind of 2025-10-13 17:39:06 -07:00
cf8fb02c33 [user-streams] Add backward support for fork/join 2025-10-13 16:20:45 -07:00
bfd2b03577 [dynamo] Remove retrieving objects by ID
ghstack-source-id: a0ca523223a05e9128393c4a7a1c2c14edb588ba
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162905
2025-10-13 14:23:02 -07:00
6441f7a7fe [user-streams] Add basic stream tests
ghstack-source-id: 74749ef6e8684b42565f7bc57cc9a4875088b8d1
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164523

merge into streams suite
2025-10-13 14:23:01 -07:00
6d30dba93d [user-streams] Handle returning the current stream with/without device index
ghstack-source-id: a7ff4bad74eccff1f9b590242188eac1ebe203a8
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165356
2025-10-13 14:22:59 -07:00
dc90a72bb5 [user-streams] Track symbolic current stream
merge into stream tests

ghstack-source-id: 0fab78f039eb26365195a7b5ceca756dd58e9724
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165212

symbolic streams update
2025-10-13 14:22:58 -07:00
44de0318c4 [user-streams] Add current stream source
ghstack-source-id: 71305220b562005f6d581506ba35c0ae06d470a9
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165211
2025-10-13 14:22:56 -07:00
66c8640559 [user-streams] Fix stream graph output semantics
ghstack-source-id: 75778deaa3a00c5162ada276aadeebce7f5ffce9
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164819

updates to graph semantics changes

More fixes

ghstack-source-id: 75778deaa3a00c5162ada276aadeebce7f5ffce9
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165027

bytecode output fixes
2025-10-13 14:22:55 -07:00
923a7c7bcc [User-streams] Make torch.Event weakref compatible
ghstack-source-id: 048d1a0f66984d6100352dc1df116f7367372155
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164522
2025-10-13 14:22:54 -07:00
e2b0cfe647 [user-streams] Make cuda streams weakref compatible
ghstack-source-id: 78c8a6c26fb331d3c124bc4c068608457d2506ce
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164304
2025-10-13 14:22:52 -07:00
60508c7ed8 [user-cuda-streams] Add cuda streams test suite
ghstack-source-id: 9df39f429ea341e48c001af566f4196552b92882
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162901
2025-10-13 14:22:51 -07:00
a43c5f210b [user-streams] Support streams as contexts
ghstack-source-id: 2c80cbd5b1e8f1f27527474bf1f602397e0a5899
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164507
2025-10-13 14:22:49 -07:00
e8bd37d77c [user-streams] Have StreamVariable inherit from StreamContextVariable
ghstack-source-id: d94859d8de4b319e74180d5d0a8a98989941d779
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164344

inheritance changes
2025-10-13 14:22:48 -07:00
28742d61be [user-streams] Move StreamContextVariable into streams module
finish moving

ghstack-source-id: 50bd3f8e73009232df51c275e5b04090009f7215
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164343
2025-10-13 14:22:47 -07:00
104dec4c55 [user-streams] update stream context to use fork/join
ghstack-source-id: 761aa0919208ca365486582ae94418f6d2db1dfb
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162903
2025-10-13 14:22:46 -07:00
4040707f1f [user-cuda-streams] Add fork/join custom ops
Make custom ops inplace

ghstack-source-id: bd66808560960c755fd9a4c42a9374fa20543737
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162900
2025-10-13 14:22:44 -07:00
c0ec620a09 [user-streams] Handle aliasing properly
ghstack-source-id: 4cb9a4413f22d5d4c5fa8f24326b2e646c7bec6d
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163028
2025-10-13 14:22:43 -07:00
47d2882ea6 [user-cuda-streams] Pass streams/events to the graph via lookup table
ghstack-source-id: 1ea262874768b8e125708d29ca63f0b7a13419b4
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162899

test fixes
2025-10-13 14:22:43 -07:00
2df9f24b3f [user-streams] Move stream code to streams module
ghstack-source-id: e8346745d1a459c284ecbafe439c06f0e9070689
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163027
2025-10-13 14:22:41 -07:00
21 changed files with 920 additions and 264 deletions

----------------------------------------
test/dynamo/test_streams.py (new file)

@@ -0,0 +1,187 @@
# Owner(s): ["module: dynamo"]
import weakref
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch.testing._internal.common_utils import requires_cuda
class TestStreams(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
@classmethod
def tearDownClass(cls):
super().tearDownClass()
def test_stream_weakref(self):
s = torch.Stream()
weakref.ref(s)
def test_event_weakref(self):
e = torch.Event()
weakref.ref(e)
def test_stream_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
with s1:
z1 = torch.add(x, y)
with s2:
z = torch.add(x, y)
y = z + 2 + z1
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
fn_opt = torch.compile(fn, fullgraph=True)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_stream_context_graph_break(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
with s1:
z1 = torch.add(x, y)
with s2:
z = torch.add(x, y)
y = z + 2 + z1
torch._dynamo.graph_break()
y = y + 1
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
fn_opt = torch.compile(fn)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_stream_input(self):
def fn(x, y, s):
z = torch.add(x, y)
y = z + 2
return y, s
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(device="cuda"))
expected = fn(*inp)
fn_opt = torch.compile(fn, fullgraph=True)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_local_stream_return(self):
def fn(x, y):
s = torch.Stream()
z = torch.add(x, y)
y = z + 2
return y, s
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
fn_opt = torch.compile(fn, fullgraph=True)
_, s0 = fn_opt(*inp)
_, s1 = fn_opt(*inp)
# Streams will be different values for each invocation
# so don't check for equality
self.assertIsInstance(s0, torch.Stream)
# Stream should be newly allocated on each call
self.assertNotEqual(s0, s1)
def test_get_current_stream_return(self):
def fn(x, s):
with s:
s0 = torch.accelerator.current_stream()
return x, s0
s_inp = torch.Stream(device="cuda")
inp = (torch.ones(2, 2) + 1, s_inp)
fn_opt = torch.compile(fn, fullgraph=True)
_, s0 = fn_opt(*inp)
_, s1 = fn_opt(*inp)
self.assertEqual(s_inp, s0)
self.assertEqual(s0, s1)
def test_get_current_stream_return_different_device(self):
def fn(x, s0, s1):
with s1:
with s0:
s = torch.accelerator.current_stream(torch.device("cuda:1"))
return s
s0 = torch.Stream(device="cuda:0")
s1 = torch.Stream(device="cuda:1")
inp = (torch.ones(2, 2) + 1, s0, s1)
fn_opt = torch.compile(fn, fullgraph=True)
s_act = fn_opt(*inp)
s_exp = fn(*inp)
self.assertEqual(s_act, s_exp)
def test_get_current_stream_return_no_index(self):
def fn(x, s0, s1):
with s1:
with s0:
s = torch.accelerator.current_stream(torch.device("cuda"))
return s
s0 = torch.Stream(device="cuda:0")
s1 = torch.Stream(device="cuda:1")
inp = (torch.ones(2, 2) + 1, s0, s1)
fn_opt = torch.compile(fn, fullgraph=True)
s_act = fn_opt(*inp)
s_exp = fn(*inp)
self.assertEqual(s_act, s_exp)
def test_fork_join_backward(self):
def fn(x, s0):
with s0:
y = torch.add(x, x)
return y
inp = (torch.ones(2, 2, requires_grad=True) + 1, torch.Stream(device="cuda"))
fn_opt = torch.compile(fn, fullgraph=True)
actual = fn_opt(*inp)
actual.sum().backward()
# expected = fn(*inp)
# expected.sum().backward()
# self.assertEqual(expected, actual)
def test_nested_stream_enter_exit(self):
pass
def test_stream_enter_exit_graph_break(self):
pass
def test_nested_stream_enter_exit_graph_break(self):
pass
def test_local_stream_enter_exit(self):
pass
def test_local_stream_nested_enter_exit(self):
pass
def test_stream_with_mutation(self):
pass
@requires_cuda
def test_run_opcheck(self):
from torch._dynamo.variables.streams import fork_stream, join_stream
from torch.library import opcheck
sample_inputs = [
(0, torch.device("cuda:0"), 1, torch.device("cuda:1")),
(2, torch.device("cuda:2"), 3, torch.device("cuda:1")),
]
for args in sample_inputs:
opcheck(fork_stream, args)
opcheck(join_stream, args)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

----------------------------------------

@@ -153,7 +153,6 @@ def reset() -> None:
GenerationTracker.clear()
TensorifyState.clear()
torch._dynamo.utils.warn_once_cache.clear()
torch._dynamo.utils.user_obj_id_to_weakref.clear()
torch._C._autograd._saved_tensors_hooks_set_tracing(False)

----------------------------------------

@@ -116,6 +116,7 @@ from .exc import (
unimplemented_v2,
Unsupported,
)
from .graph_bytecode_inputs import reset_user_object_tracking
from .guards import (
CheckFunctionManager,
get_and_maybe_log_recompilation_reasons,
@@ -314,6 +315,7 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
torch.fx._symbolic_trace._maybe_revert_all_patches()
)
exit_stack.enter_context(torch_function_mode_stack_state_mgr)
reset_user_object_tracking()
try:
return fn(*args, **kwargs)
finally:

----------------------------------------

@@ -2495,6 +2495,14 @@
}
],
"GB0249": [
{
"Gb_type": "bad device argument to torch.accelerator.current_stream",
"Context": "args={args}, kwargs={kwargs}",
"Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
"Hints": [
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
},
{
"Gb_type": "bad device argument to torch.get_device_module",
"Context": "args={args}, kwargs={kwargs}",
@@ -2734,6 +2742,12 @@
}
],
"GB0272": [
{
"Gb_type": "Failed to make weakref to User Object when storing by ID",
"Context": "user_objected: {obj}",
"Explanation": "Object does not allow us to make a weakref to it",
"Hints": []
},
{
"Gb_type": "Failed to make weakref to User Object",
"Context": "user_objected: {obj}",
@@ -2776,5 +2790,41 @@
"This is likely to be a Dynamo bug. Please report an issue to PyTorch."
]
}
],
"GB0276": [
{
"Gb_type": "Failed to make weakref to User Object",
"Context": "user_object: {value}",
"Explanation": "Object does not allow us to make a weakref to it",
"Hints": []
}
],
"GB0277": [
{
"Gb_type": "Failed to make weakref to graph-created external object",
"Context": "user_object: {example_value}",
"Explanation": "Object does not allow us to make a weakref to it",
"Hints": []
}
],
"GB0278": [
{
"Gb_type": "unsupported arguments to torch.accelerator.current_stream",
"Context": "args={args}, kwargs={kwargs}",
"Explanation": "torch.accelerator.current_stream accepts one optional argument `device`",
"Hints": [
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
}
],
"GB0279": [
{
"Gb_type": "bad device argument to torch.get_device_module",
"Context": "args={args}, kwargs={kwargs}",
"Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
"Hints": [
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
}
]
}

----------------------------------------

@@ -0,0 +1,90 @@
import weakref
from typing import Any, Callable
from torch._dynamo.source import Source
PyCodegen = Any
# This file handles types that we don't want to support as explicit FX
# graph inputs. It uses a sidetable which we populate in pre-graph bytecode
# and read from during graph execution. A dynamo-generated index provides a
# level of indirection: it lets us register objects externally in pre-graph
# bytecode and pass them to the graph without supporting their types as
# graph inputs.
index_to_bytecode_constructor: dict[int, Callable[[PyCodegen], None]] = {}
index_to_external_object_weakref: dict[int, weakref.ReferenceType[Any]] = {}
keep_alive: list[Any] = []
def has_user_objects() -> bool:
return bool(index_to_bytecode_constructor)
def get_external_object_by_index(index: int) -> Any:
assert index in index_to_external_object_weakref, (
"Index not registered in index_to_external_object_weakref"
)
obj = index_to_external_object_weakref[index]()
assert obj is not None, "User object is no longer alive"
return obj
def store_user_object_weakrefs(*args: Any) -> None:
global index_to_external_object_weakref
index_to_external_object_weakref.clear()
index_to_external_object_weakref.update(
{i: weakref.ref(arg) for i, arg in enumerate(args)}
)
def reset_user_object_tracking() -> None:
index_to_bytecode_constructor.clear()
index_to_external_object_weakref.clear()
keep_alive.clear()
def register_graph_created_object(
example_value: Any, construct_fn: Callable[[int, PyCodegen], None]
) -> int:
global index_to_bytecode_constructor
global keep_alive
keep_alive.append(example_value)
index = len(index_to_bytecode_constructor)
index_to_bytecode_constructor[index] = lambda cg: construct_fn(index, cg)
try:
index_to_external_object_weakref[index] = weakref.ref(example_value)
except TypeError as e:
from .exc import unimplemented_v2
unimplemented_v2(
gb_type="Failed to make weakref to graph-created external object",
context=f"user_object: {example_value}",
explanation="Object does not allow us to make a weakref to it",
hints=[],
from_exc=e,
)
return index
# Register a user object to be used in the graph
def register_user_object(value: Any, source: Source) -> int:
global index_to_bytecode_constructor
index = len(index_to_bytecode_constructor)
index_to_bytecode_constructor[index] = lambda cg: cg(source)
try:
index_to_external_object_weakref[index] = weakref.ref(value)
except TypeError as e:
from .exc import unimplemented_v2
unimplemented_v2(
gb_type="Failed to make weakref to User Object",
context=f"user_object: {value}",
explanation="Object does not allow us to make a weakref to it",
hints=[],
from_exc=e,
)
return index
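
Taken together, the helpers above form a small round trip: pre-graph bytecode populates the weakref sidetable, and the graph resolves indices back to live objects at call time. A minimal sketch, assuming a PyTorch build that includes this graph_bytecode_inputs module and the weakref-compatible torch.Event from the C++ changes further down:

import torch
from torch._dynamo import graph_bytecode_inputs as gbi

ev = torch.Event()  # weakref-compatible after the THPEvent changes below
# Pre-graph bytecode stores the tracked objects positionally, so index i
# maps to the i-th argument...
gbi.store_user_object_weakrefs(ev)
# ...and the compiled graph fetches the live object back by index.
assert gbi.get_external_object_by_index(0) is ev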

----------------------------------------

@@ -132,6 +132,7 @@ from .source import (
CodeSource,
ConstantSource,
ConstDictKeySource,
CurrentStreamSource,
DataclassFieldsSource,
DefaultsSource,
DictGetItemSource,
@@ -181,6 +182,7 @@ from .utils import (
common_constant_types,
dataclass_fields,
dict_keys,
get_current_stream,
get_custom_getattr,
get_torch_function_mode_stack,
get_torch_function_mode_stack_at,
@@ -757,6 +759,7 @@ def _get_closure_vars() -> dict[str, object]:
"___dataclass_fields": dataclass_fields,
"___namedtuple_fields": lambda x: x._fields,
"___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
"___get_current_stream": get_current_stream,
"__math_isnan": math.isnan,
"__numpy_isnan": None if np is None else np.isnan,
"inf": float("inf"),
@@ -1448,6 +1451,13 @@ class GuardBuilder(GuardBuilderBase):
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, CurrentStreamSource):
out = root_guard_manager.lambda_manager(
python_lambda=lambda _: get_current_stream(source.device),
source=source_name,
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, GradSource):
assert base_guard_manager # to make mypy happy
out = base_guard_manager.grad_manager(
@@ -2166,6 +2176,8 @@ class GuardBuilder(GuardBuilderBase):
range,
dict_keys,
torch.Size,
torch.Stream,
torch.cuda.streams.Stream,
*np_types,
*ok_mutable_types,
}

----------------------------------------

@@ -100,6 +100,7 @@ from .exc import (
unimplemented_v2,
unimplemented_v2_with_warning,
)
from .graph_bytecode_inputs import has_user_objects, index_to_bytecode_constructor
from .graph_deduplication import apply_graph_deduplication
from .graph_region_tracker import GraphRegionTracker
from .guards import GuardBuilder, install_guard
@@ -1512,6 +1513,37 @@ class OutputGraph(OutputGraphCommon):
from .decorators import disable
if has_user_objects():
# NB: This is where we store possible user objects before running the graph
# get_external_object_by_index is the function used in the graph to translate
# the dynamo-generated index into the actual object passed to the compiled function.
# We generate bytecode to store all user objects at the proper index in the
# call below.
codegen = PyCodegen(
self.root_tx, root, overridden_sources=overridden_sources
)
codegen.add_push_null(
lambda: codegen.load_import_from(
torch._dynamo.graph_bytecode_inputs.__name__,
"store_user_object_weakrefs",
)
)
tmp_vars = []
for constructor in reversed(index_to_bytecode_constructor.values()):
constructor(codegen)
var_name = (
self.new_var()
) # keep alive any temp objects for the rest of the frame
codegen.store(var_name)
tmp_vars.append(var_name)
for var_name in tmp_vars:
codegen.append_output(codegen.create_load(var_name))
codegen.call_function(len(index_to_bytecode_constructor), False)
codegen.pop_top()
self.add_output_instructions(codegen.get_instructions())
# to handle random calls
if len(self.random_calls) > 0:
random_calls_instructions = []
@@ -1657,7 +1689,7 @@ class OutputGraph(OutputGraphCommon):
)
elif (
vt.source is not None
and (source := getattr(vt.source, "base", None))
and (source := getattr(vt.source, "base", None)) # type: ignore[assignment]
and source.is_input
):
self.export_metadata.output_return_type[idx] = (
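
In Python terms, the prologue emitted by the block above is roughly equivalent to the sketch below (hedged: the real output is bytecode, the temp names are hypothetical, and the objects shown are stand-ins):

import torch
from torch._dynamo.graph_bytecode_inputs import store_user_object_weakrefs

# one temp per tracked object, kept alive as frame locals for the rest of
# the frame
tmp_0 = torch.Stream()  # stand-in: a graph-created object, rebuilt by its constructor fn
tmp_1 = torch.Stream()  # stand-in: a user object re-loaded from its Source
# index i maps to the i-th positional argument
store_user_object_weakrefs(tmp_0, tmp_1)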

----------------------------------------

@@ -22,6 +22,7 @@ import enum
import functools
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from torch import device as device_type
from torch._guards import ChainedSource, Guard, GuardSource, Source
from . import utils
@@ -1078,6 +1079,30 @@ class ShapeEnvSource(Source):
return GuardSource.SHAPE_ENV
@dataclasses.dataclass(frozen=True)
class CurrentStreamSource(Source):
device: device_type
def name(self) -> str:
return f"___get_current_stream(torch.device('{self.device.type}', {self.device.index}))"
def reconstruct(self, codegen: "PyCodegen") -> None:
num_args = 1
codegen.add_push_null(
lambda: codegen.load_import_from(utils.__name__, "get_current_stream")
)
codegen.add_push_null(lambda: codegen.load_import_from("torch", "device"))
codegen.extend_output([codegen.create_load_const(self.device.type)])
if self.device.index is not None:
num_args += 1
codegen.extend_output([codegen.create_load_const(self.device.index)])
codegen.extend_output(create_call_function(num_args, False))
codegen.extend_output(create_call_function(1, False))
def guard_source(self) -> GuardSource:
return GuardSource.GLOBAL
@dataclasses.dataclass(frozen=True)
class BackwardStateSource(Source):
def name(self) -> str:
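
For a concrete device such as cuda:0, CurrentStreamSource.name() renders the guard expression ___get_current_stream(torch.device('cuda', 0)), and reconstruct() emits bytecode equivalent to the following (a sketch; assumes a CUDA build):

import torch
from torch._dynamo.utils import get_current_stream

# the value the guard compares against: whichever stream is current on
# that device at guard-evaluation time
s = get_current_stream(torch.device("cuda", 0))
assert isinstance(s, torch.Stream)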

----------------------------------------

@@ -173,6 +173,7 @@ from .variables.misc import (
UnknownVariable,
)
from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable
from .variables.streams import SymbolicStreamState
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
from .variables.torch_function import (
SymbolicTorchFunctionState,
@@ -1170,6 +1171,7 @@ class InstructionTranslatorBase(
symbolic_locals: dict[str, VariableTracker]
symbolic_globals: dict[str, VariableTracker]
symbolic_torch_function_state: SymbolicTorchFunctionState
symbolic_stream_state: SymbolicStreamState
post_prune_cell_and_freevars: Optional[dict[str, VariableTracker]]
stack: list[VariableTracker]
instruction_pointer: Optional[int]
@@ -4069,6 +4071,7 @@ class InstructionTranslatorBase(
symbolic_locals: dict[str, VariableTracker],
symbolic_globals: dict[str, VariableTracker],
symbolic_torch_function_state: SymbolicTorchFunctionState,
symbolic_stream_state: SymbolicStreamState,
f_code: types.CodeType,
export: bool,
inline_depth: int,
@@ -4088,6 +4091,7 @@
self.symbolic_locals = symbolic_locals
self.symbolic_globals = symbolic_globals
self.symbolic_torch_function_state = symbolic_torch_function_state
self.symbolic_stream_state = symbolic_stream_state
# used to keep cell/freevars alive after pruning symbolic_locals (prune_dead_locals)
# in order to generate any nested closures
self.post_prune_cell_and_freevars = None
@@ -4241,6 +4245,7 @@ class InstructionTranslator(InstructionTranslatorBase):
# A global var is inserted only after a STORE_GLOBAL happens to it
symbolic_globals={},
symbolic_torch_function_state=None, # type: ignore[arg-type] # set below
symbolic_stream_state=None, # type: ignore[arg-type] # set below
f_code=f_code,
export=export,
inline_depth=0,
@@ -4345,6 +4350,8 @@
torch_function_mode_stack
)
self.symbolic_stream_state = SymbolicStreamState()
if export:
# export gets confused if we never realize unused inputs
# in export mode just eagerly realize everything
@@ -4673,6 +4680,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
sub_locals,
parent.symbolic_globals,
parent.symbolic_torch_function_state,
parent.symbolic_stream_state,
func,
)
else:
@@ -4684,6 +4692,7 @@
sub_locals,
parent.symbolic_globals,
parent.symbolic_torch_function_state,
parent.symbolic_stream_state,
# pyrefly: ignore # bad-argument-type
func,
)
@@ -4767,6 +4776,7 @@
symbolic_locals: dict[str, VariableTracker],
symbolic_globals: dict[str, VariableTracker],
symbolic_torch_function_state: SymbolicTorchFunctionState,
symbolic_stream_state: SymbolicStreamState,
funcvar: BaseUserFunctionVariable,
) -> None:
f_globals = funcvar.get_globals() # type: ignore[attr-defined]
@@ -4800,6 +4810,7 @@
symbolic_locals=symbolic_locals,
symbolic_globals=symbolic_globals,
symbolic_torch_function_state=symbolic_torch_function_state,
symbolic_stream_state=symbolic_stream_state,
instructions=instructions,
code_options={k: getattr(code, k) for k in get_code_keys()},
f_code=code,

----------------------------------------

@@ -4655,6 +4655,10 @@ def clear_torch_function_mode_stack() -> None:
_pop_torch_function_stack()
def get_current_stream(device: torch.device) -> torch.Stream:
return torch.accelerator.current_stream(device)
# call from C dynamo in order to inspect values in pdb
def _breakpoint_for_c_dynamo(*args: Any) -> None:
breakpoint()
@@ -4719,34 +4723,6 @@ def _extract_tensor_dict(t: torch.Tensor) -> dict[str, Any]:
return tensor_dict
# This is useful for reconstructing within the Dynamo graph the non-graph-input objects
# whose lifetime is governed by the user.
# e.g. torch.cuda.Event is a prime example.
user_obj_id_to_weakref: dict[int, weakref.ReferenceType[object]] = {}
def get_user_object_from_id(obj_id: int) -> Any:
obj = user_obj_id_to_weakref[obj_id]()
assert obj is not None, "User object is no longer alive"
return obj
def store_user_object_weakref(obj: object) -> None:
obj_id = id(obj)
try:
user_obj_id_to_weakref[obj_id] = weakref.ref(obj)
except TypeError as e:
from .exc import unimplemented_v2
unimplemented_v2(
gb_type="Failed to make weakref to User Object",
context=f"user_objected: {obj}",
explanation="Object does not allow us to make a weakref to it",
hints=[],
from_exc=e,
)
class CompileTimeInstructionCounter:
_counter: int = 0
_id: int = -1

----------------------------------------

@@ -37,8 +37,6 @@ from .ctx_manager import (
JvpIncrementNestingCtxManagerVariable,
SDPAKernelVariable,
SetFwdGradEnabledContextManager,
StreamContextVariable,
StreamVariable,
TemporarilyPopInterpreterStackCtxManagerVariable,
VmapIncrementNestingCtxManagerVariable,
WithEnterFunctionVariable,
@@ -131,6 +129,7 @@ from .nn_module import (
)
from .optimizer import OptimizerVariable
from .sdpa import SDPAParamsVariable
from .streams import EventVariable, StreamContextVariable, StreamVariable
from .tensor import (
DataPtrVariable,
FakeItemVariable,

----------------------------------------

@@ -45,6 +45,10 @@ import sympy
import torch
from torch import SymInt
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.graph_bytecode_inputs import (
get_external_object_by_index,
register_user_object,
)
from torch._dynamo.utils import (
get_metrics_context,
is_int_specialization_case,
@@ -172,11 +176,8 @@ from .ctx_manager import (
AutocastModeVariable,
DynamoConfigPatchVariable,
ErrorOnGraphBreakVariable,
EventVariable,
NullContextVariable,
PreserveVersionContextVariable,
StreamContextVariable,
StreamVariable,
)
from .dicts import (
ConstDictVariable,
@@ -257,6 +258,7 @@ from .nn_module import (
from .optimizer import OptimizerVariable
from .script_object import TorchScriptObjectVariable
from .sdpa import SDPAParamsVariable
from .streams import EventVariable, StreamContextVariable, StreamVariable
from .tensor import (
NumpyNdarrayVariable,
supported_const_comparison_op_values,
@@ -1036,24 +1038,20 @@ class VariableBuilder:
stream_var = VariableBuilder(self.tx, stream_source)(value.stream)
return StreamContextVariable.create(self.tx, stream_var)
elif isinstance(value, torch.Stream):
self.install_guards(GuardBuilder.ID_MATCH)
# This refers to the device-agnostic torch.Stream
self.install_guards(GuardBuilder.TYPE_MATCH)
index = register_user_object(value, self.source)
stream_proxy = self.tx.output.create_proxy(
"call_function",
type(value),
(),
{
"stream_id": value.stream_id,
"device_index": value.device_index,
"device_type": value.device_type,
},
"call_function", get_external_object_by_index, (index,), {}
)
set_example_value(stream_proxy.node, value)
return StreamVariable(
var = StreamVariable(
stream_proxy,
value,
value.device,
source=self.source,
)
return self.tx.output.side_effects.track_object_existing(value, var)
elif isinstance(value, (torch._C._SDPAParams)):
self.install_guards(GuardBuilder.TYPE_MATCH)
return SDPAParamsVariable.create(self.tx, value, self.source)
@@ -1061,12 +1059,12 @@
self.install_guards(GuardBuilder.ID_MATCH)
return FuncTorchInterpreterVariable(value)
elif isinstance(value, torch.Event):
self.install_guards(GuardBuilder.ID_MATCH)
torch._dynamo.utils.store_user_object_weakref(value)
self.install_guards(GuardBuilder.TYPE_MATCH)
index = register_user_object(value, self.source)
event_proxy = self.tx.output.create_proxy(
"call_function",
torch._dynamo.utils.get_user_object_from_id,
(id(value),),
get_external_object_by_index,
(index,),
{},
)
set_example_value(event_proxy.node, value)
@@ -2980,8 +2978,9 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
set_example_value(proxy.node, example_value)
return SymNodeVariable(proxy, example_value, **options)
elif (
inspect.isclass(proxy.node.target)
and issubclass(proxy.node.target, torch.Stream)
isinstance(example_value, torch.Stream)
and proxy.node.target
in (get_external_object_by_index, torch.accelerator.current_stream)
) or proxy.node.target in [
device_interface.current_stream
for _, device_interface in get_registered_device_interfaces()
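
With this change, a stream input reaches the graph as get_external_object_by_index(index) rather than being rebuilt from its raw (stream_id, device_index, device_type) fields, so object identity is preserved. A compile-side sketch, assuming a CUDA build (mirrors test_stream_input above):

import torch

def fn(x, s):
    return x + 1, s

compiled = torch.compile(fn, fullgraph=True)
out, s_out = compiled(torch.ones(2), torch.Stream(device="cuda"))
# the stream round-trips through the sidetable with identity intact
assert isinstance(s_out, torch.Stream)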

----------------------------------------

@@ -83,7 +83,6 @@ from ..utils import (
)
from .base import AsPythonConstantNotImplementedError, ValueMutationNew, VariableTracker
from .constant import ConstantVariable
from .ctx_manager import EventVariable, StreamVariable
from .dicts import (
ConstDictVariable,
DefaultDictVariable,
@@ -101,6 +100,7 @@ from .lists import (
TupleIteratorVariable,
TupleVariable,
)
from .streams import EventVariable, StreamVariable
from .tensor import (
FakeItemVariable,
supported_comparison_ops,

----------------------------------------

@@ -34,7 +34,6 @@ from ..bytecode_transformation import (
create_instruction,
create_setup_with,
)
from ..device_interface import get_interface_for_device
from ..exc import unimplemented_v2
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalStateSource
@@ -991,70 +990,6 @@ class ProfilerContextVariable(ContextWrappingVariable):
)
class StreamContextVariable(ContextWrappingVariable):
@staticmethod
def create(tx: "InstructionTranslator", target_value, **kwargs):
from .builder import wrap_fx_proxy_cls
current_stream_method = get_interface_for_device(
target_value.device
).current_stream
current_stream = wrap_fx_proxy_cls(
StreamVariable,
tx,
tx.output.create_proxy(
"call_function",
current_stream_method,
(None,),
{},
),
)
return StreamContextVariable(
target_values=[target_value],
initial_values=[current_stream],
device=target_value.device,
**kwargs,
)
def __init__(self, target_values, device, initial_values=None, **kwargs) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.device = device
self.set_stream = get_interface_for_device(self.device).set_stream
self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id
def enter(self, tx):
# stream generated inside the traced function
if self.target_values[0].as_proxy() is not None:
tx.output.create_proxy(
"call_function",
self.set_stream,
(self.target_values[0].as_proxy(),),
{},
)
# stream passed from outside the traced function
else:
stream = self.target_values[0].value
tx.output.create_proxy(
"call_function",
self.set_stream_id,
(stream.stream_id, stream.device_index, stream.device_type),
{},
)
self.set_stream(self.target_values[0].value)
self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value))
def exit(self, tx: "InstructionTranslator", *args):
tx.output.create_proxy(
"call_function",
self.set_stream,
(self.initial_values[0].as_proxy(),),
{},
)
self.cleanup_assert()
class PreserveVersionContextVariable(ContextWrappingVariable):
"""
Wraps torch.autograd._unsafe_preserve_version_counter
@@ -1290,142 +1225,6 @@ class FxTracebackAnnotateVariable(ContextWrappingVariable):
return "annotate"
class StreamVariable(VariableTracker):
def __init__(self, proxy, value, device, **kwargs) -> None:
if proxy is not None and "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == value
assert value.device.type == device.type, (
"stream value is not equal to the passed device"
)
super().__init__(**kwargs)
self.proxy = proxy
self.value = value
self.device = device
def python_type(self):
return torch.Stream
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
assert hasattr(self.value, name), f"no stream method found named {name}"
from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs
from .builder import wrap_fx_proxy_cls
if name in ("wait_stream", "synchronize", "wait_event"):
tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
)
return variables.ConstantVariable(None)
elif name == "query":
return wrap_fx_proxy_cls(
target_cls=variables.ConstantVariable,
tx=tx,
proxy=tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
),
)
elif name == "record_event":
return wrap_fx_proxy_cls(
target_cls=EventVariable,
tx=tx,
proxy=tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
),
)
elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
# NB : Checking for mutation is necessary because we compare
# constant values
other = args[0]
if not isinstance(other, StreamVariable):
return variables.ConstantVariable.create(NotImplemented)
return variables.ConstantVariable.create(
cmp_name_to_op_mapping[name](self.value, other.value)
)
return super().call_method(tx, name, args, kwargs)
def as_proxy(self):
return self.proxy
def reconstruct(self, codegen: "PyCodegen"):
# If we got here, this stream is fully subsumed by the graph - this means it is
# not an input or global
assert not self.source
# Since we just proved that - for other such structures, like lists and dicts, reconstruction
# is fine and sound according to dynamo principles of treating collectives. However,
# streams are special in that we want to preserve the identity of the stream as the same as in the graph
# Normally, we would do this via codegen for the proxy mapping to an output - we cannot do this yet, as we do not
# yet have a plan for how we want to handle the case where the stream is used as an input or an output. Pending
# design, to unblock current work, we lift the stream into a global and then codegen bytecode to load it from there.
prefix = f"_stream_{self.device}"
name = codegen.tx.output.install_global_by_id(prefix, self.value)
codegen.append_output(codegen.create_load_global(name, add=True))
class EventVariable(VariableTracker):
def __init__(self, proxy, value, **kwargs) -> None:
if proxy is not None and "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == value
super().__init__(**kwargs)
self.proxy = proxy
self.value = value
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
from ..utils import proxy_args_kwargs
from .builder import wrap_fx_proxy_cls
if name in ("wait", "record", "synchronize"):
tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
)
return variables.ConstantVariable(None)
elif name == "query":
return wrap_fx_proxy_cls(
target_cls=variables.ConstantVariable,
tx=tx,
proxy=tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
),
)
else:
method_name = (
f"{type(self.value).__module__}.{type(self.value).__qualname__}.{name}"
)
unimplemented_v2(
gb_type="Unsupported event method",
context=str(name),
explanation=f"Dynamo doesn't support tracing the {method_name} method. "
f"We currently support wait, record, synchronize, and query.",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
def as_proxy(self):
return self.proxy
def reconstruct(self, codegen: "PyCodegen"):
# If we got here, this event is fully subsumed by the graph - this means it is
# not an input or global
assert not self.source
# Similar to stream handling, we lift the event into a global and then codegen bytecode to load it from there.
prefix = "_event"
name = codegen.tx.output.install_global_by_id(prefix, self.value)
codegen.append_output(codegen.create_load_global(name, add=True))
class DynamoConfigPatchVariable(ContextWrappingVariable):
"""represents torch._dynamo.patch_dynamo_config"""

----------------------------------------

@@ -0,0 +1,418 @@
import collections
from typing import Any, Optional
import torch
from torch.fx import Proxy
from .. import graph_break_hints
from ..bytecode_transformation import create_call_function
from ..device_interface import get_interface_for_device
from ..exc import TYPE_CHECKING, unimplemented_v2
from ..source import AttrSource, CallFunctionNoArgsSource, TorchSource
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import ContextWrappingVariable
from .lazy import LazyVariableTracker
from .misc import GetAttrVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from ..codegen import PyCodegen
from torch._library.custom_ops import custom_op
Tensor = torch.Tensor
from torch._higher_order_ops.effects import _EffectType, _register_effectful_op
@custom_op("streams::fork", mutates_args=())
def fork_stream(
from_index: int,
from_device: torch.device,
to_index: int,
to_device: torch.device,
) -> int:
return from_index
@fork_stream.register_fake
def _(
from_index: int,
from_device: torch.device,
to_index: int,
to_device: torch.device,
) -> int:
return from_index
def fork_backward(ctx, grad_output):
from_index, from_device, to_index, to_device = ctx.args
# the reverse of a fork is a join back onto the source stream
join_stream(to_index, to_device, from_index, from_device)
# one gradient slot per input; all four inputs are non-tensors
return None, None, None, None
def fork_setup_context(ctx, inputs, output):
from_index, from_device, to_index, to_device = inputs
ctx.args = (from_index, from_device, to_index, to_device)
_register_effectful_op(fork_stream._opoverload, _EffectType.ORDERED)
fork_stream.register_autograd(fork_backward, setup_context=fork_setup_context)
@custom_op("streams::join", mutates_args=())
def join_stream(
from_index: int,
from_device: torch.device,
to_index: int,
to_device: torch.device,
) -> int:
return from_index
@join_stream.register_fake
def _(
from_index: int,
from_device: torch.device,
to_index: int,
to_device: torch.device,
) -> int:
return from_index
def join_backward(ctx, grad_output):
from_index, from_device, to_index, to_device = ctx.args
# the reverse of a join is a fork back onto the target stream
fork_stream(from_index, from_device, to_index, to_device)
# one gradient slot per input; all four inputs are non-tensors
return None, None, None, None
def join_setup_context(ctx, inputs, output):
from_index, from_device, to_index, to_device = inputs
ctx.args = (from_index, from_device, to_index, to_device)
_register_effectful_op(join_stream._opoverload, _EffectType.ORDERED)
join_stream.register_autograd(join_backward, setup_context=join_setup_context)
class SymbolicStreamState:
"""Track the currently entered stream if any"""
def __init__(self) -> None:
from ..source import CurrentStreamSource
stream_var = LazyVariableTracker.create(
torch.accelerator.current_stream(),
source=CurrentStreamSource(torch.accelerator.current_stream().device),
)
self.cur_stream_stack: collections.deque[StreamVariable] = collections.deque(
[stream_var] # type: ignore[list-item]
)
def enter_stream(self, stream: "StreamVariable") -> None:
self.cur_stream_stack.append(stream)
def exit_stream(self) -> None:
self.cur_stream_stack.pop()
def cur_stream(self, device: Optional[torch.device] = None) -> "StreamVariable":
if device is not None:
for stream in reversed(self.cur_stream_stack):
if stream.device == device:
return stream
return self.cur_stream_stack[-1]
def in_stream_context(self) -> bool:
return len(self.cur_stream_stack) > 0
class StreamContextVariable(ContextWrappingVariable):
"""This represents torch.cuda.StreamContext"""
@staticmethod
def create(
tx: "InstructionTranslator",
target_value: "StreamVariable",
**kwargs: dict[str, Any],
) -> "StreamContextVariable":
return StreamContextVariable(
target_values=[target_value],
initial_values=[
StreamContextVariable._get_current_stream(target_value.device, tx)
],
device=target_value.device,
**kwargs,
)
def __init__(
self,
target_values: list["StreamVariable"],
device: torch.device,
initial_values: Optional[list["StreamVariable"]] = None,
**kwargs: dict[str, Any],
) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.device = device
def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
# to stream, from stream is the order of the arguments
# we are entering the target, and leaving the initial stream
tx.symbolic_stream_state.enter_stream(self._get_target_values()[0])
tx.output.create_proxy(
"call_function",
torch.ops.streams.fork.default,
self._target_stream_proxies() + self._initial_stream_proxies(),
{},
)
return ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args: tuple[Any]) -> "VariableTracker":
# to stream, from stream is the order of the arguments
# we are leaving the target, and entering the initial stream
tx.symbolic_stream_state.exit_stream()
tx.output.create_proxy(
"call_function",
torch.ops.streams.join.default,
self._initial_stream_proxies() + self._target_stream_proxies(),
{},
)
return ConstantVariable.create(None)
def _initial_stream_proxies(self) -> tuple[Proxy, Proxy]:
assert self.initial_values, "No initial stream to move from"
return StreamContextVariable._extract_stream_properties(
self.initial_values[0].as_proxy()
)
def _target_stream_proxies(self) -> tuple[Proxy, Proxy]:
return StreamContextVariable._extract_stream_properties(
self._get_target_values()[0].as_proxy()
)
@staticmethod
def _extract_stream_properties(stream_proxy: Proxy) -> tuple[Proxy, Proxy]:
stream_index = GetAttrVariable.create_getattr_proxy(stream_proxy, "stream_id")
stream_device = GetAttrVariable.create_getattr_proxy(stream_proxy, "device")
return stream_index, stream_device
@staticmethod
def _get_current_stream(
device: torch.device, tx: "InstructionTranslator"
) -> "StreamVariable":
from .builder import wrap_fx_proxy_cls
current_stream_method = get_interface_for_device(device).current_stream
current_stream = wrap_fx_proxy_cls(
StreamVariable,
tx,
tx.output.create_proxy(
"call_function",
current_stream_method,
(None,),
{},
),
)
return current_stream
def _get_target_values(self) -> list["StreamVariable"]:
# We need this to be overridable, since StreamVariable does
# not store target values (it does not require any arguments)
# and captures the current stream at the time of entering the context
return self.target_values
def supports_graph_breaks(self) -> bool:
return True
class StreamVariable(StreamContextVariable):
"""Represents the device-agnostic torch.Stream class"""
def __init__(
self,
proxy: Proxy,
value: torch.Stream,
device: torch.device,
**kwargs: Any,
) -> None:
# Index into the user object table
# used to pass arbitrary objects to the graph
user_object_index = kwargs.pop("user_obj_index", None)
if proxy is not None and "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == value
assert value.device.type == device.type, (
"stream device does not match the passed device"
)
super().__init__(target_values=[], initial_values=None, device=device, **kwargs)
self.proxy = proxy
self.value = value
self.device = device
self.user_object_index = user_object_index
def python_type(self) -> type:
return torch.Stream
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> "VariableTracker":
assert hasattr(self.value, name), f"no stream method found named {name}"
from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs
from .builder import wrap_fx_proxy_cls
if name in ("wait_stream", "synchronize", "wait_event"):
tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
)
return ConstantVariable(None)
elif name == "query":
return wrap_fx_proxy_cls(
target_cls=ConstantVariable,
tx=tx,
proxy=tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
),
)
elif name == "record_event":
return wrap_fx_proxy_cls(
target_cls=EventVariable,
tx=tx,
proxy=tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
),
)
elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
from ..guards import GuardBuilder, install_guard
if self.source:
install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH))
# NB : Checking for mutation is necessary because we compare
# constant values
other = args[0]
if not isinstance(other, StreamVariable):
return ConstantVariable.create(NotImplemented)
if other.source:
install_guard(other.source.make_guard(GuardBuilder.EQUALS_MATCH))
return ConstantVariable.create(
cmp_name_to_op_mapping[name](self.value, other.value) # type: ignore[arg-type]
)
return super().call_method(tx, name, args, kwargs)
def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
# NB: Set initial values when we enter
# Don't do this at object creation, as we need to record the current stream
# at the time the context is entered.
self.initial_values = [
StreamContextVariable._get_current_stream(self.device, tx)
]
return super().enter(tx)
def as_proxy(self) -> Proxy:
return self.proxy
def module_name(self) -> str:
return "torch._C"
def fn_name(self) -> str:
return "Stream"
def reconstruct(self, codegen: "PyCodegen") -> None:
# If we got here, this stream is fully subsumed by the graph - this means it is
# not an input or global
assert not self.source
if self.user_object_index is not None:
codegen.add_push_null(
lambda: codegen.load_import_from(
torch._dynamo.graph_bytecode_inputs.__name__,
"get_external_object_by_index",
)
)
codegen.append_output(codegen.create_load_const(self.user_object_index))
codegen.extend_output(create_call_function(1, False))
else:
# TODO mlazos: evaluate if we still need this
prefix = f"_stream_{self.device}"
name = codegen.tx.output.install_global_by_id(prefix, self.value)
codegen.append_output(codegen.create_load_global(name, add=True))
@staticmethod
def construct_in_graph_stream(index: int, codegen: "PyCodegen") -> None:
# Use the source to generate the right bytecode; this
# isn't an actual graph input
source = CallFunctionNoArgsSource(AttrSource(TorchSource(), "Stream"))
codegen(source)
def _get_target_values(self) -> list["StreamVariable"]:
return [self]
class EventVariable(VariableTracker):
def __init__(self, proxy: Proxy, value: torch.Event, **kwargs: Any) -> None:
if proxy is not None and "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == value
super().__init__(**kwargs)
self.proxy = proxy
self.value = value
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
from ..utils import proxy_args_kwargs
from .builder import wrap_fx_proxy_cls
if name in ("wait", "record", "synchronize"):
tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
)
return ConstantVariable(None)
elif name == "query":
return wrap_fx_proxy_cls(
target_cls=ConstantVariable,
tx=tx,
proxy=tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
),
)
else:
method_name = (
f"{type(self.value).__module__}.{type(self.value).__qualname__}.{name}"
)
unimplemented_v2(
gb_type="Unsupported event method",
context=str(name),
explanation=f"Dynamo doesn't support tracing the {method_name} method. "
f"We currently support wait, record, synchronize, and query.",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
def as_proxy(self) -> Proxy:
return self.proxy
def reconstruct(self, codegen: "PyCodegen") -> None:
# If we got here, this event is fully subsumed by the graph - this means it is
# not an input or global
assert not self.source
# Similar to stream handling, we lift the event into a global and then codegen bytecode to load it from there.
prefix = "_event"
name = codegen.tx.output.install_global_by_id(prefix, self.value)
codegen.append_output(codegen.create_load_global(name, add=True))
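
The fork/join definitions at the top of this file follow the standard torch.library custom-op recipe: define the op, register a fake (meta) implementation, then register autograd so that backward replays the mirror op. A minimal, standalone illustration of that recipe on a tensor op (mylib::scale is hypothetical, not part of this PR):

import torch
from torch._library.custom_ops import custom_op


@custom_op("mylib::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor


@scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)


def _backward(ctx, grad):
    # one gradient per input; the non-tensor input gets None
    return grad * ctx.factor, None


def _setup_context(ctx, inputs, output):
    _, factor = inputs
    ctx.factor = factor


scale.register_autograd(_backward, setup_context=_setup_context)

x = torch.randn(3, requires_grad=True)
scale(x, 2.0).sum().backward()
assert torch.allclose(x.grad, torch.full((3,), 2.0))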

----------------------------------------

@@ -1237,6 +1237,35 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
# pyrefly: ignore # unbound-name
return VariableTracker.build(tx, module, new_source)
@register(torch.accelerator.current_stream)
def handle_current_stream(self, tx: "InstructionTranslator", *args, **kwargs):
if len(args) + len(kwargs) > 1 or (kwargs and "device" not in kwargs):
unimplemented_v2(
gb_type="unsupported arguments to torch.accelerator.current_stream",
context=f"args={args}, kwargs={kwargs}",
explanation="torch.accelerator.current_stream accepts one optional argument `device`",
hints=[
*graph_break_hints.USER_ERROR,
],
)
try:
if kwargs:
device = torch.device(kwargs["device"].as_python_constant())
elif args:
device = torch.device(args[0].as_python_constant())
else:
device = None
return tx.symbolic_stream_state.cur_stream(device)
except Exception as e:
unimplemented_v2(
gb_type="bad device argument to torch.accelerator.current_stream",
context=f"args={args}, kwargs={kwargs}",
explanation="Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
hints=[*graph_break_hints.USER_ERROR],
from_exc=e,
)
@register(torch.set_default_device)
def handle_set_default_device(
self, tx: "InstructionTranslator", *args, **kwargs
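
In practice the handler above means all of the following forms trace through the symbolic stream state instead of graph-breaking (a sketch, assuming a CUDA build):

import torch

@torch.compile(fullgraph=True)
def f(x):
    s0 = torch.accelerator.current_stream()                  # no argument
    s1 = torch.accelerator.current_stream("cuda")            # string device
    s2 = torch.accelerator.current_stream(device=torch.device("cuda", 0))  # kwarg
    return x + 1

f(torch.ones(2, device="cuda"))
# two positional arguments or an unknown kwarg hit the "unsupported
# arguments" graph break; an invalid device string hits the
# "bad device argument" graph break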

----------------------------------------

@@ -58,6 +58,7 @@ from ..exc import (
raise_observed_exception,
unimplemented_v2,
)
from ..graph_bytecode_inputs import get_external_object_by_index
from ..guards import GuardBuilder, install_guard
from ..source import (
AttrSource,
@@ -792,14 +793,31 @@ class UserDefinedClassVariable(UserDefinedVariable):
)
args = [stacked]
tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
*proxy_args_kwargs(args, kwargs),
),
)
if issubclass(self.value, torch.Stream):
# Register newly created stream for reconstruction
stream = self.value()
from ..graph_bytecode_inputs import register_graph_created_object
from .streams import StreamVariable
ind = register_graph_created_object(
stream, StreamVariable.construct_in_graph_stream
)
tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function", get_external_object_by_index, (ind,), {}
),
user_obj_index=ind,
)
else:
tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
*proxy_args_kwargs(args, kwargs),
),
)
return tensor_variable
elif self.value is random.Random:
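
The effect is user-visible: constructing a stream inside a compiled region now returns a real, freshly allocated object on every call, reconstructed through the sidetable rather than lifted to a global (this mirrors test_local_stream_return above):

import torch

@torch.compile(fullgraph=True)
def make_stream(x):
    return x + 1, torch.Stream()

_, s0 = make_stream(torch.ones(2))
_, s1 = make_stream(torch.ones(2))
assert isinstance(s0, torch.Stream)
assert s0 != s1  # a new stream is allocated per call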

----------------------------------------

@@ -49,6 +49,7 @@ static PyObject* THPEvent_pynew(
}
THPEvent* self = (THPEvent*)ptr.get();
self->weakreflist = nullptr;
// TODO: blocking and interprocess are not supported yet. To support them, the
// flag system of c10::Event needs to be refactored. C10::Event should also
@@ -73,6 +74,7 @@ PyObject* THPEvent_new(c10::DeviceType device_type, c10::EventFlag flag) {
auto self = THPObjectPtr{type->tp_alloc(type, 0)};
TORCH_CHECK(self, "Failed to allocate memory for Event");
auto self_ = reinterpret_cast<THPEvent*>(self.get());
self_->weakreflist = nullptr;
new (&self_->event) c10::Event(device_type, flag);
return self.release();
}
@@ -82,6 +84,7 @@ static void THPEvent_dealloc(THPEvent* self) {
pybind11::gil_scoped_release no_gil{};
self->event.~Event();
}
PyObject_ClearWeakRefs((PyObject*)self);
Py_TYPE(self)->tp_free((PyObject*)self);
}
@@ -274,7 +277,8 @@ static PyMethodDef THPEvent_methods[] = {
{"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr},
{"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr},
{nullptr}};
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Winvalid-offsetof"
PyTypeObject THPEventType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch.Event", /* tp_name */
@@ -300,7 +304,7 @@ PyTypeObject THPEventType = {
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
offsetof(THPEvent, weakreflist), /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
THPEvent_methods, /* tp_methods */
@@ -315,6 +319,7 @@ PyTypeObject THPEventType = {
nullptr, /* tp_alloc */
THPEvent_pynew, /* tp_new */
};
#pragma GCC diagnostic pop
void THPEvent_init(PyObject* module) {
THPEventClass = &THPEventType;

----------------------------------------

@@ -7,6 +7,7 @@
struct TORCH_API THPEvent {
PyObject_HEAD
c10::Event event;
PyObject* weakreflist;
};
TORCH_API extern PyTypeObject* THPEventClass;
TORCH_API extern PyTypeObject THPEventType;

----------------------------------------

@@ -95,6 +95,7 @@ static PyObject* THPStream_pynew(
self->device_index = static_cast<int64_t>(stream_opt->device_index());
self->device_type = static_cast<int64_t>(stream_opt->device_type());
self->context = nullptr;
self->weakreflist = nullptr;
return (PyObject*)ptr.release();
END_HANDLE_TH_ERRORS
@@ -114,11 +115,13 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) {
self->device_index = static_cast<int64_t>(stream.device_index());
self->device_type = static_cast<int64_t>(stream.device_type());
self->context = nullptr;
self->weakreflist = nullptr;
return ptr.release();
END_HANDLE_TH_ERRORS
}
static void THPStream_dealloc(THPStream* self) {
PyObject_ClearWeakRefs((PyObject*)self);
Py_TYPE(self)->tp_free((PyObject*)self);
}
@@ -436,7 +439,7 @@ static PyTypeObject THPStreamType = {
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
THPStream_richcompare, /* tp_richcompare */
0, /* tp_weaklistoffset */
offsetof(THPStream, weakreflist), /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
// NOLINTNEXTLINE(*const-cast)

----------------------------------------

@@ -13,6 +13,7 @@ struct THPStream {
int64_t device_index;
// Used to switch stream context management, initialized lazily.
PyObject* context;
PyObject* weakreflist;
};
extern TORCH_API PyTypeObject* THPStreamClass;
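
With the weaklist slot and the PyObject_ClearWeakRefs call in place for both types, torch.Event and torch.Stream can now be weak-referenced from Python, which is exactly what the graph_bytecode_inputs sidetable relies on (mirrors test_event_weakref and test_stream_weakref):

import weakref

import torch

e = torch.Event()
s = torch.Stream()
# both succeed; previously THPEvent/THPStream had tp_weaklistoffset = 0
assert weakref.ref(e)() is e
assert weakref.ref(s)() is s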