mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00

Apply UFMT to low traffic torch modules (#106249)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007

Committed by: PyTorch MergeBot
parent a4ebc61f15
commit 3bf922a6ce
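For context: UFMT chains usort (import sorting) with black (code formatting), which is what produces the purely mechanical hunks below (evidently from torch/autograd/function.py, judging by the classes touched): regrouped imports, double-quoted strings, and long calls re-wrapped. A minimal sketch of that rewrap, assuming only that the black package is installed; PyTorch itself typically drives this through its ufmt/lintrunner tooling, whose exact invocation is not shown here.

# Illustrative only: black.format_str() applies the kind of rewrap seen in this diff.
import black

# A long, single-quoted call of the sort this commit reformats.
src = "warnings.warn('mark_shared_storage is deprecated. ' 'Tensors with shared storages are automatically tracked.')\n"

# black normalizes quotes to double quotes and splits the over-long call
# across multiple lines, matching the style of the hunks below.
print(black.format_str(src, mode=black.Mode()))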
@@ -1,20 +1,29 @@
-import torch
-import torch._C as _C
-from torch._C import _functions
-import torch._functorch as _functorch
-import torch.utils.hooks as hooks
 import functools
 import warnings
 from collections import OrderedDict
 from typing import Any, List, Optional, Tuple
+
+import torch
+import torch._C as _C
+import torch._functorch as _functorch
+import torch.utils.hooks as hooks
+from torch._C import _functions
+from torch._functorch.autograd_function import custom_function_call
 
-__all__ = ["FunctionCtx", "BackwardCFunction", "FunctionMeta", "Function", "once_differentiable", "traceable",
-           "InplaceFunction", "NestedIOFunction"]
+__all__ = [
+    "FunctionCtx",
+    "BackwardCFunction",
+    "FunctionMeta",
+    "Function",
+    "once_differentiable",
+    "traceable",
+    "InplaceFunction",
+    "NestedIOFunction",
+]
 
 
 # Formerly known as: _ContextMethodMixin
 class FunctionCtx:
 
     def save_for_backward(self, *tensors: torch.Tensor):
         r"""Saves given tensors for a future call to :func:`~Function.backward`.
@@ -122,7 +131,8 @@ class FunctionCtx:
         for tensor in tensors:
             assert isinstance(tensor, torch.Tensor) or tensor is None, (
                 "save_for_forward expects all arguments to be tensors; you should "
-                "save non-tensors as attributes on ctx.")
+                "save non-tensors as attributes on ctx."
+            )
 
         self.saved_for_forward = tensors
 
@@ -165,9 +175,10 @@ class FunctionCtx:
 
     def mark_shared_storage(self, *pairs):
         warnings.warn(
-            'mark_shared_storage is deprecated. '
-            'Tensors with shared storages are automatically tracked. Note '
-            'that calls to `set_()` are not tracked')
+            "mark_shared_storage is deprecated. "
+            "Tensors with shared storages are automatically tracked. Note "
+            "that calls to `set_()` are not tracked"
+        )
 
     def mark_non_differentiable(self, *args: torch.Tensor):
         r"""Marks outputs as non-differentiable.
@@ -246,11 +257,12 @@ class FunctionCtx:
         """
         self.materialize_grads = value
 
 
 # DO NOT USE: This is only defined to be able to load old serialized models
 _ContextMethodMixin = FunctionCtx
 
-class _HookMixin:
+
+class _HookMixin:
     @staticmethod
     def _register_hook(backward_hooks, hook):
         if backward_hooks is None:
@@ -267,9 +279,11 @@ class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
         backward_fn = self._forward_cls.backward  # type: ignore[attr-defined]
         vjp_fn = self._forward_cls.vjp  # type: ignore[attr-defined]
         if backward_fn is not Function.backward and vjp_fn is not Function.vjp:
-            raise RuntimeError("Implementing both 'backward' and 'vjp' for a custom "
-                               "Function is not allowed. You should only implement one "
-                               "of them.")
+            raise RuntimeError(
+                "Implementing both 'backward' and 'vjp' for a custom "
+                "Function is not allowed. You should only implement one "
+                "of them."
+            )
         user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
         return user_fn(self, *args)
 
@@ -289,14 +303,19 @@ class FunctionMeta(type):
     version of this function (which is generated on the fly by this
     metaclass).
     """
 
     def __init__(cls, name, bases, attrs):
-        backward_fn = type(name + 'Backward', (BackwardCFunction,), {'_forward_cls': cls})
+        backward_fn = type(
+            name + "Backward", (BackwardCFunction,), {"_forward_cls": cls}
+        )
         cls._backward_cls = backward_fn
 
         super().__init__(name, bases, attrs)
 
 
-class _SingleLevelFunction(_C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta):
+class _SingleLevelFunction(
+    _C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta
+):
     @staticmethod
     def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
         r"""
@@ -338,8 +357,9 @@ class _SingleLevelFunction(_C._FunctionBase, FunctionCtx, _HookMixin, metaclass=
         ``backward`` (equivalently, ``vjp``) or :func:`ctx.save_for_forward`
         if they are intended to be used for in ``jvp``.
         """
-        raise NotImplementedError("You must implement the forward function for custom"
-                                  " autograd.Function.")
+        raise NotImplementedError(
+            "You must implement the forward function for custom" " autograd.Function."
+        )
 
     @staticmethod
     def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> Any:
@@ -381,9 +401,11 @@ class _SingleLevelFunction(_C._FunctionBase, FunctionCtx, _HookMixin, metaclass=
         first input to :func:`forward` needs gradient computed w.r.t. the
         output.
         """
-        raise NotImplementedError("You must implement either the backward or vjp method for "
-                                  "your custom autograd.Function to use it with backward "
-                                  "mode AD.")
+        raise NotImplementedError(
+            "You must implement either the backward or vjp method for "
+            "your custom autograd.Function to use it with backward "
+            "mode AD."
+        )
 
     # vjp and backward are alias of each other
     vjp = backward
@@ -406,8 +428,10 @@ class _SingleLevelFunction(_C._FunctionBase, FunctionCtx, _HookMixin, metaclass=
         You can use the :attr:`ctx` object to pass any value from the forward to this
         functions.
         """
-        raise NotImplementedError("You must implement the jvp function for custom "
-                                  "autograd.Function to use it with forward mode AD.")
+        raise NotImplementedError(
+            "You must implement the jvp function for custom "
+            "autograd.Function to use it with forward mode AD."
+        )
 
 
 class Function(_SingleLevelFunction):
@@ -443,18 +467,23 @@ class Function(_SingleLevelFunction):
         >>> # xdoctest: +SKIP
         >>> output = Exp.apply(input)
     """
 
     def __init__(self, *args, **kwargs):
         cls = self.__class__
-        warnings.warn(f"{cls} should not be instantiated. Methods on autograd functions"
-                      "are all static, so you should invoke them on the class itself. "
-                      "Instantiating an autograd function will raise an "
-                      "error in a future version of PyTorch.", DeprecationWarning)
+        warnings.warn(
+            f"{cls} should not be instantiated. Methods on autograd functions"
+            "are all static, so you should invoke them on the class itself. "
+            "Instantiating an autograd function will raise an "
+            "error in a future version of PyTorch.",
+            DeprecationWarning,
+        )
 
     def __call__(self, *args, **kwargs):
         raise RuntimeError(
             "Legacy autograd function with non-static forward method is deprecated. "
             "Please use new-style autograd function with static forward method. "
-            "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)")
+            "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)"
+        )
 
     # for the tracer
     is_traceable = False
@@ -499,7 +528,8 @@ class Function(_SingleLevelFunction):
         """
         raise NotImplementedError(
             "To use autograd.Function with vmap, you must either override the "
-            "vmap staticmethod or set generate_vmap_rule=True.")
+            "vmap staticmethod or set generate_vmap_rule=True."
+        )
 
     @classmethod
     def apply(cls, *args, **kwargs):
@@ -510,15 +540,16 @@ class Function(_SingleLevelFunction):
 
         if cls.setup_context == _SingleLevelFunction.setup_context:
             raise RuntimeError(
-                'In order to use an autograd.Function with functorch transforms '
-                '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
-                'staticmethod. For more details, please see '
-                'https://pytorch.org/docs/master/notes/extending.func.html')
+                "In order to use an autograd.Function with functorch transforms "
+                "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
+                "staticmethod. For more details, please see "
+                "https://pytorch.org/docs/master/notes/extending.func.html"
+            )
 
         return custom_function_call(cls, *args, **kwargs)
 
-def once_differentiable(fn):
+
+def once_differentiable(fn):
     @functools.wraps(fn)
     def wrapper(ctx, *args):
         with torch.no_grad():
@@ -536,8 +567,9 @@ def once_differentiable(fn):
         # Unfortunately, this leads to unexpected error messages ("no nodes
         # require computing gradients"), but I don't have a better idea.
        # These functions would raise an error in backward anyway.
-        requires_grad = any(isinstance(arg, torch.Tensor) and arg.requires_grad
-                            for arg in args)
+        requires_grad = any(
+            isinstance(arg, torch.Tensor) and arg.requires_grad for arg in args
+        )
         if not requires_grad:
             return outputs
 
@@ -546,7 +578,9 @@ def once_differentiable(fn):
 
         err_fn = _functions.DelayedError(
             b"trying to differentiate twice a function that was marked "
-            b"with @once_differentiable", len(outputs))
+            b"with @once_differentiable",
+            len(outputs),
+        )
 
         # Create aliases of each output that has requires_grad=True. We need
         # at least one of the inputs to err_fn to require grad so that the
@@ -558,6 +592,7 @@ def once_differentiable(fn):
             return var
 
         return err_fn(*[fake_requires_grad(v) for v in outputs])
 
     return wrapper
 
+
@@ -577,7 +612,6 @@ def traceable(fn_cls):
 
 
 class InplaceFunction(Function):
-
     def __init__(self, inplace=False):
         super().__init__()
         self.inplace = inplace
@@ -591,18 +625,23 @@ def _nested_map(condition, fn, condition_msg=None):
             return None
         elif isinstance(obj, (list, tuple)):
             mapped = (_map(x) for x in obj)
-            if hasattr(obj, '_fields'):
+            if hasattr(obj, "_fields"):
                 # obj is namedtuple
                 return type(obj)(*mapped)
             return type(obj)(mapped)
         elif isinstance(obj, dict):
-            return {x : _map(obj[x]) for x in obj}
+            return {x: _map(obj[x]) for x in obj}
         else:
-            raise ValueError("Auto nesting doesn't know how to process "
-                             "an input object of type " + torch.typename(obj) +
-                             (". Accepted types: " + condition_msg +
-                              ", or lists/tuples of them"
-                              if condition_msg else ""))
+            raise ValueError(
+                "Auto nesting doesn't know how to process "
+                "an input object of type "
+                + torch.typename(obj)
+                + (
+                    ". Accepted types: " + condition_msg + ", or lists/tuples of them"
+                    if condition_msg
+                    else ""
+                )
+            )
 
     return _map
 
@@ -613,8 +652,7 @@ def _jit_unwrap_structured(obj):
     return obj
 
 
-def _iter_filter(condition, allow_unknown=False, condition_msg=None,
-                 conversion=None):
+def _iter_filter(condition, allow_unknown=False, condition_msg=None, conversion=None):
     def _iter(obj):
         if conversion is not None:
             obj = conversion(obj)
@@ -632,11 +670,16 @@ def _iter_filter(condition, allow_unknown=False, condition_msg=None,
         elif allow_unknown:
             yield obj
         else:
-            raise ValueError("Auto nesting doesn't know how to process "
-                             "an input object of type " + torch.typename(obj) +
-                             (". Accepted types: " + condition_msg +
-                              ", or lists/tuples of them"
-                              if condition_msg else ""))
+            raise ValueError(
+                "Auto nesting doesn't know how to process "
+                "an input object of type "
+                + torch.typename(obj)
+                + (
+                    ". Accepted types: " + condition_msg + ", or lists/tuples of them"
+                    if condition_msg
+                    else ""
+                )
+            )
 
     return _iter
 
@@ -661,17 +704,26 @@ def _unflatten(input, proto):
     return unflatten_helper(input, proto)[0]
 
 
-_iter_jit_values = _iter_filter(lambda o: o is None or isinstance(o, torch._C.Value),
-                                condition_msg="jit's Values or None")
-_iter_tensors = _iter_filter(lambda x: isinstance(x, torch.Tensor), condition_msg="Tensors",
-                             conversion=_jit_unwrap_structured)
-_iter_tensors_permissive = _iter_filter(lambda x: isinstance(x, torch.Tensor),
-                                        allow_unknown=True,
-                                        condition_msg="Tensors (permissive)")
-_iter_None_tensors = _iter_filter(lambda o: o is None or isinstance(o, torch.Tensor),
-                                  condition_msg="Tensors or None")
-_map_tensor_data = _nested_map(lambda x: isinstance(x, torch.Tensor), lambda o: o.data,
-                               condition_msg="Tensors")
+_iter_jit_values = _iter_filter(
+    lambda o: o is None or isinstance(o, torch._C.Value),
+    condition_msg="jit's Values or None",
+)
+_iter_tensors = _iter_filter(
+    lambda x: isinstance(x, torch.Tensor),
+    condition_msg="Tensors",
+    conversion=_jit_unwrap_structured,
+)
+_iter_tensors_permissive = _iter_filter(
+    lambda x: isinstance(x, torch.Tensor),
+    allow_unknown=True,
+    condition_msg="Tensors (permissive)",
+)
+_iter_None_tensors = _iter_filter(
+    lambda o: o is None or isinstance(o, torch.Tensor), condition_msg="Tensors or None"
+)
+_map_tensor_data = _nested_map(
+    lambda x: isinstance(x, torch.Tensor), lambda o: o.data, condition_msg="Tensors"
+)
 
 
 class NestedIOFunction(Function):
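For reference, the error messages touched above (static forward, setup_context for functorch transforms, vmap/generate_vmap_rule) all describe the new-style torch.autograd.Function API. A minimal sketch of that API follows, assuming PyTorch 2.x; it mirrors the Exp example referenced in the class docstring above, and the surrounding details are illustrative rather than taken from this diff.

import torch
from torch.autograd import Function


class Exp(Function):
    # All methods are static; the class itself is never instantiated.
    generate_vmap_rule = True  # or define a vmap staticmethod, per the error above

    @staticmethod
    def forward(x):
        # With a separate setup_context, forward receives only the inputs (no ctx).
        return x.exp()

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Required for functorch transforms (vmap, grad, jvp, ...), as the
        # RuntimeError message reformatted above explains.
        ctx.save_for_backward(output)

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * result


x = torch.randn(3, requires_grad=True)
y = Exp.apply(x)  # invoke on the class; instantiating it triggers the DeprecationWarning above
y.sum().backward()
print(torch.allclose(x.grad, x.exp()))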