Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add warning when accessing Tensor::grad() in the C++ API (#59362)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/35379

- Adds a `retains_grad` attribute, backed by C++, as a native function. The Python bindings for the function are skipped to be consistent with `is_leaf`.
- Tried writing it without a native function, but the JIT test `test_tensor_properties` seems to require that it be a native function (alternatively, it might also work with a manually added prim implementation).
- The Python API now uses the `retain_grad` implementation from C++.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59362

Reviewed By: jbschlosser

Differential Revision: D28969298

Pulled By: soulitzer

fbshipit-source-id: 335f2be50b9fb870cd35dc72f7dadd6c8666cc02
Committed by: Facebook GitHub Bot
Parent: 90303157ab
Commit: f52e202840
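For orientation, here is a minimal sketch of the behavior this commit introduces, written against the libtorch C++ API; the tensor shape and variable names are illustrative only and are not taken from the diff below.

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Leaf tensor that requires grad.
  auto input = torch::rand({4}, torch::requires_grad());
  // Non-leaf tensors produced by differentiable ops on `input`.
  auto h1 = input * 3;
  auto out = (h1 * h1).sum();

  // With this commit, calling h1.grad() at this point would emit the new
  // warning, because h1 is a non-leaf tensor that does not retain its grad.

  h1.retain_grad();  // opt in to having h1's grad populated by backward()
  out.backward();

  std::cout << std::boolalpha
            << h1.retains_grad() << '\n';  // true: the new native retains_grad()
  std::cout << h1.grad() << '\n';          // defined, because retain_grad() was called
  return 0;
}
```

This mirrors the C++ test added below (AutogradAPITests, RetainGrad), which checks that the warning fires when a non-leaf tensor's grad is accessed without retain_grad() and is silent after retain_grad() has been called.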
@@ -76,6 +76,10 @@ void Tensor::retain_grad() const {
   impl::GetVariableHooks()->retain_grad(*this);
 }
 
+bool Tensor::retains_grad() const {
+  return impl::GetVariableHooks()->retains_grad(*this);
+}
+
 void Tensor::_backward(TensorList inputs,
         const c10::optional<Tensor>& gradient,
         c10::optional<bool> keep_graph,
@@ -54,6 +54,7 @@ struct TORCH_API VariableHooksInterface {
   virtual Tensor data(const Tensor&) const = 0;
   virtual int64_t _version(const Tensor&) const = 0;
   virtual void retain_grad(const Tensor&) const = 0;
+  virtual bool retains_grad(const Tensor&) const = 0;
   virtual void _backward(const Tensor&, TensorList, const c10::optional<Tensor>&, c10::optional<bool>, bool) const = 0;
   virtual void requires_grad_(const Tensor&, bool) const = 0;
 };
@@ -40,6 +40,10 @@ void retain_grad(Tensor& self) {
   return self.retain_grad();
 }
 
+bool retains_grad(const Tensor& self) {
+  return self.retains_grad();
+}
+
 Tensor _fw_primal(const Tensor& self, int64_t level) {
   AT_ERROR("_fw_primal is not implemented for Tensor");
 }
@@ -89,6 +89,10 @@
   manual_cpp_binding: True
   variants: method
 
+- func: retains_grad(Tensor self) -> bool
+  manual_cpp_binding: True
+  variants: method
+
 - func: _fw_primal(Tensor(a) self, int level) -> Tensor(a)
   variants: method
   dispatch:
@@ -711,7 +711,13 @@ class TORCH_API Tensor {
 
   /// \fn void retain_grad() const;
   ///
-  /// Enables .grad() for non-leaf Tensors.
+  /// Enables this Tensor to have their :attr:`grad` populated during
+  /// :func:`backward`. This is a no-op for leaf tensors.
+
+  /// \fn bool retains_grad() const;
+  ///
+  /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be
+  /// populated during :func:`backward`, ``false`` otherwise.
 
   const Tensor& set_requires_grad(bool requires_grad) const {
     impl_->set_requires_grad(requires_grad);
@@ -734,7 +740,16 @@ class TORCH_API Tensor {
   /// The attribute will then contain the gradients computed and future calls
   /// to `backward()` will accumulate (add) gradients into it.
   const Tensor& grad() const {
-    return impl_->grad();
+    const Tensor& maybe_grad = impl_->grad();
+    if (!is_leaf() && !retains_grad() && !maybe_grad.defined()) {
+      TORCH_WARN(
+        "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "
+        "attribute won't be populated during autograd.backward(). If you indeed want the .grad "
+        "field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. "
+        "If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor "
+        "instead. See github.com/pytorch/pytorch/pull/30531 for more informations.");
+    }
+    return maybe_grad;
   }
 
   // The Forward AD API functions below are low level and are not to be used by end
@@ -891,6 +906,8 @@ public:
 
   void retain_grad() const;
 
+  bool retains_grad() const;
+
   void _backward(TensorList inputs, const c10::optional<Tensor>& gradient, c10::optional<bool> keep_graph, bool create_graph) const;
 
   const Tensor& requires_grad_(bool _requires_grad=true) const;
@@ -564,6 +564,7 @@ Tensor class reference
     Tensor.resize_
     Tensor.resize_as_
     Tensor.retain_grad
+    Tensor.retains_grad
     Tensor.roll
     Tensor.rot90
     Tensor.round
@@ -165,9 +165,24 @@ TEST(AutogradAPITests, RetainGrad) {
   auto h1 = input * 3;
   auto out = (h1 * h1).sum();
 
+  {
+    // Warning when grad is accessed for non-leaf tensor
+    WarningCapture warnings;
+    ASSERT_FALSE(h1.grad().defined());
+    ASSERT_TRUE(
+        warnings.str().find("is not a leaf") != std::string::npos);
+  }
   // It should be possible to call retain_grad() multiple times
   h1.retain_grad();
   h1.retain_grad();
+  {
+    // If retain_grad is true for a non-leaf tensor,
+    // there should not be any warning when grad is accessed
+    WarningCapture warnings;
+    ASSERT_FALSE(h1.grad().defined());
+    ASSERT_FALSE(
+        warnings.str().find("is not a leaf") != std::string::npos);
+  }
 
   // Gradient should be accumulated
   // NOLINTNEXTLINE(bugprone-argument-comment)
@@ -95,7 +95,7 @@ SKIP_PYTHON_BINDINGS = [
    'nonzero(_(out|numpy))?',
    'set_data',
    '.*_overrideable', # overrideable functions for backend extension
-   'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retain_grad', 'set_',
+   'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retains_grad', 'set_',
    '_fw_primal', 'fake_quantize_per_tensor_affine_cachemask',
    'fake_quantize_per_channel_affine_cachemask',
 ]
@@ -3,7 +3,6 @@ import functools
 from numbers import Number
 from typing import Any, Dict, Optional, Tuple, Union
 import warnings
-import weakref
 
 import torch
 import torch._C as _C
@@ -356,33 +355,6 @@ class Tensor(torch._C._TensorBase):
         have forward mode AD gradients.
         """)
 
-    def retain_grad(self):
-        r"""Enables .grad attribute for non-leaf Tensors."""
-        if has_torch_function_unary(self):
-            return handle_torch_function(Tensor.retain_grad, (self,), self)
-        if not self.requires_grad:
-            raise RuntimeError("can't retain_grad on Tensor that has requires_grad=False")
-        if self.is_leaf:  # no-op for leaves
-            return
-        if hasattr(self, 'retains_grad'):
-            return
-        weak_self = weakref.ref(self)
-
-        def retain_grad_hook(grad):
-            var = weak_self()
-            if var is None:
-                return
-            if var._grad is None:
-                if grad.is_sparse:
-                    var._grad = grad.clone()
-                else:
-                    var._grad = grad.clone(memory_format=torch.contiguous_format)
-            else:
-                var._grad = var._grad + grad
-
-        self.register_hook(retain_grad_hook)
-        self.retains_grad = True
-
     def is_shared(self):
         r"""Checks if tensor is in shared memory.
 
@@ -996,12 +968,6 @@
             # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
             return handle_torch_function(Tensor.grad.__get__, (self,), self)  # type: ignore[attr-defined]
 
-        if self.requires_grad and not hasattr(self, "retains_grad") and not self.is_leaf and self._grad is None:
-            warnings.warn("The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "
-                          "attribute won't be populated during autograd.backward(). If you indeed want the gradient "
-                          "for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the "
-                          "non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See "
-                          "github.com/pytorch/pytorch/pull/30531 for more information.", stacklevel=2)
         return self._grad
 
     @grad.setter
@@ -4646,6 +4646,20 @@ masked_fill(mask, value) -> Tensor
 Out-of-place version of :meth:`torch.Tensor.masked_fill_`
 """)
 
+add_docstr_all('retain_grad',
+               r"""
+retain_grad() -> None
+
+Enables this Tensor to have their :attr:`grad` populated during
+:func:`backward`. This is a no-op for leaf tensors.
+""")
+
+add_docstr_all('retains_grad',
+               r"""
+Is ``True`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be
+populated during :func:`backward`, ``False`` otherwise.
+""")
+
 add_docstr_all('requires_grad',
                r"""
 Is ``True`` if gradients need to be computed for this Tensor, ``False`` otherwise.
@@ -523,7 +523,25 @@ PyObject *THPVariable_get_requires_grad(THPVariable *self, void *unused)
   if (check_has_torch_function((PyObject *)self)) {
     return handle_torch_function_getter(self, "requires_grad");
   }
-  return PyBool_FromLong(THPVariable_Unpack(self).requires_grad());
+  if(THPVariable_Unpack(self).requires_grad()) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
   END_HANDLE_TH_ERRORS
 }
 
+PyObject *THPVariable_retains_grad(THPVariable *self, void *unused)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function((PyObject *)self)) {
+    return handle_torch_function_getter(self, "retains_grad");
+  }
+  if(THPVariable_Unpack(self).retains_grad()) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
@@ -907,6 +925,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
   {"grad_fn", (getter)THPVariable_get_grad_fn, nullptr, nullptr, nullptr},
   {"_grad_fn", (getter)THPVariable_get_grad_fn, (setter)THPVariable_set_grad_fn, nullptr, nullptr},
   {"is_leaf", (getter)THPVariable_is_leaf, nullptr, nullptr, nullptr},
+  {"retains_grad", (getter)THPVariable_retains_grad, nullptr, nullptr, nullptr},
   {"data", (getter)THPVariable_get_data, (setter)THPVariable_set_data, nullptr, nullptr},
   {"_grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, nullptr, nullptr}, // Allows the python class to override .grad
   {"grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, nullptr, nullptr},
@@ -347,7 +347,8 @@ struct VariableHooks final : at::impl::VariableHooksInterface {
   void set_data(const Tensor & self, const Tensor & new_data) const override;
   Tensor data(const Tensor & self) const override;
   int64_t _version(const Tensor & self) const override;
-  void retain_grad(const Tensor & self) const override;
+  void retain_grad(const Tensor& self) const override;
+  bool retains_grad(const Tensor& self) const override;
   void _backward(const Tensor& self, at::TensorList inputs,
       const c10::optional<Tensor>& gradient, c10::optional<bool> keep_graph,
       bool create_graph) const override;
@@ -434,7 +435,7 @@ int64_t VariableHooks::_version(const Tensor & self) const {
   return self.unsafeGetTensorImpl()->version_counter().current_version();
 }
 
-void VariableHooks::retain_grad(const Tensor & self) const {
+void VariableHooks::retain_grad(const Tensor& self) const {
   TORCH_CHECK(self.requires_grad(), "can't retain_grad on Tensor that has requires_grad=False");
   if (self.is_leaf()) { // no-op for leaves
     return;
@@ -465,6 +466,14 @@ void VariableHooks::retain_grad(const Tensor & self) const {
   impl::get_autograd_meta(self)->retains_grad_ = true;
 }
 
+bool VariableHooks::retains_grad(const Tensor& self) const {
+  if (impl::get_autograd_meta(self)) {
+    return impl::get_autograd_meta(self)->retains_grad_;
+  } else {
+    return false;
+  }
+}
+
 void VariableHooks::_backward(
     const Tensor& self,
     at::TensorList inputs,
@@ -119,6 +119,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
            {"layout", "prim"}, {"T", "prim"},
            {"ndim", "prim"}, {"name", "prim"},
            {"real", "aten"}, {"imag", "aten"},
+           {"retains_grad", "aten"},
          }},
         {TypeKind::DeviceObjType, {{"type", "prim"}, {"index", "prim"}}}};
   auto kind = value_->type()->kind();
@@ -966,6 +966,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        Tensor.is_cuda.__get__: lambda self: -1,
        Tensor.is_xpu.__get__: lambda self: -1,
        Tensor.is_leaf.__get__: lambda self: -1,
+       Tensor.retains_grad.__get__: lambda self: -1,
        Tensor.is_meta.__get__: lambda self: -1,
        Tensor.is_mlc.__get__: lambda self: -1,
        Tensor.is_mkldnn.__get__: lambda self: -1,