Apply UFMT to low traffic torch modules (#106249)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007
Authored by Edward Z. Yang on 2023-07-29 10:51:26 -04:00
Committed by PyTorch MergeBot
parent a4ebc61f15
commit 3bf922a6ce
163 changed files with 8472 additions and 4412 deletions
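UFMT is usort (import sorting) followed by black (code formatting), so the diff below is a purely mechanical reformatting of torch/autograd/function.py with no behavior change. As a rough, hypothetical sketch only (the `ufmt format` command is the standard ufmt CLI and is not taken from this PR; PyTorch's lintrunner setup may pin different formatter versions), a pass like this can be reproduced on a single file with:

# Hypothetical sketch: re-run a UFMT-style pass over one touched file.
# Assumes the `ufmt` package (usort + black) is installed and on PATH; the
# formatter versions pinned by PyTorch's lintrunner config may differ.
import subprocess

def reformat_with_ufmt(path: str) -> None:
    # `ufmt format` rewrites the file in place: usort groups and sorts the
    # imports, then black normalizes quoting, wrapping, and blank lines.
    result = subprocess.run(
        ["ufmt", "format", path], capture_output=True, text=True, check=False
    )
    print(result.stdout or result.stderr)

reformat_with_ufmt("torch/autograd/function.py")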

torch/autograd/function.py

@@ -1,20 +1,29 @@
-import torch
-import torch._C as _C
-from torch._C import _functions
-import torch._functorch as _functorch
-import torch.utils.hooks as hooks
 import functools
 import warnings
 from collections import OrderedDict
 from typing import Any, List, Optional, Tuple
+
+import torch
+import torch._C as _C
+import torch._functorch as _functorch
+import torch.utils.hooks as hooks
+from torch._C import _functions
 from torch._functorch.autograd_function import custom_function_call

-__all__ = ["FunctionCtx", "BackwardCFunction", "FunctionMeta", "Function", "once_differentiable", "traceable",
-           "InplaceFunction", "NestedIOFunction"]
+__all__ = [
+    "FunctionCtx",
+    "BackwardCFunction",
+    "FunctionMeta",
+    "Function",
+    "once_differentiable",
+    "traceable",
+    "InplaceFunction",
+    "NestedIOFunction",
+]
+

 # Formerly known as: _ContextMethodMixin
 class FunctionCtx:
-
     def save_for_backward(self, *tensors: torch.Tensor):
         r"""Saves given tensors for a future call to :func:`~Function.backward`.
@@ -122,7 +131,8 @@ class FunctionCtx:
         for tensor in tensors:
             assert isinstance(tensor, torch.Tensor) or tensor is None, (
                 "save_for_forward expects all arguments to be tensors; you should "
-                "save non-tensors as attributes on ctx.")
+                "save non-tensors as attributes on ctx."
+            )

         self.saved_for_forward = tensors
@@ -165,9 +175,10 @@ class FunctionCtx:
     def mark_shared_storage(self, *pairs):
         warnings.warn(
-            'mark_shared_storage is deprecated. '
-            'Tensors with shared storages are automatically tracked. Note '
-            'that calls to `set_()` are not tracked')
+            "mark_shared_storage is deprecated. "
+            "Tensors with shared storages are automatically tracked. Note "
+            "that calls to `set_()` are not tracked"
+        )

     def mark_non_differentiable(self, *args: torch.Tensor):
         r"""Marks outputs as non-differentiable.
@@ -246,11 +257,12 @@ class FunctionCtx:
         """
         self.materialize_grads = value

+
 # DO NOT USE: This is only defined to be able to load old serialized models
 _ContextMethodMixin = FunctionCtx

-class _HookMixin:
-
+
+class _HookMixin:
     @staticmethod
     def _register_hook(backward_hooks, hook):
         if backward_hooks is None:
@@ -267,9 +279,11 @@ class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
         backward_fn = self._forward_cls.backward  # type: ignore[attr-defined]
         vjp_fn = self._forward_cls.vjp  # type: ignore[attr-defined]
         if backward_fn is not Function.backward and vjp_fn is not Function.vjp:
-            raise RuntimeError("Implementing both 'backward' and 'vjp' for a custom "
-                               "Function is not allowed. You should only implement one "
-                               "of them.")
+            raise RuntimeError(
+                "Implementing both 'backward' and 'vjp' for a custom "
+                "Function is not allowed. You should only implement one "
+                "of them."
+            )
         user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
         return user_fn(self, *args)
@@ -289,14 +303,19 @@ class FunctionMeta(type):
             version of this function (which is generated on the fly by this
             metaclass).
     """

     def __init__(cls, name, bases, attrs):
-        backward_fn = type(name + 'Backward', (BackwardCFunction,), {'_forward_cls': cls})
+        backward_fn = type(
+            name + "Backward", (BackwardCFunction,), {"_forward_cls": cls}
+        )
         cls._backward_cls = backward_fn

         super().__init__(name, bases, attrs)


-class _SingleLevelFunction(_C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta):
+class _SingleLevelFunction(
+    _C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta
+):
     @staticmethod
     def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
         r"""
@@ -338,8 +357,9 @@ class _SingleLevelFunction(_C._FunctionBase, FunctionCtx, _HookMixin, metaclass=
         ``backward`` (equivalently, ``vjp``) or :func:`ctx.save_for_forward`
         if they are intended to be used for in ``jvp``.
         """
-        raise NotImplementedError("You must implement the forward function for custom"
-                                  " autograd.Function.")
+        raise NotImplementedError(
+            "You must implement the forward function for custom" " autograd.Function."
+        )

     @staticmethod
     def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> Any:
@@ -381,9 +401,11 @@ class _SingleLevelFunction(_C._FunctionBase, FunctionCtx, _HookMixin, metaclass=
         first input to :func:`forward` needs gradient computed w.r.t. the
         output.
         """
-        raise NotImplementedError("You must implement either the backward or vjp method for "
-                                  "your custom autograd.Function to use it with backward "
-                                  "mode AD.")
+        raise NotImplementedError(
+            "You must implement either the backward or vjp method for "
+            "your custom autograd.Function to use it with backward "
+            "mode AD."
+        )

     # vjp and backward are alias of each other
     vjp = backward
@@ -406,8 +428,10 @@ class _SingleLevelFunction(_C._FunctionBase, FunctionCtx, _HookMixin, metaclass=
         You can use the :attr:`ctx` object to pass any value from the forward to this
         functions.
         """
-        raise NotImplementedError("You must implement the jvp function for custom "
-                                  "autograd.Function to use it with forward mode AD.")
+        raise NotImplementedError(
+            "You must implement the jvp function for custom "
+            "autograd.Function to use it with forward mode AD."
+        )


 class Function(_SingleLevelFunction):
@@ -443,18 +467,23 @@ class Function(_SingleLevelFunction):
         >>> # xdoctest: +SKIP
         >>> output = Exp.apply(input)
     """

     def __init__(self, *args, **kwargs):
         cls = self.__class__
-        warnings.warn(f"{cls} should not be instantiated. Methods on autograd functions"
-                      "are all static, so you should invoke them on the class itself. "
-                      "Instantiating an autograd function will raise an "
-                      "error in a future version of PyTorch.", DeprecationWarning)
+        warnings.warn(
+            f"{cls} should not be instantiated. Methods on autograd functions"
+            "are all static, so you should invoke them on the class itself. "
+            "Instantiating an autograd function will raise an "
+            "error in a future version of PyTorch.",
+            DeprecationWarning,
+        )

     def __call__(self, *args, **kwargs):
         raise RuntimeError(
             "Legacy autograd function with non-static forward method is deprecated. "
             "Please use new-style autograd function with static forward method. "
-            "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)")
+            "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)"
+        )

     # for the tracer
     is_traceable = False
@@ -499,7 +528,8 @@ class Function(_SingleLevelFunction):
         """
         raise NotImplementedError(
             "To use autograd.Function with vmap, you must either override the "
-            "vmap staticmethod or set generate_vmap_rule=True.")
+            "vmap staticmethod or set generate_vmap_rule=True."
+        )

     @classmethod
     def apply(cls, *args, **kwargs):
@@ -510,15 +540,16 @@
         if cls.setup_context == _SingleLevelFunction.setup_context:
             raise RuntimeError(
-                'In order to use an autograd.Function with functorch transforms '
-                '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
-                'staticmethod. For more details, please see '
-                'https://pytorch.org/docs/master/notes/extending.func.html')
+                "In order to use an autograd.Function with functorch transforms "
+                "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
+                "staticmethod. For more details, please see "
+                "https://pytorch.org/docs/master/notes/extending.func.html"
+            )

         return custom_function_call(cls, *args, **kwargs)


-def once_differentiable(fn):
-
+def once_differentiable(fn):
     @functools.wraps(fn)
     def wrapper(ctx, *args):
         with torch.no_grad():
@@ -536,8 +567,9 @@ def once_differentiable(fn):
         # Unfortunately, this leads to unexpected error messages ("no nodes
         # require computing gradients"), but I don't have a better idea.
        # These functions would raise an error in backward anyway.
-        requires_grad = any(isinstance(arg, torch.Tensor) and arg.requires_grad
-                            for arg in args)
+        requires_grad = any(
+            isinstance(arg, torch.Tensor) and arg.requires_grad for arg in args
+        )
         if not requires_grad:
             return outputs
@@ -546,7 +578,9 @@ def once_differentiable(fn):
         err_fn = _functions.DelayedError(
             b"trying to differentiate twice a function that was marked "
-            b"with @once_differentiable", len(outputs))
+            b"with @once_differentiable",
+            len(outputs),
+        )

         # Create aliases of each output that has requires_grad=True. We need
         # at least one of the inputs to err_fn to require grad so that the
@@ -558,6 +592,7 @@ def once_differentiable(fn):
             return var

         return err_fn(*[fake_requires_grad(v) for v in outputs])
+
     return wrapper
@@ -577,7 +612,6 @@ def traceable(fn_cls):


 class InplaceFunction(Function):
-
     def __init__(self, inplace=False):
         super().__init__()
         self.inplace = inplace
@@ -591,18 +625,23 @@ def _nested_map(condition, fn, condition_msg=None):
             return None
         elif isinstance(obj, (list, tuple)):
             mapped = (_map(x) for x in obj)
-            if hasattr(obj, '_fields'):
+            if hasattr(obj, "_fields"):
                 # obj is namedtuple
                 return type(obj)(*mapped)
             return type(obj)(mapped)
         elif isinstance(obj, dict):
-            return {x : _map(obj[x]) for x in obj}
+            return {x: _map(obj[x]) for x in obj}
         else:
-            raise ValueError("Auto nesting doesn't know how to process "
-                             "an input object of type " + torch.typename(obj) +
-                             (". Accepted types: " + condition_msg +
-                              ", or lists/tuples of them"
-                              if condition_msg else ""))
+            raise ValueError(
+                "Auto nesting doesn't know how to process "
+                "an input object of type "
+                + torch.typename(obj)
+                + (
+                    ". Accepted types: " + condition_msg + ", or lists/tuples of them"
+                    if condition_msg
+                    else ""
+                )
+            )

     return _map
@@ -613,8 +652,7 @@ def _jit_unwrap_structured(obj):
     return obj


-def _iter_filter(condition, allow_unknown=False, condition_msg=None,
-                 conversion=None):
+def _iter_filter(condition, allow_unknown=False, condition_msg=None, conversion=None):
     def _iter(obj):
         if conversion is not None:
             obj = conversion(obj)
@@ -632,11 +670,16 @@ def _iter_filter(condition, allow_unknown=False, condition_msg=None,
         elif allow_unknown:
             yield obj
         else:
-            raise ValueError("Auto nesting doesn't know how to process "
-                             "an input object of type " + torch.typename(obj) +
-                             (". Accepted types: " + condition_msg +
-                              ", or lists/tuples of them"
-                              if condition_msg else ""))
+            raise ValueError(
+                "Auto nesting doesn't know how to process "
+                "an input object of type "
+                + torch.typename(obj)
+                + (
+                    ". Accepted types: " + condition_msg + ", or lists/tuples of them"
+                    if condition_msg
+                    else ""
+                )
+            )

     return _iter
@@ -661,17 +704,26 @@ def _unflatten(input, proto):
     return unflatten_helper(input, proto)[0]


-_iter_jit_values = _iter_filter(lambda o: o is None or isinstance(o, torch._C.Value),
-                                condition_msg="jit's Values or None")
-_iter_tensors = _iter_filter(lambda x: isinstance(x, torch.Tensor), condition_msg="Tensors",
-                             conversion=_jit_unwrap_structured)
-_iter_tensors_permissive = _iter_filter(lambda x: isinstance(x, torch.Tensor),
-                                        allow_unknown=True,
-                                        condition_msg="Tensors (permissive)")
-_iter_None_tensors = _iter_filter(lambda o: o is None or isinstance(o, torch.Tensor),
-                                  condition_msg="Tensors or None")
-_map_tensor_data = _nested_map(lambda x: isinstance(x, torch.Tensor), lambda o: o.data,
-                               condition_msg="Tensors")
+_iter_jit_values = _iter_filter(
+    lambda o: o is None or isinstance(o, torch._C.Value),
+    condition_msg="jit's Values or None",
+)
+_iter_tensors = _iter_filter(
+    lambda x: isinstance(x, torch.Tensor),
+    condition_msg="Tensors",
+    conversion=_jit_unwrap_structured,
+)
+_iter_tensors_permissive = _iter_filter(
+    lambda x: isinstance(x, torch.Tensor),
+    allow_unknown=True,
+    condition_msg="Tensors (permissive)",
+)
+_iter_None_tensors = _iter_filter(
+    lambda o: o is None or isinstance(o, torch.Tensor), condition_msg="Tensors or None"
+)
+_map_tensor_data = _nested_map(
+    lambda x: isinstance(x, torch.Tensor), lambda o: o.data, condition_msg="Tensors"
+)


 class NestedIOFunction(Function):