Add warning when accessing Tensor::grad() in the C++ API (#59362)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/35379

 - Adds a `retains_grad` attribute backed by C++ as a native function. The Python bindings for the function are skipped to be consistent with `is_leaf`.
   - Tried writing it without a native function, but the jit test `test_tensor_properties` seems to require that it be a native function (alternatively, it might also work if we manually add a prim implementation?).
 - The Python API now uses the `retain_grad` implementation from C++ (usage sketched below).
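
A minimal sketch of the new behavior, mirroring the updated `AutogradAPITests.RetainGrad` C++ test in this PR (assumes a libtorch build; variable names are illustrative):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  auto input = torch::ones({2, 2}, torch::requires_grad());
  auto h1 = input * 3;            // non-leaf tensor

  h1.grad();                      // now warns: non-leaf, not retaining grad, grad undefined

  h1.retain_grad();               // opt in to gradient retention
  h1.retain_grad();               // calling it more than once is fine
  h1.grad();                      // no warning: h1.retains_grad() is now true

  (h1 * h1).sum().backward();
  std::cout << h1.grad() << "\n";          // populated because retain_grad() was called
  std::cout << h1.retains_grad() << "\n";  // prints 1 (new native function)
}
```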

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59362

Reviewed By: jbschlosser

Differential Revision: D28969298

Pulled By: soulitzer

fbshipit-source-id: 335f2be50b9fb870cd35dc72f7dadd6c8666cc02
Author: Jeffrey Wan
Date: 2021-06-08 19:40:03 -07:00
Committed by: Facebook GitHub Bot
Parent: 90303157ab
Commit: f52e202840
14 changed files with 96 additions and 40 deletions

View File

@@ -76,6 +76,10 @@ void Tensor::retain_grad() const {
impl::GetVariableHooks()->retain_grad(*this);
}
bool Tensor::retains_grad() const {
return impl::GetVariableHooks()->retains_grad(*this);
}
void Tensor::_backward(TensorList inputs,
const c10::optional<Tensor>& gradient,
c10::optional<bool> keep_graph,

View File

@@ -54,6 +54,7 @@ struct TORCH_API VariableHooksInterface {
virtual Tensor data(const Tensor&) const = 0;
virtual int64_t _version(const Tensor&) const = 0;
virtual void retain_grad(const Tensor&) const = 0;
virtual bool retains_grad(const Tensor&) const = 0;
virtual void _backward(const Tensor&, TensorList, const c10::optional<Tensor>&, c10::optional<bool>, bool) const = 0;
virtual void requires_grad_(const Tensor&, bool) const = 0;
};
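
For context, the new pure-virtual `retains_grad` follows the existing hooks pattern: ATen only sees this abstract interface, and the autograd library registers the concrete implementation that calls like `Tensor::retains_grad()` reach through `impl::GetVariableHooks()`. A rough, simplified sketch of that indirection (the `Demo*` names are hypothetical, not PyTorch's real registration machinery):

```cpp
#include <ATen/ATen.h>
#include <iostream>

// Abstract interface visible to the core library (stands in for VariableHooksInterface).
struct DemoHooks {
  virtual ~DemoHooks() = default;
  virtual bool retains_grad(const at::Tensor& self) const = 0;
};

// Registry slot that the "autograd" side fills in at load time.
DemoHooks*& demo_hooks() {
  static DemoHooks* hooks = nullptr;
  return hooks;
}

// What a method like Tensor::retains_grad() boils down to on the core side.
bool demo_retains_grad(const at::Tensor& self) {
  return demo_hooks()->retains_grad(self);
}

// Concrete implementation the "autograd" side would install.
struct DemoAutogradHooks final : DemoHooks {
  bool retains_grad(const at::Tensor& self) const override {
    return self.retains_grad();  // defer to the real implementation for the demo
  }
};

int main() {
  static DemoAutogradHooks impl;
  demo_hooks() = &impl;
  std::cout << demo_retains_grad(at::ones({2})) << "\n";  // prints 0: plain tensor
}
```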

View File

@@ -40,6 +40,10 @@ void retain_grad(Tensor& self) {
return self.retain_grad();
}
bool retains_grad(const Tensor& self) {
return self.retains_grad();
}
Tensor _fw_primal(const Tensor& self, int64_t level) {
AT_ERROR("_fw_primal is not implemented for Tensor");
}

View File

@@ -89,6 +89,10 @@
manual_cpp_binding: True
variants: method
- func: retains_grad(Tensor self) -> bool
manual_cpp_binding: True
variants: method
- func: _fw_primal(Tensor(a) self, int level) -> Tensor(a)
variants: method
dispatch:

View File

@@ -711,7 +711,13 @@ class TORCH_API Tensor {
/// \fn void retain_grad() const;
///
/// Enables .grad() for non-leaf Tensors.
/// Enables this Tensor to have its :attr:`grad` populated during
/// :func:`backward`. This is a no-op for leaf tensors.
/// \fn bool retains_grad() const;
///
/// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be
/// populated during :func:`backward`, ``false`` otherwise.
const Tensor& set_requires_grad(bool requires_grad) const {
impl_->set_requires_grad(requires_grad);
@@ -734,7 +740,16 @@ class TORCH_API Tensor {
/// The attribute will then contain the gradients computed and future calls
/// to `backward()` will accumulate (add) gradients into it.
const Tensor& grad() const {
return impl_->grad();
const Tensor& maybe_grad = impl_->grad();
if (!is_leaf() && !retains_grad() && !maybe_grad.defined()) {
TORCH_WARN(
"The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "
"attribute won't be populated during autograd.backward(). If you indeed want the .grad "
"field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. "
"If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor "
"instead. See github.com/pytorch/pytorch/pull/30531 for more informations.");
}
return maybe_grad;
}
// The Forward AD API functions below are low level and are not to be used by end
@@ -891,6 +906,8 @@ public:
void retain_grad() const;
bool retains_grad() const;
void _backward(TensorList inputs, const c10::optional<Tensor>& gradient, c10::optional<bool> keep_graph, bool create_graph) const;
const Tensor& requires_grad_(bool _requires_grad=true) const;

View File

@@ -564,6 +564,7 @@ Tensor class reference
Tensor.resize_
Tensor.resize_as_
Tensor.retain_grad
Tensor.retains_grad
Tensor.roll
Tensor.rot90
Tensor.round

View File

@@ -165,9 +165,24 @@ TEST(AutogradAPITests, RetainGrad) {
auto h1 = input * 3;
auto out = (h1 * h1).sum();
{
// Warning when grad is accessed for non-leaf tensor
WarningCapture warnings;
ASSERT_FALSE(h1.grad().defined());
ASSERT_TRUE(
warnings.str().find("is not a leaf") != std::string::npos);
}
// It should be possible to call retain_grad() multiple times
h1.retain_grad();
h1.retain_grad();
{
// If retain_grad is true for a non-leaf tensor,
// there should not be any warning when grad is accessed
WarningCapture warnings;
ASSERT_FALSE(h1.grad().defined());
ASSERT_FALSE(
warnings.str().find("is not a leaf") != std::string::npos);
}
// Gradient should be accumulated
// NOLINTNEXTLINE(bugprone-argument-comment)

View File

@@ -95,7 +95,7 @@ SKIP_PYTHON_BINDINGS = [
'nonzero(_(out|numpy))?',
'set_data',
'.*_overrideable', # overrideable functions for backend extension
'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retain_grad', 'set_',
'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retains_grad', 'set_',
'_fw_primal', 'fake_quantize_per_tensor_affine_cachemask',
'fake_quantize_per_channel_affine_cachemask',
]

View File

@@ -3,7 +3,6 @@ import functools
from numbers import Number
from typing import Any, Dict, Optional, Tuple, Union
import warnings
import weakref
import torch
import torch._C as _C
@@ -356,33 +355,6 @@ class Tensor(torch._C._TensorBase):
have forward mode AD gradients.
""")
def retain_grad(self):
r"""Enables .grad attribute for non-leaf Tensors."""
if has_torch_function_unary(self):
return handle_torch_function(Tensor.retain_grad, (self,), self)
if not self.requires_grad:
raise RuntimeError("can't retain_grad on Tensor that has requires_grad=False")
if self.is_leaf: # no-op for leaves
return
if hasattr(self, 'retains_grad'):
return
weak_self = weakref.ref(self)
def retain_grad_hook(grad):
var = weak_self()
if var is None:
return
if var._grad is None:
if grad.is_sparse:
var._grad = grad.clone()
else:
var._grad = grad.clone(memory_format=torch.contiguous_format)
else:
var._grad = var._grad + grad
self.register_hook(retain_grad_hook)
self.retains_grad = True
def is_shared(self):
r"""Checks if tensor is in shared memory.
@@ -996,12 +968,6 @@ class Tensor(torch._C._TensorBase):
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
return handle_torch_function(Tensor.grad.__get__, (self,), self) # type: ignore[attr-defined]
if self.requires_grad and not hasattr(self, "retains_grad") and not self.is_leaf and self._grad is None:
warnings.warn("The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "
"attribute won't be populated during autograd.backward(). If you indeed want the gradient "
"for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the "
"non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See "
"github.com/pytorch/pytorch/pull/30531 for more information.", stacklevel=2)
return self._grad
@grad.setter

View File

@@ -4646,6 +4646,20 @@ masked_fill(mask, value) -> Tensor
Out-of-place version of :meth:`torch.Tensor.masked_fill_`
""")
add_docstr_all('retain_grad',
r"""
retain_grad() -> None
Enables this Tensor to have its :attr:`grad` populated during
:func:`backward`. This is a no-op for leaf tensors.
""")
add_docstr_all('retains_grad',
r"""
Is ``True`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be
populated during :func:`backward`, ``False`` otherwise.
""")
add_docstr_all('requires_grad',
r"""
Is ``True`` if gradients need to be computed for this Tensor, ``False`` otherwise.

View File

@@ -523,7 +523,25 @@ PyObject *THPVariable_get_requires_grad(THPVariable *self, void *unused)
if (check_has_torch_function((PyObject *)self)) {
return handle_torch_function_getter(self, "requires_grad");
}
return PyBool_FromLong(THPVariable_Unpack(self).requires_grad());
if(THPVariable_Unpack(self).requires_grad()) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
END_HANDLE_TH_ERRORS
}
PyObject *THPVariable_retains_grad(THPVariable *self, void *unused)
{
HANDLE_TH_ERRORS
if (check_has_torch_function((PyObject *)self)) {
return handle_torch_function_getter(self, "retains_grad");
}
if(THPVariable_Unpack(self).retains_grad()) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
END_HANDLE_TH_ERRORS
}
@@ -907,6 +925,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
{"grad_fn", (getter)THPVariable_get_grad_fn, nullptr, nullptr, nullptr},
{"_grad_fn", (getter)THPVariable_get_grad_fn, (setter)THPVariable_set_grad_fn, nullptr, nullptr},
{"is_leaf", (getter)THPVariable_is_leaf, nullptr, nullptr, nullptr},
{"retains_grad", (getter)THPVariable_retains_grad, nullptr, nullptr, nullptr},
{"data", (getter)THPVariable_get_data, (setter)THPVariable_set_data, nullptr, nullptr},
{"_grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, nullptr, nullptr}, // Allows the python class to override .grad
{"grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, nullptr, nullptr},

View File

@@ -347,7 +347,8 @@ struct VariableHooks final : at::impl::VariableHooksInterface {
void set_data(const Tensor & self, const Tensor & new_data) const override;
Tensor data(const Tensor & self) const override;
int64_t _version(const Tensor & self) const override;
void retain_grad(const Tensor & self) const override;
void retain_grad(const Tensor& self) const override;
bool retains_grad(const Tensor& self) const override;
void _backward(const Tensor& self, at::TensorList inputs,
const c10::optional<Tensor>& gradient, c10::optional<bool> keep_graph,
bool create_graph) const override;
@@ -434,7 +435,7 @@ int64_t VariableHooks::_version(const Tensor & self) const {
return self.unsafeGetTensorImpl()->version_counter().current_version();
}
void VariableHooks::retain_grad(const Tensor & self) const {
void VariableHooks::retain_grad(const Tensor& self) const {
TORCH_CHECK(self.requires_grad(), "can't retain_grad on Tensor that has requires_grad=False");
if (self.is_leaf()) { // no-op for leaves
return;
@@ -465,6 +466,14 @@ void VariableHooks::retain_grad(const Tensor & self) const {
impl::get_autograd_meta(self)->retains_grad_ = true;
}
bool VariableHooks::retains_grad(const Tensor& self) const {
if (impl::get_autograd_meta(self)) {
return impl::get_autograd_meta(self)->retains_grad_;
} else {
return false;
}
}
void VariableHooks::_backward(
const Tensor& self,
at::TensorList inputs,
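
As I read the implementation above, the observable semantics are: tensors with no autograd metadata report `false`, leaves report `false` even when they require grad (since `retain_grad()` is a no-op for them), and only a non-leaf that has opted in reports `true`. A small hedged check (illustrative names, assumes a libtorch build):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  auto plain = torch::ones({2});                 // no autograd metadata
  std::cout << plain.retains_grad() << "\n";     // 0: hits the "else" branch above

  auto leaf = torch::ones({2}, torch::requires_grad());
  leaf.retain_grad();                            // no-op for leaves
  std::cout << leaf.retains_grad() << "\n";      // 0: retains_grad_ never set

  auto nonleaf = leaf * 2;
  nonleaf.retain_grad();                         // sets retains_grad_ on the non-leaf
  std::cout << nonleaf.retains_grad() << "\n";   // 1
}
```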

View File

@@ -119,6 +119,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
{"layout", "prim"}, {"T", "prim"},
{"ndim", "prim"}, {"name", "prim"},
{"real", "aten"}, {"imag", "aten"},
{"retains_grad", "aten"},
}},
{TypeKind::DeviceObjType, {{"type", "prim"}, {"index", "prim"}}}};
auto kind = value_->type()->kind();

View File

@@ -966,6 +966,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
Tensor.is_cuda.__get__: lambda self: -1,
Tensor.is_xpu.__get__: lambda self: -1,
Tensor.is_leaf.__get__: lambda self: -1,
Tensor.retains_grad.__get__: lambda self: -1,
Tensor.is_meta.__get__: lambda self: -1,
Tensor.is_mlc.__get__: lambda self: -1,
Tensor.is_mkldnn.__get__: lambda self: -1,