mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	[user-streams] Fix stream graph output semantics
ghstack-source-id: 75778deaa3a00c5162ada276aadeebce7f5ffce9 Pull Request resolved: https://github.com/pytorch/pytorch/pull/164819 updates to graph semantics changes More fixes ghstack-source-id: 75778deaa3a00c5162ada276aadeebce7f5ffce9 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165027 bytecode output fixes
This commit is contained in:
		@ -2790,5 +2790,13 @@
 | 
			
		||||
      "Explanation": "Object does not allow us to make a weakref to it",
 | 
			
		||||
      "Hints": []
 | 
			
		||||
    }
 | 
			
		||||
  ],
 | 
			
		||||
  "GB0277": [
 | 
			
		||||
    {
 | 
			
		||||
      "Gb_type": "Failed to make weakref to graph-created external object",
 | 
			
		||||
      "Context": "user_object: {example_value}",
 | 
			
		||||
      "Explanation": "Object does not allow us to make a weakref to it",
 | 
			
		||||
      "Hints": []
 | 
			
		||||
    }
 | 
			
		||||
  ]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1,9 +1,11 @@
 | 
			
		||||
import weakref
 | 
			
		||||
from typing import Any
 | 
			
		||||
from typing import Any, Callable
 | 
			
		||||
 | 
			
		||||
from torch._dynamo.source import Source
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
PyCodegen = Any
 | 
			
		||||
 | 
			
		||||
# This file is to handle types that we don't want to support
 | 
			
		||||
# as explicit FX graph inputs. This uses a sidetable which
 | 
			
		||||
# we populate in bytecode and which is loaded during graph execution.
 | 
			
		||||
@ -11,44 +13,70 @@ from torch._dynamo.source import Source
 | 
			
		||||
# We use a dynamo-generated index as a level of indirection;
 | 
			
		||||
# this allows us to register objects externally in pre-graph bytecode that we want
 | 
			
		||||
# to pass to the graph without supporting their types as graph inputs.
 | 
			
		||||
index_to_source: dict[int, Source] = {}
 | 
			
		||||
index_to_bytecode_constructor: dict[int, Callable[[PyCodegen], None]] = {}
 | 
			
		||||
 | 
			
		||||
index_to_user_object_weakref: dict[int, weakref.ReferenceType[Any]] = {}
 | 
			
		||||
index_to_external_object_weakref: dict[int, weakref.ReferenceType[Any]] = {}
 | 
			
		||||
 | 
			
		||||
keep_alive: list[Any] = []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def has_user_objects() -> bool:
 | 
			
		||||
    return bool(index_to_source)
 | 
			
		||||
    return bool(index_to_bytecode_constructor)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_user_object_by_index(index: int) -> Any:
 | 
			
		||||
    assert index in index_to_user_object_weakref, (
 | 
			
		||||
def get_external_object_by_index(index: int) -> Any:
 | 
			
		||||
    assert index in index_to_external_object_weakref, (
 | 
			
		||||
        "Index not registered in index_to_user_object_weakref"
 | 
			
		||||
    )
 | 
			
		||||
    obj = index_to_user_object_weakref[index]()
 | 
			
		||||
    obj = index_to_external_object_weakref[index]()
 | 
			
		||||
    assert obj is not None, "User object is no longer alive"
 | 
			
		||||
    return index_to_user_object_weakref[index]()
 | 
			
		||||
    return index_to_external_object_weakref[index]()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def store_user_object_weakrefs(*args: Any) -> None:
 | 
			
		||||
    global index_to_user_object_weakref
 | 
			
		||||
    index_to_user_object_weakref.clear()
 | 
			
		||||
    index_to_user_object_weakref.update(
 | 
			
		||||
    global index_to_external_object_weakref
 | 
			
		||||
    index_to_external_object_weakref.clear()
 | 
			
		||||
    index_to_external_object_weakref.update(
 | 
			
		||||
        {i: weakref.ref(arg) for i, arg in enumerate(args)}
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def reset_user_object_tracking() -> None:
 | 
			
		||||
    index_to_source.clear()
 | 
			
		||||
    index_to_user_object_weakref.clear()
 | 
			
		||||
    index_to_bytecode_constructor.clear()
 | 
			
		||||
    index_to_external_object_weakref.clear()
 | 
			
		||||
    keep_alive.clear()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def register_graph_created_object(
    example_value: Any, construct_fn: Callable[[int, PyCodegen], None]
) -> int:
    """Register an object created while tracing so bytecode can rebuild it.

    The object is assigned the next free index in
    ``index_to_bytecode_constructor``. The stored constructor takes only the
    codegen and forwards ``(index, codegen)`` to ``construct_fn``.
    ``example_value`` is appended to ``keep_alive`` and a weak reference to it
    is stored in ``index_to_external_object_weakref`` under the same index.

    Args:
        example_value: The object created during tracing.
        construct_fn: Callback invoked with ``(index, codegen)`` to emit the
            bytecode that constructs the object.

    Returns:
        The index assigned to the object.

    If ``example_value`` cannot be weakly referenced, the failure is reported
    through ``unimplemented_v2`` (presumably raising a graph-break exception;
    confirm against ``.exc``).
    """
    # Create the weakref before touching any table. If the object does not
    # support weak references, the failure path must not leave behind a
    # registered constructor and keep-alive entry with no matching weakref.
    external_ref: weakref.ReferenceType[Any] | None = None
    try:
        external_ref = weakref.ref(example_value)
    except TypeError as e:
        from .exc import unimplemented_v2

        unimplemented_v2(
            gb_type="Failed to make weakref to graph-created external object",
            context=f"user_object: {example_value}",
            explanation="Object does not allow us to make a weakref to it",
            hints=[],
            from_exc=e,
        )

    # Hold a strong reference so the weakref stored below stays valid.
    keep_alive.append(example_value)
    index = len(index_to_bytecode_constructor)
    index_to_bytecode_constructor[index] = lambda cg: construct_fn(index, cg)
    # Only reachable with ``external_ref`` unset if ``unimplemented_v2``
    # returned instead of raising; then no weakref is stored, as before.
    if external_ref is not None:
        index_to_external_object_weakref[index] = external_ref
    return index
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Register a user object to be used in the graph
 | 
			
		||||
def register_user_object(value: Any, source: Source) -> int:
 | 
			
		||||
    global index_to_source
 | 
			
		||||
    index = len(index_to_source)
 | 
			
		||||
    index_to_source[index] = source
 | 
			
		||||
    global index_to_bytecode_constructor
 | 
			
		||||
    index = len(index_to_bytecode_constructor)
 | 
			
		||||
    index_to_bytecode_constructor[index] = lambda cg: cg(source)
 | 
			
		||||
    try:
 | 
			
		||||
        index_to_user_object_weakref[index] = weakref.ref(value)
 | 
			
		||||
        index_to_external_object_weakref[index] = weakref.ref(value)
 | 
			
		||||
    except TypeError as e:
 | 
			
		||||
        from .exc import unimplemented_v2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -100,7 +100,7 @@ from .exc import (
 | 
			
		||||
    unimplemented_v2,
 | 
			
		||||
    unimplemented_v2_with_warning,
 | 
			
		||||
)
 | 
			
		||||
from .graph_bytecode_inputs import has_user_objects, index_to_source
 | 
			
		||||
from .graph_bytecode_inputs import has_user_objects, index_to_bytecode_constructor
 | 
			
		||||
from .graph_deduplication import apply_graph_deduplication
 | 
			
		||||
from .graph_region_tracker import GraphRegionTracker
 | 
			
		||||
from .guards import GuardBuilder, install_guard
 | 
			
		||||
@ -1528,9 +1528,19 @@ class OutputGraph(OutputGraphCommon):
 | 
			
		||||
                    "store_user_object_weakrefs",
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            for source in reversed(index_to_source.values()):
 | 
			
		||||
                codegen(source)
 | 
			
		||||
            codegen.call_function(len(index_to_source), False)
 | 
			
		||||
            tmp_vars = []
 | 
			
		||||
            for constructor in reversed(index_to_bytecode_constructor.values()):
 | 
			
		||||
                constructor(codegen)
 | 
			
		||||
                var_name = (
 | 
			
		||||
                    self.new_var()
 | 
			
		||||
                )  # keep alive any temp objects for the rest of the frame
 | 
			
		||||
                codegen.store(var_name)
 | 
			
		||||
                tmp_vars.append(var_name)
 | 
			
		||||
 | 
			
		||||
            for var_name in tmp_vars:
 | 
			
		||||
                codegen.append_output(codegen.create_load(var_name))
 | 
			
		||||
 | 
			
		||||
            codegen.call_function(len(index_to_bytecode_constructor), False)
 | 
			
		||||
            codegen.pop_top()
 | 
			
		||||
            self.add_output_instructions(codegen.get_instructions())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -46,7 +46,7 @@ import torch
 | 
			
		||||
from torch import SymInt
 | 
			
		||||
from torch._dispatch.python import enable_python_dispatcher
 | 
			
		||||
from torch._dynamo.graph_bytecode_inputs import (
 | 
			
		||||
    get_user_object_by_index,
 | 
			
		||||
    get_external_object_by_index,
 | 
			
		||||
    register_user_object,
 | 
			
		||||
)
 | 
			
		||||
from torch._dynamo.utils import (
 | 
			
		||||
@ -1042,7 +1042,7 @@ class VariableBuilder:
 | 
			
		||||
            self.install_guards(GuardBuilder.TYPE_MATCH)
 | 
			
		||||
            index = register_user_object(value, self.source)
 | 
			
		||||
            stream_proxy = self.tx.output.create_proxy(
 | 
			
		||||
                "call_function", get_user_object_by_index, (index,), {}
 | 
			
		||||
                "call_function", get_external_object_by_index, (index,), {}
 | 
			
		||||
            )
 | 
			
		||||
            set_example_value(stream_proxy.node, value)
 | 
			
		||||
            var = StreamVariable(
 | 
			
		||||
@ -1063,7 +1063,7 @@ class VariableBuilder:
 | 
			
		||||
            index = register_user_object(value, self.source)
 | 
			
		||||
            event_proxy = self.tx.output.create_proxy(
 | 
			
		||||
                "call_function",
 | 
			
		||||
                get_user_object_by_index,
 | 
			
		||||
                get_external_object_by_index,
 | 
			
		||||
                (index,),
 | 
			
		||||
                {},
 | 
			
		||||
            )
 | 
			
		||||
@ -2978,8 +2978,8 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
 | 
			
		||||
        set_example_value(proxy.node, example_value)
 | 
			
		||||
        return SymNodeVariable(proxy, example_value, **options)
 | 
			
		||||
    elif (
 | 
			
		||||
        inspect.isclass(proxy.node.target)
 | 
			
		||||
        and issubclass(proxy.node.target, torch.Stream)
 | 
			
		||||
        isinstance(example_value, torch.Stream)
 | 
			
		||||
        and proxy.node.target == get_external_object_by_index
 | 
			
		||||
    ) or proxy.node.target in [
 | 
			
		||||
        device_interface.current_stream
 | 
			
		||||
        for _, device_interface in get_registered_device_interfaces()
 | 
			
		||||
 | 
			
		||||
@ -4,8 +4,10 @@ import torch
 | 
			
		||||
from torch.fx import Proxy
 | 
			
		||||
 | 
			
		||||
from .. import graph_break_hints
 | 
			
		||||
from ..bytecode_transformation import create_call_function
 | 
			
		||||
from ..device_interface import get_interface_for_device
 | 
			
		||||
from ..exc import TYPE_CHECKING, unimplemented_v2
 | 
			
		||||
from ..source import AttrSource, CallFunctionNoArgsSource, TorchSource
 | 
			
		||||
from .base import VariableTracker
 | 
			
		||||
from .constant import ConstantVariable
 | 
			
		||||
from .ctx_manager import ContextWrappingVariable
 | 
			
		||||
@ -171,6 +173,9 @@ class StreamVariable(StreamContextVariable):
 | 
			
		||||
        device: torch.device,
 | 
			
		||||
        **kwargs: Any,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        # Index into the user object table
 | 
			
		||||
        # used to pass arbitrary objects to the graph
 | 
			
		||||
        user_object_index = kwargs.pop("user_obj_index", None)
 | 
			
		||||
        if proxy is not None and "example_value" in proxy.node.meta:
 | 
			
		||||
            assert proxy.node.meta["example_value"] == value
 | 
			
		||||
        assert value.device.type == device.type, (
 | 
			
		||||
@ -181,6 +186,8 @@ class StreamVariable(StreamContextVariable):
 | 
			
		||||
        self.value = value
 | 
			
		||||
        self.device = device
 | 
			
		||||
 | 
			
		||||
        self.user_object_index = user_object_index
 | 
			
		||||
 | 
			
		||||
    def python_type(self) -> type:
 | 
			
		||||
        return torch.Stream
 | 
			
		||||
 | 
			
		||||
@ -259,15 +266,27 @@ class StreamVariable(StreamContextVariable):
 | 
			
		||||
        # If we got here, this stream is fully subsumed by the graph - this means it is
 | 
			
		||||
        # not an input or global
 | 
			
		||||
        assert not self.source
 | 
			
		||||
        # Since we just proved that - for other such structures, like lists and dicts, reconstruction
 | 
			
		||||
        # is fine and sound according to dynamo principles of treating collections. However,
 | 
			
		||||
        # streams are special in that we want to preserve the identity of the stream as the same as in the graph
 | 
			
		||||
        # Normally, we would do this via codegen for the proxy mapping to an output - we cannot do this yet, as we do not
 | 
			
		||||
        # yet have a plan for how we want to handle the case where the stream is used as an input or an output. Pending
 | 
			
		||||
        # design, to unblock current work, we lift the stream into a global and then codegen bytecode to load it from there.
 | 
			
		||||
        prefix = f"_stream_{self.device}"
 | 
			
		||||
        name = codegen.tx.output.install_global_by_id(prefix, self.value)
 | 
			
		||||
        codegen.append_output(codegen.create_load_global(name, add=True))
 | 
			
		||||
        if self.user_object_index is not None:
 | 
			
		||||
            codegen.add_push_null(
 | 
			
		||||
                lambda: codegen.load_import_from(
 | 
			
		||||
                    torch._dynamo.graph_bytecode_inputs.__name__,
 | 
			
		||||
                    "get_external_object_by_index",
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            codegen.append_output(codegen.create_load_const(self.user_object_index))
 | 
			
		||||
            codegen.extend_output(create_call_function(1, False))
 | 
			
		||||
        else:
 | 
			
		||||
            # TODO mlazos: evaluate if we still need this
 | 
			
		||||
            prefix = f"_stream_{self.device}"
 | 
			
		||||
            name = codegen.tx.output.install_global_by_id(prefix, self.value)
 | 
			
		||||
            codegen.append_output(codegen.create_load_global(name, add=True))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
    def construct_in_graph_stream(index: int, codegen: "PyCodegen") -> None:
        """Emit bytecode for a zero-argument call to ``torch.Stream``.

        The call is described with source objects only as a convenient way to
        generate the right bytecode; it does not represent an actual input.
        ``index`` is not used here. It is part of the ``(index, codegen)``
        signature that ``register_graph_created_object`` expects of its
        ``construct_fn`` callback.
        """
        stream_class_source = AttrSource(TorchSource(), "Stream")
        codegen(CallFunctionNoArgsSource(stream_class_source))
 | 
			
		||||
 | 
			
		||||
    def _get_target_values(self) -> list["StreamVariable"]:
 | 
			
		||||
        return [self]
 | 
			
		||||
 | 
			
		||||
@ -58,6 +58,7 @@ from ..exc import (
 | 
			
		||||
    raise_observed_exception,
 | 
			
		||||
    unimplemented_v2,
 | 
			
		||||
)
 | 
			
		||||
from ..graph_bytecode_inputs import get_external_object_by_index
 | 
			
		||||
from ..guards import GuardBuilder, install_guard
 | 
			
		||||
from ..source import (
 | 
			
		||||
    AttrSource,
 | 
			
		||||
@ -792,14 +793,31 @@ class UserDefinedClassVariable(UserDefinedVariable):
 | 
			
		||||
                )
 | 
			
		||||
                args = [stacked]
 | 
			
		||||
 | 
			
		||||
            tensor_variable = wrap_fx_proxy(
 | 
			
		||||
                tx=tx,
 | 
			
		||||
                proxy=tx.output.create_proxy(
 | 
			
		||||
                    "call_function",
 | 
			
		||||
                    self.value,
 | 
			
		||||
                    *proxy_args_kwargs(args, kwargs),
 | 
			
		||||
                ),
 | 
			
		||||
            )
 | 
			
		||||
            if issubclass(self.value, torch.Stream):
 | 
			
		||||
                # Register newly created stream for reconstruction
 | 
			
		||||
                stream = self.value()
 | 
			
		||||
                from ..graph_bytecode_inputs import register_graph_created_object
 | 
			
		||||
                from .streams import StreamVariable
 | 
			
		||||
 | 
			
		||||
                ind = register_graph_created_object(
 | 
			
		||||
                    stream, StreamVariable.construct_in_graph_stream
 | 
			
		||||
                )
 | 
			
		||||
                tensor_variable = wrap_fx_proxy(
 | 
			
		||||
                    tx=tx,
 | 
			
		||||
                    proxy=tx.output.create_proxy(
 | 
			
		||||
                        "call_function", get_external_object_by_index, (ind,), {}
 | 
			
		||||
                    ),
 | 
			
		||||
                    user_obj_index=ind,
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                tensor_variable = wrap_fx_proxy(
 | 
			
		||||
                    tx=tx,
 | 
			
		||||
                    proxy=tx.output.create_proxy(
 | 
			
		||||
                        "call_function",
 | 
			
		||||
                        self.value,
 | 
			
		||||
                        *proxy_args_kwargs(args, kwargs),
 | 
			
		||||
                    ),
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            return tensor_variable
 | 
			
		||||
        elif self.value is random.Random:
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user