mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
This adds support for backward hooks that are *both*: (1) interior to the graph, and (2) dynamically generated (e.g., lambdas). We do this by creating a BackwardState object that is used to register the hooks in the forward pass, and is then populated by dynamo *after* the forward pass runs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/120382 Approved by: https://github.com/xmfan
81 lines
1.8 KiB
Python
81 lines
1.8 KiB
Python
# This module contains functions that *will be allowed* by dynamo
|
|
|
|
import functools
|
|
|
|
import torch
|
|
import torch.utils._pytree as pytree
|
|
|
|
try:
|
|
import numpy as np
|
|
except ModuleNotFoundError:
|
|
np = None # type: ignore[assignment]
|
|
|
|
|
|
def is_compiling() -> bool:
    """Report whether we are currently being compiled/traced.

    This eager-mode stub always answers ``False``; the compiler is expected
    to special-case this function and substitute its own answer.
    """
    return False
|
|
|
|
|
|
def wrap_inline(fn):
    """Create an extra frame around ``fn`` that is not in skipfiles.

    The wrapper is a pure pass-through; the additional call frame itself is
    the point of this helper.
    """

    @functools.wraps(fn)
    def _frame(*args, **kwargs):
        # Forward everything unchanged.
        return fn(*args, **kwargs)

    return _frame
|
|
|
|
|
|
def call_hook(hook, *args):
    """Used by compiled autograd to handle a hook returning ``None``.

    A hook that returns ``None`` is treated as leaving its first argument
    unchanged; any other return value is passed through as-is.
    """
    out = hook(*args)
    return args[0] if out is None else out
|
|
|
|
|
|
def wrap_numpy(f):
    r"""Decorator that turns a function from ``np.ndarray``\s to ``np.ndarray``\s
    into a function from ``torch.Tensor``\s to ``torch.Tensor``\s.
    """
    # Without numpy available there is nothing to translate; return f untouched.
    if not np:
        return f

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Convert every tensor (positional or keyword) to an ndarray, run the
        # wrapped function, then convert ndarrays in the result back to tensors.
        np_args, np_kwargs = pytree.tree_map_only(
            torch.Tensor, lambda t: t.numpy(), (args, kwargs)
        )
        result = f(*np_args, **np_kwargs)
        return pytree.tree_map_only(np.ndarray, torch.as_tensor, result)

    return wrapper
|
|
|
|
|
|
class FakeContext:
    """Minimal stand-in for an autograd function's ``ctx``.

    Exposes only ``saved_tensors``, which is all the compiled backward path
    needs.
    """

    def __init__(self, saved_tensors):
        # Cache the tensors on the instance so later accesses do not have to
        # go through the C++ binding again.
        self.saved_tensors = saved_tensors
|
|
|
|
|
|
def call_backward(backward_fn, saved_tensors, *args):
    """Invoke ``backward_fn`` with a ``FakeContext`` holding ``saved_tensors``.

    Mirrors eager autograd's convention of always producing a tuple of grads:
    a single (non-tuple) grad output is wrapped in a one-element tuple.
    """
    result = backward_fn(FakeContext(saved_tensors), *args)

    # NOTE: an exact `type(...) is tuple` check on purpose — tuple subclasses
    # (e.g. namedtuples) still get wrapped, matching eager behavior.
    if type(result) is not tuple:
        result = (result,)

    return result
|
|
|
|
|
|
def untyped_storage_size(x: torch.Tensor):
    """Return the size of ``x``'s untyped storage (a byte count per torch docs)."""
    storage = x.untyped_storage()
    return storage.size()
|
|
|
|
|
|
def call_hook_from_backward_state(*args, bw_state, hook_name: str, **kwargs):
    """Look up ``hook_name`` on ``bw_state`` and invoke it with the given args."""
    hook = getattr(bw_state, hook_name)
    return hook(*args, **kwargs)
|