Move Tensor.grad back into C++

`Tensor.grad` was moved to Python in #30531 to add a warning. However, that
warning has since been lowered into C++, so this wrapper is no longer
necessary.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76675

Approved by: https://github.com/albanD
Author: Peter Bell
Date: 2022-06-09 16:31:59 +01:00
Committed by: PyTorch MergeBot
Parent: dd620c4575
Commit: 7843a5e882

7 changed files with 19 additions and 36 deletions

View File

@@ -1663,6 +1663,7 @@ else:
 res = torch.gather(src, dim, idx)
 weight = torch.rand_like(res, device=device) * 10 ** 6
 res.backward(weight)
+assert src.grad is not None
 grad = src.grad.detach().clone()
 if torch.device(device).type == 'cuda':

View File

@@ -945,6 +945,7 @@ class _TensorBase(metaclass=_TensorMeta):
     grad_fn: Any
     _grad_fn: Any
     _grad: Optional[Tensor]
+    grad: Optional[Tensor]
     _backward_hooks: Optional[Dict[_int, Callable[[Tensor], Optional[Tensor]]]]
     ${tensor_method_hints}

View File

@@ -1062,34 +1062,6 @@ class Tensor(torch._C._TensorBase):
         else:
             return super(Tensor, self).rename(names)
 
-    @property
-    def grad(self):
-        """
-        This attribute is ``None`` by default and becomes a Tensor the first time a call to
-        :func:`backward` computes gradients for ``self``.
-
-        The attribute will then contain the gradients computed and future calls to
-        :func:`backward` will accumulate (add) gradients into it.
-        """
-        if has_torch_function_unary(self):
-            # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
-            return handle_torch_function(Tensor.grad.__get__, (self,), self)  # type: ignore[attr-defined]
-        return self._grad
-
-    @grad.setter
-    def grad(self, new_grad):
-        if has_torch_function_unary(self):
-            # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
-            return handle_torch_function(Tensor.grad.__set__, (self,), self, new_grad)  # type: ignore[attr-defined]
-        self._grad = new_grad
-
-    @grad.deleter
-    def grad(self):
-        if has_torch_function_unary(self):
-            # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
-            return handle_torch_function(Tensor.grad.__delete__, (self,), self)  # type: ignore[attr-defined]
-        del self._grad
-
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         """

View File

@@ -4989,6 +4989,14 @@ masked_fill(mask, value) -> Tensor
 Out-of-place version of :meth:`torch.Tensor.masked_fill_`
 """)
 
+add_docstr_all('grad',
+               r"""
+This attribute is ``None`` by default and becomes a Tensor the first time a call to
+:func:`backward` computes gradients for ``self``.
+
+The attribute will then contain the gradients computed and future calls to
+:func:`backward` will accumulate (add) gradients into it.
+""")
+
 add_docstr_all('retain_grad',
                r"""
 retain_grad() -> None
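
As a quick illustration of the behaviour the docstring added above describes,
here is a minimal sketch in plain PyTorch (illustrative only, not part of this
diff):

    import torch

    # .grad is None until backward() first computes gradients for the tensor.
    x = torch.ones(3, requires_grad=True)
    print(x.grad)   # None

    (x * 2).sum().backward()
    print(x.grad)   # tensor([2., 2., 2.])

    # A second backward() accumulates (adds) into the existing .grad.
    (x * 3).sum().backward()
    print(x.grad)   # tensor([5., 5., 5.])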

View File

@@ -615,6 +615,7 @@ class Module:
                         grad_applied = fn(param.grad)
                     should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
                     if should_use_set_data:
+                        assert out_param.grad is not None
                         out_param.grad.data = grad_applied
                     else:
                         assert param.grad.is_leaf

View File

@@ -29,17 +29,17 @@ def clip_grad_norm_(
     """
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
-    parameters = [p for p in parameters if p.grad is not None]
+    grads = [p.grad for p in parameters if p.grad is not None]
     max_norm = float(max_norm)
     norm_type = float(norm_type)
-    if len(parameters) == 0:
+    if len(grads) == 0:
         return torch.tensor(0.)
-    device = parameters[0].grad.device
+    device = grads[0].device
     if norm_type == inf:
-        norms = [p.grad.detach().abs().max().to(device) for p in parameters]
+        norms = [g.detach().abs().max().to(device) for g in grads]
         total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
     else:
-        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+        total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
     if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
         raise RuntimeError(
             f'The total norm of order {norm_type} for gradients from '
@@ -51,8 +51,8 @@ def clip_grad_norm_(
     # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
     # when the gradients do not reside in CPU memory.
     clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
-    for p in parameters:
-        p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device))
+    for g in grads:
+        g.detach().mul_(clip_coef_clamped.to(g.device))
     return total_norm
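
For orientation, a minimal usage sketch of `clip_grad_norm_`: the rewrite above
only changes how the gradients are collected internally, the public call is
unchanged. The toy model and shapes below are made up for illustration.

    import torch
    from torch import nn
    from torch.nn.utils import clip_grad_norm_

    # Hypothetical toy model, just to produce some gradients.
    model = nn.Linear(4, 2)
    loss = model(torch.randn(8, 4)).pow(2).sum()
    loss.backward()

    # Rescales gradients in place so their total 2-norm is at most max_norm,
    # skipping parameters whose .grad is None, and returns the measured norm.
    total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0)
    print(total_norm)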

View File

@@ -23,7 +23,7 @@ class _LazyImport:
         except ImportError:
             # If packaging isn't installed, try and use the vendored copy
             # in pkg_resources
-            from pkg_resources import packaging  # type: ignore[attr-defined]
+            from pkg_resources import packaging  # type: ignore[attr-defined, no-redef]
         return getattr(packaging.version, self._cls_name)
 
     def __call__(self, *args, **kwargs):