[user-streams] Track symbolic current stream

merge into stream tests

ghstack-source-id: 0fab78f039eb26365195a7b5ceca756dd58e9724
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165212

symbolic streams update
This commit is contained in:
Michael Lazos
2025-10-13 14:22:58 -07:00
parent 44de0318c4
commit dc90a72bb5
5 changed files with 106 additions and 1 deletions

View File

@ -2495,6 +2495,14 @@
}
],
"GB0249": [
{
"Gb_type": "bad device argument to torch.accelerator.current_stream",
"Context": "args={args}, kwargs={kwargs}",
"Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
"Hints": [
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
},
{
"Gb_type": "bad device argument to torch.get_device_module",
"Context": "args={args}, kwargs={kwargs}",
@ -2798,5 +2806,25 @@
"Explanation": "Object does not allow us to make a weakref to it",
"Hints": []
}
],
"GB0278": [
{
"Gb_type": "unsupported arguments to torch.accelerator.current_stream",
"Context": "args={args}, kwargs={kwargs}",
"Explanation": "torch.accelerator.current_stream accepts one optional argument `device`",
"Hints": [
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
}
],
"GB0279": [
{
"Gb_type": "bad device argument to torch.get_device_module",
"Context": "args={args}, kwargs={kwargs}",
"Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
"Hints": [
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
}
]
}

View File

@ -173,6 +173,7 @@ from .variables.misc import (
UnknownVariable,
)
from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable
from .variables.streams import SymbolicStreamState
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
from .variables.torch_function import (
SymbolicTorchFunctionState,
@ -1170,6 +1171,7 @@ class InstructionTranslatorBase(
symbolic_locals: dict[str, VariableTracker]
symbolic_globals: dict[str, VariableTracker]
symbolic_torch_function_state: SymbolicTorchFunctionState
symbolic_stream_state: SymbolicStreamState
post_prune_cell_and_freevars: Optional[dict[str, VariableTracker]]
stack: list[VariableTracker]
instruction_pointer: Optional[int]
@ -4069,6 +4071,7 @@ class InstructionTranslatorBase(
symbolic_locals: dict[str, VariableTracker],
symbolic_globals: dict[str, VariableTracker],
symbolic_torch_function_state: SymbolicTorchFunctionState,
symbolic_stream_state: SymbolicStreamState,
f_code: types.CodeType,
export: bool,
inline_depth: int,
@ -4088,6 +4091,7 @@ class InstructionTranslatorBase(
self.symbolic_locals = symbolic_locals
self.symbolic_globals = symbolic_globals
self.symbolic_torch_function_state = symbolic_torch_function_state
self.symbolic_stream_state = symbolic_stream_state
# used to keep cell/freevars alive after pruning symbolic_locals (prune_dead_locals)
# in order to generate any nested closures
self.post_prune_cell_and_freevars = None
@ -4241,6 +4245,7 @@ class InstructionTranslator(InstructionTranslatorBase):
# A global var is inserted only after a STORE_GLOBAL happens to it
symbolic_globals={},
symbolic_torch_function_state=None, # type: ignore[arg-type] # set below
symbolic_stream_state=None, # type: ignore[arg-type] # set below
f_code=f_code,
export=export,
inline_depth=0,
@ -4345,6 +4350,8 @@ class InstructionTranslator(InstructionTranslatorBase):
torch_function_mode_stack
)
self.symbolic_stream_state = SymbolicStreamState()
if export:
# export gets confused if we never realize unused inputs
# in export mode just eagerly realize everything
@ -4673,6 +4680,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
sub_locals,
parent.symbolic_globals,
parent.symbolic_torch_function_state,
parent.symbolic_stream_state,
func,
)
else:
@ -4684,6 +4692,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
sub_locals,
parent.symbolic_globals,
parent.symbolic_torch_function_state,
parent.symbolic_stream_state,
# pyrefly: ignore # bad-argument-type
func,
)
@ -4767,6 +4776,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
symbolic_locals: dict[str, VariableTracker],
symbolic_globals: dict[str, VariableTracker],
symbolic_torch_function_state: SymbolicTorchFunctionState,
symbolic_stream_state: SymbolicStreamState,
funcvar: BaseUserFunctionVariable,
) -> None:
f_globals = funcvar.get_globals() # type: ignore[attr-defined]
@ -4800,6 +4810,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
symbolic_locals=symbolic_locals,
symbolic_globals=symbolic_globals,
symbolic_torch_function_state=symbolic_torch_function_state,
symbolic_stream_state=symbolic_stream_state,
instructions=instructions,
code_options={k: getattr(code, k) for k in get_code_keys()},
f_code=code,

View File

@ -2979,7 +2979,8 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
return SymNodeVariable(proxy, example_value, **options)
elif (
isinstance(example_value, torch.Stream)
and proxy.node.target == get_external_object_by_index
and proxy.node.target
in (get_external_object_by_index, torch.accelerator.current_stream)
) or proxy.node.target in [
device_interface.current_stream
for _, device_interface in get_registered_device_interfaces()

View File

@ -1,3 +1,4 @@
import collections
from typing import Any, Optional
import torch
@ -11,6 +12,7 @@ from ..source import AttrSource, CallFunctionNoArgsSource, TorchSource
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import ContextWrappingVariable
from .lazy import LazyVariableTracker
from .misc import GetAttrVariable
@ -65,6 +67,38 @@ def _(
pass
class SymbolicStreamState:
    """Track the stack of streams entered during tracing.

    The stack is seeded with the ambient (eager) current stream so that
    ``cur_stream()`` always has a fallback even when no ``with stream:``
    context has been entered in the traced code.
    """

    def __init__(self) -> None:
        from ..source import CurrentStreamSource

        # Call torch.accelerator.current_stream() once and reuse it for both
        # the tracked value and its device-based source (the original queried
        # the accelerator twice for the same stream).
        initial_stream = torch.accelerator.current_stream()
        stream_var = LazyVariableTracker.create(
            initial_stream,
            source=CurrentStreamSource(initial_stream.device),
        )
        # Innermost entered stream is at the right end of the deque.
        self.cur_stream_stack: collections.deque[StreamVariable] = collections.deque(
            [stream_var]  # type: ignore[list-item]
        )

    def enter_stream(self, stream: "StreamVariable") -> None:
        """Record that tracing entered a stream context."""
        self.cur_stream_stack.append(stream)

    def exit_stream(self) -> None:
        """Record that tracing exited the innermost stream context."""
        self.cur_stream_stack.pop()

    def cur_stream(self, device: Optional[torch.device] = None) -> "StreamVariable":
        """Return the innermost entered stream.

        If ``device`` is given, return the innermost entered stream on that
        device, falling back to the top of the stack when none matches.
        """
        if device is not None:
            # Search innermost-first for a stream on the requested device.
            for stream in reversed(self.cur_stream_stack):
                if stream.device == device:
                    return stream
        return self.cur_stream_stack[-1]

    def in_stream_context(self) -> bool:
        # NOTE(review): the stack is seeded with the ambient stream in
        # __init__ and only balanced push/pops follow, so this is currently
        # always True — confirm whether callers expect "user entered a
        # stream context" instead.
        return len(self.cur_stream_stack) > 0
class StreamContextVariable(ContextWrappingVariable):
"""This represents torch.cuda.StreamContext"""
@ -98,6 +132,7 @@ class StreamContextVariable(ContextWrappingVariable):
def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
# to stream, from stream is the order of the arguments
# we are entering the target, and leaving the initial stream
tx.symbolic_stream_state.enter_stream(self._get_target_values()[0])
tx.output.create_proxy(
"call_function",
torch.ops.streams.fork.default,
@ -109,6 +144,7 @@ class StreamContextVariable(ContextWrappingVariable):
def exit(self, tx: "InstructionTranslator", *args: tuple[Any]) -> "VariableTracker":
# to stream, from stream is the order of the arguments
# we are leaving the target, and entering the initial stream
tx.symbolic_stream_state.exit_stream()
tx.output.create_proxy(
"call_function",
torch.ops.streams.join.default,

View File

@ -1237,6 +1237,35 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
# pyrefly: ignore # unbound-name
return VariableTracker.build(tx, module, new_source)
@register(torch.accelerator.current_stream)
def handle_current_stream(self, tx: "InstructionTranslator", *args, **kwargs):
    """Trace ``torch.accelerator.current_stream`` via Dynamo's symbolic
    stream state rather than querying the eager accelerator.

    Accepts at most one optional ``device`` argument (positional or
    keyword); graph-breaks on any other call shape or on a device value
    that ``torch.device`` rejects.
    """
    if len(args) + len(kwargs) > 1 or (kwargs and "device" not in kwargs):
        unimplemented_v2(
            gb_type="unsupported arguments to torch.accelerator.current_stream",
            context=f"args={args}, kwargs={kwargs}",
            explanation="torch.accelerator.current_stream accepts one optional argument `device`",
            hints=[
                *graph_break_hints.USER_ERROR,
            ],
        )

    device = None
    # Keep the try body minimal: only the device-argument parsing should be
    # reported as a "bad device argument" graph break. The original also
    # wrapped the cur_stream() lookup, which could misattribute unrelated
    # failures to the user's device argument.
    try:
        if kwargs:
            device = torch.device(kwargs["device"].as_python_constant())
        elif args:
            device = torch.device(args[0].as_python_constant())
    except Exception as e:
        unimplemented_v2(
            gb_type="bad device argument to torch.accelerator.current_stream",
            context=f"args={args}, kwargs={kwargs}",
            explanation="Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
            hints=[*graph_break_hints.USER_ERROR],
            from_exc=e,
        )
    return tx.symbolic_stream_state.cur_stream(device)
@register(torch.set_default_device)
def handle_set_default_device(
self, tx: "InstructionTranslator", *args, **kwargs