mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[user-cuda-streams] Pass streams/events to the graph via lookup table (#162899)
Stores streams in a global object look table that maps a dynamo selected index to objects. This index is generated during tracing, and at runtime, a helper function is called from the bytecode to populate this map. This differs from the previous implementation that simply mapped IDs to the associated objects. This required specialization on the IDs of the specific objects, while this new approach does not. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162899 Approved by: https://github.com/anijain2305 ghstack dependencies: #163027
This commit is contained in:
committed by
PyTorch MergeBot
parent
f15c25d5c3
commit
04e36611bb
@ -116,6 +116,7 @@ from .exc import (
|
||||
unimplemented_v2,
|
||||
Unsupported,
|
||||
)
|
||||
from .graph_bytecode_inputs import reset_user_object_tracking
|
||||
from .guards import (
|
||||
CheckFunctionManager,
|
||||
get_and_maybe_log_recompilation_reasons,
|
||||
@ -314,6 +315,7 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
torch.fx._symbolic_trace._maybe_revert_all_patches()
|
||||
)
|
||||
exit_stack.enter_context(torch_function_mode_stack_state_mgr)
|
||||
reset_user_object_tracking()
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
|
@ -2734,6 +2734,12 @@
|
||||
}
|
||||
],
|
||||
"GB0272": [
|
||||
{
|
||||
"Gb_type": "Failed to make weakref to User Object when storing by ID",
|
||||
"Context": "user_objected: {obj}",
|
||||
"Explanation": "Object does not allow us to make a weakref to it",
|
||||
"Hints": []
|
||||
},
|
||||
{
|
||||
"Gb_type": "Failed to make weakref to User Object",
|
||||
"Context": "user_objected: {obj}",
|
||||
@ -2776,5 +2782,13 @@
|
||||
"This is likely to be a Dynamo bug. Please report an issue to PyTorch."
|
||||
]
|
||||
}
|
||||
],
|
||||
"GB0276": [
|
||||
{
|
||||
"Gb_type": "Failed to make weakref to User Object",
|
||||
"Context": "user_object: {value}",
|
||||
"Explanation": "Object does not allow us to make a weakref to it",
|
||||
"Hints": []
|
||||
}
|
||||
]
|
||||
}
|
||||
|
62
torch/_dynamo/graph_bytecode_inputs.py
Normal file
62
torch/_dynamo/graph_bytecode_inputs.py
Normal file
@ -0,0 +1,62 @@
|
||||
import weakref
|
||||
from typing import Any
|
||||
|
||||
from torch._dynamo.source import Source
|
||||
|
||||
|
||||
# This file is to handle types that we don't want to support
|
||||
# as explicit FX graph inputs. This uses a sidetable which
|
||||
# we populate in bytecode and is loaded during graph execution
|
||||
|
||||
# We use a dynamo-generated index as a level of indirection
|
||||
# this allows us to register objects externally in pre-graph bytecode that we want
|
||||
# to pass to the graph, but not support their types as graph inputs
|
||||
index_to_source: dict[int, Source] = {}
|
||||
|
||||
index_to_user_object_weakref: dict[int, weakref.ReferenceType[Any]] = {}
|
||||
|
||||
|
||||
def has_user_objects() -> bool:
|
||||
return bool(index_to_source)
|
||||
|
||||
|
||||
def get_user_object_by_index(index: int) -> Any:
|
||||
assert index in index_to_user_object_weakref, (
|
||||
"Index not registered in index_to_user_object_weakref"
|
||||
)
|
||||
obj = index_to_user_object_weakref[index]()
|
||||
assert obj is not None, "User object is no longer alive"
|
||||
return index_to_user_object_weakref[index]()
|
||||
|
||||
|
||||
def store_user_object_weakrefs(*args: Any) -> None:
|
||||
global index_to_user_object_weakref
|
||||
index_to_user_object_weakref.clear()
|
||||
index_to_user_object_weakref.update(
|
||||
{i: weakref.ref(arg) for i, arg in enumerate(args)}
|
||||
)
|
||||
|
||||
|
||||
def reset_user_object_tracking() -> None:
|
||||
index_to_source.clear()
|
||||
index_to_user_object_weakref.clear()
|
||||
|
||||
|
||||
# Register a user object to be used in the graph
|
||||
def register_user_object(value: Any, source: Source) -> int:
|
||||
global index_to_source
|
||||
index = len(index_to_source)
|
||||
index_to_source[index] = source
|
||||
try:
|
||||
index_to_user_object_weakref[index] = weakref.ref(value)
|
||||
except TypeError as e:
|
||||
from .exc import unimplemented_v2
|
||||
|
||||
unimplemented_v2(
|
||||
gb_type="Failed to make weakref to User Object",
|
||||
context=f"user_object: {value}",
|
||||
explanation="Object does not allow us to make a weakref to it",
|
||||
hints=[],
|
||||
from_exc=e,
|
||||
)
|
||||
return index
|
@ -2166,6 +2166,8 @@ class GuardBuilder(GuardBuilderBase):
|
||||
range,
|
||||
dict_keys,
|
||||
torch.Size,
|
||||
torch.Stream,
|
||||
torch.cuda.streams.Stream,
|
||||
*np_types,
|
||||
*ok_mutable_types,
|
||||
}
|
||||
|
@ -100,6 +100,7 @@ from .exc import (
|
||||
unimplemented_v2,
|
||||
unimplemented_v2_with_warning,
|
||||
)
|
||||
from .graph_bytecode_inputs import has_user_objects, index_to_source
|
||||
from .graph_deduplication import apply_graph_deduplication
|
||||
from .graph_region_tracker import GraphRegionTracker
|
||||
from .guards import GuardBuilder, install_guard
|
||||
@ -1520,6 +1521,27 @@ class OutputGraph(OutputGraphCommon):
|
||||
|
||||
from .decorators import disable
|
||||
|
||||
if has_user_objects():
|
||||
# NB: This is where we store possible user objects before running the graph
|
||||
# index_to_user_object_weakref is the function used in the graph to translate
|
||||
# the dynamo-generated index into the actual object passed to the compiled function.
|
||||
# We generate bytecode to store all user objects at the proper index in the below
|
||||
# call.
|
||||
codegen = PyCodegen(
|
||||
self.root_tx, root, overridden_sources=overridden_sources
|
||||
)
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(
|
||||
torch._dynamo.graph_bytecode_inputs.__name__,
|
||||
"store_user_object_weakrefs",
|
||||
)
|
||||
)
|
||||
for source in reversed(index_to_source.values()):
|
||||
codegen(source)
|
||||
codegen.call_function(len(index_to_source), False)
|
||||
codegen.pop_top()
|
||||
self.add_output_instructions(codegen.get_instructions())
|
||||
|
||||
# to handle random calls
|
||||
if len(self.random_calls) > 0:
|
||||
random_calls_instructions = []
|
||||
@ -1665,7 +1687,7 @@ class OutputGraph(OutputGraphCommon):
|
||||
)
|
||||
elif (
|
||||
vt.source is not None
|
||||
and (source := getattr(vt.source, "base", None))
|
||||
and (source := getattr(vt.source, "base", None)) # type: ignore[assignment]
|
||||
and source.is_input
|
||||
):
|
||||
self.export_metadata.output_return_type[idx] = (
|
||||
|
@ -4725,6 +4725,7 @@ def _extract_tensor_dict(t: torch.Tensor) -> dict[str, Any]:
|
||||
user_obj_id_to_weakref: dict[int, weakref.ReferenceType[object]] = {}
|
||||
|
||||
|
||||
# TODO: mlazos to remove after replacing w/ above API
|
||||
def get_user_object_from_id(obj_id: int) -> Any:
|
||||
obj = user_obj_id_to_weakref[obj_id]()
|
||||
assert obj is not None, "User object is no longer alive"
|
||||
@ -4739,7 +4740,7 @@ def store_user_object_weakref(obj: object) -> None:
|
||||
from .exc import unimplemented_v2
|
||||
|
||||
unimplemented_v2(
|
||||
gb_type="Failed to make weakref to User Object",
|
||||
gb_type="Failed to make weakref to User Object when storing by ID",
|
||||
context=f"user_objected: {obj}",
|
||||
explanation="Object does not allow us to make a weakref to it",
|
||||
hints=[],
|
||||
|
@ -45,6 +45,10 @@ import sympy
|
||||
import torch
|
||||
from torch import SymInt
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._dynamo.graph_bytecode_inputs import (
|
||||
get_user_object_by_index,
|
||||
register_user_object,
|
||||
)
|
||||
from torch._dynamo.utils import (
|
||||
get_metrics_context,
|
||||
is_int_specialization_case,
|
||||
@ -1035,16 +1039,10 @@ class VariableBuilder:
|
||||
stream_var = VariableBuilder(self.tx, stream_source)(value.stream)
|
||||
return StreamContextVariable.create(self.tx, stream_var)
|
||||
elif isinstance(value, torch.Stream):
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
index = register_user_object(value, self.source)
|
||||
stream_proxy = self.tx.output.create_proxy(
|
||||
"call_function",
|
||||
type(value),
|
||||
(),
|
||||
{
|
||||
"stream_id": value.stream_id,
|
||||
"device_index": value.device_index,
|
||||
"device_type": value.device_type,
|
||||
},
|
||||
"call_function", get_user_object_by_index, (index,), {}
|
||||
)
|
||||
set_example_value(stream_proxy.node, value)
|
||||
return StreamVariable(
|
||||
@ -1060,12 +1058,12 @@ class VariableBuilder:
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
return FuncTorchInterpreterVariable(value)
|
||||
elif isinstance(value, torch.Event):
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
torch._dynamo.utils.store_user_object_weakref(value)
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
index = register_user_object(value, self.source)
|
||||
event_proxy = self.tx.output.create_proxy(
|
||||
"call_function",
|
||||
torch._dynamo.utils.get_user_object_from_id,
|
||||
(id(value),),
|
||||
get_user_object_by_index,
|
||||
(index,),
|
||||
{},
|
||||
)
|
||||
set_example_value(event_proxy.node, value)
|
||||
|
@ -70,11 +70,19 @@ class StreamVariable(VariableTracker):
|
||||
),
|
||||
)
|
||||
elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
|
||||
if self.source:
|
||||
install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH))
|
||||
|
||||
# NB : Checking for mutation is necessary because we compare
|
||||
# constant values
|
||||
other = args[0]
|
||||
if not isinstance(other, StreamVariable):
|
||||
return ConstantVariable.create(NotImplemented)
|
||||
|
||||
if other.source:
|
||||
install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH))
|
||||
return ConstantVariable.create(
|
||||
cmp_name_to_op_mapping[name](self.value, other.value) # type: ignore[arg-type]
|
||||
)
|
||||
|
Reference in New Issue
Block a user