mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132415 Approved by: https://github.com/XuehaiPan, https://github.com/jamesjwu
140 lines
3.7 KiB
Python
140 lines
3.7 KiB
Python
# mypy: allow-untyped-defs
|
|
# This module contains functions that *will be allowed* by dynamo
|
|
|
|
import functools
|
|
from typing import List
|
|
from typing_extensions import deprecated
|
|
|
|
import torch
|
|
import torch.utils._pytree as pytree
|
|
|
|
|
|
try:
|
|
import numpy as np
|
|
except ModuleNotFoundError:
|
|
np = None # type: ignore[assignment]
|
|
|
|
|
|
@deprecated(
    "`torch._dynamo.external_utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.",
    category=FutureWarning,
)
def is_compiling() -> bool:
    """
    Deprecated alias for ``torch.compiler.is_compiling()``.

    Reports whether we are currently tracing/compiling with torch.compile()
    or torch.export(). To check specifically whether TorchDynamo is the one
    tracing, use torch.compiler.is_dynamo_compiling() instead.

    Calling this function emits a FutureWarning (see the decorator); new code
    should call torch.compiler.is_compiling() directly.
    """
    compiling = torch.compiler.is_compiling()
    return compiling
|
|
|
|
|
|
def wrap_inline(fn):
    """
    Return a pass-through wrapper around ``fn``.

    The wrapper forwards all positional and keyword arguments to ``fn`` and
    returns its result unchanged. Its only purpose is to create an extra frame
    around ``fn`` that is not in skipfiles. ``functools.wraps`` copies ``fn``'s
    metadata (``__name__``, ``__doc__``, ``__wrapped__``, ...) onto the wrapper.
    """

    def _forwarding_frame(*args, **kwargs):
        return fn(*args, **kwargs)

    return functools.wraps(fn)(_forwarding_frame)
|
|
|
|
|
|
def call_hook(hook, *args):
    """
    Call ``hook(*args)`` on behalf of compiled autograd.

    If the hook returns None, the first positional argument is returned in its
    place; any other return value is passed through unchanged.
    """
    hook_output = hook(*args)
    return args[0] if hook_output is None else hook_output
|
|
|
|
|
|
def wrap_numpy(f):
    r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
    from ``torch.Tensor``s to ``torch.Tensor``s.

    Every ``torch.Tensor`` found (recursively, via pytree) in the positional and
    keyword arguments is converted with ``.numpy()`` before ``f`` is called, and
    every ``np.ndarray`` found in ``f``'s return value is converted back with
    ``torch.as_tensor``. Other arguments and outputs pass through unchanged.

    If NumPy could not be imported, ``f`` is returned as-is.
    """
    # ``np`` is None when the optional NumPy import failed; compare against
    # that sentinel explicitly rather than relying on module truthiness.
    if np is None:
        return f

    @functools.wraps(f)
    def wrap(*args, **kwargs):
        # Call ``.numpy()`` as a method so Tensor subclasses dispatch normally.
        args, kwargs = pytree.tree_map_only(
            torch.Tensor, lambda x: x.numpy(), (args, kwargs)
        )
        out = f(*args, **kwargs)
        return pytree.tree_map_only(np.ndarray, torch.as_tensor, out)

    return wrap
|
|
|
|
|
|
class FakeBackwardCFunction:
    """
    Proxy around a ``BackwardCFunction`` with a substituted ``saved_tensors``.

    ``real`` and ``saved_tensors`` are stored directly on this object; every
    other attribute lookup that fails here is forwarded to ``real``.
    """

    def __init__(
        self,
        real: torch.autograd.function.BackwardCFunction,
        saved_tensors: List[torch.Tensor],
    ) -> None:
        # Object that receives all attribute lookups not satisfied on self.
        self.real = real
        # Stored as an instance attribute, so it shadows any ``saved_tensors``
        # that ``real`` may have.
        self.saved_tensors = saved_tensors

    def __getattr__(self, attr):
        # Python only calls __getattr__ after normal lookup fails, so the
        # instance attributes set in __init__ are never forwarded.
        target = self.real
        return getattr(target, attr)
|
|
|
|
|
|
# Eager counterpart of a lifted autograd.Function.backward.
def call_backward(backward_c_function, saved_tensors, *args):
    """
    Run the ``backward`` of the class found at ``backward_c_function._forward_cls``.

    A FakeBackwardCFunction wrapping ``backward_c_function`` and ``saved_tensors``
    is passed as the first (ctx) argument, followed by ``*args``. The result is
    always a tuple: a non-tuple return value is wrapped in a 1-tuple.
    """
    ctx = FakeBackwardCFunction(backward_c_function, saved_tensors)
    result = ctx._forward_cls.backward(ctx, *args)  # type: ignore[attr-defined]

    # Eager autograd wraps a single (non-tuple) grad output in a tuple; mirror it.
    return result if type(result) is tuple else (result,)
|
|
|
|
|
|
def untyped_storage_size(x: torch.Tensor):
    """Return ``x.untyped_storage().size()``: the byte size of the whole storage backing ``x``."""
    storage = x.untyped_storage()
    return storage.size()
|
|
|
|
|
|
class FakeCompiledAutogradEngine:
    """
    Final-callback queue operations over a caller-owned Python list.

    Each method takes the list explicitly instead of keeping any state on
    the class.
    """

    @staticmethod
    def queue_callback(final_callbacks, cb):
        """Append ``cb`` to the end of ``final_callbacks``."""
        final_callbacks.append(cb)

    @staticmethod
    def exec_final_callbacks(final_callbacks):
        """
        Call every callback in ``final_callbacks`` in order, then empty the list.

        The list length is re-read on every iteration, so callbacks appended
        while the queue is being drained (for example by a callback that calls
        queue_callback) are also run.
        """
        # Keep this an index-driven while loop rather than a ``for`` loop: the
        # list may grow while it is being drained.
        pos = 0
        while pos < len(final_callbacks):
            pending = final_callbacks[pos]
            pending()
            pos += 1
        final_callbacks.clear()

    @staticmethod
    def _exec_final_callbacks_stub():
        """No-op placeholder; takes no arguments and returns None."""
        pass
|
|
|
|
|
|
def call_hook_from_backward_state(*args, bw_state, hook_name: str, **kwargs):
    """
    Fetch the attribute ``hook_name`` from ``bw_state`` and call it with the
    given positional and keyword arguments, returning whatever it returns.
    """
    hook = getattr(bw_state, hook_name)
    return hook(*args, **kwargs)
|
|
|
|
|
|
def call_module_hooks_from_backward_state(
    _, result, *args, bw_state, hooks_name: str, module_name: str
):
    """
    Run a chain of module hooks stored on ``bw_state``.

    ``getattr(bw_state, module_name)`` supplies the module and
    ``getattr(bw_state, hooks_name)`` the hooks to iterate. Each hook is called
    as ``hook(module, result, *args)``. A non-None return value replaces
    ``result`` for the remaining hooks, and None leaves it as is. The first
    positional argument is ignored.

    Returns the final value of ``result``.
    """
    module = getattr(bw_state, module_name)
    for hook in getattr(bw_state, hooks_name):
        candidate = hook(module, result, *args)
        if candidate is not None:
            result = candidate
    return result
|