mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	[user-streams] Track symbolic current stream
merge into stream tests ghstack-source-id: 0fab78f039eb26365195a7b5ceca756dd58e9724 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165212 symbolic streams update
This commit is contained in:
		@ -2495,6 +2495,14 @@
 | 
			
		||||
    }
 | 
			
		||||
  ],
 | 
			
		||||
  "GB0249": [
 | 
			
		||||
    {
 | 
			
		||||
      "Gb_type": "bad device argument to torch.accelerator.current_stream",
 | 
			
		||||
      "Context": "args={args}, kwargs={kwargs}",
 | 
			
		||||
      "Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
 | 
			
		||||
      "Hints": [
 | 
			
		||||
        "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
 | 
			
		||||
      ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "Gb_type": "bad device argument to torch.get_device_module",
 | 
			
		||||
      "Context": "args={args}, kwargs={kwargs}",
 | 
			
		||||
@ -2798,5 +2806,25 @@
 | 
			
		||||
      "Explanation": "Object does not allow us to make a weakref to it",
 | 
			
		||||
      "Hints": []
 | 
			
		||||
    }
 | 
			
		||||
  ],
 | 
			
		||||
  "GB0278": [
 | 
			
		||||
    {
 | 
			
		||||
      "Gb_type": "unsupported arguments to torch.accelerator.current_stream",
 | 
			
		||||
      "Context": "args={args}, kwargs={kwargs}",
 | 
			
		||||
      "Explanation": "torch.accelerator.current_stream accepts one optional argument `device`",
 | 
			
		||||
      "Hints": [
 | 
			
		||||
        "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
 | 
			
		||||
      ]
 | 
			
		||||
    }
 | 
			
		||||
  ],
 | 
			
		||||
  "GB0279": [
 | 
			
		||||
    {
 | 
			
		||||
      "Gb_type": "bad device argument to torch.get_device_module",
 | 
			
		||||
      "Context": "args={args}, kwargs={kwargs}",
 | 
			
		||||
      "Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
 | 
			
		||||
      "Hints": [
 | 
			
		||||
        "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
 | 
			
		||||
      ]
 | 
			
		||||
    }
 | 
			
		||||
  ]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -173,6 +173,7 @@ from .variables.misc import (
 | 
			
		||||
    UnknownVariable,
 | 
			
		||||
)
 | 
			
		||||
from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable
 | 
			
		||||
from .variables.streams import SymbolicStreamState
 | 
			
		||||
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
 | 
			
		||||
from .variables.torch_function import (
 | 
			
		||||
    SymbolicTorchFunctionState,
 | 
			
		||||
@ -1170,6 +1171,7 @@ class InstructionTranslatorBase(
 | 
			
		||||
    symbolic_locals: dict[str, VariableTracker]
 | 
			
		||||
    symbolic_globals: dict[str, VariableTracker]
 | 
			
		||||
    symbolic_torch_function_state: SymbolicTorchFunctionState
 | 
			
		||||
    symbolic_stream_state: SymbolicStreamState
 | 
			
		||||
    post_prune_cell_and_freevars: Optional[dict[str, VariableTracker]]
 | 
			
		||||
    stack: list[VariableTracker]
 | 
			
		||||
    instruction_pointer: Optional[int]
 | 
			
		||||
@ -4069,6 +4071,7 @@ class InstructionTranslatorBase(
 | 
			
		||||
        symbolic_locals: dict[str, VariableTracker],
 | 
			
		||||
        symbolic_globals: dict[str, VariableTracker],
 | 
			
		||||
        symbolic_torch_function_state: SymbolicTorchFunctionState,
 | 
			
		||||
        symbolic_stream_state: SymbolicStreamState,
 | 
			
		||||
        f_code: types.CodeType,
 | 
			
		||||
        export: bool,
 | 
			
		||||
        inline_depth: int,
 | 
			
		||||
@ -4088,6 +4091,7 @@ class InstructionTranslatorBase(
 | 
			
		||||
        self.symbolic_locals = symbolic_locals
 | 
			
		||||
        self.symbolic_globals = symbolic_globals
 | 
			
		||||
        self.symbolic_torch_function_state = symbolic_torch_function_state
 | 
			
		||||
        self.symbolic_stream_state = symbolic_stream_state
 | 
			
		||||
        # used to keep cell/freevars alive after pruning symbolic_locals (prune_dead_locals)
 | 
			
		||||
        # in order to generate any nested closures
 | 
			
		||||
        self.post_prune_cell_and_freevars = None
 | 
			
		||||
@ -4241,6 +4245,7 @@ class InstructionTranslator(InstructionTranslatorBase):
 | 
			
		||||
            # A global var is inserted only after a STORE_GLOBAL happens to it
 | 
			
		||||
            symbolic_globals={},
 | 
			
		||||
            symbolic_torch_function_state=None,  # type: ignore[arg-type] # set below
 | 
			
		||||
            symbolic_stream_state=None,  # type: ignore[arg-type] # set below
 | 
			
		||||
            f_code=f_code,
 | 
			
		||||
            export=export,
 | 
			
		||||
            inline_depth=0,
 | 
			
		||||
@ -4345,6 +4350,8 @@ class InstructionTranslator(InstructionTranslatorBase):
 | 
			
		||||
                torch_function_mode_stack
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            self.symbolic_stream_state = SymbolicStreamState()
 | 
			
		||||
 | 
			
		||||
            if export:
 | 
			
		||||
                # export gets confused if we never realize unused inputs
 | 
			
		||||
                # in export mode just eagerly realize everything
 | 
			
		||||
@ -4673,6 +4680,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
 | 
			
		||||
                sub_locals,
 | 
			
		||||
                parent.symbolic_globals,
 | 
			
		||||
                parent.symbolic_torch_function_state,
 | 
			
		||||
                parent.symbolic_stream_state,
 | 
			
		||||
                func,
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
@ -4684,6 +4692,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
 | 
			
		||||
                sub_locals,
 | 
			
		||||
                parent.symbolic_globals,
 | 
			
		||||
                parent.symbolic_torch_function_state,
 | 
			
		||||
                parent.symbolic_stream_state,
 | 
			
		||||
                # pyrefly: ignore  # bad-argument-type
 | 
			
		||||
                func,
 | 
			
		||||
            )
 | 
			
		||||
@ -4767,6 +4776,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
 | 
			
		||||
        symbolic_locals: dict[str, VariableTracker],
 | 
			
		||||
        symbolic_globals: dict[str, VariableTracker],
 | 
			
		||||
        symbolic_torch_function_state: SymbolicTorchFunctionState,
 | 
			
		||||
        symbolic_stream_state: SymbolicStreamState,
 | 
			
		||||
        funcvar: BaseUserFunctionVariable,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        f_globals = funcvar.get_globals()  # type: ignore[attr-defined]
 | 
			
		||||
@ -4800,6 +4810,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
 | 
			
		||||
            symbolic_locals=symbolic_locals,
 | 
			
		||||
            symbolic_globals=symbolic_globals,
 | 
			
		||||
            symbolic_torch_function_state=symbolic_torch_function_state,
 | 
			
		||||
            symbolic_stream_state=symbolic_stream_state,
 | 
			
		||||
            instructions=instructions,
 | 
			
		||||
            code_options={k: getattr(code, k) for k in get_code_keys()},
 | 
			
		||||
            f_code=code,
 | 
			
		||||
 | 
			
		||||
@ -2979,7 +2979,8 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
 | 
			
		||||
        return SymNodeVariable(proxy, example_value, **options)
 | 
			
		||||
    elif (
 | 
			
		||||
        isinstance(example_value, torch.Stream)
 | 
			
		||||
        and proxy.node.target == get_external_object_by_index
 | 
			
		||||
        and proxy.node.target
 | 
			
		||||
        in (get_external_object_by_index, torch.accelerator.current_stream)
 | 
			
		||||
    ) or proxy.node.target in [
 | 
			
		||||
        device_interface.current_stream
 | 
			
		||||
        for _, device_interface in get_registered_device_interfaces()
 | 
			
		||||
 | 
			
		||||
@ -1,3 +1,4 @@
 | 
			
		||||
import collections
 | 
			
		||||
from typing import Any, Optional
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
@ -11,6 +12,7 @@ from ..source import AttrSource, CallFunctionNoArgsSource, TorchSource
 | 
			
		||||
from .base import VariableTracker
 | 
			
		||||
from .constant import ConstantVariable
 | 
			
		||||
from .ctx_manager import ContextWrappingVariable
 | 
			
		||||
from .lazy import LazyVariableTracker
 | 
			
		||||
from .misc import GetAttrVariable
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -65,6 +67,38 @@ def _(
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SymbolicStreamState:
 | 
			
		||||
    """Track the currently entered stream if any"""
 | 
			
		||||
 | 
			
		||||
    def __init__(self) -> None:
 | 
			
		||||
        from ..source import CurrentStreamSource
 | 
			
		||||
 | 
			
		||||
        stream_var = LazyVariableTracker.create(
 | 
			
		||||
            torch.accelerator.current_stream(),
 | 
			
		||||
            source=CurrentStreamSource(torch.accelerator.current_stream().device),
 | 
			
		||||
        )
 | 
			
		||||
        self.cur_stream_stack: collections.deque[StreamVariable] = collections.deque(
 | 
			
		||||
            [stream_var]  # type: ignore[list-item]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def enter_stream(self, stream: "StreamVariable") -> None:
 | 
			
		||||
        self.cur_stream_stack.append(stream)
 | 
			
		||||
 | 
			
		||||
    def exit_stream(self) -> None:
 | 
			
		||||
        self.cur_stream_stack.pop()
 | 
			
		||||
 | 
			
		||||
    def cur_stream(self, device: Optional[torch.device] = None) -> "StreamVariable":
 | 
			
		||||
        if device is not None:
 | 
			
		||||
            for stream in reversed(self.cur_stream_stack):
 | 
			
		||||
                if stream.device == device:
 | 
			
		||||
                    return stream
 | 
			
		||||
 | 
			
		||||
        return self.cur_stream_stack[-1]
 | 
			
		||||
 | 
			
		||||
    def in_stream_context(self) -> bool:
 | 
			
		||||
        return len(self.cur_stream_stack) > 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class StreamContextVariable(ContextWrappingVariable):
 | 
			
		||||
    """This represents torch.cuda.StreamContext"""
 | 
			
		||||
 | 
			
		||||
@ -98,6 +132,7 @@ class StreamContextVariable(ContextWrappingVariable):
 | 
			
		||||
    def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
 | 
			
		||||
        # to stream, from stream is the order of the arguments
 | 
			
		||||
        # we are entering the target, and leaving the initial stream
 | 
			
		||||
        tx.symbolic_stream_state.enter_stream(self._get_target_values()[0])
 | 
			
		||||
        tx.output.create_proxy(
 | 
			
		||||
            "call_function",
 | 
			
		||||
            torch.ops.streams.fork.default,
 | 
			
		||||
@ -109,6 +144,7 @@ class StreamContextVariable(ContextWrappingVariable):
 | 
			
		||||
    def exit(self, tx: "InstructionTranslator", *args: tuple[Any]) -> "VariableTracker":
 | 
			
		||||
        # to stream, from stream is the order of the arguments
 | 
			
		||||
        # we are leaving the target, and entering the initial stream
 | 
			
		||||
        tx.symbolic_stream_state.exit_stream()
 | 
			
		||||
        tx.output.create_proxy(
 | 
			
		||||
            "call_function",
 | 
			
		||||
            torch.ops.streams.join.default,
 | 
			
		||||
 | 
			
		||||
@ -1237,6 +1237,35 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
 | 
			
		||||
            # pyrefly: ignore  # unbound-name
 | 
			
		||||
            return VariableTracker.build(tx, module, new_source)
 | 
			
		||||
 | 
			
		||||
        @register(torch.accelerator.current_stream)
 | 
			
		||||
        def handle_current_stream(self, tx: "InstructionTranslator", *args, **kwargs):
 | 
			
		||||
            if len(args) + len(kwargs) > 1 or (kwargs and "device" not in kwargs):
 | 
			
		||||
                unimplemented_v2(
 | 
			
		||||
                    gb_type="unsupported arguments to torch.accelerator.current_stream",
 | 
			
		||||
                    context=f"args={args}, kwargs={kwargs}",
 | 
			
		||||
                    explanation="torch.accelerator.current_stream accepts one optional argument `device`",
 | 
			
		||||
                    hints=[
 | 
			
		||||
                        *graph_break_hints.USER_ERROR,
 | 
			
		||||
                    ],
 | 
			
		||||
                )
 | 
			
		||||
            try:
 | 
			
		||||
                if kwargs:
 | 
			
		||||
                    device = torch.device(kwargs["device"].as_python_constant())
 | 
			
		||||
                elif args:
 | 
			
		||||
                    device = torch.device(args[0].as_python_constant())
 | 
			
		||||
                else:
 | 
			
		||||
                    device = None
 | 
			
		||||
 | 
			
		||||
                return tx.symbolic_stream_state.cur_stream(device)
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                unimplemented_v2(
 | 
			
		||||
                    gb_type="bad device argument to torch.accelerator.current_stream",
 | 
			
		||||
                    context=f"args={args}, kwargs={kwargs}",
 | 
			
		||||
                    explanation="Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
 | 
			
		||||
                    hints=[*graph_break_hints.USER_ERROR],
 | 
			
		||||
                    from_exc=e,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        @register(torch.set_default_device)
 | 
			
		||||
        def handle_set_default_device(
 | 
			
		||||
            self, tx: "InstructionTranslator", *args, **kwargs
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user