mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 23:15:01 +08:00
Compare commits
17 Commits
ciflow/h10
...
mlazos/use
| Author | SHA1 | Date | |
|---|---|---|---|
| f70a6ac1a6 | |||
| 020bd1f830 | |||
| c84e73df6e | |||
| 6120d39fdb | |||
| 17ed117a90 | |||
| 3e941aebb7 | |||
| e403203714 | |||
| 1e80b1ad7d | |||
| 3dbe758856 | |||
| 7faa4842f8 | |||
| 45200769d1 | |||
| 38695fbfb4 | |||
| 0c01bec755 | |||
| 02cbfe3469 | |||
| 3e8fcf18ab | |||
| 3e5b122936 | |||
| 422233fde2 |
146
test/dynamo/test_streams.py
Normal file
146
test/dynamo/test_streams.py
Normal file
@ -0,0 +1,146 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
import weakref
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
|
||||
|
||||
class TestStreams(torch._dynamo.test_case.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
super().tearDownClass()
|
||||
|
||||
def test_stream_weakref(self):
|
||||
s = torch.Stream()
|
||||
weakref.ref(s)
|
||||
|
||||
def test_event_weakref(self):
|
||||
e = torch.Event()
|
||||
weakref.ref(e)
|
||||
|
||||
def test_stream_enter_exit(self):
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
s1 = torch.Stream()
|
||||
with s1:
|
||||
z1 = torch.add(x, y)
|
||||
with s2:
|
||||
z = torch.add(x, y)
|
||||
y = z + 2 + z1
|
||||
|
||||
return y
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
|
||||
expected = fn(*inp)
|
||||
fn_opt = torch.compile(fn, fullgraph=True)
|
||||
actual = fn_opt(*inp)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_stream_context_graph_break(self):
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
s1 = torch.Stream()
|
||||
with s1:
|
||||
z1 = torch.add(x, y)
|
||||
with s2:
|
||||
z = torch.add(x, y)
|
||||
y = z + 2 + z1
|
||||
torch._dynamo.graph_break()
|
||||
y = y + 1
|
||||
|
||||
return y
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
|
||||
expected = fn(*inp)
|
||||
fn_opt = torch.compile(fn)
|
||||
actual = fn_opt(*inp)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_stream_input(self):
|
||||
def fn(x, y, s):
|
||||
z = torch.add(x, y)
|
||||
y = z + 2
|
||||
return y, s
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(device="cuda"))
|
||||
expected = fn(*inp)
|
||||
fn_opt = torch.compile(fn, fullgraph=True)
|
||||
actual = fn_opt(*inp)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_local_stream_return(self):
|
||||
def fn(x, y):
|
||||
s = torch.Stream()
|
||||
z = torch.add(x, y)
|
||||
y = z + 2
|
||||
return y, s
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
|
||||
fn_opt = torch.compile(fn, fullgraph=True)
|
||||
_, s0 = fn_opt(*inp)
|
||||
_, s1 = fn_opt(*inp)
|
||||
# Streams will be different values for each invocation
|
||||
# so don't check for equality
|
||||
self.assertIsInstance(s0, torch.Stream)
|
||||
# Stream should be newly allocated on each call
|
||||
self.assertNotEqual(s0, s1)
|
||||
|
||||
def test_get_current_stream_return(self):
|
||||
def fn(x, s):
|
||||
with s:
|
||||
s0 = torch.cuda.current_stream()
|
||||
return x, s0
|
||||
|
||||
s_inp = torch.Stream(device="cuda")
|
||||
inp = (torch.ones(2, 2) + 1, s_inp)
|
||||
fn_opt = torch.compile(fn, fullgraph=True)
|
||||
_, s0 = fn_opt(*inp)
|
||||
_, s1 = fn_opt(*inp)
|
||||
self.assertEqual(s_inp, s0)
|
||||
self.assertEqual(s0, s1)
|
||||
|
||||
def test_nested_stream_enter_exit(self):
|
||||
pass
|
||||
|
||||
def test_stream_enter_exit_graph_break(self):
|
||||
pass
|
||||
|
||||
def test_nested_stream_enter_exit_graph_break(self):
|
||||
pass
|
||||
|
||||
def test_local_stream_enter_exit(self):
|
||||
pass
|
||||
|
||||
def test_local_stream_nested_enter_exit(self):
|
||||
pass
|
||||
|
||||
def test_stream_with_mutation(self):
|
||||
pass
|
||||
|
||||
def test_run_opcheck(self):
|
||||
from torch._dynamo.variables.streams import fork_stream_, join_stream_
|
||||
from torch.library import opcheck
|
||||
|
||||
sample_inputs = [
|
||||
(1, torch.device("cuda:0"), 1, [torch.randn(3), torch.randn(3)]),
|
||||
(
|
||||
2,
|
||||
torch.device("cuda:0"),
|
||||
0,
|
||||
[torch.randn(2, 3, device="cuda"), torch.randn(2, 3, device="cuda")],
|
||||
),
|
||||
]
|
||||
for args in sample_inputs:
|
||||
opcheck(fork_stream_, args)
|
||||
opcheck(join_stream_, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
@ -149,7 +149,6 @@ def reset() -> None:
|
||||
GenerationTracker.clear()
|
||||
TensorifyState.clear()
|
||||
torch._dynamo.utils.warn_once_cache.clear()
|
||||
torch._dynamo.utils.user_obj_id_to_weakref.clear()
|
||||
torch._C._autograd._saved_tensors_hooks_set_tracing(False)
|
||||
|
||||
|
||||
|
||||
@ -116,6 +116,7 @@ from .exc import (
|
||||
unimplemented_v2,
|
||||
Unsupported,
|
||||
)
|
||||
from .graph_bytecode_inputs import reset_user_object_tracking
|
||||
from .guards import (
|
||||
CheckFunctionManager,
|
||||
get_and_maybe_log_recompilation_reasons,
|
||||
@ -310,6 +311,7 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
torch.fx._symbolic_trace._maybe_revert_all_patches()
|
||||
)
|
||||
exit_stack.enter_context(torch_function_mode_stack_state_mgr)
|
||||
reset_user_object_tracking()
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
|
||||
@ -2734,6 +2734,12 @@
|
||||
}
|
||||
],
|
||||
"GB0272": [
|
||||
{
|
||||
"Gb_type": "Failed to make weakref to User Object when storing by ID",
|
||||
"Context": "user_objected: {obj}",
|
||||
"Explanation": "Object does not allow us to make a weakref to it",
|
||||
"Hints": []
|
||||
},
|
||||
{
|
||||
"Gb_type": "Failed to make weakref to User Object",
|
||||
"Context": "user_objected: {obj}",
|
||||
@ -2763,5 +2769,13 @@
|
||||
"This is likely to be a Dynamo bug. Please report an issue to PyTorch."
|
||||
]
|
||||
}
|
||||
],
|
||||
"GB0275": [
|
||||
{
|
||||
"Gb_type": "Failed to make weakref to User Object",
|
||||
"Context": "user_object: {value}",
|
||||
"Explanation": "Object does not allow us to make a weakref to it",
|
||||
"Hints": []
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
65
torch/_dynamo/graph_bytecode_inputs.py
Normal file
65
torch/_dynamo/graph_bytecode_inputs.py
Normal file
@ -0,0 +1,65 @@
|
||||
import weakref
|
||||
from typing import Any
|
||||
|
||||
from torch._dynamo.source import Source
|
||||
|
||||
|
||||
# This file is to handle types that we don't want to support
|
||||
# as explicit FX graph inputs. This uses a sidetable which
|
||||
# we populate in bytecode and is loaded during graph execution
|
||||
|
||||
# We use a dynamo-generated index as a level of indirection
|
||||
# this allows us to register objects externally in pre-graph bytecode that we want
|
||||
# to pass to the graph, but not support their types as graph inputs
|
||||
index_to_source: dict[int, Source] = {}
|
||||
|
||||
index_to_user_object_weakref: dict[int, weakref.ReferenceType[Any]] = {}
|
||||
|
||||
|
||||
def has_user_objects() -> bool:
|
||||
return bool(index_to_source)
|
||||
|
||||
|
||||
def get_user_object_by_index(index: int) -> Any:
|
||||
print(index)
|
||||
breakpoint()
|
||||
assert index in index_to_user_object_weakref, (
|
||||
"Index not registered in index_to_user_object_weakref"
|
||||
)
|
||||
print(index_to_user_object_weakref[index])
|
||||
obj = index_to_user_object_weakref[index]()
|
||||
assert obj is not None, "User object is no longer alive"
|
||||
return index_to_user_object_weakref[index]()
|
||||
|
||||
|
||||
def store_user_object_weakrefs(*args: Any) -> None:
|
||||
global index_to_user_object_weakref
|
||||
index_to_user_object_weakref.clear()
|
||||
index_to_user_object_weakref.update(
|
||||
{i: weakref.ref(arg) for i, arg in enumerate(args)}
|
||||
)
|
||||
|
||||
|
||||
def reset_user_object_tracking() -> None:
|
||||
index_to_source.clear()
|
||||
index_to_user_object_weakref.clear()
|
||||
|
||||
|
||||
# Register a user object to be used in the graph
|
||||
def register_user_object(value: Any, source: Source) -> int:
|
||||
global index_to_source
|
||||
index = len(index_to_source)
|
||||
index_to_source[index] = source
|
||||
try:
|
||||
index_to_user_object_weakref[index] = weakref.ref(value)
|
||||
except TypeError as e:
|
||||
from .exc import unimplemented_v2
|
||||
|
||||
unimplemented_v2(
|
||||
gb_type="Failed to make weakref to User Object",
|
||||
context=f"user_object: {value}",
|
||||
explanation="Object does not allow us to make a weakref to it",
|
||||
hints=[],
|
||||
from_exc=e,
|
||||
)
|
||||
return index
|
||||
@ -2162,6 +2162,8 @@ class GuardBuilder(GuardBuilderBase):
|
||||
range,
|
||||
dict_keys,
|
||||
torch.Size,
|
||||
torch.Stream,
|
||||
torch.cuda.streams.Stream,
|
||||
*np_types,
|
||||
*ok_mutable_types,
|
||||
}
|
||||
|
||||
@ -99,6 +99,7 @@ from .exc import (
|
||||
unimplemented_v2,
|
||||
unimplemented_v2_with_warning,
|
||||
)
|
||||
from .graph_bytecode_inputs import has_user_objects, index_to_source
|
||||
from .graph_deduplication import apply_graph_deduplication
|
||||
from .graph_region_tracker import GraphRegionTracker
|
||||
from .guards import GuardBuilder, install_guard
|
||||
@ -1508,6 +1509,27 @@ class OutputGraph(OutputGraphCommon):
|
||||
|
||||
from .decorators import disable
|
||||
|
||||
if has_user_objects():
|
||||
# NB: This is where we store possible user objects before running the graph
|
||||
# index_to_user_object_weakref is the function used in the graph to translate
|
||||
# the dynamo-generated index into the actual object passed to the compiled function.
|
||||
# We generate bytecode to store all user objects at the proper index in the below
|
||||
# call.
|
||||
codegen = PyCodegen(
|
||||
self.root_tx, root, overridden_sources=overridden_sources
|
||||
)
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(
|
||||
torch._dynamo.graph_bytecode_inputs.__name__,
|
||||
"store_user_object_weakrefs",
|
||||
)
|
||||
)
|
||||
for source in reversed(index_to_source.values()):
|
||||
codegen(source)
|
||||
codegen.call_function(len(index_to_source), False)
|
||||
codegen.pop_top()
|
||||
self.add_output_instructions(codegen.get_instructions())
|
||||
|
||||
# to handle random calls
|
||||
if len(self.random_calls) > 0:
|
||||
random_calls_instructions = []
|
||||
@ -1645,7 +1667,7 @@ class OutputGraph(OutputGraphCommon):
|
||||
)
|
||||
elif (
|
||||
vt.source is not None
|
||||
and (source := getattr(vt.source, "base", None))
|
||||
and (source := getattr(vt.source, "base", None)) # type: ignore[assignment]
|
||||
and source.is_input
|
||||
):
|
||||
self.export_metadata.output_return_type[idx] = (
|
||||
|
||||
@ -3333,6 +3333,7 @@ def get_fake_value(
|
||||
UserError,
|
||||
UserErrorType,
|
||||
)
|
||||
from .variables.streams import stream_state_mgr
|
||||
|
||||
op = node.op
|
||||
|
||||
@ -3347,13 +3348,27 @@ def get_fake_value(
|
||||
if (
|
||||
torch._dynamo.config.use_graph_deduplication
|
||||
or torch._dynamo.config.track_nodes_for_deduplication
|
||||
or stream_state_mgr.in_stream_context()
|
||||
):
|
||||
flat_args_kwargs = get_fake_values_from_nodes(
|
||||
tx, _get_flat_args(node, {}), allow_non_graph_fake
|
||||
)
|
||||
id_to_initial_version = {
|
||||
id(arg): arg._version for arg in flat_args_kwargs if is_fake(arg)
|
||||
}
|
||||
flat_args = _get_flat_args(node, {})
|
||||
if stream_state_mgr.in_stream_context():
|
||||
for arg in flat_args:
|
||||
if isinstance(arg, torch.fx.Node):
|
||||
stream_state_mgr.track_node(arg)
|
||||
|
||||
if (
|
||||
torch._dynamo.config.use_graph_deduplication
|
||||
or torch._dynamo.config.track_nodes_for_deduplication
|
||||
):
|
||||
flat_args_kwargs = get_fake_values_from_nodes(
|
||||
tx, flat_args, allow_non_graph_fake
|
||||
)
|
||||
id_to_initial_version = {
|
||||
id(arg): arg._version for arg in flat_args_kwargs if is_fake(arg)
|
||||
}
|
||||
else:
|
||||
flat_args_kwargs = []
|
||||
id_to_initial_version = {}
|
||||
else:
|
||||
flat_args_kwargs = []
|
||||
id_to_initial_version = {}
|
||||
@ -3502,6 +3517,9 @@ def get_fake_value(
|
||||
torch.Tensor, functools.partial(ensure_graph_fake, tx=tx), ret_val
|
||||
)
|
||||
|
||||
if stream_state_mgr.in_stream_context():
|
||||
stream_state_mgr.track_internal_node(node)
|
||||
|
||||
if (
|
||||
torch._dynamo.config.use_graph_deduplication
|
||||
or torch._dynamo.config.track_nodes_for_deduplication
|
||||
@ -4701,34 +4719,6 @@ def _extract_tensor_dict(t: torch.Tensor) -> dict[str, Any]:
|
||||
return tensor_dict
|
||||
|
||||
|
||||
# This is useful for reconstructing within the Dynamo graph the non-graph-input objects
|
||||
# whose lifetime is governed by the user.
|
||||
# e.g. torch.cuda.Event is a prime example.
|
||||
user_obj_id_to_weakref: dict[int, weakref.ReferenceType[object]] = {}
|
||||
|
||||
|
||||
def get_user_object_from_id(obj_id: int) -> Any:
|
||||
obj = user_obj_id_to_weakref[obj_id]()
|
||||
assert obj is not None, "User object is no longer alive"
|
||||
return obj
|
||||
|
||||
|
||||
def store_user_object_weakref(obj: object) -> None:
|
||||
obj_id = id(obj)
|
||||
try:
|
||||
user_obj_id_to_weakref[obj_id] = weakref.ref(obj)
|
||||
except TypeError as e:
|
||||
from .exc import unimplemented_v2
|
||||
|
||||
unimplemented_v2(
|
||||
gb_type="Failed to make weakref to User Object",
|
||||
context=f"user_objected: {obj}",
|
||||
explanation="Object does not allow us to make a weakref to it",
|
||||
hints=[],
|
||||
from_exc=e,
|
||||
)
|
||||
|
||||
|
||||
class CompileTimeInstructionCounter:
|
||||
_counter: int = 0
|
||||
_id: int = -1
|
||||
|
||||
@ -36,8 +36,6 @@ from .ctx_manager import (
|
||||
JvpIncrementNestingCtxManagerVariable,
|
||||
SDPAKernelVariable,
|
||||
SetFwdGradEnabledContextManager,
|
||||
StreamContextVariable,
|
||||
StreamVariable,
|
||||
TemporarilyPopInterpreterStackCtxManagerVariable,
|
||||
VmapIncrementNestingCtxManagerVariable,
|
||||
WithEnterFunctionVariable,
|
||||
@ -130,6 +128,7 @@ from .nn_module import (
|
||||
)
|
||||
from .optimizer import OptimizerVariable
|
||||
from .sdpa import SDPAParamsVariable
|
||||
from .streams import EventVariable, StreamContextVariable, StreamVariable
|
||||
from .tensor import (
|
||||
DataPtrVariable,
|
||||
FakeItemVariable,
|
||||
|
||||
@ -45,6 +45,10 @@ import sympy
|
||||
import torch
|
||||
from torch import SymInt
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._dynamo.graph_bytecode_inputs import (
|
||||
get_user_object_by_index,
|
||||
register_user_object,
|
||||
)
|
||||
from torch._dynamo.utils import (
|
||||
get_metrics_context,
|
||||
is_int_specialization_case,
|
||||
@ -172,11 +176,8 @@ from .ctx_manager import (
|
||||
AutocastModeVariable,
|
||||
DynamoConfigPatchVariable,
|
||||
ErrorOnGraphBreakVariable,
|
||||
EventVariable,
|
||||
NullContextVariable,
|
||||
PreserveVersionContextVariable,
|
||||
StreamContextVariable,
|
||||
StreamVariable,
|
||||
)
|
||||
from .dicts import (
|
||||
ConstDictVariable,
|
||||
@ -257,6 +258,7 @@ from .nn_module import (
|
||||
from .optimizer import OptimizerVariable
|
||||
from .script_object import TorchScriptObjectVariable
|
||||
from .sdpa import SDPAParamsVariable
|
||||
from .streams import EventVariable, StreamContextVariable, StreamVariable
|
||||
from .tensor import (
|
||||
NumpyNdarrayVariable,
|
||||
supported_const_comparison_op_values,
|
||||
@ -1036,24 +1038,20 @@ class VariableBuilder:
|
||||
stream_var = VariableBuilder(self.tx, stream_source)(value.stream)
|
||||
return StreamContextVariable.create(self.tx, stream_var)
|
||||
elif isinstance(value, torch.Stream):
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
# This refers to the device-agnostic torch.Stream
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
index = register_user_object(value, self.source)
|
||||
stream_proxy = self.tx.output.create_proxy(
|
||||
"call_function",
|
||||
type(value),
|
||||
(),
|
||||
{
|
||||
"stream_id": value.stream_id,
|
||||
"device_index": value.device_index,
|
||||
"device_type": value.device_type,
|
||||
},
|
||||
"call_function", get_user_object_by_index, (index,), {}
|
||||
)
|
||||
set_example_value(stream_proxy.node, value)
|
||||
return StreamVariable(
|
||||
var = StreamVariable(
|
||||
stream_proxy,
|
||||
value,
|
||||
value.device,
|
||||
source=self.source,
|
||||
)
|
||||
return self.tx.output.side_effects.track_object_existing(value, var)
|
||||
elif isinstance(value, (torch._C._SDPAParams)):
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
return SDPAParamsVariable.create(self.tx, value, self.source)
|
||||
@ -1061,12 +1059,12 @@ class VariableBuilder:
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
return FuncTorchInterpreterVariable(value)
|
||||
elif isinstance(value, torch.Event):
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
torch._dynamo.utils.store_user_object_weakref(value)
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
index = register_user_object(value, self.source)
|
||||
event_proxy = self.tx.output.create_proxy(
|
||||
"call_function",
|
||||
torch._dynamo.utils.get_user_object_from_id,
|
||||
(id(value),),
|
||||
get_user_object_by_index,
|
||||
(index,),
|
||||
{},
|
||||
)
|
||||
set_example_value(event_proxy.node, value)
|
||||
@ -2980,8 +2978,8 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
|
||||
set_example_value(proxy.node, example_value)
|
||||
return SymNodeVariable(proxy, example_value, **options)
|
||||
elif (
|
||||
inspect.isclass(proxy.node.target)
|
||||
and issubclass(proxy.node.target, torch.Stream)
|
||||
isinstance(example_value, torch.Stream)
|
||||
and proxy.node.target == get_user_object_by_index
|
||||
) or proxy.node.target in [
|
||||
device_interface.current_stream
|
||||
for _, device_interface in get_registered_device_interfaces()
|
||||
@ -3069,6 +3067,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
|
||||
set_example_value(proxy.node, example_value)
|
||||
return ConstantVariable.create(example_value, **options)
|
||||
else:
|
||||
breakpoint()
|
||||
unimplemented_v2(
|
||||
gb_type="torch.* op returned non-Tensor",
|
||||
context=f"example_value type: {typestr(example_value)}; op: {proxy.node.op}; target: {proxy.node.target}",
|
||||
|
||||
@ -83,7 +83,6 @@ from ..utils import (
|
||||
)
|
||||
from .base import AsPythonConstantNotImplementedError, ValueMutationNew, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .ctx_manager import EventVariable, StreamVariable
|
||||
from .dicts import (
|
||||
ConstDictVariable,
|
||||
DefaultDictVariable,
|
||||
@ -101,6 +100,7 @@ from .lists import (
|
||||
TupleIteratorVariable,
|
||||
TupleVariable,
|
||||
)
|
||||
from .streams import EventVariable, StreamVariable
|
||||
from .tensor import (
|
||||
FakeItemVariable,
|
||||
supported_comparison_ops,
|
||||
|
||||
@ -34,7 +34,6 @@ from ..bytecode_transformation import (
|
||||
create_instruction,
|
||||
create_setup_with,
|
||||
)
|
||||
from ..device_interface import get_interface_for_device
|
||||
from ..exc import unimplemented_v2
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
from ..source import AttrSource, GlobalStateSource
|
||||
@ -991,70 +990,6 @@ class ProfilerContextVariable(ContextWrappingVariable):
|
||||
)
|
||||
|
||||
|
||||
class StreamContextVariable(ContextWrappingVariable):
|
||||
@staticmethod
|
||||
def create(tx: "InstructionTranslator", target_value, **kwargs):
|
||||
from .builder import wrap_fx_proxy_cls
|
||||
|
||||
current_stream_method = get_interface_for_device(
|
||||
target_value.device
|
||||
).current_stream
|
||||
current_stream = wrap_fx_proxy_cls(
|
||||
StreamVariable,
|
||||
tx,
|
||||
tx.output.create_proxy(
|
||||
"call_function",
|
||||
current_stream_method,
|
||||
(None,),
|
||||
{},
|
||||
),
|
||||
)
|
||||
return StreamContextVariable(
|
||||
target_values=[target_value],
|
||||
initial_values=[current_stream],
|
||||
device=target_value.device,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __init__(self, target_values, device, initial_values=None, **kwargs) -> None:
|
||||
super().__init__(
|
||||
target_values=target_values, initial_values=initial_values, **kwargs
|
||||
)
|
||||
self.device = device
|
||||
self.set_stream = get_interface_for_device(self.device).set_stream
|
||||
self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id
|
||||
|
||||
def enter(self, tx):
|
||||
# stream generated inside the traced function
|
||||
if self.target_values[0].as_proxy() is not None:
|
||||
tx.output.create_proxy(
|
||||
"call_function",
|
||||
self.set_stream,
|
||||
(self.target_values[0].as_proxy(),),
|
||||
{},
|
||||
)
|
||||
# stream passed from outside the traced function
|
||||
else:
|
||||
stream = self.target_values[0].value
|
||||
tx.output.create_proxy(
|
||||
"call_function",
|
||||
self.set_stream_id,
|
||||
(stream.stream_id, stream.device_index, stream.device_type),
|
||||
{},
|
||||
)
|
||||
self.set_stream(self.target_values[0].value)
|
||||
self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value))
|
||||
|
||||
def exit(self, tx: "InstructionTranslator", *args):
|
||||
tx.output.create_proxy(
|
||||
"call_function",
|
||||
self.set_stream,
|
||||
(self.initial_values[0].as_proxy(),),
|
||||
{},
|
||||
)
|
||||
self.cleanup_assert()
|
||||
|
||||
|
||||
class PreserveVersionContextVariable(ContextWrappingVariable):
|
||||
"""
|
||||
Wraps torch.autograd._unsafe_preserve_version_counter
|
||||
@ -1262,142 +1197,6 @@ class SDPAKernelVariable(ContextWrappingVariable):
|
||||
return "_sdpa_kernel_variadic"
|
||||
|
||||
|
||||
class StreamVariable(VariableTracker):
|
||||
def __init__(self, proxy, value, device, **kwargs) -> None:
|
||||
if proxy is not None and "example_value" in proxy.node.meta:
|
||||
assert proxy.node.meta["example_value"] == value
|
||||
assert value.device.type == device.type, (
|
||||
"stream value is not equal to the passed device"
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
self.proxy = proxy
|
||||
self.value = value
|
||||
self.device = device
|
||||
|
||||
def python_type(self):
|
||||
return torch.Stream
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
assert hasattr(self.value, name), f"no stream method found named {name}"
|
||||
|
||||
from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs
|
||||
from .builder import wrap_fx_proxy_cls
|
||||
|
||||
if name in ("wait_stream", "synchronize", "wait_event"):
|
||||
tx.output.create_proxy(
|
||||
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
|
||||
)
|
||||
return variables.ConstantVariable(None)
|
||||
elif name == "query":
|
||||
return wrap_fx_proxy_cls(
|
||||
target_cls=variables.ConstantVariable,
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
|
||||
),
|
||||
)
|
||||
elif name == "record_event":
|
||||
return wrap_fx_proxy_cls(
|
||||
target_cls=EventVariable,
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
|
||||
),
|
||||
)
|
||||
elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
|
||||
# NB : Checking for mutation is necessary because we compare
|
||||
# constant values
|
||||
other = args[0]
|
||||
if not isinstance(other, StreamVariable):
|
||||
return variables.ConstantVariable.create(NotImplemented)
|
||||
return variables.ConstantVariable.create(
|
||||
cmp_name_to_op_mapping[name](self.value, other.value)
|
||||
)
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def as_proxy(self):
|
||||
return self.proxy
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
# If we got here, this stream is fully subsumed by the graph - this means it is
|
||||
# not an input or global
|
||||
assert not self.source
|
||||
# Since we just proved that - for other such structures, like lists and dicts, reconstruction
|
||||
# is fine and sound according to dynamo principles of treating collectives. However,
|
||||
# streams are special in that we want to preserve the identity of the stream as the same as in the graph
|
||||
# Normally, we would do this via codegen for the proxy mapping to an output - we cannot do this yet, as we do not
|
||||
# yet have a plan for how we want to handle the case where the stream is used as an input or an output. Pending
|
||||
# design, to unblock current work, we lift the stream into a global and then codegen bytecode to load it from there.
|
||||
prefix = f"_stream_{self.device}"
|
||||
name = codegen.tx.output.install_global_by_id(prefix, self.value)
|
||||
codegen.append_output(codegen.create_load_global(name, add=True))
|
||||
|
||||
|
||||
class EventVariable(VariableTracker):
|
||||
def __init__(self, proxy, value, **kwargs) -> None:
|
||||
if proxy is not None and "example_value" in proxy.node.meta:
|
||||
assert proxy.node.meta["example_value"] == value
|
||||
super().__init__(**kwargs)
|
||||
self.proxy = proxy
|
||||
self.value = value
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
from ..utils import proxy_args_kwargs
|
||||
from .builder import wrap_fx_proxy_cls
|
||||
|
||||
if name in ("wait", "record", "synchronize"):
|
||||
tx.output.create_proxy(
|
||||
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
|
||||
)
|
||||
return variables.ConstantVariable(None)
|
||||
elif name == "query":
|
||||
return wrap_fx_proxy_cls(
|
||||
target_cls=variables.ConstantVariable,
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
|
||||
),
|
||||
)
|
||||
else:
|
||||
method_name = (
|
||||
f"{type(self.value).__module__}.{type(self.value).__qualname__}.{name}"
|
||||
)
|
||||
unimplemented_v2(
|
||||
gb_type="Unsupported event method",
|
||||
context=str(name),
|
||||
explanation=f"Dynamo doesn't support tracing the {method_name} method. "
|
||||
f"We currently support wait, record, synchronize, and query.",
|
||||
hints=[
|
||||
*graph_break_hints.SUPPORTABLE,
|
||||
],
|
||||
)
|
||||
|
||||
def as_proxy(self):
|
||||
return self.proxy
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
# If we got here, this event is fully subsumed by the graph - this means it is
|
||||
# not an input or global
|
||||
assert not self.source
|
||||
# Similar to stream handling, we lift the event into a global and then codegen bytecode to load it from there.
|
||||
prefix = "_event"
|
||||
name = codegen.tx.output.install_global_by_id(prefix, self.value)
|
||||
codegen.append_output(codegen.create_load_global(name, add=True))
|
||||
|
||||
|
||||
class DynamoConfigPatchVariable(ContextWrappingVariable):
|
||||
"""represents torch._dynamo.patch_dynamo_config"""
|
||||
|
||||
|
||||
420
torch/_dynamo/variables/streams.py
Normal file
420
torch/_dynamo/variables/streams.py
Normal file
@ -0,0 +1,420 @@
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.fx import Node, Proxy
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from .. import graph_break_hints
|
||||
from ..bytecode_transformation import create_call_function
|
||||
from ..device_interface import get_interface_for_device
|
||||
from ..exc import TYPE_CHECKING, unimplemented_v2
|
||||
from .base import VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .ctx_manager import ContextWrappingVariable
|
||||
from .misc import GetAttrVariable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
from ..codegen import PyCodegen
|
||||
|
||||
from torch._library.custom_ops import custom_op
|
||||
|
||||
|
||||
# Avoid circular dependency for the dataclass
|
||||
TensorVariable = Any
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
@custom_op("streams::fork", mutates_args={"args"})
|
||||
def fork_stream_(
|
||||
index: int, device: torch.device, device_index: int, args: list[Tensor]
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@fork_stream_.register_fake
|
||||
def _(index: int, device: torch.device, device_index: int, args: list[Tensor]) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@custom_op("streams::join", mutates_args={"args"})
|
||||
def join_stream_(
|
||||
index: int, device: torch.device, device_index: int, args: list[Tensor]
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@join_stream_.register_fake
|
||||
def _(index: int, device: torch.device, device_index: int, args: list[Tensor]) -> None:
|
||||
pass
|
||||
|
||||
|
||||
_keep_alive = []
|
||||
|
||||
|
||||
def add_dynamo_owned_stream(s):
|
||||
global _keep_alive
|
||||
_keep_alive.append(s)
|
||||
|
||||
|
||||
# Stream state consists of the fork stream node
|
||||
# and the external to the stream that are accessed from within the
|
||||
# stream
|
||||
@dataclass
|
||||
class StreamState:
|
||||
# the fork node that initiated the creation of this stream state
|
||||
# we will finalize it once the stream state is popped
|
||||
fork_node: Node
|
||||
# Nodes not created within the stream
|
||||
external_nodes: OrderedSet[Node]
|
||||
# Nodes created within the stream
|
||||
internal_nodes: OrderedSet[Node]
|
||||
|
||||
|
||||
class StreamStateManager:
|
||||
"""
|
||||
Class used to track the current stream context we are in and identify
|
||||
any used tensors as external (created outside the stream context) or
|
||||
internal (created within the stream context). We use this information to
|
||||
ensure the fork op is dependent on any external tensors, so that it will not
|
||||
be reordered before them or after ops which use the externally created tensors.
|
||||
Analagously, we use the internal tensors to ensure that the join op is not
|
||||
reordered before any internally created tensors or after ops which use the
|
||||
internally created tensors.
|
||||
|
||||
To actually implement this, we have a stack of stream states which track any external tensors that
|
||||
have not yet been seen within the stream context and any tensors created within the stream context.
|
||||
Once we exit the stream context we populate the args of fork with all external tensors which have been used,
|
||||
and join with any internal tensors that were created.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.state_stack: deque[StreamState] = deque()
|
||||
|
||||
def in_stream_context(self) -> bool:
|
||||
return bool(self.state_stack)
|
||||
|
||||
def track_internal_node(self, node: Node) -> None:
|
||||
# if we are in a stream context, all created nodes are internal
|
||||
if self.in_stream_context():
|
||||
val = node.meta.get("example_value")
|
||||
if isinstance(val, torch.Tensor):
|
||||
# Only add tensor nodes
|
||||
# if we have seen the node before, it is an internal
|
||||
self._cur_state().internal_nodes.add(node)
|
||||
|
||||
def track_node(self, node: Node) -> None:
|
||||
# If we are in a stream context, args of ops may be external
|
||||
if self.in_stream_context():
|
||||
val = node.meta.get("example_value")
|
||||
if isinstance(val, torch.Tensor) and node not in self._internal_nodes():
|
||||
self._external_nodes().add(node)
|
||||
|
||||
def push_stream_state(self, node: Node) -> None:
|
||||
self.state_stack.append(StreamState(node, OrderedSet(), OrderedSet()))
|
||||
|
||||
def pop_stream_state(self) -> StreamState:
|
||||
assert self.state_stack, "No stream state to pop"
|
||||
return self.state_stack.pop()
|
||||
|
||||
def _cur_state(self) -> StreamState:
|
||||
assert self.state_stack, "No stream state to pop"
|
||||
return self.state_stack[-1]
|
||||
|
||||
def _internal_nodes(self) -> OrderedSet[Node]:
|
||||
return self._cur_state().internal_nodes
|
||||
|
||||
def _external_nodes(self) -> OrderedSet[Node]:
|
||||
return self._cur_state().external_nodes
|
||||
|
||||
|
||||
stream_state_mgr = StreamStateManager()
|
||||
|
||||
|
||||
class StreamContextVariable(ContextWrappingVariable):
|
||||
"""This represents torch.cuda.StreamContext"""
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
tx: "InstructionTranslator",
|
||||
target_value: "StreamVariable",
|
||||
**kwargs: dict[str, Any],
|
||||
) -> "StreamContextVariable":
|
||||
return StreamContextVariable(
|
||||
target_values=[target_value],
|
||||
initial_values=[
|
||||
StreamContextVariable._get_current_stream(target_value.device, tx)
|
||||
],
|
||||
device=target_value.device,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_values: list["StreamVariable"],
|
||||
device: torch.device,
|
||||
initial_values: Optional[list["StreamVariable"]] = None,
|
||||
**kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
super().__init__(
|
||||
target_values=target_values, initial_values=initial_values, **kwargs
|
||||
)
|
||||
self.device = device
|
||||
self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id
|
||||
|
||||
def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
|
||||
stream_id, device, device_index = (
|
||||
StreamContextVariable._extract_stream_properties(
|
||||
self._get_target_values()[0].as_proxy()
|
||||
)
|
||||
)
|
||||
proxy = tx.output.create_proxy(
|
||||
"call_function",
|
||||
torch.ops.streams.fork.default,
|
||||
(stream_id, device, device_index, []),
|
||||
{},
|
||||
)
|
||||
stream_state_mgr.push_stream_state(proxy.node)
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
def exit(self, tx: "InstructionTranslator", *args: tuple[Any]) -> "VariableTracker":
|
||||
state = stream_state_mgr.pop_stream_state()
|
||||
initial_stream_proxy = self.initial_values[0].as_proxy()
|
||||
stream_id, device, device_index = (
|
||||
StreamContextVariable._extract_stream_properties(initial_stream_proxy)
|
||||
)
|
||||
tx.output.create_node(
|
||||
"call_function",
|
||||
torch.ops.streams.join.default,
|
||||
(
|
||||
stream_id.node,
|
||||
device.node,
|
||||
device_index.node,
|
||||
list(state.internal_nodes),
|
||||
),
|
||||
{},
|
||||
)
|
||||
state.fork_node.args = (
|
||||
state.fork_node.args[0],
|
||||
state.fork_node.args[1],
|
||||
state.fork_node.args[2],
|
||||
list(state.external_nodes),
|
||||
)
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
@staticmethod
|
||||
def _extract_stream_properties(stream_proxy: Proxy) -> tuple[Proxy, Proxy, Proxy]:
|
||||
stream_index = GetAttrVariable.create_getattr_proxy(stream_proxy, "stream_id")
|
||||
stream_device = GetAttrVariable.create_getattr_proxy(stream_proxy, "device")
|
||||
stream_device_index = GetAttrVariable.create_getattr_proxy(
|
||||
stream_proxy, "device_index"
|
||||
)
|
||||
return stream_index, stream_device, stream_device_index
|
||||
|
||||
@staticmethod
|
||||
def _get_current_stream(
|
||||
device: torch.device, tx: "InstructionTranslator"
|
||||
) -> "StreamVariable":
|
||||
from .builder import wrap_fx_proxy_cls
|
||||
|
||||
current_stream_method = get_interface_for_device(device).current_stream
|
||||
current_stream = wrap_fx_proxy_cls(
|
||||
StreamVariable,
|
||||
tx,
|
||||
tx.output.create_proxy(
|
||||
"call_function",
|
||||
current_stream_method,
|
||||
(None,),
|
||||
{},
|
||||
),
|
||||
)
|
||||
return current_stream
|
||||
|
||||
def _get_target_values(self) -> list["StreamVariable"]:
|
||||
# We need this to be overridable, since StreamVariable does
|
||||
# not store target values (it does not require any arguments)
|
||||
# and captures the current stream at the time of entering the context
|
||||
return self.target_values
|
||||
|
||||
def supports_graph_breaks(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class StreamVariable(StreamContextVariable):
|
||||
"""Represents the device-agnostic torch.Stream class"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proxy: Proxy,
|
||||
value: torch.Stream,
|
||||
device: torch.device,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
user_ind = kwargs.pop("user_obj_index", None)
|
||||
if proxy is not None and "example_value" in proxy.node.meta:
|
||||
assert proxy.node.meta["example_value"] == value
|
||||
assert value.device.type == device.type, (
|
||||
"stream value is not equal to the passed device"
|
||||
)
|
||||
super().__init__(target_values=[], initial_values=None, device=device, **kwargs)
|
||||
self.proxy = proxy
|
||||
self.value = value
|
||||
self.device = device
|
||||
self.user_ind = user_ind
|
||||
|
||||
def python_type(self) -> type:
|
||||
return torch.Stream
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> "VariableTracker":
|
||||
assert hasattr(self.value, name), f"no stream method found named {name}"
|
||||
|
||||
from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs
|
||||
from .builder import wrap_fx_proxy_cls
|
||||
|
||||
if name in ("wait_stream", "synchronize", "wait_event"):
|
||||
tx.output.create_proxy(
|
||||
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
|
||||
)
|
||||
return ConstantVariable(None)
|
||||
elif name == "query":
|
||||
return wrap_fx_proxy_cls(
|
||||
target_cls=ConstantVariable,
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
|
||||
),
|
||||
)
|
||||
elif name == "record_event":
|
||||
return wrap_fx_proxy_cls(
|
||||
target_cls=EventVariable,
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
|
||||
),
|
||||
)
|
||||
elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
|
||||
if self.source:
|
||||
install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH))
|
||||
|
||||
# NB : Checking for mutation is necessary because we compare
|
||||
# constant values
|
||||
other = args[0]
|
||||
if not isinstance(other, StreamVariable):
|
||||
return ConstantVariable.create(NotImplemented)
|
||||
|
||||
if other.source:
|
||||
install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH))
|
||||
return ConstantVariable.create(
|
||||
cmp_name_to_op_mapping[name](self.value, other.value) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
|
||||
# NB: Set initial values and target values when we enter
|
||||
# Don't do this at object creation, as we need to record the current stream
|
||||
# at the time the context is entered.
|
||||
self.initial_values = [
|
||||
StreamContextVariable._get_current_stream(self.device, tx)
|
||||
]
|
||||
return super().enter(tx)
|
||||
|
||||
def as_proxy(self) -> Proxy:
|
||||
return self.proxy
|
||||
|
||||
def module_name(self) -> str:
|
||||
return "torch._C"
|
||||
|
||||
def fn_name(self) -> str:
|
||||
return "Stream"
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
# If we got here, this stream is fully subsumed by the graph - this means it is
|
||||
# not an input or global
|
||||
assert not self.source
|
||||
if self.user_ind is not None:
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(
|
||||
torch._dynamo.graph_bytecode_inputs.__name__,
|
||||
"get_user_object_by_index",
|
||||
)
|
||||
)
|
||||
codegen.append_output(codegen.create_load_const(self.user_ind))
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
else:
|
||||
# TODO mlazos: evaluate if we still need this
|
||||
prefix = f"_stream_{self.device}"
|
||||
name = codegen.tx.output.install_global_by_id(prefix, self.value)
|
||||
codegen.append_output(codegen.create_load_global(name, add=True))
|
||||
|
||||
def _get_target_values(self) -> list["StreamVariable"]:
|
||||
return [self]
|
||||
|
||||
|
||||
class EventVariable(VariableTracker):
|
||||
def __init__(self, proxy: Proxy, value: torch.Event, **kwargs: Any) -> None:
|
||||
if proxy is not None and "example_value" in proxy.node.meta:
|
||||
assert proxy.node.meta["example_value"] == value
|
||||
super().__init__(**kwargs)
|
||||
self.proxy = proxy
|
||||
self.value = value
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
from ..utils import proxy_args_kwargs
|
||||
from .builder import wrap_fx_proxy_cls
|
||||
|
||||
if name in ("wait", "record", "synchronize"):
|
||||
tx.output.create_proxy(
|
||||
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
|
||||
)
|
||||
return ConstantVariable(None)
|
||||
elif name == "query":
|
||||
return wrap_fx_proxy_cls(
|
||||
target_cls=ConstantVariable,
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
|
||||
),
|
||||
)
|
||||
else:
|
||||
method_name = (
|
||||
f"{type(self.value).__module__}.{type(self.value).__qualname__}.{name}"
|
||||
)
|
||||
unimplemented_v2(
|
||||
gb_type="Unsupported event method",
|
||||
context=str(name),
|
||||
explanation=f"Dynamo doesn't support tracing the {method_name} method. "
|
||||
f"We currently support wait, record, synchronize, and query.",
|
||||
hints=[
|
||||
*graph_break_hints.SUPPORTABLE,
|
||||
],
|
||||
)
|
||||
|
||||
def as_proxy(self) -> Proxy:
|
||||
return self.proxy
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
# If we got here, this event is fully subsumed by the graph - this means it is
|
||||
# not an input or global
|
||||
assert not self.source
|
||||
# Similar to stream handling, we lift the event into a global and then codegen bytecode to load it from there.
|
||||
prefix = "_event"
|
||||
name = codegen.tx.output.install_global_by_id(prefix, self.value)
|
||||
codegen.append_output(codegen.create_load_global(name, add=True))
|
||||
@ -245,6 +245,7 @@ class BaseTorchVariable(VariableTracker):
|
||||
|
||||
def __init__(self, value, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
print(value)
|
||||
self.value = value
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
|
||||
@ -58,6 +58,7 @@ from ..exc import (
|
||||
raise_observed_exception,
|
||||
unimplemented_v2,
|
||||
)
|
||||
from ..graph_bytecode_inputs import get_user_object_by_index, register_user_object
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
from ..source import (
|
||||
AttrSource,
|
||||
@ -66,6 +67,7 @@ from ..source import (
|
||||
DictGetItemSource,
|
||||
GetItemSource,
|
||||
RandomValueSource,
|
||||
TorchSource,
|
||||
TypeDictSource,
|
||||
TypeMROSource,
|
||||
TypeSource,
|
||||
@ -792,14 +794,33 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||
)
|
||||
args = [stacked]
|
||||
|
||||
tensor_variable = wrap_fx_proxy(
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_function",
|
||||
self.value,
|
||||
*proxy_args_kwargs(args, kwargs),
|
||||
),
|
||||
)
|
||||
if issubclass(self.value, torch.Stream):
|
||||
torch_stream_src = CallFunctionNoArgsSource(
|
||||
AttrSource(TorchSource(), "Stream")
|
||||
)
|
||||
# Register newly created stream for reconstruction
|
||||
stream = self.value()
|
||||
from .streams import add_dynamo_owned_stream
|
||||
|
||||
add_dynamo_owned_stream(stream)
|
||||
ind = register_user_object(stream, torch_stream_src)
|
||||
breakpoint()
|
||||
tensor_variable = wrap_fx_proxy(
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_function", get_user_object_by_index, (ind,), {}
|
||||
),
|
||||
user_obj_index=ind,
|
||||
)
|
||||
else:
|
||||
tensor_variable = wrap_fx_proxy(
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_function",
|
||||
self.value,
|
||||
*proxy_args_kwargs(args, kwargs),
|
||||
),
|
||||
)
|
||||
|
||||
return tensor_variable
|
||||
elif self.value is random.Random:
|
||||
|
||||
@ -49,6 +49,7 @@ static PyObject* THPEvent_pynew(
|
||||
}
|
||||
|
||||
THPEvent* self = (THPEvent*)ptr.get();
|
||||
self->weakreflist = nullptr;
|
||||
|
||||
// TODO: blocking and interprocess are not supported yet. To support them, the
|
||||
// flag system of c10::Event needs to be refactored. C10::Event should also
|
||||
@ -73,6 +74,7 @@ PyObject* THPEvent_new(c10::DeviceType device_type, c10::EventFlag flag) {
|
||||
auto self = THPObjectPtr{type->tp_alloc(type, 0)};
|
||||
TORCH_CHECK(self, "Failed to allocate memory for Event");
|
||||
auto self_ = reinterpret_cast<THPEvent*>(self.get());
|
||||
self_->weakreflist = nullptr;
|
||||
new (&self_->event) c10::Event(device_type, flag);
|
||||
return self.release();
|
||||
}
|
||||
@ -82,6 +84,9 @@ static void THPEvent_dealloc(THPEvent* self) {
|
||||
pybind11::gil_scoped_release no_gil{};
|
||||
self->event.~Event();
|
||||
}
|
||||
if (self->weakreflist != nullptr) {
|
||||
PyObject_ClearWeakRefs((PyObject*)self);
|
||||
}
|
||||
Py_TYPE(self)->tp_free((PyObject*)self);
|
||||
}
|
||||
|
||||
@ -300,7 +305,7 @@ PyTypeObject THPEventType = {
|
||||
nullptr, /* tp_traverse */
|
||||
nullptr, /* tp_clear */
|
||||
nullptr, /* tp_richcompare */
|
||||
0, /* tp_weaklistoffset */
|
||||
offsetof(THPEvent, weakreflist), /* tp_weaklistoffset */
|
||||
nullptr, /* tp_iter */
|
||||
nullptr, /* tp_iternext */
|
||||
THPEvent_methods, /* tp_methods */
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
struct TORCH_API THPEvent {
|
||||
PyObject_HEAD
|
||||
c10::Event event;
|
||||
PyObject* weakreflist;
|
||||
};
|
||||
TORCH_API extern PyTypeObject* THPEventClass;
|
||||
TORCH_API extern PyTypeObject THPEventType;
|
||||
|
||||
@ -95,6 +95,7 @@ static PyObject* THPStream_pynew(
|
||||
self->device_index = static_cast<int64_t>(stream_opt->device_index());
|
||||
self->device_type = static_cast<int64_t>(stream_opt->device_type());
|
||||
self->context = nullptr;
|
||||
self->weakreflist = nullptr;
|
||||
|
||||
return (PyObject*)ptr.release();
|
||||
END_HANDLE_TH_ERRORS
|
||||
@ -114,11 +115,15 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) {
|
||||
self->device_index = static_cast<int64_t>(stream.device_index());
|
||||
self->device_type = static_cast<int64_t>(stream.device_type());
|
||||
self->context = nullptr;
|
||||
self->weakreflist = nullptr;
|
||||
return ptr.release();
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static void THPStream_dealloc(THPStream* self) {
|
||||
if (self->weakreflist != nullptr) {
|
||||
PyObject_ClearWeakRefs((PyObject*)self);
|
||||
}
|
||||
Py_TYPE(self)->tp_free((PyObject*)self);
|
||||
}
|
||||
|
||||
@ -436,7 +441,7 @@ static PyTypeObject THPStreamType = {
|
||||
nullptr, /* tp_traverse */
|
||||
nullptr, /* tp_clear */
|
||||
THPStream_richcompare, /* tp_richcompare */
|
||||
0, /* tp_weaklistoffset */
|
||||
offsetof(THPStream, weakreflist), /* tp_weaklistoffset */
|
||||
nullptr, /* tp_iter */
|
||||
nullptr, /* tp_iternext */
|
||||
// NOLINTNEXTLINE(*const-cast)
|
||||
|
||||
@ -13,6 +13,7 @@ struct THPStream {
|
||||
int64_t device_index;
|
||||
// Used to switch stream context management, initialized lazily.
|
||||
PyObject* context;
|
||||
PyObject* weakreflist;
|
||||
};
|
||||
extern TORCH_API PyTypeObject* THPStreamClass;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user