mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This work rewrites vmap support in torch.compile by inlining most of the frames into the existing FX graph. It also unlocks PyTorch to support features that were previously missing, such as keyword args. Fixes: https://github.com/pytorch/pytorch/issues/114306 Pull Request resolved: https://github.com/pytorch/pytorch/pull/116050 Approved by: https://github.com/zou3519
712 lines
24 KiB
Python
712 lines
24 KiB
Python
import dataclasses
|
|
import inspect
|
|
from typing import Callable, Dict, List, Optional
|
|
|
|
import torch._C
|
|
from torch._guards import Guard
|
|
|
|
from .. import variables
|
|
from ..bytecode_transformation import create_call_function, create_instruction
|
|
from ..device_interface import get_interface_for_device
|
|
from ..exc import unimplemented, Unsupported
|
|
from ..guards import GuardBuilder, install_guard
|
|
from ..source import AttrSource, GlobalStateSource
|
|
from .base import VariableTracker
|
|
from .functions import (
|
|
NestedUserFunctionVariable,
|
|
UserFunctionVariable,
|
|
UserMethodVariable,
|
|
WrappedUserFunctionVariable,
|
|
WrappedUserMethodVariable,
|
|
)
|
|
|
|
|
|
@dataclasses.dataclass
class ContextMangerState:
    """
    Mutable side-car for context manager variables.

    VariableTrackers are copied, so mutating `self` on them is not
    allowed.  Context managers point at one of these containers instead;
    it is not copied along with them, which makes it safe to mutate.
    """

    # Callback that undoes the effect of entering the context; None when
    # nothing is pending.
    cleanup_fn: Optional[Callable] = None
    # Graph handle stored by some `enter()` implementations for use on exit.
    proxy: Optional[torch.fx.Proxy] = None

    def cleanup(self):
        """Invoke the pending cleanup callback (if any), then clear it."""
        if self.cleanup_fn is None:
            return
        self.cleanup_fn()
        self.cleanup_fn = None

    def cleanup_assert(self):
        """Same as `cleanup`, but insist that a callback is still pending."""
        assert self.cleanup_fn, "multiple exits?"
        self.cleanup()
|
|
|
|
|
|
class ContextWrappingVariable(VariableTracker):
    """
    Base tracker for context managers.

    `enter` applies `target_values` through `_call_func` and registers a
    cleanup that re-applies `initial_values`; `exit` runs that cleanup.
    Subclasses provide `_call_func`, `module_name` and `fn_name`, or
    override `enter`/`exit` entirely.
    """

    _nonvar_fields = {
        "cm_obj",
        "target_values",
        "initial_values",
        "state",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, target_values, initial_values=None, *, state=None, **kwargs):
        super().__init__(**kwargs)
        self.target_values = target_values
        self.initial_values = initial_values
        # Reuse the caller-provided mutable state when given; otherwise
        # start from a fresh container.
        if state is None:
            state = ContextMangerState()
        self.state = state

    def enter(self, tx):
        """Apply the target values, arm the cleanup hook, return a None constant."""
        self._call_func(tx, self.target_values)
        self.set_cleanup_hook(tx)
        return variables.ConstantVariable.create(None)

    def set_cleanup_hook(self, tx, fn=None):
        """
        Store `fn` as the pending cleanup and register it on `tx.output`.

        Without an explicit `fn`, the cleanup re-applies
        `self.initial_values` via `_call_func`.
        """

        def restore_initial_values():
            self._call_func(tx, self.initial_values)

        self.state.cleanup_fn = restore_initial_values if fn is None else fn
        tx.output.add_cleanup_hook(self.state.cleanup)

    def exit(self, tx, *args):
        """Run the pending cleanup (asserting one exists) and return a None constant."""
        self.state.cleanup_assert()
        return variables.ConstantVariable.create(None)

    def reconstruct(self, codegen):
        """Reconstruct `<module_name()>.<fn_name()>` through an AttrSource."""
        module_source = codegen.tx.import_source(self.module_name())
        return AttrSource(module_source, self.fn_name()).reconstruct(codegen)

    def module_name(self):
        raise NotImplementedError("module_name called on base")

    def fn_name(self):
        raise NotImplementedError("fn_name called on base")

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """
        Handle use of the context manager as a decorator: wrap the single
        function/method argument together with this context.
        """
        assert len(args) == 1
        if isinstance(args[0], NestedUserFunctionVariable):
            # Swap the nested function for a plain user function variable
            # (in place, in the caller's list).
            args[0] = UserFunctionVariable(args[0].get_function())
        wrapped = args[0]
        assert isinstance(wrapped, (UserMethodVariable, UserFunctionVariable))

        # The method check runs before the function check.
        if isinstance(wrapped, UserMethodVariable):
            return WrappedUserMethodVariable(wrapped, self)

        if isinstance(wrapped, UserFunctionVariable):
            return WrappedUserFunctionVariable(wrapped, self)
|
|
|
|
|
|
class GenericContextWrappingVariable(ContextWrappingVariable):
    """
    Context manager variable backed by a concrete object (`cm_obj`).

    Entering and exiting are performed by calling the object's own
    `__enter__` / `__exit__` functions as user methods.
    """

    def __init__(self, target_values, initial_values=None, *, cm_obj=None, **kwargs):
        assert cm_obj is not None
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.cm_obj = cm_obj

    def enter(self, tx):
        """Call `cm_obj.__enter__` and return its result."""
        if self.source is None:
            enter_source = None
        else:
            enter_source = AttrSource(self.source, "__enter__")
        try:
            enter_method = variables.UserMethodVariable(
                self.cm_obj.__enter__.__func__,
                variables.UserDefinedObjectVariable(self.cm_obj),
                source=enter_source,
            )
            return enter_method.call_function(tx, [], {})
        except Unsupported as e:
            raise unimplemented(
                f"Unsupported context manager {self.cm_obj}'s __enter__ function"
            ) from e

    def exit(self, tx, *args):
        """
        Call `cm_obj.__exit__` with three None constants, decrement
        `tx.generic_context_manager_depth`, and return the call's result.
        """
        if self.source is None:
            exit_source = None
        else:
            exit_source = AttrSource(self.source, "__exit__")
        try:
            exit_method = variables.UserMethodVariable(
                self.cm_obj.__exit__.__func__,
                variables.UserDefinedObjectVariable(self.cm_obj),
                source=exit_source,
            )
            # The three positional arguments of `__exit__` are all passed
            # as None constants.
            none_triple = [variables.ConstantVariable.create(None) for _ in range(3)]
            result = exit_method.call_function(tx, none_triple, {})
        except Unsupported as e:
            raise unimplemented(
                f"Unsupported context manager {self.cm_obj}'s __exit__ function"
            ) from e

        tx.generic_context_manager_depth -= 1
        return result
|
|
|
|
|
|
class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
    """represents torch VMap increment/decrement nesting"""

    # The vmap level gets baked into the generated FX graph, hence the
    # guard.  That is fine as long as vmap is only invoked from inside the
    # function being compiled; if eager code calls vmap around the compiled
    # function, the levels may differ and the graph may be invalid.
    _guards_singleton = Guard(
        GlobalStateSource(), GuardBuilder.FUNCTORCH_CURRENT_LEVEL_MATCH
    )

    @staticmethod
    def create(tx, target_values, **kwargs):
        """Build the variable; `target_values` is `(batch_size, randomness)`."""
        return VmapIncrementNestingCtxManagerVariable(
            target_values=target_values,
            initial_values=None,
            **kwargs,
        )

    def enter(self, tx):
        """
        Install the level guard, increment the vmap nesting for real,
        record the same call in the graph, and return the new level as a
        constant.
        """
        install_guard(self._guards_singleton)
        batch_size, randomness = self.target_values
        level = torch._C._functorch._vmap_increment_nesting(batch_size, randomness)
        # If tracing ends without an explicit exit, the hook undoes the
        # real increment performed above.
        self.set_cleanup_hook(tx, lambda: torch._C._functorch._vmap_decrement_nesting())
        increment_node = tx.output.create_node(
            "call_function",
            torch._C._functorch._vmap_increment_nesting,
            (batch_size, randomness),
            {},
        )
        self.state.proxy = increment_node
        return variables.ConstantVariable.create(level)

    def exit(self, tx, *args):
        """Run the pending cleanup (if any) and record the decrement in the graph."""
        self.state.cleanup()
        decrement = torch._C._functorch._vmap_decrement_nesting
        tx.output.create_node("call_function", decrement, (), {})
        return variables.ConstantVariable.create(None)
|
|
|
|
|
|
class GradModeVariable(ContextWrappingVariable):
    """represents torch.{no_grad,enable_grad,set_grad_mode}()"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE)

    @staticmethod
    def create(tx, target_value, initialized=False, **kwargs):
        """
        Build a variable that switches grad mode to `target_value`.

        The current grad mode is captured as the value to restore.  With
        `initialized=True` the target mode is applied immediately.
        """
        grad_mode = GradModeVariable(
            target_values=[target_value],
            initial_values=[torch.is_grad_enabled()],
            **kwargs,
        )
        if initialized:
            grad_mode._call_func(tx, grad_mode.target_values)
        return grad_mode

    def __init__(self, target_values, initial_values=None, initialized=True, **kwargs):
        # `initialized` is accepted but not read here.
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        """Switch to the target grad mode."""
        self._call_func(tx, self.target_values)
        return variables.ConstantVariable.create(None)

    def exit(self, tx, *args):
        """Switch back to the grad mode captured at creation."""
        self._call_func(tx, self.initial_values)
        return variables.ConstantVariable.create(None)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ):
        # undo eager initialization
        self._call_func(tx, self.initial_values)
        return super().call_function(tx, args, kwargs)

    def _call_func(self, tx, values):
        """Set grad mode to `values[0]`, both in the graph and for real."""
        assert len(values) == 1
        enabled = values[0]
        # Coalesce grad mode mutations: nothing is emitted when the mode
        # already matches.
        if torch.is_grad_enabled() != enabled:
            set_grad = torch._C._set_grad_enabled
            tx.output.create_node("call_function", set_grad, (enabled,), {})
            torch._C._set_grad_enabled(enabled)

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "set_grad_enabled"
|
|
|
|
|
|
class InferenceModeVariable(ContextWrappingVariable):
    """
    Context manager variable for inference mode.

    `target_values` is passed as-is to
    `torch.autograd.grad_mode._enter_inference_mode`.
    """

    @staticmethod
    def create(tx, target_values, **kwargs):
        """Build the variable, capturing the current inference-mode state."""
        var = InferenceModeVariable(
            target_values, initial_values=torch.is_inference_mode_enabled(), **kwargs
        )
        return var

    def __init__(
        self,
        target_values,
        initial_values=None,
        **kwargs,
    ):
        if initial_values is None:
            # This must be called here since function defaults are evaluated at import time
            initial_values = torch.is_inference_mode_enabled()
        # The base initializer stores `target_values`, so it is not
        # assigned a second time here.
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )

    def exit(self, tx, *args):
        """
        Leave inference mode for real (via the pending cleanup) and record
        the matching `_exit_inference_mode` call in the graph.
        """
        self.state.cleanup_assert()
        tx.output.create_node(
            "call_function",
            torch.autograd.grad_mode._exit_inference_mode,
            (self.state.proxy,),
            {},
        )

    def enter(self, tx):
        """
        Enter inference mode for real, arm a cleanup that exits it, and
        record the `_enter_inference_mode` call in the graph.  The graph
        node is kept in `self.state.proxy` so `exit` can pass it along.
        """
        ctx = torch.autograd.grad_mode._enter_inference_mode(self.target_values)
        self.set_cleanup_hook(
            tx, lambda: torch.autograd.grad_mode._exit_inference_mode(ctx)
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch.autograd.grad_mode._enter_inference_mode,
            (self.target_values,),
            {},
        )

    def module_name(self):
        # `inference_mode` is an attribute of the `torch` package, not a
        # submodule.  `reconstruct` imports the module named here and then
        # looks up `fn_name()` on it, so this must be an importable module.
        return "torch"

    def fn_name(self):
        return "inference_mode"
|
|
|
|
|
|
class TorchFunctionDisableVariable(ContextWrappingVariable):
    """represents whether torch function overrides are enabled or not"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE)

    @staticmethod
    def create(tx, **kwargs):
        """
        Build the variable and disable torch function right away.

        The state to restore is read from `tx.output.torch_function_enabled`,
        and a cleanup hook restoring it is armed before returning.
        """
        disabled = TorchFunctionDisableVariable(
            target_values=[False],
            initial_values=[tx.output.torch_function_enabled],
            **kwargs,
        )
        # mlazos: I think this is here to make sure we don't reinvoke on clone()
        disabled._call_func(tx, [False])
        disabled.set_cleanup_hook(tx)
        return disabled

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        # The state change already happened in `create`.
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx, values):
        """Forward the single value in `values` to `tx.output.set_torch_function_state`."""
        assert len(values) == 1
        enabled = values[0]
        tx.output.set_torch_function_state(enabled)
|
|
|
|
|
|
class DeterministicAlgorithmsVariable(ContextWrappingVariable):
    """represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()"""

    _guards_singleton = Guard(
        GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS
    )

    @staticmethod
    def create(tx, target_value, **kwargs):
        """
        Build the variable and apply `target_value` immediately.

        The current setting is captured as the value to restore, and a
        cleanup hook restoring it is armed before returning.
        """
        var = DeterministicAlgorithmsVariable(
            target_values=[target_value],
            initial_values=[torch.are_deterministic_algorithms_enabled()],
            **kwargs,
        )
        var._call_func(tx, [target_value])
        var.set_cleanup_hook(tx)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        # The setting was already applied in `create`.
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx, values):
        """Set deterministic algorithms to `values[0]`, in the graph and for real."""
        assert len(values) == 1
        value = values[0]
        # The original statement ended in a stray comma, which wrapped the
        # call in a throwaway 1-tuple; it is a plain call statement now.
        tx.output.create_node(
            "call_function", torch._C._set_deterministic_algorithms, (value,), {}
        )
        torch._C._set_deterministic_algorithms(value)

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "use_deterministic_algorithms"
|
|
|
|
|
|
class DisabledSavedTensorsHooksVariable(ContextWrappingVariable):
    """represents torch.autograd.graph.disable_saved_tensors_hook."""

    @staticmethod
    def create(tx, target_value, **kwargs):
        """
        Build the variable and apply `target_value` immediately.

        The currently configured "disabled" error message is captured as
        the value to restore, and a cleanup hook restoring it is armed.
        """
        previous_message = (
            torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
        )
        hooks_var = DisabledSavedTensorsHooksVariable(
            target_values=[target_value],
            initial_values=[previous_message],
            **kwargs,
        )
        hooks_var._call_func(tx, [target_value])
        hooks_var.set_cleanup_hook(tx)
        return hooks_var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )

    def enter(self, tx):
        # The hooks were already disabled in `create`.
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx, values):
        """Disable the hooks with message `values[0]`, or re-enable them when it is None."""
        assert len(values) == 1
        message = values[0]
        if message is None:
            # We are exiting this context and if prev_message was None, we re-enable `saved_tensors_hooks`.
            tx.output.create_node(
                "call_function", torch._C._autograd._saved_tensors_hooks_enable, (), {}
            )
            torch._C._autograd._saved_tensors_hooks_enable()
            return

        # Disable `saved_tensors_hooks` with message (`message`)
        # OR
        # we are exiting this context and restoring the previous message.
        tx.output.create_node(
            "call_function",
            torch._C._autograd._saved_tensors_hooks_disable,
            (message,),
            {},
        )
        torch._C._autograd._saved_tensors_hooks_disable(message)

    def module_name(self):
        return "torch.autograd.graph"

    def fn_name(self):
        return "disable_saved_tensors_hooks"
|
|
|
|
|
|
class AutocastModeVariable(ContextWrappingVariable):
    """
    Context manager variable for the autocast entry points.

    `target_values` is the positional argument list
    `[device_type, dtype, enabled, cache_enabled]` handed to
    `torch.amp._enter_autocast`.
    """

    @staticmethod
    def create(func, args, kwargs):
        """
        Build the variable from a call to one of the autocast functions.

        The arguments are bound against `func`'s signature (defaults
        applied) and reduced to python constants.  `kwargs` is cleared in
        place once the binding is done.
        """
        supported = [
            torch.amp.autocast_mode.autocast,
            torch.cuda.amp.autocast,
            torch.cpu.amp.autocast,
        ]
        assert func in supported
        # Signature being bound:
        #   device_type : str,
        #   dtype : Optional[_dtype] = None,
        #   enabled : bool = True,
        #   cache_enabled : Optional[bool] = None
        bound = inspect.signature(func).bind(*args, **kwargs)
        bound.apply_defaults()
        kwargs.clear()

        # The cuda/cpu specific entry points take no `device_type`; it is
        # implied by which function was called.
        device_specific = [torch.cuda.amp.autocast, torch.cpu.amp.autocast]
        target_values = []
        for key in ["device_type", "dtype", "enabled", "cache_enabled"]:
            if key == "device_type" and func in device_specific:
                arg = "cuda" if func is torch.cuda.amp.autocast else "cpu"
            else:
                arg = bound.arguments[key]
            target_values.append(
                arg.as_python_constant() if isinstance(arg, VariableTracker) else arg
            )

        return AutocastModeVariable(target_values, initial_values=None, **kwargs)

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.target_values = target_values

    def exit(self, tx, *args):
        """Exit autocast for real (pending cleanup) and record `_exit_autocast` in the graph."""
        self.state.cleanup_assert()
        exit_args = (self.state.proxy,)
        tx.output.create_node("call_function", torch.amp._exit_autocast, exit_args, {})

    def enter(self, tx):
        """
        Enter autocast for real, arm a cleanup that exits it, and record
        `_enter_autocast` in the graph; the node is kept in
        `self.state.proxy` for `exit`.
        """
        ctx = torch.amp._enter_autocast(*self.target_values)
        self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx))
        enter_args = tuple(self.target_values)
        self.state.proxy = tx.output.create_node(
            "call_function", torch.amp._enter_autocast, enter_args, {}
        )

    def module_name(self):
        return "torch.amp.autocast_mode"

    def fn_name(self):
        return "autocast"
|
|
|
|
|
|
class NullContextVariable(ContextWrappingVariable):
    """
    Stand-in for Python's contextlib.nullcontext.

    Serves as a placeholder for context managers Dynamo does not support
    yet, e.g. torch.autograd.profiler.record_function.
    """

    def __init__(self, target_values=None, **kwargs):
        super().__init__(target_values=target_values, **kwargs)

    def enter(self, tx):
        """No-op: just produce a None constant."""
        none_const = variables.ConstantVariable.create(None)
        return none_const

    def exit(self, tx, *args):
        """No-op: just produce a None constant."""
        none_const = variables.ConstantVariable.create(None)
        return none_const

    def module_name(self):
        return "contextlib"

    def fn_name(self):
        return "nullcontext"
|
|
|
|
|
|
class StreamContextVariable(ContextWrappingVariable):
    """
    Context manager variable that switches the current stream.

    `target_values[0]` is the stream variable to switch to and
    `initial_values[0]` the stream variable to restore afterwards.
    """

    @staticmethod
    def create(tx, target_value, **kwargs):
        """
        Build the variable for `target_value`.

        A graph call to the device interface's `current_stream` (with
        `None` as its argument) is recorded and wrapped as a
        StreamVariable; it becomes the value restored on exit.
        """
        from .builder import wrap_fx_proxy_cls

        device = target_value.device
        current_stream_method = get_interface_for_device(device).current_stream
        current_stream_proxy = tx.output.create_proxy(
            "call_function",
            current_stream_method,
            (None,),
            {},
        )
        current_stream = wrap_fx_proxy_cls(StreamVariable, tx, current_stream_proxy)
        return StreamContextVariable(
            target_values=[target_value],
            initial_values=[current_stream],
            device=target_value.device,
            **kwargs,
        )

    def __init__(self, target_values, device, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.device = device
        # Device-interface entry points used to switch streams, by stream
        # object and by id respectively.
        self.set_stream = get_interface_for_device(self.device).set_stream
        self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id

    def enter(self, tx):
        """
        Record the stream switch in the graph, perform it for real, and
        arm a cleanup that switches back to the initial stream.
        """
        target = self.target_values[0]
        if target.as_proxy() is not None:
            # stream generated inside the traced function
            tx.output.create_proxy(
                "call_function",
                self.set_stream,
                (target.as_proxy(),),
                {},
            )
        else:
            # stream passed from outside the traced function
            stream = target.value
            stream_key = (stream.stream_id, stream.device_index, stream.device_type)
            tx.output.create_proxy(
                "call_function",
                self.set_stream_id,
                stream_key,
                {},
            )
        self.set_stream(self.target_values[0].value)
        self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value))

    def exit(self, tx, *args):
        """Record the switch back to the initial stream, then run the pending cleanup."""
        restore_args = (self.initial_values[0].as_proxy(),)
        tx.output.create_proxy("call_function", self.set_stream, restore_args, {})
        self.state.cleanup_assert()
|
|
|
|
|
|
class StreamVariable(VariableTracker):
    """
    Tracks a stream object.

    Holds the graph proxy for the stream (may be None), the concrete
    stream `value`, and the `device` it belongs to.
    """

    def __init__(self, proxy, value, device, **kwargs):
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        assert (
            value.device.type == device.type
        ), "stream value is not equal to the passed device"
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value
        self.device = device

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        Record a supported stream method call in the graph.

        `wait_stream`, `synchronize` and `wait_event` yield a None
        constant; `query` is wrapped as a ConstantVariable and
        `record_event` as an EventVariable.
        """
        assert hasattr(self.value, name), f"no stream method found named {name}"
        assert name in [
            "wait_stream",
            "synchronize",
            "query",
            "record_event",
            "wait_event",
        ], f" unsupported stream method {name}"

        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait_stream", "synchronize", "wait_event"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return variables.ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=variables.ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name == "record_event":
            return wrap_fx_proxy_cls(
                target_cls=EventVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        else:
            # `self.device` is not a str (its `.type` is read in
            # `__init__`), so `self.device + "..."` would raise TypeError
            # before the unsupported method could be reported.  Format it
            # instead.
            unimplemented(f"{self.device} stream method {name} unsupported")

    def as_proxy(self):
        return self.proxy

    def reconstruct(self, codegen):
        # If we got here, this stream is fully subsumed by the graph - this means it is
        # not an input or global
        assert not self.source
        # Since we just proved that - for other such structures, like lists and dicts, reconstruction
        # is fine and sound according to dynamo principles of treating collectives. However,
        # streams are special in that we want to preserve the identity of the stream as the same as in the graph
        # Normally, we would do this via codegen for the proxy mapping to an output - we cannot do this yet, as we do not
        # yet have a plan for how we want to handle the case where the stream is used as an input or an output. Pending
        # design, to unblock current work, we lift the stream into a global and then codegen bytecode to load it from there.
        name = f"_stream_{self.device}_{id(self.value)}"
        if name not in codegen.tx.output.global_scope:
            codegen.tx.output.install_global(name, self.value)

        return [codegen.create_load_global(name, push_null=False, add=True)]
|
|
|
|
|
|
class EventVariable(VariableTracker):
    """Tracks an event object: its graph proxy (may be None) and concrete `value`."""

    def __init__(self, proxy, value, **kwargs):
        # When the proxy carries an example value it must agree with `value`.
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        Record a supported event method call in the graph.

        `wait`, `record` and `synchronize` yield a None constant; `query`
        is wrapped as a ConstantVariable; anything else is reported
        through `unimplemented`.
        """
        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait", "record", "synchronize"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return variables.ConstantVariable(None)

        if name == "query":
            query_proxy = tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return wrap_fx_proxy_cls(
                target_cls=variables.ConstantVariable,
                tx=tx,
                proxy=query_proxy,
            )

        unimplemented(f"event method {name} unsupported")

    def as_proxy(self):
        return self.proxy
|
|
|
|
|
|
class WithExitFunctionVariable(VariableTracker):
    """
    The exit callable of a context manager variable.

    Calling it forwards to `ctx.exit`; `target` is the jump target used
    for the `SETUP_WITH` instruction emitted by `reconstruct`.
    """

    def __init__(self, ctx: ContextWrappingVariable, target, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(ctx, ContextWrappingVariable)
        self.ctx = ctx
        self.target = target

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Delegate to the wrapped context's `exit`; keyword arguments are not allowed."""
        assert not kwargs
        return self.ctx.exit(tx, *args)

    def reconstruct(self, codegen):
        # Note here we reconstruct the context manager rather than the
        # exit function.  The handler generated by BlockStackEntry
        # will re-enter the context in the resume function.
        manager_source = AttrSource(
            codegen.tx.import_source(self.ctx.module_name()), self.ctx.fn_name()
        )
        instructions = manager_source.reconstruct(codegen)

        if not codegen.tx.output.partial_convert:
            return instructions

        # Load the target values as constants, call the manager with them
        # and set up the `with` block.
        const_loads = [
            codegen.create_load_const(val) for val in self.ctx.target_values
        ]
        instructions.extend(const_loads)
        instructions.extend(
            [
                *create_call_function(len(const_loads), True),
                create_instruction("SETUP_WITH", target=self.target),
                create_instruction("POP_TOP"),
            ]
        )
        return instructions
|