# mypy: ignore-errors
"""
This file contains a collection of context manager classes used by Dynamo for tracking
and managing various PyTorch runtime states during graph compilation. These context
managers handle different aspects of PyTorch's execution environment, including:
- Autograd states (grad mode, inference mode)
- CUDA streams and events
- Profiling contexts
- Deterministic algorithms
- Forward/backward AD modes
- SDPA (Scaled Dot Product Attention) kernels
- FSDP (Fully Sharded Data Parallel) states
- AMP (Automatic Mixed Precision) autocast states
The context managers ensure proper state transitions during graph compilation by
tracking enter/exit points and managing cleanup operations. They help maintain
consistency between eager execution and compiled graph behavior by capturing and
restoring state changes.
"""
import inspect
import sys
import warnings
from contextlib import ExitStack
from typing import TYPE_CHECKING, Union
import torch._C
from torch._guards import Guard
from .. import graph_break_hints, variables
from ..bytecode_transformation import (
create_call_function,
create_instruction,
create_setup_with,
)
from ..device_interface import get_interface_for_device
from ..exc import unimplemented_v2
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalStateSource
from ..utils import _get_error_on_graph_break, _set_error_on_graph_break
from .base import VariableTracker
from .functions import (
NestedUserFunctionVariable,
SkipFunctionVariable,
UserFunctionVariable,
UserMethodVariable,
WrappedNestedUserFunctionVariable,
WrappedSkipFunctionVariable,
WrappedUserFunctionVariable,
WrappedUserMethodVariable,
)
from .user_defined import UserDefinedObjectVariable
if TYPE_CHECKING:
from torch._dynamo.codegen import PyCodegen
from torch._dynamo.symbolic_convert import InstructionTranslator
class ContextWrappingVariable(VariableTracker):
_nonvar_fields = {
"cm_obj",
"target_values",
"initial_values",
"state",
*VariableTracker._nonvar_fields,
}
def __init__(self, target_values, initial_values=None, **kwargs) -> None:
super().__init__(**kwargs)
self.target_values = target_values
self.initial_values = initial_values
def enter(self, tx):
self._call_func(tx, self.target_values)
self.set_cleanup_hook(tx)
return variables.ConstantVariable.create(None)
def set_cleanup_hook(self, tx: "InstructionTranslator", fn=None):
if fn is None:
def fn():
self._call_func(tx, self.initial_values)
self.cleanup_fn = fn
tx.output.add_cleanup_hook(self.cleanup)
def exit(self, tx: "InstructionTranslator", *args):
self.cleanup_assert()
return variables.ConstantVariable.create(None)
def reconstruct_type(self, codegen: "PyCodegen"):
codegen(
AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name())
)
def reconstruct(self, codegen: "PyCodegen"):
codegen.add_push_null(lambda: self.reconstruct_type(codegen))
target_values = self.target_values
if not target_values:
target_values = ()
codegen.extend_output([codegen.create_load_const(val) for val in target_values])
codegen.extend_output(create_call_function(len(target_values), False))
def module_name(self):
raise NotImplementedError("module_name called on base")
def fn_name(self):
raise NotImplementedError("fn_name called on base")
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
assert len(args) == 1
assert isinstance(
args[0],
(
NestedUserFunctionVariable,
SkipFunctionVariable,
UserMethodVariable,
UserFunctionVariable,
),
)
if isinstance(args[0], NestedUserFunctionVariable):
return WrappedNestedUserFunctionVariable(args[0], self)
if isinstance(args[0], SkipFunctionVariable):
return WrappedSkipFunctionVariable(args[0], self)
if isinstance(args[0], UserMethodVariable):
return WrappedUserMethodVariable(args[0], self)
if isinstance(args[0], UserFunctionVariable):
return WrappedUserFunctionVariable(args[0], self)
def supports_graph_breaks(self):
return True
def exit_on_graph_break(self):
return True
def cleanup(self):
if self.cleanup_fn is not None:
self.cleanup_fn()
self.cleanup_fn = None
def cleanup_assert(self):
assert self.cleanup_fn, "multiple exits?"
self.cleanup()
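# A hedged sketch (hypothetical, not part of this module) of what a minimal
# subclass looks like: implement _call_func plus the reconstruction names, and
# the base class supplies enter/exit, cleanup hooks, and bytecode reconstruction.
#
#     class MyModeVariable(ContextWrappingVariable):  # hypothetical
#         def _call_func(self, tx, values):
#             ...  # flip global state and record a node in the FX graph
#         def module_name(self):
#             return "torch"
#         def fn_name(self):
#             return "my_mode"  # hypothetical eager entry point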
class GenericContextWrappingVariable(UserDefinedObjectVariable):
# Some methods in ContextWrappingVariable assume the arguments are
# Python constants, which might not always be the case here.
def __init__(self, cm_obj, **kwargs) -> None:
assert cm_obj is not None
super().__init__(
value=cm_obj,
value_type=cm_obj.__class__,
**kwargs,
)
self.cm_obj = cm_obj
def module_name(self):
return self.cm_obj.__module__
def fn_name(self):
return type(self.cm_obj).__name__
def enter(self, tx):
source = None if self.source is None else AttrSource(self.source, "__enter__")
return variables.UserMethodVariable(
self.cm_obj.__enter__.__func__,
self,
source=source,
).call_function(tx, [], {})
def exit(self, tx: "InstructionTranslator", *args):
source = None if self.source is None else AttrSource(self.source, "__exit__")
x = variables.UserMethodVariable(
self.cm_obj.__exit__.__func__,
self,
source=source,
).call_function(tx, args, {})
tx.active_generic_context_managers.pop()
return x
def supports_graph_breaks(self):
return False
def exit_on_graph_break(self):
return True
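# For reference, a hedged sketch of the eager pattern this tracker handles:
# an arbitrary user-defined context manager used inside a compiled region.
#
#     class MyCtx:  # hypothetical user class
#         def __enter__(self): ...
#         def __exit__(self, *exc): ...
#
#     @torch.compile
#     def f(x):
#         with MyCtx():
#             return x + 1
#
# __enter__/__exit__ are traced as ordinary user methods; since
# supports_graph_breaks() is False, a graph break inside the `with` body makes
# Dynamo exit the context manager rather than resume inside it.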
class RepararametrizeModuleContextVariable(GenericContextWrappingVariable):
def __init__(self, ctx_manager_vt, mod):
self.cm_vt = ctx_manager_vt
self.mod = mod
# We don't call super().__init__() because we're delegating most methods to cm_vt
def enter(self, tx: "InstructionTranslator"):
# Custom enter implementation with side effects
self.old_parameters_var = self.mod.var_getattr(tx, "_parameters").realize()
self.old_buffer_var = self.mod.var_getattr(tx, "_buffers").realize()
tx.output.side_effects.ignore_mutations_on(self.old_parameters_var)
tx.output.side_effects.ignore_mutations_on(self.old_buffer_var)
return self.cm_vt.enter(tx)
def exit(self, tx: "InstructionTranslator", *args):
# Custom exit implementation with side effects
x = self.cm_vt.exit(tx, *args)
tx.output.side_effects.stop_ignoring_mutations_on(self.old_buffer_var)
tx.output.side_effects.stop_ignoring_mutations_on(self.old_parameters_var)
return x
# Forward all other method calls to self.cm_vt
def __getattr__(self, name):
# This will be called for any attribute not explicitly defined in this class
return getattr(self.cm_vt, name)
class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
"""represents torch grad requires grad"""
@staticmethod
def create(tx: "InstructionTranslator", target_values, **kwargs):
return GradInplaceRequiresGradCtxManagerVariable(
target_values=target_values,
initial_values=None,
**kwargs,
)
def enter(self, tx):
[enabled] = self.target_values
self.prev_state = torch._C._functorch.get_inplace_requires_grad_allowed()
torch._C._functorch.set_inplace_requires_grad_allowed(enabled)
self.set_cleanup_hook(
tx,
lambda: torch._C._functorch.set_inplace_requires_grad_allowed(
self.prev_state
),
)
self.proxy = tx.output.create_node(
"call_function",
torch._C._functorch.set_inplace_requires_grad_allowed,
(enabled,),
{},
)
return variables.ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
self.cleanup()
tx.output.create_node(
"call_function",
torch._C._functorch.set_inplace_requires_grad_allowed,
(self.prev_state,),
{},
)
return variables.ConstantVariable.create(None)
class TemporarilyPopInterpreterStackCtxManagerVariable(ContextWrappingVariable):
"""represents torch._functorch.pyfunction.temporarily_pop_interpreter_stack()"""
@staticmethod
def create(tx: "InstructionTranslator", target_values, **kwargs):
return TemporarilyPopInterpreterStackCtxManagerVariable(
target_values=target_values,
initial_values=None,
**kwargs,
)
def enter(self, tx):
self.saved = torch._C._functorch.pop_dynamic_layer_stack()
self.set_cleanup_hook(
tx,
lambda: torch._C._functorch.push_dynamic_layer_stack(self.saved),
)
self.proxy = tx.output.create_node(
"call_function",
torch._C._functorch.pop_dynamic_layer_stack,
(),
{},
)
return variables.ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
self.cleanup()
tx.output.create_node(
"call_function",
torch._C._functorch.push_dynamic_layer_stack,
(self.proxy,),
{},
)
return variables.ConstantVariable.create(None)
class JvpIncrementNestingCtxManagerVariable(ContextWrappingVariable):
"""represents torch.func.jvp increment/decrement nesting"""
# A guard is needed as the jvp level is baked into the torch FX graph
# This is fine if jvp is only called from within the function
# being compiled. But the FX graph may be invalid in the case of a jvp
# call from eager that calls the compiled function, as the jvp levels
# may be different.
_guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)
@staticmethod
def create(tx: "InstructionTranslator", **kwargs):
var = JvpIncrementNestingCtxManagerVariable(
target_values=None,
initial_values=None,
**kwargs,
)
return var
def enter(self, tx):
install_guard(self._guards_singleton)
jvp_level = torch._functorch.eager_transforms.enter_jvp_nesting()
self.set_cleanup_hook(
tx, lambda: torch._functorch.eager_transforms.exit_jvp_nesting()
)
self.proxy = tx.output.create_node(
"call_function",
torch._C._functorch._jvp_increment_nesting,
(),
{},
)
return variables.ConstantVariable.create(jvp_level)
def exit(self, tx: "InstructionTranslator", *args):
self.cleanup()
tx.output.create_node(
"call_function", torch._C._functorch._jvp_decrement_nesting, (), {}
)
return variables.ConstantVariable.create(None)
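# Typical eager trigger, as a hedged sketch: torch.func.jvp called inside the
# compiled region.
#
#     def f(x):
#         return torch.func.jvp(torch.sin, (x,), (torch.ones_like(x),))
#
# The FUNCTORCH_STACK_MATCH guard then protects against reusing the compiled
# graph under a different ambient jvp nesting level.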
class SetFwdGradEnabledContextManager(ContextWrappingVariable):
"""represents torch.autograd.forward_ad._set_fwd_grad_enabled() to enable/disable fwd grad"""
@staticmethod
def create(tx: "InstructionTranslator", target_values, **kwargs):
return SetFwdGradEnabledContextManager(
target_values=target_values,
initial_values=None,
**kwargs,
)
def enter(self, tx):
[mode] = self.target_values
self.prev_state = torch._C._is_fwd_grad_enabled()
torch._C._set_fwd_grad_enabled(mode)
self.set_cleanup_hook(
tx,
lambda: torch._C._set_fwd_grad_enabled(self.prev_state),
)
self.proxy = tx.output.create_node(
"call_function",
torch._C._set_fwd_grad_enabled,
(mode,),
{},
)
return variables.ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
self.cleanup()
tx.output.create_node(
"call_function",
torch._C._set_fwd_grad_enabled,
(self.prev_state,),
{},
)
return variables.ConstantVariable.create(None)
class DualLevelContextManager(ContextWrappingVariable):
"""Represents torch.autograd.forward_ad.dual_level ctx manager"""
_guards_singleton = Guard(GlobalStateSource(), GuardBuilder.DUAL_LEVEL)
@staticmethod
def create(tx: "InstructionTranslator", **kwargs):
return DualLevelContextManager(
target_values=None,
initial_values=None,
**kwargs,
)
def enter(self, tx):
install_guard(self._guards_singleton)
self.new_level = torch.autograd.forward_ad.enter_dual_level()
self.set_cleanup_hook(
tx, lambda: torch.autograd.forward_ad.exit_dual_level(level=self.new_level)
)
self.proxy = tx.output.create_node(
"call_function",
torch._C._enter_dual_level,
(),
{},
)
return variables.ConstantVariable.create(self.new_level)
def exit(self, tx: "InstructionTranslator", *args):
self.cleanup()
tx.output.create_node(
"call_function",
torch._C._exit_dual_level,
(self.new_level,),
{},
)
return variables.ConstantVariable.create(None)
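# Eager pattern handled here, as a hedged sketch:
#
#     with torch.autograd.forward_ad.dual_level():
#         y = torch.autograd.forward_ad.make_dual(x, tangent)
#
# The DUAL_LEVEL guard ensures the level baked into the graph is still valid
# when the compiled code is reused.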
class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable):
"""represents torch.func.grad increment/decrement nesting"""
# A guard is needed as the grad level is baked into the torch FX graph
# This is fine if grad is only called from within the function
# being compiled. But the FX graph may be invalid in the case of a grad
# call from eager that calls the compiled function, as the grad levels
# may be different.
_guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)
@staticmethod
def create(tx: "InstructionTranslator", **kwargs):
var = GradIncrementNestingCtxManagerVariable(
target_values=None,
initial_values=None,
**kwargs,
)
return var
def enter(self, tx):
install_guard(self._guards_singleton)
grad_level = torch._C._functorch._grad_increment_nesting()
self.set_cleanup_hook(tx, lambda: torch._C._functorch._grad_decrement_nesting())
self.proxy = tx.output.create_node(
"call_function",
torch._C._functorch._grad_increment_nesting,
(),
{},
)
return variables.ConstantVariable.create(grad_level)
def exit(self, tx: "InstructionTranslator", *args):
self.cleanup()
tx.output.create_node(
"call_function", torch._C._functorch._grad_decrement_nesting, (), {}
)
return variables.ConstantVariable.create(None)
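# Typical eager trigger, as a hedged sketch: torch.func.grad called inside the
# compiled region.
#
#     def f(x):
#         return torch.func.grad(lambda t: t.sin().sum())(x)
#
# As with the jvp/vmap variants, soundness relies on the functorch-stack guard.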
class CatchWarningsCtxManagerVariable(ContextWrappingVariable):
"""Delay a call to warnings.catch_warnings"""
@staticmethod
def create(tx: "InstructionTranslator", catch_warnings_args):
return CatchWarningsCtxManagerVariable(
catch_warnings_args=catch_warnings_args,
target_values=None,
initial_values=None,
)
def __init__(self, catch_warnings_args, **kwargs) -> None:
assert isinstance(catch_warnings_args, dict), catch_warnings_args
super().__init__(**kwargs)
self.catch_warnings_args = catch_warnings_args
def enter(self, tx):
kwargs = {
k: v.as_python_constant() for k, v in self.catch_warnings_args.items()
}
ctx_val = warnings.catch_warnings(**kwargs)
self.set_cleanup_hook(tx, lambda: ctx_val.__exit__(None, None, None))
return variables.ConstantVariable.create(ctx_val.__enter__())
def reconstruct(self, cg):
cg.add_push_null(lambda: cg.load_import_from("warnings", "catch_warnings"))
cg.foreach(self.catch_warnings_args.values())
keys = tuple(self.catch_warnings_args.keys())
cg.extend_output(cg.create_call_function_kw(len(keys), keys, False))
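# Eager pattern handled here, as a hedged sketch:
#
#     with warnings.catch_warnings(record=True):
#         ...
#
# The call is delayed: the kwargs are stored as VariableTrackers, converted to
# Python constants only in enter(), and reconstruct() re-emits the call with
# keyword arguments for the resume function.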
class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
"""represents torch VMap increment/decrement nesting"""
# A guard is needed as the vmap level is baked into the torch FX graph
# generated. This is fine if vmap is only called from within the function
# being compiled. But the FX graph may be invalid in the case of a vmap
# call from eager that calls the compiled function, as the vmap levels
# may be different.
_guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)
@staticmethod
def create(tx: "InstructionTranslator", target_values, **kwargs):
var = VmapIncrementNestingCtxManagerVariable(
target_values=target_values,
initial_values=None,
**kwargs,
)
return var
def enter(self, tx):
install_guard(self._guards_singleton)
batch_size, randomness = self.target_values
if isinstance(batch_size, variables.SymNodeVariable):
batch_size_value = batch_size.sym_num
batch_size_node = batch_size.as_proxy().node
else:
batch_size_value = batch_size.as_python_constant()
batch_size_node = batch_size.as_python_constant()
randomness = randomness.as_python_constant()
vmap_level = torch._C._functorch._vmap_increment_nesting(
batch_size_value, randomness
)
self.set_cleanup_hook(tx, lambda: torch._C._functorch._vmap_decrement_nesting())
self.proxy = tx.output.create_node(
"call_function",
torch._functorch.predispatch._vmap_increment_nesting,
(batch_size_node, randomness),
{},
)
return variables.ConstantVariable.create(vmap_level)
def exit(self, tx: "InstructionTranslator", *args):
self.cleanup()
tx.output.create_node(
"call_function",
torch._functorch.predispatch._vmap_decrement_nesting,
(),
{},
)
return variables.ConstantVariable.create(None)
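# Typical eager trigger, as a hedged sketch:
#
#     def f(x):
#         return torch.vmap(torch.sin)(x)
#
# Symbolic batch sizes flow through as SymNode proxies (see enter() above), so
# the nesting call recorded in the graph can carry a dynamic batch size.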
class GradModeVariable(ContextWrappingVariable):
"""represents torch.{no_grad,enable_grad,set_grad_mode}()"""
_guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE)
@staticmethod
def create(tx: "InstructionTranslator", target_value, initialized=False, **kwargs):
var = GradModeVariable(
target_values=[target_value],
initial_values=[torch.is_grad_enabled()],
**kwargs,
)
if initialized:
var._call_func(tx, var.target_values)
return var
def __init__(
self, target_values, initial_values=None, initialized=True, **kwargs
) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
install_guard(self._guards_singleton)
def enter(self, tx):
self._call_func(tx, self.target_values)
return variables.ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
self._call_func(tx, self.initial_values)
return variables.ConstantVariable.create(None)
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
):
self._call_func(tx, self.initial_values) # undo eager initialization
return super().call_function(tx, args, kwargs)
def _call_func(self, tx: "InstructionTranslator", values):
assert len(values) == 1
value = values[0]
# Coalesce grad mode mutations
if torch.is_grad_enabled() != value:
tx.output.create_node(
"call_function", torch._C._set_grad_enabled, (value,), {}
)
torch._C._set_grad_enabled(value)
def module_name(self):
return "torch"
def fn_name(self):
return "set_grad_enabled"
class InferenceModeVariable(ContextWrappingVariable):
@staticmethod
def create(tx: "InstructionTranslator", target_value, **kwargs):
var = InferenceModeVariable(
[target_value], initial_values=torch.is_inference_mode_enabled(), **kwargs
)
return var
def __init__(
self,
target_values,
initial_values=None,
**kwargs,
) -> None:
if initial_values is None:
# This must be called here since function defaults are evaluated at import time
initial_values = torch.is_inference_mode_enabled()
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.target_values = target_values
def exit(self, tx: "InstructionTranslator", *args):
self.cleanup_assert()
tx.output.create_node(
"call_function",
torch.autograd.grad_mode._exit_inference_mode,
(self.proxy,),
{},
)
def enter(self, tx):
disabled_inference_mode_forcibly = False
if (
torch._dynamo.config.fake_tensor_disable_inference_mode
and self.target_values[0]
):
# Do not set the inference mode because we keep it off during
# compilation. Set the grad_enabled to False to reflect the relevant
# part of inference_mode to torch.compile.
disabled_inference_mode_forcibly = True
prior = torch.is_grad_enabled()
torch._C._set_grad_enabled(False)
else:
ctx = torch.autograd.grad_mode._enter_inference_mode(*self.target_values)
def cleanup_hook():
if disabled_inference_mode_forcibly:
torch._C._set_grad_enabled(prior)
else:
torch.autograd.grad_mode._exit_inference_mode(ctx)
self.set_cleanup_hook(tx, cleanup_hook)
self.proxy = tx.output.create_node(
"call_function",
torch.autograd.grad_mode._enter_inference_mode,
(*self.target_values,),
{},
)
def module_name(self):
return "torch"
def fn_name(self):
return "inference_mode"
class CUDADeviceVariable(ContextWrappingVariable):
"""represents torch.cuda.device"""
@staticmethod
def create(tx: "InstructionTranslator", device, **kwargs):
var = CUDADeviceVariable(
target_values=[torch.cuda._get_device_index(device, optional=True)],
initial_values=None,
**kwargs,
)
return var
def __init__(
self,
target_values,
initial_values=None,
**kwargs,
) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.target_values = target_values
def exit(self, tx: "InstructionTranslator", *args):
self.cleanup_assert()
tx.output.create_node(
"call_function",
torch.cuda._maybe_exchange_device,
(self.proxy,),
{},
)
return variables.ConstantVariable.create(False)
def enter(self, tx):
prev_idx = torch.cuda._exchange_device(*self.target_values)
self.set_cleanup_hook(tx, lambda: torch.cuda._maybe_exchange_device(prev_idx))
self.proxy = tx.output.create_node(
"call_function",
torch.cuda._exchange_device,
(*self.target_values,),
{},
)
def module_name(self):
return "torch.cuda"
def fn_name(self):
return "device"
class TorchFunctionDisableVariable(ContextWrappingVariable):
"""represents whether torch function overrides are enabled or not"""
_guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE)
@staticmethod
def create(tx: "InstructionTranslator", **kwargs):
var = TorchFunctionDisableVariable(
target_values=[],
initial_values=[],
**kwargs,
)
return var
def __init__(
self, target_values, initial_values=None, only_subclass=True, **kwargs
) -> None:
assert len(target_values) == 0
assert len(initial_values) == 0
from ..symbolic_convert import InstructionTranslator
tx = InstructionTranslator.current_tx()
self.only_subclass = only_subclass
self.initial_torch_function_subclass_enabled = (
tx.symbolic_torch_function_state.torch_function_subclass_enabled
)
self.initial_torch_function_mode_enabled = (
tx.symbolic_torch_function_state.torch_function_mode_enabled
)
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
install_guard(self._guards_singleton)
def set_cleanup_hook(self, tx: "InstructionTranslator", fn=None):
if fn is None:
def fn():
tx.symbolic_torch_function_state.torch_function_subclass_enabled = (
self.initial_torch_function_subclass_enabled
)
if not self.only_subclass:
tx.symbolic_torch_function_state.torch_function_mode_enabled = (
self.initial_torch_function_mode_enabled
)
self.cleanup_fn = fn
tx.output.add_cleanup_hook(self.cleanup)
def _call_func(self, tx: "InstructionTranslator", values):
assert len(values) == 0
tx.symbolic_torch_function_state.torch_function_subclass_enabled = False
if not self.only_subclass:
tx.symbolic_torch_function_state.torch_function_mode_enabled = False
def module_name(self):
return "torch._C"
def fn_name(self):
if self.only_subclass:
return "DisableTorchFunctionSubclass"
return "DisableTorchFunction"
class DeterministicAlgorithmsVariable(ContextWrappingVariable):
"""represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()"""
_guards_singleton = Guard(
GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS
)
@staticmethod
def create(tx: "InstructionTranslator", target_value, **kwargs):
var = DeterministicAlgorithmsVariable(
target_values=[target_value],
initial_values=[torch.are_deterministic_algorithms_enabled()],
**kwargs,
)
var._call_func(tx, [target_value])
var.set_cleanup_hook(tx)
return var
def __init__(self, target_values, initial_values=None, **kwargs) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
install_guard(self._guards_singleton)
def enter(self, tx):
return variables.ConstantVariable.create(None)
def _call_func(self, tx: "InstructionTranslator", values):
assert len(values) == 1
value = values[0]
tx.output.create_node(
"call_function", torch._C._set_deterministic_algorithms, (value,), {}
)
torch._C._set_deterministic_algorithms(value)
def module_name(self):
return "torch"
def fn_name(self):
return "use_deterministic_algorithms"
class DisabledSavedTensorsHooksVariable(ContextWrappingVariable):
"""represents torch.autograd.graph.disable_saved_tensors_hook."""
@staticmethod
def create(tx: "InstructionTranslator", target_value, **kwargs):
var = DisabledSavedTensorsHooksVariable(
target_values=[target_value],
initial_values=[
torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
],
**kwargs,
)
var._call_func(tx, [target_value])
var.set_cleanup_hook(tx)
return var
def __init__(self, target_values, initial_values=None, **kwargs) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
def enter(self, tx):
return variables.ConstantVariable.create(None)
def _call_func(self, tx: "InstructionTranslator", values):
assert len(values) == 1
value = values[0]
if value is not None:
# Disable `saved_tensors_hooks` with message (`value`)
# OR
# we are exiting this context and restoring the previous message.
tx.output.create_node(
"call_function",
torch._C._autograd._saved_tensors_hooks_disable,
(value,),
{},
)
torch._C._autograd._saved_tensors_hooks_disable(value)
else:
# We are exiting this context and if prev_message was None, we re-enable `saved_tensors_hooks`.
tx.output.create_node(
"call_function", torch._C._autograd._saved_tensors_hooks_enable, (), {}
)
torch._C._autograd._saved_tensors_hooks_enable()
def module_name(self):
return "torch.autograd.graph"
def fn_name(self):
return "disable_saved_tensors_hooks"
class AutocastModeVariable(ContextWrappingVariable):
@staticmethod
def create(func, args, kwargs):
assert func in [
torch.amp.autocast_mode.autocast,
torch.cuda.amp.autocast,
torch.cpu.amp.autocast,
]
# Signature of torch.amp.autocast for reference:
#   device_type : str,
#   dtype : Optional[_dtype] = None,
#   enabled : bool = True,
#   cache_enabled : Optional[bool] = None
bound_args = inspect.signature(func).bind(*args, **kwargs)
bound_args.apply_defaults()
target_values = []
kwargs.clear()
for key in ["device_type", "dtype", "enabled", "cache_enabled"]:
if key == "device_type" and func in [
torch.cuda.amp.autocast,
torch.cpu.amp.autocast,
]:
arg = "cuda" if func is torch.cuda.amp.autocast else "cpu"
else:
arg = bound_args.arguments[key]
if isinstance(arg, VariableTracker):
target_values.append(arg.as_python_constant())
else:
target_values.append(arg)
var = AutocastModeVariable(target_values, initial_values=None, **kwargs)
return var
def __init__(self, target_values, initial_values=None, **kwargs) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.target_values = target_values
def exit(self, tx: "InstructionTranslator", *args):
self.cleanup_assert()
tx.output.create_node(
"call_function", torch.amp._exit_autocast, (self.proxy,), {}
)
return variables.ConstantVariable.create(None)
def enter(self, tx):
ctx = torch.amp._enter_autocast(*self.target_values)
self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx))
self.proxy = tx.output.create_node(
"call_function", torch.amp._enter_autocast, (*self.target_values,), {}
)
def module_name(self):
return "torch.amp.autocast_mode"
def fn_name(self):
return "autocast"
class NullContextVariable(ContextWrappingVariable):
"""
This class represents Python contextlib.nullcontext.
"""
def __init__(self, target_values=None, **kwargs) -> None:
super().__init__(target_values=target_values, **kwargs)
def enter(self, tx):
none = variables.ConstantVariable.create(None)
return self.target_values[0] if self.target_values else none
def exit(self, tx: "InstructionTranslator", *args):
return variables.ConstantVariable.create(None)
def module_name(self):
return "contextlib"
def fn_name(self):
return "nullcontext"
class ProfilerContextVariable(ContextWrappingVariable):
"""
This class represents the torch profiler context objects. Dynamo ignores
all side effects of their __init__, __enter__, and __exit__ methods by
treating the object mostly as a `contextlib.nullcontext`, except for edge
cases such as __enter__, which returns the object itself rather than
`None`, matching the torch objects' implementation.
"""
def __init__(self, **kwargs) -> None:
super().__init__(target_values=None, **kwargs)
def enter(self, tx):
return self
def exit(self, tx: "InstructionTranslator", *args):
return variables.ConstantVariable.create(None)
def module_name(self):
return "contextlib"
def fn_name(self):
return "nullcontext"
def reconstruct(self, cg):
unimplemented_v2(
gb_type="torch.profiler object escaped from compiled region",
context=str(self),
explanation="Dynamo doesn't support compiling a region that returns a torch.profiler context manager.",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
class StreamContextVariable(ContextWrappingVariable):
@staticmethod
def create(tx: "InstructionTranslator", target_value, **kwargs):
from .builder import wrap_fx_proxy_cls
current_stream_method = get_interface_for_device(
target_value.device
).current_stream
current_stream = wrap_fx_proxy_cls(
StreamVariable,
tx,
tx.output.create_proxy(
"call_function",
current_stream_method,
(None,),
{},
),
)
return StreamContextVariable(
target_values=[target_value],
initial_values=[current_stream],
device=target_value.device,
**kwargs,
)
def __init__(self, target_values, device, initial_values=None, **kwargs) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.device = device
self.set_stream = get_interface_for_device(self.device).set_stream
self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id
def enter(self, tx):
# stream generated inside the traced function
if self.target_values[0].as_proxy() is not None:
tx.output.create_proxy(
"call_function",
self.set_stream,
(self.target_values[0].as_proxy(),),
{},
)
# stream passed from outside the traced function
else:
stream = self.target_values[0].value
tx.output.create_proxy(
"call_function",
self.set_stream_id,
(stream.stream_id, stream.device_index, stream.device_type),
{},
)
self.set_stream(self.target_values[0].value)
self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value))
def exit(self, tx: "InstructionTranslator", *args):
tx.output.create_proxy(
"call_function",
self.set_stream,
(self.initial_values[0].as_proxy(),),
{},
)
self.cleanup_assert()
class PreserveVersionContextVariable(ContextWrappingVariable):
"""
Wraps torch.autograd._unsafe_preserve_version_counter
"""
@staticmethod
def _create_lambda_from_tensors(tx, tensors):
if isinstance(tensors, variables.TensorVariable):
versions = variables.TupleVariable(
[x.var_getattr(tx, "_version") for x in [tensors]]
)
tensors = variables.TupleVariable([tensors])
else:
versions = variables.TupleVariable(
[x.var_getattr(tx, "_version") for x in tensors.items]
)
return PreserveVersionContextVariable(tensors, versions)
@staticmethod
def constructor(tx):
return variables.LambdaVariable(
lambda tensors: PreserveVersionContextVariable._create_lambda_from_tensors(
tx, tensors
)
)
def __init__(self, tensors, prev_versions, **kwargs) -> None:
kwargs.setdefault("target_values", None)
super().__init__(**kwargs)
self.tensors = tensors
self.prev_versions = prev_versions
# The context manager accepts Union[Tensor, Tuple[Tensor]]
if isinstance(self.tensors, variables.TensorVariable):
self.tensors = variables.TupleVariable([self.tensors])
if isinstance(
self.prev_versions, (variables.ConstantVariable, variables.SymNodeVariable)
):
self.prev_versions = variables.TupleVariable([self.prev_versions])
def enter(self, tx):
pass
def exit(self, tx: "InstructionTranslator", *args):
from ..tensor_version_op import _unsafe_set_version_counter
return variables.TorchInGraphFunctionVariable(
_unsafe_set_version_counter
).call_function(tx, [self.tensors, self.prev_versions], {})
def reconstruct(self, codegen: "PyCodegen"):
unimplemented_v2(
gb_type="torch.autograd._unsafe_preserve_version_counter escaped from compiled region",
context=str(self),
explanation=(
"Dynamo doesn't support compiling a region that returns "
"a torch.autograd._unsafe_preserve_version_counter context manager."
),
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
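# Eager pattern handled here, as a hedged sketch:
#
#     with torch.autograd._unsafe_preserve_version_counter(t):
#         t.add_(1)  # version counter is restored on exit
#
# exit() restores the saved versions through the in-graph
# _unsafe_set_version_counter op.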
class FSDPParamGroupUseTrainingStateVariable(ContextWrappingVariable):
_guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FSDP_TRAINING_STATE)
@staticmethod
def create(tx: "InstructionTranslator", param_group_var, target_value, **kwargs):
var = FSDPParamGroupUseTrainingStateVariable(
param_group_var=param_group_var,
target_values=[target_value],
initial_values=[param_group_var.value._training_state],
**kwargs,
)
return var
def __init__(
self, param_group_var, target_values, initial_values=None, **kwargs
) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.param_group_var = param_group_var
install_guard(self._guards_singleton)
def enter(self, tx):
self._call_func(tx, self.target_values)
return variables.ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
self._call_func(tx, self.initial_values)
return variables.ConstantVariable.create(None)
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
):
self._call_func(tx, self.initial_values) # undo eager initialization
return super().call_function(tx, args, kwargs)
def _call_func(self, tx: "InstructionTranslator", values):
assert len(values) == 1
value = values[0]
if self.param_group_var.value._training_state != value:
self.param_group_var.call_method(
tx,
"__setattr__",
(
variables.ConstantVariable.create("_training_state"),
variables.EnumVariable(value),
),
{},
)
self.param_group_var.value._training_state = value
def module_name(self):
return "torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup"
def fn_name(self):
return "use_training_state"
class SDPAKernelVariable(ContextWrappingVariable):
"""represents torch.nn.attention.sdpa_kernel"""
@staticmethod
def create(tx: "InstructionTranslator", backends, set_priority=False, **kwargs):
if isinstance(backends, torch.nn.attention.SDPBackend):
backends = [backends]
var = SDPAKernelVariable(
target_values=backends,
initial_values=None,
set_priority=set_priority,
**kwargs,
)
return var
def __init__(
self,
target_values: list[torch.nn.attention.SDPBackend],
initial_values=None,
set_priority: bool = False,
**kwargs,
) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.set_priority = set_priority
@staticmethod
def _backends_to_nodes(tx, backends):
# convert to/from string in order to bake the backend into FX graph
nodes = [
tx.output.create_node(
"call_function",
torch.nn.attention._backend_from_string,
(backend.name,),
{},
)
for backend in backends
]
return nodes
def enter(self, tx):
self.prev_backends = torch.nn.attention._cur_sdpa_kernel_backends(
with_priority=self.set_priority
)
self.set_cleanup_hook(
tx,
lambda: torch.nn.attention._sdpa_kernel(
self.prev_backends, set_priority=self.set_priority
),
)
torch.nn.attention._sdpa_kernel(
self.target_values, set_priority=self.set_priority
)
arg = self._backends_to_nodes(tx, self.target_values)
tx.output.create_node(
"call_function",
torch.nn.attention._sdpa_kernel,
(arg, bool(self.set_priority)),
{},
)
return variables.ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
self.cleanup_assert()
arg = self._backends_to_nodes(tx, self.prev_backends)
tx.output.create_node(
"call_function",
torch.nn.attention._sdpa_kernel,
(arg, bool(self.set_priority)),
{},
)
return variables.ConstantVariable.create(None)
def module_name(self):
return "torch.nn.attention"
# use a private version of sdpa_kernel that accepts variadic arguments
# since dynamo reconstructs the contents of target_values one-by-one
def fn_name(self):
return "_sdpa_kernel_variadic"
class FxTracebackAnnotateVariable(ContextWrappingVariable):
"""
fx.traceback.annotate is a context manager that lets users annotate FX
graph nodes with custom metadata. In Dynamo, we don't need to trace the
context manager itself; instead we run it directly, so that the FX graphs
Dynamo creates carry the right custom metadata. This variable tracker
simply runs the __enter__ and __exit__ methods (instead of tracing them).
"""
def __init__(self, target_values, initial_values=None, **kwargs) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
def enter(self, tx, *args):
# Run the annotation ctx manager in eager. Also ensure that the
# preserve_node_meta context manager is set up; this is important for
# passing the metadata on to the create_proxy'd nodes.
stack = ExitStack()
stack.enter_context(torch.fx.traceback.annotate(self.target_values))
stack.enter_context(torch.fx.traceback.preserve_node_meta())
self.set_cleanup_hook(tx, lambda: stack.close())
return variables.ConstantVariable.create(None)
def module_name(self):
return "torch.fx.traceback"
def fn_name(self):
return "annotate"
class StreamVariable(VariableTracker):
def __init__(self, proxy, value, device, **kwargs) -> None:
if proxy is not None and "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == value
assert value.device.type == device.type, (
"stream value is not equal to the passed device"
)
super().__init__(**kwargs)
self.proxy = proxy
self.value = value
self.device = device
def python_type(self):
return torch.Stream
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
assert hasattr(self.value, name), f"no stream method found named {name}"
from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs
from .builder import wrap_fx_proxy_cls
if name in ("wait_stream", "synchronize", "wait_event"):
tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
)
return variables.ConstantVariable(None)
elif name == "query":
return wrap_fx_proxy_cls(
target_cls=variables.ConstantVariable,
tx=tx,
proxy=tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
),
)
elif name == "record_event":
return wrap_fx_proxy_cls(
target_cls=EventVariable,
tx=tx,
proxy=tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
),
)
elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
# NB: checking for mutation is necessary because we compare
# constant values
other = args[0]
if not isinstance(other, StreamVariable):
return variables.ConstantVariable.create(NotImplemented)
return variables.ConstantVariable.create(
cmp_name_to_op_mapping[name](self.value, other.value)
)
return super().call_method(tx, name, args, kwargs)
def as_proxy(self):
return self.proxy
def reconstruct(self, codegen: "PyCodegen"):
# If we got here, this stream is fully subsumed by the graph - this means it is
# not an input or global
assert not self.source
# Given that, reconstruction would be fine and sound for other such structures,
# like lists and dicts, under Dynamo's usual treatment of collections. Streams
# are special, however: the reconstructed stream must preserve the identity of
# the stream used in the graph. Normally, we would do this via codegen mapping
# the proxy to an output, but we cannot do that yet, as we do not have a plan
# for handling streams used as inputs or outputs. Pending that design, to
# unblock current work, we lift the stream into a global and then codegen
# bytecode to load it from there.
prefix = f"_stream_{self.device}"
name = codegen.tx.output.install_global_by_id(prefix, self.value)
codegen.append_output(codegen.create_load_global(name, add=True))
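# Eager pattern that produces these trackers, as a hedged sketch (assumes a
# CUDA build):
#
#     s = torch.cuda.Stream()
#
#     @torch.compile
#     def f(x):
#         with torch.cuda.stream(s):
#             return x + 1
#
# Method calls such as wait_stream/synchronize are re-emitted as call_method
# nodes, while rich comparisons fall back to the wrapped stream objects.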
class EventVariable(VariableTracker):
def __init__(self, proxy, value, **kwargs) -> None:
if proxy is not None and "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == value
super().__init__(**kwargs)
self.proxy = proxy
self.value = value
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
from ..utils import proxy_args_kwargs
from .builder import wrap_fx_proxy_cls
if name in ("wait", "record", "synchronize"):
tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
)
return variables.ConstantVariable(None)
elif name == "query":
return wrap_fx_proxy_cls(
target_cls=variables.ConstantVariable,
tx=tx,
proxy=tx.output.create_proxy(
"call_method", name, *proxy_args_kwargs([self] + args, kwargs)
),
)
else:
method_name = (
f"{type(self.value).__module__}.{type(self.value).__qualname__}.{name}"
)
unimplemented_v2(
gb_type="Unsupported event method",
context=str(name),
explanation=f"Dynamo doesn't support tracing the {method_name} method. "
f"We currently support wait, record, synchronize, and query.",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
def as_proxy(self):
return self.proxy
def reconstruct(self, codegen: "PyCodegen"):
# If we got here, this event is fully subsumed by the graph - this means it is
# not an input or global
assert not self.source
# Similar to stream handling, we lift the event into a global and then codegen bytecode to load it from there.
prefix = "_event"
name = codegen.tx.output.install_global_by_id(prefix, self.value)
codegen.append_output(codegen.create_load_global(name, add=True))
class DynamoConfigPatchVariable(ContextWrappingVariable):
"""represents torch._dynamo.patch_dynamo_config"""
# NOTE: no need to guard on dynamo config because dynamo config should not affect soundness
# (though it may affect tracing behavior)
def __init__(self, target_values, **kwargs) -> None:
target_values = tuple(target_values.items())
super().__init__(target_values=(target_values,), initial_values=None, **kwargs)
self.initial_values = {}
for key, _ in target_values:
self.initial_values[key] = torch._dynamo.config.__getattr__(key)
self.initial_values = (tuple(self.initial_values.items()),)
def _call_func(self, tx: "InstructionTranslator", values):
assert len(values) == 1
value = values[0]
# manually patch dynamo config
for key, val in value:
torch._dynamo.config.__setattr__(key, val)
# No need to keep track of global side effects because
# dynamo will properly restore this context manager for
# unsupported instructions and continuation functions.
# Dynamo config also should not affect the semantics of the compiled graph.
def module_name(self):
return "torch._dynamo"
def fn_name(self):
return "patch_dynamo_config"
class ErrorOnGraphBreakVariable(ContextWrappingVariable):
"""represents torch._dynamo.error_on_graph_break"""
def __init__(self, error_on_graph_break, **kwargs) -> None:
super().__init__(
target_values=(error_on_graph_break,),
initial_values=(_get_error_on_graph_break(),),
**kwargs,
)
def _call_func(self, tx: "InstructionTranslator", values):
assert len(values) == 1
_set_error_on_graph_break(values[0])
def module_name(self):
return "torch._dynamo"
def fn_name(self):
return "error_on_graph_break"
class WithEnterFunctionVariable(VariableTracker):
def __init__(
self,
ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable],
**kwargs,
) -> None:
super().__init__(**kwargs)
self.ctx = ctx
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
assert not args
assert not kwargs
# NOTE: we assume that the instruction immediately after the current CALL instruction
# is the first instruction of the block.
return tx.enter_ctx(self.ctx, tx.current_instruction)
def reconstruct(self, codegen: "PyCodegen"):
try:
type_str = f"{self.ctx.module_name()}.{self.ctx.fn_name()}"
except NotImplementedError:
type_str = str(type(self.ctx))
unimplemented_v2(
gb_type="Attempted to reconstruct context manager's __enter__ method",
context=str(self.ctx),
explanation=f"Attempted to reconstruct context manager {type_str} while tracing `with ...:`",
hints=[
"It is likely there is a graph break while tracing `with ctx:` "
"but outside the actual `ctx.__enter__()` method. "
"`torch.compile` does not expect this to happen.",
*graph_break_hints.DIFFICULT,
*graph_break_hints.DYNAMO_BUG,
],
)
class WithExitFunctionVariable(VariableTracker):
_nonvar_fields = {
"target",
*VariableTracker._nonvar_fields,
}
def __init__(
self,
ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable],
target,
**kwargs,
) -> None:
super().__init__(**kwargs)
assert isinstance(
ctx, (ContextWrappingVariable, GenericContextWrappingVariable)
)
self.ctx = ctx
self.target = target
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
assert not kwargs
return self.ctx.exit(tx, *args)
def reconstruct(self, codegen: "PyCodegen"):
# Note here we reconstruct the context manager rather than the
# exit function. The handler generated by BlockStackEntry
# will re-enter the context in the resume function.
self.ctx.reconstruct_type(codegen)
if codegen.tx.output.partial_convert:
if sys.version_info >= (3, 11):
codegen.append_output(create_instruction("PUSH_NULL"))
if sys.version_info < (3, 13):
codegen.append_output(create_instruction("SWAP", arg=2))
codegen.extend_output(
[codegen.create_load_const(val) for val in self.ctx.target_values]
)
codegen.extend_output(
create_call_function(len(self.ctx.target_values), False)
)
codegen.append_output(create_setup_with(self.target))
codegen.append_output(create_instruction("POP_TOP"))