mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-31 20:34:54 +08:00)

Compare commits: ciflow/bin ... mlazos/use (20 commits)
| SHA1 |
|---|
| bcf1a53297 |
| 75c6f9b93b |
| cf8fb02c33 |
| bfd2b03577 |
| 6441f7a7fe |
| 6d30dba93d |
| dc90a72bb5 |
| 44de0318c4 |
| 66c8640559 |
| 923a7c7bcc |
| e2b0cfe647 |
| 60508c7ed8 |
| a43c5f210b |
| e8bd37d77c |
| 28742d61be |
| 104dec4c55 |
| 4040707f1f |
| c0ec620a09 |
| 47d2882ea6 |
| 2df9f24b3f |
test/dynamo/test_streams.py (new file, 187 lines)
```diff
@@ -0,0 +1,187 @@
# Owner(s): ["module: dynamo"]
import weakref

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch.testing._internal.common_utils import requires_cuda


class TestStreams(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

    def test_stream_weakref(self):
        s = torch.Stream()
        weakref.ref(s)

    def test_event_weakref(self):
        e = torch.Event()
        weakref.ref(e)

    def test_stream_enter_exit(self):
        def fn(x, y):
            s2 = torch.Stream()
            s1 = torch.Stream()
            with s1:
                z1 = torch.add(x, y)
            with s2:
                z = torch.add(x, y)
                y = z + 2 + z1

            return y

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
        expected = fn(*inp)
        fn_opt = torch.compile(fn, fullgraph=True)
        actual = fn_opt(*inp)
        self.assertEqual(expected, actual)

    def test_stream_context_graph_break(self):
        def fn(x, y):
            s2 = torch.Stream()
            s1 = torch.Stream()
            with s1:
                z1 = torch.add(x, y)
            with s2:
                z = torch.add(x, y)
                y = z + 2 + z1
                torch._dynamo.graph_break()
                y = y + 1

            return y

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
        expected = fn(*inp)
        fn_opt = torch.compile(fn)
        actual = fn_opt(*inp)
        self.assertEqual(expected, actual)

    def test_stream_input(self):
        def fn(x, y, s):
            z = torch.add(x, y)
            y = z + 2
            return y, s

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(device="cuda"))
        expected = fn(*inp)
        fn_opt = torch.compile(fn, fullgraph=True)
        actual = fn_opt(*inp)
        self.assertEqual(expected, actual)

    def test_local_stream_return(self):
        def fn(x, y):
            s = torch.Stream()
            z = torch.add(x, y)
            y = z + 2
            return y, s

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
        fn_opt = torch.compile(fn, fullgraph=True)
        _, s0 = fn_opt(*inp)
        _, s1 = fn_opt(*inp)
        # Streams will be different values for each invocation
        # so don't check for equality
        self.assertIsInstance(s0, torch.Stream)
        # Stream should be newly allocated on each call
        self.assertNotEqual(s0, s1)

    def test_get_current_stream_return(self):
        def fn(x, s):
            with s:
                s0 = torch.accelerator.current_stream()
            return x, s0

        s_inp = torch.Stream(device="cuda")
        inp = (torch.ones(2, 2) + 1, s_inp)
        fn_opt = torch.compile(fn, fullgraph=True)
        _, s0 = fn_opt(*inp)
        _, s1 = fn_opt(*inp)
        self.assertEqual(s_inp, s0)
        self.assertEqual(s0, s1)

    def test_get_current_stream_return_different_device(self):
        def fn(x, s0, s1):
            with s1:
                with s0:
                    s = torch.accelerator.current_stream(torch.device("cuda:1"))
            return s

        s0 = torch.Stream(device="cuda:0")
        s1 = torch.Stream(device="cuda:1")
        inp = (torch.ones(2, 2) + 1, s0, s1)
        fn_opt = torch.compile(fn, fullgraph=True)
        s_act = fn_opt(*inp)
        s_exp = fn(*inp)
        self.assertEqual(s_act, s_exp)

    def test_get_current_stream_return_no_index(self):
        def fn(x, s0, s1):
            with s1:
                with s0:
                    s = torch.accelerator.current_stream(torch.device("cuda"))
            return s

        s0 = torch.Stream(device="cuda:0")
        s1 = torch.Stream(device="cuda:1")
        inp = (torch.ones(2, 2) + 1, s0, s1)
        fn_opt = torch.compile(fn, fullgraph=True)
        s_act = fn_opt(*inp)
        s_exp = fn(*inp)
        self.assertEqual(s_act, s_exp)

    def test_fork_join_backward(self):
        def fn(x, s0):
            with s0:
                y = torch.add(x, x)
            return y

        inp = (torch.ones(2, 2, requires_grad=True) + 1, torch.Stream(device="cuda"))
        fn_opt = torch.compile(fn, fullgraph=True)
        actual = fn_opt(*inp)
        actual.sum().backward()
        # expected = fn(*inp)
        # expected.sum().backward()
        # self.assertEqual(expected, actual)

    def test_nested_stream_enter_exit(self):
        pass

    def test_stream_enter_exit_graph_break(self):
        pass

    def test_nested_stream_enter_exit_graph_break(self):
        pass

    def test_local_stream_enter_exit(self):
        pass

    def test_local_stream_nested_enter_exit(self):
        pass

    def test_stream_with_mutation(self):
        pass

    @requires_cuda
    def test_run_opcheck(self):
        from torch._dynamo.variables.streams import fork_stream, join_stream
        from torch.library import opcheck

        sample_inputs = [
            (0, torch.device("cuda:0"), 1, torch.device("cuda:1")),
            (2, torch.device("cuda:2"), 3, torch.device("cuda:1")),
        ]
        for args in sample_inputs:
            opcheck(fork_stream, args)
            opcheck(join_stream, args)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
```
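The core pattern these tests exercise can be reproduced standalone. A minimal sketch, assuming a CUDA (or other accelerator) build, since `torch.Stream()` and stream switching need an accelerator backend:

```python
import torch

def fn(x):
    s = torch.Stream()
    with s:  # traced through the streams::fork / streams::join ops added below
        y = x + 1
    return y

torch.compile(fn, fullgraph=True)(torch.ones(2, 2))
```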
```diff
@@ -153,7 +153,6 @@ def reset() -> None:
        GenerationTracker.clear()
        TensorifyState.clear()
        torch._dynamo.utils.warn_once_cache.clear()
        torch._dynamo.utils.user_obj_id_to_weakref.clear()
        torch._C._autograd._saved_tensors_hooks_set_tracing(False)
```
```diff
@@ -116,6 +116,7 @@ from .exc import (
    unimplemented_v2,
    Unsupported,
)
from .graph_bytecode_inputs import reset_user_object_tracking
from .guards import (
    CheckFunctionManager,
    get_and_maybe_log_recompilation_reasons,
@@ -314,6 +315,7 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
                torch.fx._symbolic_trace._maybe_revert_all_patches()
            )
            exit_stack.enter_context(torch_function_mode_stack_state_mgr)
            reset_user_object_tracking()
            try:
                return fn(*args, **kwargs)
            finally:
```
```diff
@@ -2495,6 +2495,14 @@
    }
  ],
  "GB0249": [
    {
      "Gb_type": "bad device argument to torch.accelerator.current_stream",
      "Context": "args={args}, kwargs={kwargs}",
      "Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
      "Hints": [
        "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
      ]
    },
    {
      "Gb_type": "bad device argument to torch.get_device_module",
      "Context": "args={args}, kwargs={kwargs}",
@@ -2734,6 +2742,12 @@
    }
  ],
  "GB0272": [
    {
      "Gb_type": "Failed to make weakref to User Object when storing by ID",
      "Context": "user_objected: {obj}",
      "Explanation": "Object does not allow us to make a weakref to it",
      "Hints": []
    },
    {
      "Gb_type": "Failed to make weakref to User Object",
      "Context": "user_objected: {obj}",
@@ -2776,5 +2790,41 @@
        "This is likely to be a Dynamo bug. Please report an issue to PyTorch."
      ]
    }
  ],
  "GB0276": [
    {
      "Gb_type": "Failed to make weakref to User Object",
      "Context": "user_object: {value}",
      "Explanation": "Object does not allow us to make a weakref to it",
      "Hints": []
    }
  ],
  "GB0277": [
    {
      "Gb_type": "Failed to make weakref to graph-created external object",
      "Context": "user_object: {example_value}",
      "Explanation": "Object does not allow us to make a weakref to it",
      "Hints": []
    }
  ],
  "GB0278": [
    {
      "Gb_type": "unsupported arguments to torch.accelerator.current_stream",
      "Context": "args={args}, kwargs={kwargs}",
      "Explanation": "torch.accelerator.current_stream accepts one optional argument `device`",
      "Hints": [
        "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
      ]
    }
  ],
  "GB0279": [
    {
      "Gb_type": "bad device argument to torch.get_device_module",
      "Context": "args={args}, kwargs={kwargs}",
      "Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
      "Hints": [
        "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
      ]
    }
  ]
}
```
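For illustration, a hypothetical repro of the new GB0278 entry; as its hint notes, the same call would also fail in eager, since `current_stream` takes at most one `device` argument:

```python
import torch

@torch.compile(fullgraph=True)
def f(x):
    # Two positional arguments should trip
    # "unsupported arguments to torch.accelerator.current_stream".
    s = torch.accelerator.current_stream("cuda", 0)
    return x + 1
```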
torch/_dynamo/graph_bytecode_inputs.py (new file, 90 lines)
```diff
@@ -0,0 +1,90 @@
import weakref
from typing import Any, Callable

from torch._dynamo.source import Source


PyCodegen = Any

# This file handles types that we don't want to support as explicit FX graph
# inputs. It uses a side table that we populate in pre-graph bytecode and read
# from during graph execution.

# We use a dynamo-generated index as a level of indirection; this allows us to
# register objects externally in pre-graph bytecode that we want to pass to the
# graph without supporting their types as graph inputs.
index_to_bytecode_constructor: dict[int, Callable[[PyCodegen], None]] = {}

index_to_external_object_weakref: dict[int, weakref.ReferenceType[Any]] = {}

keep_alive: list[Any] = []


def has_user_objects() -> bool:
    return bool(index_to_bytecode_constructor)


def get_external_object_by_index(index: int) -> Any:
    assert index in index_to_external_object_weakref, (
        "Index not registered in index_to_external_object_weakref"
    )
    obj = index_to_external_object_weakref[index]()
    assert obj is not None, "User object is no longer alive"
    return obj


def store_user_object_weakrefs(*args: Any) -> None:
    global index_to_external_object_weakref
    index_to_external_object_weakref.clear()
    index_to_external_object_weakref.update(
        {i: weakref.ref(arg) for i, arg in enumerate(args)}
    )


def reset_user_object_tracking() -> None:
    index_to_bytecode_constructor.clear()
    index_to_external_object_weakref.clear()
    keep_alive.clear()


def register_graph_created_object(
    example_value: Any, construct_fn: Callable[[int, PyCodegen], None]
) -> int:
    global index_to_bytecode_constructor
    global keep_alive
    keep_alive.append(example_value)
    index = len(index_to_bytecode_constructor)
    index_to_bytecode_constructor[index] = lambda cg: construct_fn(index, cg)
    try:
        index_to_external_object_weakref[index] = weakref.ref(example_value)
    except TypeError as e:
        from .exc import unimplemented_v2

        unimplemented_v2(
            gb_type="Failed to make weakref to graph-created external object",
            context=f"user_object: {example_value}",
            explanation="Object does not allow us to make a weakref to it",
            hints=[],
            from_exc=e,
        )
    return index


# Register a user object to be used in the graph
def register_user_object(value: Any, source: Source) -> int:
    global index_to_bytecode_constructor
    index = len(index_to_bytecode_constructor)
    index_to_bytecode_constructor[index] = lambda cg: cg(source)
    try:
        index_to_external_object_weakref[index] = weakref.ref(value)
    except TypeError as e:
        from .exc import unimplemented_v2

        unimplemented_v2(
            gb_type="Failed to make weakref to User Object",
            context=f"user_object: {value}",
            explanation="Object does not allow us to make a weakref to it",
            hints=[],
            from_exc=e,
        )
    return index
```
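A minimal sketch of the side-table round trip these helpers implement, assuming an accelerator build and using `LocalSource("s")` as a stand-in source:

```python
import torch
from torch._dynamo.graph_bytecode_inputs import (
    get_external_object_by_index,
    register_user_object,
    reset_user_object_tracking,
    store_user_object_weakrefs,
)
from torch._dynamo.source import LocalSource  # stand-in source for illustration

reset_user_object_tracking()
s = torch.Stream()                               # a weakref-able user object
idx = register_user_object(s, LocalSource("s"))  # compile time: assign index 0
store_user_object_weakrefs(s)                    # what the emitted prologue bytecode does
assert get_external_object_by_index(idx) is s    # graph time: look it up again
```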
```diff
@@ -132,6 +132,7 @@ from .source import (
    CodeSource,
    ConstantSource,
    ConstDictKeySource,
    CurrentStreamSource,
    DataclassFieldsSource,
    DefaultsSource,
    DictGetItemSource,
@@ -181,6 +182,7 @@ from .utils import (
    common_constant_types,
    dataclass_fields,
    dict_keys,
    get_current_stream,
    get_custom_getattr,
    get_torch_function_mode_stack,
    get_torch_function_mode_stack_at,
@@ -757,6 +759,7 @@ def _get_closure_vars() -> dict[str, object]:
            "___dataclass_fields": dataclass_fields,
            "___namedtuple_fields": lambda x: x._fields,
            "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
            "___get_current_stream": get_current_stream,
            "__math_isnan": math.isnan,
            "__numpy_isnan": None if np is None else np.isnan,
            "inf": float("inf"),
@@ -1448,6 +1451,13 @@ class GuardBuilder(GuardBuilderBase):
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, CurrentStreamSource):
            out = root_guard_manager.lambda_manager(
                python_lambda=lambda _: get_current_stream(source.device),
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, GradSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.grad_manager(
@@ -2166,6 +2176,8 @@ class GuardBuilder(GuardBuilderBase):
                range,
                dict_keys,
                torch.Size,
                torch.Stream,
                torch.cuda.streams.Stream,
                *np_types,
                *ok_mutable_types,
            }
```
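The `CurrentStreamSource` branch installs a lambda guard that re-reads the live current stream at guard-evaluation time. A rough sketch of what that lambda computes (assumption: the guard machinery compares this result against the recorded `example_value`):

```python
import torch
from torch._dynamo.utils import get_current_stream  # helper added in this PR

dev = torch.device("cuda", 0)
# The guard's lambda ignores its argument and evaluates, in effect:
stream_now = get_current_stream(dev)  # == torch.accelerator.current_stream(dev)
```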
```diff
@@ -100,6 +100,7 @@ from .exc import (
    unimplemented_v2,
    unimplemented_v2_with_warning,
)
from .graph_bytecode_inputs import has_user_objects, index_to_bytecode_constructor
from .graph_deduplication import apply_graph_deduplication
from .graph_region_tracker import GraphRegionTracker
from .guards import GuardBuilder, install_guard
@@ -1512,6 +1513,37 @@ class OutputGraph(OutputGraphCommon):

        from .decorators import disable

        if has_user_objects():
            # NB: This is where we store any user objects before running the
            # graph. get_external_object_by_index is the function the graph
            # calls to translate a dynamo-generated index into the actual
            # object passed to the compiled function. The bytecode generated
            # below stores all user objects at their assigned indices.
            codegen = PyCodegen(
                self.root_tx, root, overridden_sources=overridden_sources
            )
            codegen.add_push_null(
                lambda: codegen.load_import_from(
                    torch._dynamo.graph_bytecode_inputs.__name__,
                    "store_user_object_weakrefs",
                )
            )
            tmp_vars = []
            for constructor in reversed(index_to_bytecode_constructor.values()):
                constructor(codegen)
                # keep alive any temp objects for the rest of the frame
                var_name = self.new_var()
                codegen.store(var_name)
                tmp_vars.append(var_name)

            for var_name in tmp_vars:
                codegen.append_output(codegen.create_load(var_name))

            codegen.call_function(len(index_to_bytecode_constructor), False)
            codegen.pop_top()
            self.add_output_instructions(codegen.get_instructions())

        # to handle random calls
        if len(self.random_calls) > 0:
            random_calls_instructions = []
@@ -1657,7 +1689,7 @@ class OutputGraph(OutputGraphCommon):
                            )
                        elif (
                            vt.source is not None
                            and (source := getattr(vt.source, "base", None))
                            and (source := getattr(vt.source, "base", None))  # type: ignore[assignment]
                            and source.is_input
                        ):
                            self.export_metadata.output_return_type[idx] = (
```
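The emitted prologue is morally equivalent to the following Python; `make_obj_a`/`make_obj_b` are hypothetical stand-ins for whatever each registered constructor pushes:

```python
from torch._dynamo.graph_bytecode_inputs import store_user_object_weakrefs

tmp_a = make_obj_a()  # constructors run in reversed(index_to_bytecode_constructor) order
tmp_b = make_obj_b()  # each temp local keeps its object alive for the rest of the frame
store_user_object_weakrefs(tmp_a, tmp_b)  # repopulates the index -> weakref side table
```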
```diff
@@ -22,6 +22,7 @@ import enum
import functools
from typing import Any, Callable, Optional, TYPE_CHECKING, Union

from torch import device as device_type
from torch._guards import ChainedSource, Guard, GuardSource, Source

from . import utils
@@ -1078,6 +1079,30 @@ class ShapeEnvSource(Source):
        return GuardSource.SHAPE_ENV


@dataclasses.dataclass(frozen=True)
class CurrentStreamSource(Source):
    device: device_type

    def name(self) -> str:
        return f"___get_current_stream(torch.device('{self.device.type}', {self.device.index}))"

    def reconstruct(self, codegen: "PyCodegen") -> None:
        num_args = 1
        codegen.add_push_null(
            lambda: codegen.load_import_from(utils.__name__, "get_current_stream")
        )
        codegen.add_push_null(lambda: codegen.load_import_from("torch", "device"))
        codegen.extend_output([codegen.create_load_const(self.device.type)])
        if self.device.index is not None:
            num_args += 1
            codegen.extend_output([codegen.create_load_const(self.device.index)])
        codegen.extend_output(create_call_function(num_args, False))
        codegen.extend_output(create_call_function(1, False))

    def guard_source(self) -> GuardSource:
        return GuardSource.GLOBAL


@dataclasses.dataclass(frozen=True)
class BackwardStateSource(Source):
    def name(self) -> str:
```
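For reference, the guard name such a source produces; a small sketch assuming a cuda:0 device:

```python
import torch
from torch._dynamo.source import CurrentStreamSource

src = CurrentStreamSource(torch.device("cuda", 0))
print(src.name())  # ___get_current_stream(torch.device('cuda', 0))
```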
```diff
@@ -173,6 +173,7 @@ from .variables.misc import (
    UnknownVariable,
)
from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable
from .variables.streams import SymbolicStreamState
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
from .variables.torch_function import (
    SymbolicTorchFunctionState,
@@ -1170,6 +1171,7 @@ class InstructionTranslatorBase(
    symbolic_locals: dict[str, VariableTracker]
    symbolic_globals: dict[str, VariableTracker]
    symbolic_torch_function_state: SymbolicTorchFunctionState
    symbolic_stream_state: SymbolicStreamState
    post_prune_cell_and_freevars: Optional[dict[str, VariableTracker]]
    stack: list[VariableTracker]
    instruction_pointer: Optional[int]
@@ -4069,6 +4071,7 @@ class InstructionTranslatorBase(
        symbolic_locals: dict[str, VariableTracker],
        symbolic_globals: dict[str, VariableTracker],
        symbolic_torch_function_state: SymbolicTorchFunctionState,
        symbolic_stream_state: SymbolicStreamState,
        f_code: types.CodeType,
        export: bool,
        inline_depth: int,
@@ -4088,6 +4091,7 @@ class InstructionTranslatorBase(
        self.symbolic_locals = symbolic_locals
        self.symbolic_globals = symbolic_globals
        self.symbolic_torch_function_state = symbolic_torch_function_state
        self.symbolic_stream_state = symbolic_stream_state
        # used to keep cell/freevars alive after pruning symbolic_locals (prune_dead_locals)
        # in order to generate any nested closures
        self.post_prune_cell_and_freevars = None
@@ -4241,6 +4245,7 @@ class InstructionTranslator(InstructionTranslatorBase):
            # A global var is inserted only after a STORE_GLOBAL happens to it
            symbolic_globals={},
            symbolic_torch_function_state=None,  # type: ignore[arg-type] # set below
            symbolic_stream_state=None,  # type: ignore[arg-type] # set below
            f_code=f_code,
            export=export,
            inline_depth=0,
@@ -4345,6 +4350,8 @@ class InstructionTranslator(InstructionTranslatorBase):
                torch_function_mode_stack
            )

            self.symbolic_stream_state = SymbolicStreamState()

            if export:
                # export gets confused if we never realize unused inputs
                # in export mode just eagerly realize everything
@@ -4673,6 +4680,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
                sub_locals,
                parent.symbolic_globals,
                parent.symbolic_torch_function_state,
                parent.symbolic_stream_state,
                func,
            )
        else:
@@ -4684,6 +4692,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
                sub_locals,
                parent.symbolic_globals,
                parent.symbolic_torch_function_state,
                parent.symbolic_stream_state,
                # pyrefly: ignore  # bad-argument-type
                func,
            )
@@ -4767,6 +4776,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
        symbolic_locals: dict[str, VariableTracker],
        symbolic_globals: dict[str, VariableTracker],
        symbolic_torch_function_state: SymbolicTorchFunctionState,
        symbolic_stream_state: SymbolicStreamState,
        funcvar: BaseUserFunctionVariable,
    ) -> None:
        f_globals = funcvar.get_globals()  # type: ignore[attr-defined]
@@ -4800,6 +4810,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
            symbolic_locals=symbolic_locals,
            symbolic_globals=symbolic_globals,
            symbolic_torch_function_state=symbolic_torch_function_state,
            symbolic_stream_state=symbolic_stream_state,
            instructions=instructions,
            code_options={k: getattr(code, k) for k in get_code_keys()},
            f_code=code,
```
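The plumbing above threads one new piece of state through every translator, including inlined frames: a stack of entered streams, defined in `variables/streams.py` below. Its discipline, sketched with a hypothetical `s1_vt` StreamVariable (the constructor reads the ambient current stream, so an accelerator backend is assumed):

```python
from torch._dynamo.variables.streams import SymbolicStreamState

state = SymbolicStreamState()    # seeded with the ambient current stream
state.enter_stream(s1_vt)        # `with s1:` pushes
assert state.cur_stream() is s1_vt
state.exit_stream()              # leaving the context pops back
```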
```diff
@@ -4655,6 +4655,10 @@ def clear_torch_function_mode_stack() -> None:
        _pop_torch_function_stack()


def get_current_stream(device: torch.device) -> torch.Stream:
    return torch.accelerator.current_stream(device)


# call from C dynamo in order to inspect values in pdb
def _breakpoint_for_c_dynamo(*args: Any) -> None:
    breakpoint()
@@ -4719,34 +4723,6 @@ def _extract_tensor_dict(t: torch.Tensor) -> dict[str, Any]:
    return tensor_dict


# This is useful for reconstructing within the Dynamo graph the non-graph-input objects
# whose lifetime is governed by the user.
# e.g. torch.cuda.Event is a prime example.
user_obj_id_to_weakref: dict[int, weakref.ReferenceType[object]] = {}


def get_user_object_from_id(obj_id: int) -> Any:
    obj = user_obj_id_to_weakref[obj_id]()
    assert obj is not None, "User object is no longer alive"
    return obj


def store_user_object_weakref(obj: object) -> None:
    obj_id = id(obj)
    try:
        user_obj_id_to_weakref[obj_id] = weakref.ref(obj)
    except TypeError as e:
        from .exc import unimplemented_v2

        unimplemented_v2(
            gb_type="Failed to make weakref to User Object",
            context=f"user_objected: {obj}",
            explanation="Object does not allow us to make a weakref to it",
            hints=[],
            from_exc=e,
        )


class CompileTimeInstructionCounter:
    _counter: int = 0
    _id: int = -1
```
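Both the removed id-based helper and its index-based replacement graph-break for the same underlying reason: not every Python object admits a weak reference. A minimal illustration:

```python
import weakref

class Slotted:
    __slots__ = ("x",)  # no __weakref__ slot, so weakrefs are rejected

weakref.ref(Slotted())  # TypeError -> "Failed to make weakref to User Object"
```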
```diff
@@ -37,8 +37,6 @@ from .ctx_manager import (
    JvpIncrementNestingCtxManagerVariable,
    SDPAKernelVariable,
    SetFwdGradEnabledContextManager,
    StreamContextVariable,
    StreamVariable,
    TemporarilyPopInterpreterStackCtxManagerVariable,
    VmapIncrementNestingCtxManagerVariable,
    WithEnterFunctionVariable,
@@ -131,6 +129,7 @@ from .nn_module import (
)
from .optimizer import OptimizerVariable
from .sdpa import SDPAParamsVariable
from .streams import EventVariable, StreamContextVariable, StreamVariable
from .tensor import (
    DataPtrVariable,
    FakeItemVariable,
```
```diff
@@ -45,6 +45,10 @@ import sympy
import torch
from torch import SymInt
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.graph_bytecode_inputs import (
    get_external_object_by_index,
    register_user_object,
)
from torch._dynamo.utils import (
    get_metrics_context,
    is_int_specialization_case,
@@ -172,11 +176,8 @@ from .ctx_manager import (
    AutocastModeVariable,
    DynamoConfigPatchVariable,
    ErrorOnGraphBreakVariable,
    EventVariable,
    NullContextVariable,
    PreserveVersionContextVariable,
    StreamContextVariable,
    StreamVariable,
)
from .dicts import (
    ConstDictVariable,
@@ -257,6 +258,7 @@ from .nn_module import (
)
from .optimizer import OptimizerVariable
from .script_object import TorchScriptObjectVariable
from .sdpa import SDPAParamsVariable
from .streams import EventVariable, StreamContextVariable, StreamVariable
from .tensor import (
    NumpyNdarrayVariable,
    supported_const_comparison_op_values,
@@ -1036,24 +1038,20 @@ class VariableBuilder:
            stream_var = VariableBuilder(self.tx, stream_source)(value.stream)
            return StreamContextVariable.create(self.tx, stream_var)
        elif isinstance(value, torch.Stream):
            self.install_guards(GuardBuilder.ID_MATCH)
            # This refers to the device-agnostic torch.Stream
            self.install_guards(GuardBuilder.TYPE_MATCH)
            index = register_user_object(value, self.source)
            stream_proxy = self.tx.output.create_proxy(
                "call_function",
                type(value),
                (),
                {
                    "stream_id": value.stream_id,
                    "device_index": value.device_index,
                    "device_type": value.device_type,
                },
                "call_function", get_external_object_by_index, (index,), {}
            )
            set_example_value(stream_proxy.node, value)
            return StreamVariable(
            var = StreamVariable(
                stream_proxy,
                value,
                value.device,
                source=self.source,
            )
            return self.tx.output.side_effects.track_object_existing(value, var)
        elif isinstance(value, (torch._C._SDPAParams)):
            self.install_guards(GuardBuilder.TYPE_MATCH)
            return SDPAParamsVariable.create(self.tx, value, self.source)
@@ -1061,12 +1059,12 @@ class VariableBuilder:
            self.install_guards(GuardBuilder.ID_MATCH)
            return FuncTorchInterpreterVariable(value)
        elif isinstance(value, torch.Event):
            self.install_guards(GuardBuilder.ID_MATCH)
            torch._dynamo.utils.store_user_object_weakref(value)
            self.install_guards(GuardBuilder.TYPE_MATCH)
            index = register_user_object(value, self.source)
            event_proxy = self.tx.output.create_proxy(
                "call_function",
                torch._dynamo.utils.get_user_object_from_id,
                (id(value),),
                get_external_object_by_index,
                (index,),
                {},
            )
            set_example_value(event_proxy.node, value)
@@ -2980,8 +2978,9 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
        set_example_value(proxy.node, example_value)
        return SymNodeVariable(proxy, example_value, **options)
    elif (
        inspect.isclass(proxy.node.target)
        and issubclass(proxy.node.target, torch.Stream)
        isinstance(example_value, torch.Stream)
        and proxy.node.target
        in (get_external_object_by_index, torch.accelerator.current_stream)
    ) or proxy.node.target in [
        device_interface.current_stream
        for _, device_interface in get_registered_device_interfaces()
```
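Mirroring `test_stream_input` above, a sketch of the path this hunk changes (CUDA assumed): the stream argument is now guarded with TYPE_MATCH and enters the graph as `get_external_object_by_index(index)` rather than being rebuilt from `(stream_id, device_index, device_type)`:

```python
import torch

def f(x, s):
    return x + 1, s

out, s_out = torch.compile(f, fullgraph=True)(
    torch.ones(2, 2), torch.Stream(device="cuda")
)
```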
```diff
@@ -83,7 +83,6 @@ from ..utils import (
)
from .base import AsPythonConstantNotImplementedError, ValueMutationNew, VariableTracker
from .constant import ConstantVariable
from .ctx_manager import EventVariable, StreamVariable
from .dicts import (
    ConstDictVariable,
    DefaultDictVariable,
@@ -101,6 +100,7 @@ from .lists import (
    TupleIteratorVariable,
    TupleVariable,
)
from .streams import EventVariable, StreamVariable
from .tensor import (
    FakeItemVariable,
    supported_comparison_ops,
```
```diff
@@ -34,7 +34,6 @@ from ..bytecode_transformation import (
    create_instruction,
    create_setup_with,
)
from ..device_interface import get_interface_for_device
from ..exc import unimplemented_v2
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalStateSource
@@ -991,70 +990,6 @@ class ProfilerContextVariable(ContextWrappingVariable):
        )


class StreamContextVariable(ContextWrappingVariable):
    @staticmethod
    def create(tx: "InstructionTranslator", target_value, **kwargs):
        from .builder import wrap_fx_proxy_cls

        current_stream_method = get_interface_for_device(
            target_value.device
        ).current_stream
        current_stream = wrap_fx_proxy_cls(
            StreamVariable,
            tx,
            tx.output.create_proxy(
                "call_function",
                current_stream_method,
                (None,),
                {},
            ),
        )
        return StreamContextVariable(
            target_values=[target_value],
            initial_values=[current_stream],
            device=target_value.device,
            **kwargs,
        )

    def __init__(self, target_values, device, initial_values=None, **kwargs) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.device = device
        self.set_stream = get_interface_for_device(self.device).set_stream
        self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id

    def enter(self, tx):
        # stream generated inside the traced function
        if self.target_values[0].as_proxy() is not None:
            tx.output.create_proxy(
                "call_function",
                self.set_stream,
                (self.target_values[0].as_proxy(),),
                {},
            )
        # stream passed from outside the traced function
        else:
            stream = self.target_values[0].value
            tx.output.create_proxy(
                "call_function",
                self.set_stream_id,
                (stream.stream_id, stream.device_index, stream.device_type),
                {},
            )
        self.set_stream(self.target_values[0].value)
        self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value))

    def exit(self, tx: "InstructionTranslator", *args):
        tx.output.create_proxy(
            "call_function",
            self.set_stream,
            (self.initial_values[0].as_proxy(),),
            {},
        )
        self.cleanup_assert()


class PreserveVersionContextVariable(ContextWrappingVariable):
    """
    Wraps torch.autograd._unsafe_preserve_version_counter
@@ -1290,142 +1225,6 @@ class FxTracebackAnnotateVariable(ContextWrappingVariable):
        return "annotate"


class StreamVariable(VariableTracker):
    def __init__(self, proxy, value, device, **kwargs) -> None:
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        assert value.device.type == device.type, (
            "stream value is not equal to the passed device"
        )
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value
        self.device = device

    def python_type(self):
        return torch.Stream

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        assert hasattr(self.value, name), f"no stream method found named {name}"

        from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait_stream", "synchronize", "wait_event"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return variables.ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=variables.ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name == "record_event":
            return wrap_fx_proxy_cls(
                target_cls=EventVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
            # NB : Checking for mutation is necessary because we compare
            # constant values
            other = args[0]
            if not isinstance(other, StreamVariable):
                return variables.ConstantVariable.create(NotImplemented)
            return variables.ConstantVariable.create(
                cmp_name_to_op_mapping[name](self.value, other.value)
            )

        return super().call_method(tx, name, args, kwargs)

    def as_proxy(self):
        return self.proxy

    def reconstruct(self, codegen: "PyCodegen"):
        # If we got here, this stream is fully subsumed by the graph - this means it is
        # not an input or global
        assert not self.source
        # Since we just proved that - for other such structures, like lists and dicts, reconstruction
        # is fine and sound according to dynamo principles of treating collectives. However,
        # streams are special in that we want to preserve the identity of the stream as the same as in the graph
        # Normally, we would do this via codegen for the proxy mapping to an output - we cannot do this yet, as we do not
        # yet have a plan for how we want to handle the case where the stream is used as an input or an output. Pending
        # design, to unblock current work, we lift the stream into a global and then codegen bytecode to load it from there.
        prefix = f"_stream_{self.device}"
        name = codegen.tx.output.install_global_by_id(prefix, self.value)
        codegen.append_output(codegen.create_load_global(name, add=True))


class EventVariable(VariableTracker):
    def __init__(self, proxy, value, **kwargs) -> None:
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait", "record", "synchronize"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return variables.ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=variables.ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        else:
            method_name = (
                f"{type(self.value).__module__}.{type(self.value).__qualname__}.{name}"
            )
            unimplemented_v2(
                gb_type="Unsupported event method",
                context=str(name),
                explanation=f"Dynamo doesn't support tracing the {method_name} method. "
                f"We currently support wait, record, synchronize, and query.",
                hints=[
                    *graph_break_hints.SUPPORTABLE,
                ],
            )

    def as_proxy(self):
        return self.proxy

    def reconstruct(self, codegen: "PyCodegen"):
        # If we got here, this event is fully subsumed by the graph - this means it is
        # not an input or global
        assert not self.source
        # Similar to stream handling, we lift the event into a global and then codegen bytecode to load it from there.
        prefix = "_event"
        name = codegen.tx.output.install_global_by_id(prefix, self.value)
        codegen.append_output(codegen.create_load_global(name, add=True))


class DynamoConfigPatchVariable(ContextWrappingVariable):
    """represents torch._dynamo.patch_dynamo_config"""
```
torch/_dynamo/variables/streams.py (new file, 418 lines)
							| @ -0,0 +1,418 @@ | ||||
| import collections | ||||
| from typing import Any, Optional | ||||
|  | ||||
| import torch | ||||
| from torch.fx import Proxy | ||||
|  | ||||
| from .. import graph_break_hints | ||||
| from ..bytecode_transformation import create_call_function | ||||
| from ..device_interface import get_interface_for_device | ||||
| from ..exc import TYPE_CHECKING, unimplemented_v2 | ||||
| from ..source import AttrSource, CallFunctionNoArgsSource, TorchSource | ||||
| from .base import VariableTracker | ||||
| from .constant import ConstantVariable | ||||
| from .ctx_manager import ContextWrappingVariable | ||||
| from .lazy import LazyVariableTracker | ||||
| from .misc import GetAttrVariable | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from torch._dynamo.symbolic_convert import InstructionTranslator | ||||
|  | ||||
|     from ..codegen import PyCodegen | ||||
|  | ||||
| from torch._library.custom_ops import custom_op | ||||
|  | ||||
|  | ||||
| Tensor = torch.Tensor | ||||
|  | ||||
| from torch._higher_order_ops.effects import _EffectType, _register_effectful_op | ||||
|  | ||||
|  | ||||
| @custom_op("streams::fork", mutates_args=()) | ||||
| def fork_stream( | ||||
|     from_index: int, | ||||
|     from_device: torch.device, | ||||
|     to_index: int, | ||||
|     to_device: torch.device, | ||||
| ) -> int: | ||||
|     return from_index | ||||
|  | ||||
|  | ||||
| @fork_stream.register_fake | ||||
| def _( | ||||
|     from_index: int, | ||||
|     from_device: torch.device, | ||||
|     to_index: int, | ||||
|     to_device: torch.device, | ||||
| ) -> int: | ||||
|     return from_index | ||||
|  | ||||
|  | ||||
| def fork_backward(ctx, grad_output): | ||||
|     from_index, from_device, to_index, to_device = ctx.args | ||||
|     from_index = join_stream(to_index, to_device, from_index, from_device) | ||||
|     return None, from_index, None, None, None, None | ||||
|  | ||||
|  | ||||
| def fork_setup_context(ctx, inputs, output): | ||||
|     from_index, from_device, to_index, to_device, _ = inputs | ||||
|     ctx.args = (from_index, from_device, to_index, to_device) | ||||
|  | ||||
|  | ||||
| _register_effectful_op(fork_stream._opoverload, _EffectType.ORDERED) | ||||
| fork_stream.register_autograd(fork_backward, setup_context=fork_setup_context) | ||||
|  | ||||
|  | ||||
| @custom_op("streams::join", mutates_args=()) | ||||
| def join_stream( | ||||
|     from_index: int, | ||||
|     from_device: torch.device, | ||||
|     to_index: int, | ||||
|     to_device: torch.device, | ||||
| ) -> int: | ||||
|     return from_index | ||||
|  | ||||
|  | ||||
| @join_stream.register_fake | ||||
| def _( | ||||
|     from_index: int, | ||||
|     from_device: torch.device, | ||||
|     to_index: int, | ||||
|     to_device: torch.device, | ||||
| ) -> int: | ||||
|     return from_index | ||||
|  | ||||
|  | ||||
| def join_backward(ctx, grad_output): | ||||
|     from_index, from_device, to_index, to_device = ctx.args | ||||
|     from_index = fork_stream(from_index, from_device, to_index, to_device) | ||||
|     return None, from_index, None, None, None, None | ||||
|  | ||||
|  | ||||
| def join_setup_context(ctx, inputs, output): | ||||
|     from_index, from_device, to_index, to_device = inputs | ||||
|     ctx.args = (from_index, from_device, to_index, to_device) | ||||
|  | ||||
|  | ||||
| _register_effectful_op(join_stream._opoverload, _EffectType.ORDERED) | ||||
| join_stream.register_autograd(join_backward, setup_context=join_setup_context) | ||||
|  | ||||
|  | ||||
| class SymbolicStreamState: | ||||
|     """Track the currently entered stream if any""" | ||||
|  | ||||
|     def __init__(self) -> None: | ||||
|         from ..source import CurrentStreamSource | ||||
|  | ||||
|         stream_var = LazyVariableTracker.create( | ||||
|             torch.accelerator.current_stream(), | ||||
|             source=CurrentStreamSource(torch.accelerator.current_stream().device), | ||||
|         ) | ||||
|         self.cur_stream_stack: collections.deque[StreamVariable] = collections.deque( | ||||
|             [stream_var]  # type: ignore[list-item] | ||||
|         ) | ||||
|  | ||||
|     def enter_stream(self, stream: "StreamVariable") -> None: | ||||
|         self.cur_stream_stack.append(stream) | ||||
|  | ||||
|     def exit_stream(self) -> None: | ||||
|         self.cur_stream_stack.pop() | ||||
|  | ||||
|     def cur_stream(self, device: Optional[torch.device] = None) -> "StreamVariable": | ||||
|         if device is not None: | ||||
|             for stream in reversed(self.cur_stream_stack): | ||||
|                 if stream.device == device: | ||||
|                     return stream | ||||
|  | ||||
|         return self.cur_stream_stack[-1] | ||||
|  | ||||
|     def in_stream_context(self) -> bool: | ||||
|         return len(self.cur_stream_stack) > 0 | ||||
|  | ||||
|  | ||||
| class StreamContextVariable(ContextWrappingVariable): | ||||
|     """This represents torch.cuda.StreamContext""" | ||||
|  | ||||
|     @staticmethod | ||||
|     def create( | ||||
|         tx: "InstructionTranslator", | ||||
|         target_value: "StreamVariable", | ||||
|         **kwargs: dict[str, Any], | ||||
|     ) -> "StreamContextVariable": | ||||
|         return StreamContextVariable( | ||||
|             target_values=[target_value], | ||||
|             initial_values=[ | ||||
|                 StreamContextVariable._get_current_stream(target_value.device, tx) | ||||
|             ], | ||||
|             device=target_value.device, | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         target_values: list["StreamVariable"], | ||||
|         device: torch.device, | ||||
|         initial_values: Optional[list["StreamVariable"]] = None, | ||||
|         **kwargs: dict[str, Any], | ||||
|     ) -> None: | ||||
|         super().__init__( | ||||
|             target_values=target_values, initial_values=initial_values, **kwargs | ||||
|         ) | ||||
|         self.device = device | ||||
|  | ||||
|     def enter(self, tx: "InstructionTranslator") -> "VariableTracker": | ||||
|         # to stream, from stream is the order of the arguments | ||||
|         # we are entering the target, and leaving the initial stream | ||||
|         tx.symbolic_stream_state.enter_stream(self._get_target_values()[0]) | ||||
|         tx.output.create_proxy( | ||||
|             "call_function", | ||||
|             torch.ops.streams.fork.default, | ||||
|             self._target_stream_proxies() + self._initial_stream_proxies(), | ||||
|             {}, | ||||
|         ) | ||||
|         return ConstantVariable.create(None) | ||||
|  | ||||
|     def exit(self, tx: "InstructionTranslator", *args: tuple[Any]) -> "VariableTracker": | ||||
|         # to stream, from stream is the order of the arguments | ||||
|         # we are leaving the target, and entering the initial stream | ||||
|         tx.symbolic_stream_state.exit_stream() | ||||
|         tx.output.create_proxy( | ||||
|             "call_function", | ||||
|             torch.ops.streams.join.default, | ||||
|             self._initial_stream_proxies() + self._target_stream_proxies(), | ||||
|             {}, | ||||
|         ) | ||||
|         return ConstantVariable.create(None) | ||||
|  | ||||
|     def _initial_stream_proxies(self) -> tuple[Proxy, Proxy]: | ||||
|         assert self.initial_values, "No initial stream to move from" | ||||
|         return StreamContextVariable._extract_stream_properties( | ||||
|             self.initial_values[0].as_proxy() | ||||
|         ) | ||||
|  | ||||
|     def _target_stream_proxies(self) -> tuple[Proxy, Proxy]: | ||||
|         return StreamContextVariable._extract_stream_properties( | ||||
|             self._get_target_values()[0].as_proxy() | ||||
|         ) | ||||
|  | ||||
|     @staticmethod | ||||
|     def _extract_stream_properties(stream_proxy: Proxy) -> tuple[Proxy, Proxy]: | ||||
|         stream_index = GetAttrVariable.create_getattr_proxy(stream_proxy, "stream_id") | ||||
|         stream_device = GetAttrVariable.create_getattr_proxy(stream_proxy, "device") | ||||
|         return stream_index, stream_device | ||||
|  | ||||
|     @staticmethod | ||||
|     def _get_current_stream( | ||||
|         device: torch.device, tx: "InstructionTranslator" | ||||
|     ) -> "StreamVariable": | ||||
|         from .builder import wrap_fx_proxy_cls | ||||
|  | ||||
|         current_stream_method = get_interface_for_device(device).current_stream | ||||
|         current_stream = wrap_fx_proxy_cls( | ||||
|             StreamVariable, | ||||
|             tx, | ||||
|             tx.output.create_proxy( | ||||
|                 "call_function", | ||||
|                 current_stream_method, | ||||
|                 (None,), | ||||
|                 {}, | ||||
|             ), | ||||
|         ) | ||||
|         return current_stream | ||||
|  | ||||
|     def _get_target_values(self) -> list["StreamVariable"]: | ||||
|         # We need this to be overridable, since StreamVariable does | ||||
|         # not store target values (it does not require any arguments) | ||||
|         # and captures the current stream at the time of entering the context | ||||
|         return self.target_values | ||||
|  | ||||
|     def supports_graph_breaks(self) -> bool: | ||||
|         return True | ||||
|  | ||||
|  | ||||
| class StreamVariable(StreamContextVariable): | ||||
|     """Represents the device-agnostic torch.Stream class""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         proxy: Proxy, | ||||
|         value: torch.Stream, | ||||
|         device: torch.device, | ||||
|         **kwargs: Any, | ||||
|     ) -> None: | ||||
|         # Index into the user object table | ||||
|         # used to pass arbitrary objects to the graph | ||||
|         user_object_index = kwargs.pop("user_obj_index", None) | ||||
|         if proxy is not None and "example_value" in proxy.node.meta: | ||||
|             assert proxy.node.meta["example_value"] == value | ||||
|         assert value.device.type == device.type, ( | ||||
|             "stream's device type does not match the passed device" | ||||
|         ) | ||||
|         super().__init__(target_values=[], initial_values=None, device=device, **kwargs) | ||||
|         self.proxy = proxy | ||||
|         self.value = value | ||||
|         self.device = device | ||||
|  | ||||
|         self.user_object_index = user_object_index | ||||
|  | ||||
|     def python_type(self) -> type: | ||||
|         return torch.Stream | ||||
|  | ||||
|     def call_method( | ||||
|         self, | ||||
|         tx: "InstructionTranslator", | ||||
|         name: str, | ||||
|         args: list[VariableTracker], | ||||
|         kwargs: dict[str, VariableTracker], | ||||
|     ) -> "VariableTracker": | ||||
|         assert hasattr(self.value, name), f"no stream method found named {name}" | ||||
|  | ||||
|         from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs | ||||
|         from .builder import wrap_fx_proxy_cls | ||||
|  | ||||
|         if name in ("wait_stream", "synchronize", "wait_event"): | ||||
|             tx.output.create_proxy( | ||||
|                 "call_method", name, *proxy_args_kwargs([self] + args, kwargs) | ||||
|             ) | ||||
|             return ConstantVariable(None) | ||||
|         elif name == "query": | ||||
|             return wrap_fx_proxy_cls( | ||||
|                 target_cls=ConstantVariable, | ||||
|                 tx=tx, | ||||
|                 proxy=tx.output.create_proxy( | ||||
|                     "call_method", name, *proxy_args_kwargs([self] + args, kwargs) | ||||
|                 ), | ||||
|             ) | ||||
|         elif name == "record_event": | ||||
|             return wrap_fx_proxy_cls( | ||||
|                 target_cls=EventVariable, | ||||
|                 tx=tx, | ||||
|                 proxy=tx.output.create_proxy( | ||||
|                     "call_method", name, *proxy_args_kwargs([self] + args, kwargs) | ||||
|                 ), | ||||
|             ) | ||||
|         elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs: | ||||
|             from ..guards import GuardBuilder, install_guard | ||||
|  | ||||
|             if self.source: | ||||
|                 install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH)) | ||||
|  | ||||
|             # NB: the comparison result is baked into the graph as a | ||||
|             # constant, so EQUALS_MATCH guards must ensure the same | ||||
|             # streams are seen on later calls | ||||
|             other = args[0] | ||||
|             if not isinstance(other, StreamVariable): | ||||
|                 return ConstantVariable.create(NotImplemented) | ||||
|  | ||||
|             if other.source: | ||||
|                 install_guard(other.source.make_guard(GuardBuilder.EQUALS_MATCH)) | ||||
|             return ConstantVariable.create( | ||||
|                 cmp_name_to_op_mapping[name](self.value, other.value)  # type: ignore[arg-type] | ||||
|             ) | ||||
|  | ||||
|         return super().call_method(tx, name, args, kwargs) | ||||
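A short sketch of the behavior this comparison branch implies, assuming stream arguments to compiled functions are supported as on this branch:

import torch

@torch.compile(fullgraph=True)
def pick(x, s1, s2):
    # `==` is constant-folded at trace time; EQUALS_MATCH guards on the
    # stream inputs force a recompile if later calls compare differently
    if s1 == s2:
        return x + 1
    return x - 1

s = torch.Stream()
print(pick(torch.ones(2), s, s))               # traces the `==` branch
print(pick(torch.ones(2), s, torch.Stream()))  # guard miss, recompiles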
|  | ||||
|     def enter(self, tx: "InstructionTranslator") -> "VariableTracker": | ||||
|         # NB: Set initial values when we enter | ||||
|         # Don't do this at object creation, as we need to record the current stream | ||||
|         # at the time the context is entered. | ||||
|         self.initial_values = [ | ||||
|             StreamContextVariable._get_current_stream(self.device, tx) | ||||
|         ] | ||||
|         return super().enter(tx) | ||||
|  | ||||
|     def as_proxy(self) -> Proxy: | ||||
|         return self.proxy | ||||
|  | ||||
|     def module_name(self) -> str: | ||||
|         return "torch._C" | ||||
|  | ||||
|     def fn_name(self) -> str: | ||||
|         return "Stream" | ||||
|  | ||||
|     def reconstruct(self, codegen: "PyCodegen") -> None: | ||||
|         # If we got here, this stream is fully subsumed by the graph - this means it is | ||||
|         # not an input or global | ||||
|         assert not self.source | ||||
|         if self.user_object_index is not None: | ||||
|             codegen.add_push_null( | ||||
|                 lambda: codegen.load_import_from( | ||||
|                     torch._dynamo.graph_bytecode_inputs.__name__, | ||||
|                     "get_external_object_by_index", | ||||
|                 ) | ||||
|             ) | ||||
|             codegen.append_output(codegen.create_load_const(self.user_object_index)) | ||||
|             codegen.extend_output(create_call_function(1, False)) | ||||
|         else: | ||||
|             # TODO mlazos: evaluate if we still need this | ||||
|             prefix = f"_stream_{self.device}" | ||||
|             name = codegen.tx.output.install_global_by_id(prefix, self.value) | ||||
|             codegen.append_output(codegen.create_load_global(name, add=True)) | ||||
|  | ||||
|     @staticmethod | ||||
|     def construct_in_graph_stream(index: int, codegen: "PyCodegen") -> None: | ||||
|         # Use a source to generate the construction bytecode; this is not | ||||
|         # an actual input, so `index` is unused and a fresh torch.Stream() | ||||
|         # call is emitted instead | ||||
|         source = CallFunctionNoArgsSource(AttrSource(TorchSource(), "Stream")) | ||||
|         codegen(source) | ||||
|  | ||||
|     def _get_target_values(self) -> list["StreamVariable"]: | ||||
|         return [self] | ||||
|  | ||||
|  | ||||
| class EventVariable(VariableTracker): | ||||
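|     """Represents the device-agnostic torch.Event class""" | ||||
|  | ||||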
|     def __init__(self, proxy: Proxy, value: torch.Event, **kwargs: Any) -> None: | ||||
|         if proxy is not None and "example_value" in proxy.node.meta: | ||||
|             assert proxy.node.meta["example_value"] == value | ||||
|         super().__init__(**kwargs) | ||||
|         self.proxy = proxy | ||||
|         self.value = value | ||||
|  | ||||
|     def call_method( | ||||
|         self, | ||||
|         tx: "InstructionTranslator", | ||||
|         name: str, | ||||
|         args: list[VariableTracker], | ||||
|         kwargs: dict[str, VariableTracker], | ||||
|     ) -> VariableTracker: | ||||
|         from ..utils import proxy_args_kwargs | ||||
|         from .builder import wrap_fx_proxy_cls | ||||
|  | ||||
|         if name in ("wait", "record", "synchronize"): | ||||
|             tx.output.create_proxy( | ||||
|                 "call_method", name, *proxy_args_kwargs([self] + args, kwargs) | ||||
|             ) | ||||
|             return ConstantVariable(None) | ||||
|         elif name == "query": | ||||
|             return wrap_fx_proxy_cls( | ||||
|                 target_cls=ConstantVariable, | ||||
|                 tx=tx, | ||||
|                 proxy=tx.output.create_proxy( | ||||
|                     "call_method", name, *proxy_args_kwargs([self] + args, kwargs) | ||||
|                 ), | ||||
|             ) | ||||
|         else: | ||||
|             method_name = ( | ||||
|                 f"{type(self.value).__module__}.{type(self.value).__qualname__}.{name}" | ||||
|             ) | ||||
|             unimplemented_v2( | ||||
|                 gb_type="Unsupported event method", | ||||
|                 context=str(name), | ||||
|                 explanation=f"Dynamo doesn't support tracing the {method_name} method. " | ||||
|                 f"We currently support wait, record, synchronize, and query.", | ||||
|                 hints=[ | ||||
|                     *graph_break_hints.SUPPORTABLE, | ||||
|                 ], | ||||
|             ) | ||||
|  | ||||
|     def as_proxy(self) -> Proxy: | ||||
|         return self.proxy | ||||
|  | ||||
|     def reconstruct(self, codegen: "PyCodegen") -> None: | ||||
|         # If we got here, this event is fully subsumed by the graph - this means it is | ||||
|         # not an input or global | ||||
|         assert not self.source | ||||
|         # Similar to stream handling, we lift the event into a global and then codegen bytecode to load it from there. | ||||
|         prefix = "_event" | ||||
|         name = codegen.tx.output.install_global_by_id(prefix, self.value) | ||||
|         codegen.append_output(codegen.create_load_global(name, add=True)) | ||||
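A hedged sketch of the event methods traced above (wait, record, synchronize, query); it assumes an accelerator backend is available at runtime so that record/wait are meaningful:

import torch

@torch.compile(fullgraph=True)
def staged(x):
    s = torch.Stream()
    e = torch.Event()
    with s:
        y = x * 2
        e.record()   # traced as a call_method proxy returning None
    e.wait()         # traced the same way; query() would return a constant
    return y + 1

print(staged(torch.ones(3)))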
| @ -1237,6 +1237,35 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): | ||||
|             # pyrefly: ignore  # unbound-name | ||||
|             return VariableTracker.build(tx, module, new_source) | ||||
|  | ||||
|         @register(torch.accelerator.current_stream) | ||||
|         def handle_current_stream(self, tx: "InstructionTranslator", *args, **kwargs): | ||||
|             if len(args) + len(kwargs) > 1 or (kwargs and "device" not in kwargs): | ||||
|                 unimplemented_v2( | ||||
|                     gb_type="unsupported arguments to torch.accelerator.current_stream", | ||||
|                     context=f"args={args}, kwargs={kwargs}", | ||||
|                     explanation="torch.accelerator.current_stream accepts one optional argument `device`", | ||||
|                     hints=[ | ||||
|                         *graph_break_hints.USER_ERROR, | ||||
|                     ], | ||||
|                 ) | ||||
|             try: | ||||
|                 if kwargs: | ||||
|                     device = torch.device(kwargs["device"].as_python_constant()) | ||||
|                 elif args: | ||||
|                     device = torch.device(args[0].as_python_constant()) | ||||
|                 else: | ||||
|                     device = None | ||||
|             except Exception as e: | ||||
|                 unimplemented_v2( | ||||
|                     gb_type="bad device argument to torch.accelerator.current_stream", | ||||
|                     context=f"args={args}, kwargs={kwargs}", | ||||
|                     explanation="Expected a valid string/torch.device argument ('cpu', 'cuda', etc.)", | ||||
|                     hints=[*graph_break_hints.USER_ERROR], | ||||
|                     from_exc=e, | ||||
|                 ) | ||||
|             # Device parsing is the only step that can raise; resolving the | ||||
|             # stream outside the try avoids misreporting unrelated errors | ||||
|             return tx.symbolic_stream_state.cur_stream(device) | ||||
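A minimal sketch of what this handler enables, assuming an accelerator backend (e.g. CUDA) is present at runtime:

import torch

@torch.compile(fullgraph=True)
def f(x):
    s = torch.accelerator.current_stream()  # resolved by the handler above
    with s:
        return x + 1

print(f(torch.ones(2)))
# A bad argument graph-breaks with the USER_ERROR hints, e.g.:
#   torch.accelerator.current_stream("not-a-device")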
|  | ||||
|         @register(torch.set_default_device) | ||||
|         def handle_set_default_device( | ||||
|             self, tx: "InstructionTranslator", *args, **kwargs | ||||
|  | ||||
| @ -58,6 +58,7 @@ from ..exc import ( | ||||
|     raise_observed_exception, | ||||
|     unimplemented_v2, | ||||
| ) | ||||
| from ..graph_bytecode_inputs import get_external_object_by_index | ||||
| from ..guards import GuardBuilder, install_guard | ||||
| from ..source import ( | ||||
|     AttrSource, | ||||
| @ -792,14 +793,31 @@ class UserDefinedClassVariable(UserDefinedVariable): | ||||
|                 ) | ||||
|                 args = [stacked] | ||||
|  | ||||
|             tensor_variable = wrap_fx_proxy( | ||||
|                 tx=tx, | ||||
|                 proxy=tx.output.create_proxy( | ||||
|                     "call_function", | ||||
|                     self.value, | ||||
|                     *proxy_args_kwargs(args, kwargs), | ||||
|                 ), | ||||
|             ) | ||||
|             if issubclass(self.value, torch.Stream): | ||||
|                 # Register newly created stream for reconstruction | ||||
|                 stream = self.value() | ||||
|                 from ..graph_bytecode_inputs import register_graph_created_object | ||||
|                 from .streams import StreamVariable | ||||
|  | ||||
|                 ind = register_graph_created_object( | ||||
|                     stream, StreamVariable.construct_in_graph_stream | ||||
|                 ) | ||||
|                 tensor_variable = wrap_fx_proxy( | ||||
|                     tx=tx, | ||||
|                     proxy=tx.output.create_proxy( | ||||
|                         "call_function", get_external_object_by_index, (ind,), {} | ||||
|                     ), | ||||
|                     user_obj_index=ind, | ||||
|                 ) | ||||
|             else: | ||||
|                 tensor_variable = wrap_fx_proxy( | ||||
|                     tx=tx, | ||||
|                     proxy=tx.output.create_proxy( | ||||
|                         "call_function", | ||||
|                         self.value, | ||||
|                         *proxy_args_kwargs(args, kwargs), | ||||
|                     ), | ||||
|                 ) | ||||
|  | ||||
|             return tensor_variable | ||||
|         elif self.value is random.Random: | ||||
|  | ||||
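A hedged sketch of the new path above: a stream constructed inside the compiled region is registered as a graph-created object, enters the graph through get_external_object_by_index, and is rebuilt when returned:

import torch

@torch.compile(fullgraph=True)
def make_and_use(x):
    s = torch.Stream()  # registered via register_graph_created_object
    with s:
        y = x + 1
    return y, s         # returning s exercises StreamVariable.reconstruct

out, stream = make_and_use(torch.ones(2))
assert isinstance(stream, torch.Stream)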
| @ -49,6 +49,7 @@ static PyObject* THPEvent_pynew( | ||||
|   } | ||||
|  | ||||
|   THPEvent* self = (THPEvent*)ptr.get(); | ||||
|   self->weakreflist = nullptr; | ||||
|  | ||||
|   // TODO: blocking and interprocess are not supported yet. To support them, the | ||||
|   // flag system of c10::Event needs to be refactored. C10::Event should also | ||||
| @ -73,6 +74,7 @@ PyObject* THPEvent_new(c10::DeviceType device_type, c10::EventFlag flag) { | ||||
|   auto self = THPObjectPtr{type->tp_alloc(type, 0)}; | ||||
|   TORCH_CHECK(self, "Failed to allocate memory for Event"); | ||||
|   auto self_ = reinterpret_cast<THPEvent*>(self.get()); | ||||
|   self_->weakreflist = nullptr; | ||||
|   new (&self_->event) c10::Event(device_type, flag); | ||||
|   return self.release(); | ||||
| } | ||||
| @ -82,6 +84,7 @@ static void THPEvent_dealloc(THPEvent* self) { | ||||
|   // Clear weak references before destroying the event so that weakref | ||||
|   // callbacks still observe a valid object (per CPython's tp_dealloc rules). | ||||
|   PyObject_ClearWeakRefs((PyObject*)self); | ||||
|   { | ||||
|     pybind11::gil_scoped_release no_gil{}; | ||||
|     self->event.~Event(); | ||||
|   } | ||||
|   Py_TYPE(self)->tp_free((PyObject*)self); | ||||
| } | ||||
|  | ||||
| @ -274,7 +277,8 @@ static PyMethodDef THPEvent_methods[] = { | ||||
|     {"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr}, | ||||
|     {"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr}, | ||||
|     {nullptr}}; | ||||
|  | ||||
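| // THPEvent is not standard-layout (it embeds a c10::Event), so the | ||||
| // offsetof below triggers -Winvalid-offsetof; the computed offset is | ||||
| // still valid for this use. | ||||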
| #pragma GCC diagnostic push | ||||
| #pragma GCC diagnostic ignored "-Winvalid-offsetof" | ||||
| PyTypeObject THPEventType = { | ||||
|     PyVarObject_HEAD_INIT(nullptr, 0) | ||||
|     "torch.Event", /* tp_name */ | ||||
| @ -300,7 +304,7 @@ PyTypeObject THPEventType = { | ||||
|     nullptr, /* tp_traverse */ | ||||
|     nullptr, /* tp_clear */ | ||||
|     nullptr, /* tp_richcompare */ | ||||
|     0, /* tp_weaklistoffset */ | ||||
|     offsetof(THPEvent, weakreflist), /* tp_weaklistoffset */ | ||||
|     nullptr, /* tp_iter */ | ||||
|     nullptr, /* tp_iternext */ | ||||
|     THPEvent_methods, /* tp_methods */ | ||||
| @ -315,6 +319,7 @@ PyTypeObject THPEventType = { | ||||
|     nullptr, /* tp_alloc */ | ||||
|     THPEvent_pynew, /* tp_new */ | ||||
| }; | ||||
| #pragma GCC diagnostic pop | ||||
|  | ||||
| void THPEvent_init(PyObject* module) { | ||||
|   THPEventClass = &THPEventType; | ||||
|  | ||||
| @ -7,6 +7,7 @@ | ||||
| struct TORCH_API THPEvent { | ||||
|   PyObject_HEAD | ||||
|   c10::Event event; | ||||
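|   // List of weak references to this object (see tp_weaklistoffset) | ||||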
|   PyObject* weakreflist; | ||||
| }; | ||||
| TORCH_API extern PyTypeObject* THPEventClass; | ||||
| TORCH_API extern PyTypeObject THPEventType; | ||||
|  | ||||
| @ -95,6 +95,7 @@ static PyObject* THPStream_pynew( | ||||
|   self->device_index = static_cast<int64_t>(stream_opt->device_index()); | ||||
|   self->device_type = static_cast<int64_t>(stream_opt->device_type()); | ||||
|   self->context = nullptr; | ||||
|   self->weakreflist = nullptr; | ||||
|  | ||||
|   return (PyObject*)ptr.release(); | ||||
|   END_HANDLE_TH_ERRORS | ||||
| @ -114,11 +115,13 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) { | ||||
|   self->device_index = static_cast<int64_t>(stream.device_index()); | ||||
|   self->device_type = static_cast<int64_t>(stream.device_type()); | ||||
|   self->context = nullptr; | ||||
|   self->weakreflist = nullptr; | ||||
|   return ptr.release(); | ||||
|   END_HANDLE_TH_ERRORS | ||||
| } | ||||
|  | ||||
| static void THPStream_dealloc(THPStream* self) { | ||||
|   PyObject_ClearWeakRefs((PyObject*)self); | ||||
|   Py_TYPE(self)->tp_free((PyObject*)self); | ||||
| } | ||||
|  | ||||
| @ -436,7 +439,7 @@ static PyTypeObject THPStreamType = { | ||||
|     nullptr, /* tp_traverse */ | ||||
|     nullptr, /* tp_clear */ | ||||
|     THPStream_richcompare, /* tp_richcompare */ | ||||
|     0, /* tp_weaklistoffset */ | ||||
|     offsetof(THPStream, weakreflist), /* tp_weaklistoffset */ | ||||
|     nullptr, /* tp_iter */ | ||||
|     nullptr, /* tp_iternext */ | ||||
|     // NOLINTNEXTLINE(*const-cast) | ||||
|  | ||||
| @ -13,6 +13,7 @@ struct THPStream { | ||||
|   int64_t device_index; | ||||
|   // Used to switch stream context management, initialized lazily. | ||||
|   PyObject* context; | ||||
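|   // List of weak references to this object (see tp_weaklistoffset) | ||||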
|   PyObject* weakreflist; | ||||
| }; | ||||
| extern TORCH_API PyTypeObject* THPStreamClass; | ||||
|  | ||||
|  | ||||
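Together, the weakreflist fields and the tp_weaklistoffset entries make torch.Stream and torch.Event weakly referenceable. A small sketch of what that enables, e.g. registries that do not extend object lifetimes:

import weakref
import torch

# A registry that tracks live events without keeping them alive.
live = weakref.WeakSet()

e = torch.Event()
live.add(e)             # only works once tp_weaklistoffset is wired up
assert len(live) == 1

del e                   # refcount hits zero; the weak entry is dropped
assert len(live) == 0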