For AC HOP, dynamo traces it without kwargs. (kwargs are only inputs to the HOP, not to the body)
55f01a48af/torch/_dynamo/variables/higher_order_ops.py (L2594-L2609)
When we add non-strict support, we should match this calling convention too.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165145
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #164296, #164321, #164419, #164420, #164340, #163602, #164431, #164433, #164437
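As a rough sketch of that calling convention (names below are illustrative, not taken from the linked code): the body graph module Dynamo produces takes only flattened positional args, while checkpoint-level options travel as kwargs of the HOP call itself.

    # body traced by Dynamo: positional args only
    def wrap_body(x, y):
        return torch.matmul(x, y).sin()

    # kwargs (if any) attach to the HOP call, not to wrap_body
    out = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_gmod, x, y)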
# mypy: allow-untyped-defs
import inspect
import itertools
import logging
from typing import Any, Optional

import torch
import torch.utils._pytree as pytree
from torch._higher_order_ops.utils import reenter_make_fx
from torch._logging import warning_once
from torch._ops import HigherOrderOperator
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import (
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.types import _dtype


log = logging.getLogger(__name__)

uid = itertools.count(1)


# Used for testing the HigherOrderOperator mechanism
class Wrap(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("wrap")

    def __call__(self, func, *args, **kwargs):
        # Dynamo already traces the body of the HigherOrderOp beforehand when it
        # encounters this op, so there is no need to trace into it here.
        import torch._dynamo  # noqa: F401
        from torch._dynamo import disable

        @disable
        def wrapper():
            result = func(*args, **kwargs)
            return result

        return wrapper()


wrap = Wrap()
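
# A minimal usage sketch of `wrap` in eager mode: it simply invokes the wrapped
# callable with the given args/kwargs, e.g.
#
#   out = wrap(torch.sin, torch.randn(3))  # same result as calling torch.sin directly
#
# Under torch.compile, Dynamo traces the callable into a subgraph beforehand, so
# the eager call above is only the non-compiled fallback path.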


class WrapWithSetGradEnabled(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("wrap_with_set_grad_enabled")

    def __call__(self, enable_grad, wrapped_func, *args, **kwargs):
        # Dynamo already traces the body of the HigherOrderOp beforehand when it
        # encounters this op, so there is no need to trace into it here.
        import torch._dynamo  # noqa: F401
        from torch._dynamo import disable

        @disable
        def wrapper():
            prev = torch.is_grad_enabled()
            torch.set_grad_enabled(enable_grad)
            res = wrapped_func(*args, **kwargs)
            torch.set_grad_enabled(prev)
            return res

        return wrapper()


wrap_with_set_grad_enabled = WrapWithSetGradEnabled()
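
# A usage sketch for `wrap_with_set_grad_enabled`: the first argument is the
# grad mode to use for the wrapped call, e.g.
#
#   x = torch.randn(3, requires_grad=True)
#   out = wrap_with_set_grad_enabled(False, torch.sin, x)
#
# behaves like running `torch.sin(x)` under `torch.set_grad_enabled(False)` and
# then restoring the previous grad mode.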


class WrapWithAutocast(HigherOrderOperator):
    def __init__(self):
        super().__init__("wrap_with_autocast")

    def __call__(
        self,
        device_type: str,
        dtype: Optional[_dtype],
        enabled: bool,
        cache_enabled: Optional[bool],
        wrapped_func,
        *args,
        **kwargs,
    ):
        # Dynamo already traces the body of the HigherOrderOp beforehand when it
        # encounters this op, so there is no need to trace into it here.
        import torch._dynamo  # noqa: F401
        from torch._dynamo import disable

        @disable
        def wrapper():
            with torch.autocast(device_type, dtype, enabled, cache_enabled):
                return wrapped_func(*args, **kwargs)

        return wrapper()


wrap_with_autocast = WrapWithAutocast()
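
# A usage sketch for `wrap_with_autocast`: the positional prefix mirrors the
# arguments of torch.autocast, followed by the wrapped callable and its inputs.
#
#   a, b = torch.randn(4, 4), torch.randn(4, 4)
#   out = wrap_with_autocast("cpu", torch.bfloat16, True, None, torch.mm, a, b)
#
# is equivalent to running `torch.mm(a, b)` inside
# `torch.autocast("cpu", torch.bfloat16, True, None)`.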


# This HOP allows you to bypass Dynamo tracing of the wrapper function while
# still tracing the inner function.
# It takes two callables: the first, `wrapper_fn`, accepts `inner_fn` and
# returns a callable with the same signature. The second is `inner_fn` itself.
# Any extra *args and **kwargs are forwarded to `wrapper_fn(inner_fn)` when it
# is executed.
class DynamoBypassingWrapper(HigherOrderOperator):
    def __init__(self):
        super().__init__("dynamo_bypassing_wrapper")

    def __call__(
        self,
        wrapper_fn_or_key,
        inner_fn,
        *args,
        **kwargs,
    ):
        # Dynamo already traces the body of the HigherOrderOp beforehand when it
        # encounters this op, so there is no need to trace into it here.
        import torch._dynamo  # noqa: F401
        from torch._dynamo import disable

        is_compiling = isinstance(wrapper_fn_or_key, str)
        if is_compiling:
            assert isinstance(inner_fn, torch.fx.GraphModule)
            wrapper_fn = inner_fn.meta[wrapper_fn_or_key]
        else:
            wrapper_fn = wrapper_fn_or_key

        @disable
        def wrapper():
            return wrapper_fn(inner_fn)(*args, **kwargs)

        return wrapper()


dynamo_bypassing_wrapper = DynamoBypassingWrapper()
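
# A usage sketch for `dynamo_bypassing_wrapper` (names below are illustrative):
# `wrapper_fn` receives the inner callable and returns a callable with the same
# signature.
#
#   def with_no_grad(fn):
#       def run(*args, **kwargs):
#           with torch.no_grad():
#               return fn(*args, **kwargs)
#       return run
#
#   out = dynamo_bypassing_wrapper(with_no_grad, torch.sin, torch.randn(3))
#
# Eagerly this runs `with_no_grad(torch.sin)(...)`; under torch.compile the
# wrapper itself is not traced by Dynamo, while the inner function is, and the
# wrapper is recovered from the traced GraphModule's meta via a string key.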


class WrapActivationCheckpoint(HigherOrderOperator):
    """
    This operator wraps torch.utils.checkpoint. It prevents TorchDynamo from
    looking into saved tensor hooks and passes control directly to AOT
    Autograd, which is fine with tracing saved tensor hooks. Because AOT
    Autograd traces the torch.utils.checkpoint code, we get a backward graph
    with recomputed forward nodes.

    However, we might deprecate this operator soon. The difficulty arises in
    the functionalization of RNG ops. Today, there are two different
    functionalizations of RNG ops - one in AOT Autograd and the other in
    Inductor - and they are difficult to map to each other. The RNG states also
    complicate pattern matching in Inductor. Because it is easier to implement,
    we are currently inclined towards functionalization at the Inductor level,
    which means that duplication/recomputation is done as a compiler pass in
    the partitioners. See TagActivationCheckpoint for more information.
    """

    def __init__(self) -> None:
        super().__init__("wrap_activation_checkpoint", cacheable=False)

    def __call__(self, function, *args, **kwargs):
        # use_reentrant is set to False because this op is going to be traced,
        # and we ensure that AOT Autograd traces through the non-reentrant
        # version of checkpointing.
        import torch.fx.traceback as fx_traceback
        from torch.fx import Interpreter

        kwargs["use_reentrant"] = False
        kwargs["preserve_rng_state"] = False
        # Using an interpreter allows preservation of metadata through the torch.compile stack.
        with fx_traceback.preserve_node_meta():
            from torch.utils.checkpoint import checkpoint

            return checkpoint(Interpreter(function).run, *args, **kwargs)


wrap_activation_checkpoint = WrapActivationCheckpoint()
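
# A usage sketch for `wrap_activation_checkpoint` (illustrative): the first
# argument is an FX GraphModule; eagerly the op runs it under non-reentrant
# torch.utils.checkpoint.
#
#   def body(x, y):
#       return torch.matmul(x, y).sin()
#
#   gm = torch.fx.symbolic_trace(body)
#   x = torch.randn(4, 4, requires_grad=True)
#   out = wrap_activation_checkpoint(gm, x, torch.randn(4, 4))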


class TagActivationCheckpoint(HigherOrderOperator):
    """
    This operator is meant to be used only with the torch.compile stack. It
    accepts an FX graph module that needs to be checkpointed, and adds a
    "recomputable" tag to the nodes of the FX graph that should be recomputed.

    The goal is to:
    1. Avoid using Dynamo to trace through saved tensor hooks.
    2. For the selective checkpointing case, let AOTAutograd trace through
       saved tensor hooks, but with special logic in a TorchDispatchMode that
       overrides the usual saved_tensor_hooks logic in order to tag the nodes.
    3. Rely on the partitioners to actually duplicate the nodes.
    This sits well in the torch.compile stack, because by the time the graph
    reaches the partitioner, Inductor has already run its functionalization of
    RNG ops (by setting a fixed seed for each random op, see
    `replace_random_passes`). Therefore, the duplication of nodes, by design,
    respects the RNG states in the forward and in the recomputed forward in
    the backward.
    """

    def __init__(self) -> None:
        super().__init__("tag_activation_checkpoint", cacheable=False)

    @staticmethod
    def divide_kwargs(kwargs):
        """
        The checkpoint fn can receive a mix of kwargs for the checkpointed fn
        and kwargs for the checkpoint fn itself. For example
        >> def gn(x, y, z=None):
        >>     a = torch.matmul(x, y)
        >>     if z is not None:
        >>         return torch.matmul(a, z)
        >>     return a
        >> def fn(x, y, z):
        >>     return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z))
        In the above case, z belongs to the checkpointed function gn, while
        use_reentrant belongs to the checkpoint function itself. This function
        splits the kwargs into checkpoint_kwargs and gmod_kwargs (or
        checkpointed_fn_kwargs).
        We sort the keys to ensure the same graph from run to run, for better
        debuggability. It is not required for correctness.
        """
        from torch.utils.checkpoint import checkpoint

        ckpt_signature = inspect.signature(checkpoint)
        checkpoint_keys = set()
        for name in ckpt_signature.parameters:
            if name in ("function", "args", "kwargs"):
                continue
            checkpoint_keys.add(name)

        # `preserve_rng_state` is not a regular kwarg
        checkpoint_keys.add("preserve_rng_state")

        checkpoint_kwargs = {
            name: kwargs[name] for name in kwargs.keys() if name in checkpoint_keys
        }
        gmod_kwargs = {
            name: kwargs[name] for name in kwargs.keys() if name not in checkpoint_keys
        }
        return checkpoint_kwargs, gmod_kwargs
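
    # A usage sketch for `divide_kwargs` (keys are illustrative): kwargs that
    # appear in the signature of torch.utils.checkpoint.checkpoint go to the
    # checkpoint call; everything else is forwarded to the checkpointed fn.
    #
    #   ckpt_kwargs, fn_kwargs = TagActivationCheckpoint.divide_kwargs(
    #       {"use_reentrant": False, "z": some_tensor}
    #   )
    #   # ckpt_kwargs == {"use_reentrant": False}, fn_kwargs == {"z": some_tensor}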

    @staticmethod
    def tag_nodes(gmod, is_sac):
        from torch.utils.checkpoint import CheckpointPolicy

        unique_graph_id = next(uid)
        for node in gmod.graph.nodes:
            if node.op in ("call_function", "call_method", "call_module"):
                node.meta["ac_graph_id"] = unique_graph_id
                if is_sac:
                    # For selective checkpointing, we will populate this tag later in _CachingTorchDispatchMode.
                    node.meta["recompute"] = None
                else:
                    # Under vanilla activation checkpointing, all nodes should be recomputed.
                    node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
        return gmod

    def __call__(self, gmod, *args, **kwargs):
        dispatch_key_set = torch._ops._compute_keyset(
            args, kwargs, self.non_fallthrough_keys
        )
        dispatch_key = dispatch_key_set.highestPriorityTypeId()
        if dispatch_key == torch._C.DispatchKey.PreDispatch:
            return super().__call__(gmod, *args, **kwargs)

        return tag_activation_checkpoint_impl(gmod, *args, **kwargs)


tag_activation_checkpoint = TagActivationCheckpoint()
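
# A sketch of how this HOP typically arises under torch.compile (the user code
# below is illustrative): Dynamo captures the checkpointed region as an FX
# graph module and emits a `tag_activation_checkpoint` call instead of tracing
# through torch.utils.checkpoint itself.
#
#   import torch.utils.checkpoint as cp
#
#   @torch.compile
#   def fn(x, y):
#       return cp.checkpoint(
#           lambda a, b: torch.matmul(a, b).sin(), x, y, use_reentrant=False
#       )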


def tag_activation_checkpoint_impl(gmod, *args, **kwargs):
    import torch.fx.traceback as fx_traceback
    from torch.fx import Interpreter

    if "_checkpoint_context_fn" in gmod.meta:
        warning_once(
            log,
            """
Detected that context_fn is passed to torch.utils.checkpoint under torch.compile.
Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_).
""",
        )
        # use_reentrant is set to False because this op is going to be traced,
        # and we ensure that AOT Autograd traces through the non-reentrant
        # version of checkpointing.
        kwargs["use_reentrant"] = False
        # preserve_rng_state is set to False because we want to prevent AOTAutograd from tracing through
        # the `torch.random.fork_rng` op (which is not supported yet under CUDA).
        # This doesn't mean that we don't preserve RNG state. Instead, we will always preserve RNG state
        # regardless of this flag (by doing RNG functionalization via `replace_random_passes` in Inductor
        # instead of in AOTAutograd).
        kwargs["preserve_rng_state"] = False
        kwargs["context_fn"] = gmod.meta["_checkpoint_context_fn"]
        # We first tag all nodes as "recompute" in this graph, and then we undo the "recompute" tag
        # for specific nodes in _CachingTorchDispatchMode in torch/utils/checkpoint.py.
        gmod = TagActivationCheckpoint.tag_nodes(gmod, is_sac=True)
        # Using an interpreter allows preservation of metadata through the torch.compile stack.
        with fx_traceback.preserve_node_meta():
            from torch.utils.checkpoint import checkpoint

            return checkpoint(Interpreter(gmod).run, *args, **kwargs)
    else:
        gmod = TagActivationCheckpoint.tag_nodes(gmod, is_sac=False)
        # Using an interpreter allows preservation of metadata through the torch.compile stack.
        # TODO: We want to use the same `checkpoint(Interpreter(gmod).run, *args, **kwargs)` here
        # as in the `context_fn != None` case, but that depends on in-place op support in
        # TorchDispatchMode + torch.compile.
        # (For details on the in-place op issue, run the `test_compile_selective_checkpoint_inplace_op` unit test.)
        with fx_traceback.preserve_node_meta():
            return Interpreter(gmod).run(*args)


@tag_activation_checkpoint.py_impl(ProxyTorchDispatchMode)
def proxy_mode_key(
    proxy_mode: ProxyTorchDispatchMode,
    gmod: GraphModule,
    *args: Any,
    **kwargs: Any,
) -> tuple[torch.Tensor]:
    import torch.fx.traceback as fx_traceback
    from torch.fx import Interpreter

    assert proxy_mode.pre_dispatch, (
        "post-dispatch mode should have inlined in the Autograd key"
    )
    example_out = tag_activation_checkpoint(gmod, *args, **kwargs)
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)  # type: ignore[union-attr]
    proxy_kwargs = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, kwargs)  # type: ignore[union-attr]
    qualname = proxy_mode.tracer.get_fresh_qualname("wrap_body")  # type: ignore[union-attr]

    # TODO (tmanlaibaatar) don't we need flat_apply here??
    # Dynamo already traced the gmod body without kwargs
    flat_args, _ = pytree.tree_flatten(args)
    with fx_traceback.preserve_node_meta():
        gmod_aten = reenter_make_fx(Interpreter(gmod).run)(*flat_args)
    gmod_aten.meta["_checkpoint_context_fn"] = gmod.meta["_checkpoint_context_fn"]
    proxy_mode.tracer.root.register_module(qualname, gmod_aten)  # type: ignore[union-attr]
    proxy_gmod = proxy_mode.tracer.unwrap_proxy(gmod_aten)  # type: ignore[union-attr, call-overload]
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function",
        tag_activation_checkpoint,
        (proxy_gmod, *proxy_args),
        proxy_kwargs,
    )
    return track_tensor_tree(
        example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
    )