Revert "Add tensor post accumulate grad hook API (#107063)"

This reverts commit 3f655277d44909e0770e77e1b4fe1c9b0f39d7b9.

Reverted https://github.com/pytorch/pytorch/pull/107063 on behalf of https://github.com/ZainRizvi due to Diff train weirdness. Need to temporarily revert this PR and will right land it soon afterwards ([comment](https://github.com/pytorch/pytorch/pull/107063#issuecomment-1690799057))
Author: PyTorch MergeBot
Date: 2023-08-24 00:12:32 +00:00
Parent: bc0790559b
Commit: 432fce4e0d
14 changed files with 8 additions and 378 deletions

@@ -889,166 +889,6 @@ class TestAutograd(TestCase):
        with self.assertRaisesRegex(ValueError, "Expected allow_unused to be True or not passed when"):
            torch.autograd.grad(y, x, allow_unused=False, materialize_grads=True)

    def test_post_accumulate_grad_hook_on_non_leaf(self):
        def hook(tensor):
            tensor.sub_(1.)

        leaf = torch.rand(3, requires_grad=True)
        non_leaf = 2. * leaf
        with self.assertRaisesRegex(
                RuntimeError,
                "post accumulate grad hooks cannot be registered on non-leaf tensors"):
            non_leaf.register_post_accumulate_grad_hook(hook)

    def test_post_accumulate_grad_hook_multiple_hooks(self):
        def hook1(tensor):
            tensor.sub_(tensor.grad)

        def hook2(tensor):
            tensor.mul_(4.)

        tensor = torch.rand(3, requires_grad=True)
        tensor_ref = tensor.clone().detach()
        tensor.register_post_accumulate_grad_hook(hook1)
        tensor.register_post_accumulate_grad_hook(hook2)
        sum = tensor.sum()
        sum.backward()
        # both hooks should be called, in order
        self.assertEqual(4. * (tensor_ref - 1.), tensor)

    def test_post_accumulate_grad_hook_multiple_tensors(self):
        def hook(tensor):
            tensor.sub_(tensor.grad)

        tensor1 = torch.rand(3, requires_grad=True)
        tensor1_ref = tensor1.clone().detach()
        tensor2 = torch.rand(5, requires_grad=True)
        tensor2_ref = tensor2.clone().detach()
        tensor1.register_post_accumulate_grad_hook(hook)
        tensor2.register_post_accumulate_grad_hook(hook)
        tensor1.sum().backward()
        tensor2.sum().backward()
        # both tensors should have been modified
        self.assertEqual(tensor1_ref - 1., tensor1)
        self.assertEqual(tensor2_ref - 1., tensor2)

    def test_post_accumulate_grad_hook_returns_not_None(self):
        def bad_hook(tensor):
            return tensor.grad

        tensor = torch.rand(2, 3, requires_grad=True)
        tensor.register_post_accumulate_grad_hook(bad_hook)
        # should error!
        with self.assertRaisesRegex(RuntimeError, "hooks should return None."):
            tensor.sum().backward()

    def test_post_accumulate_grad_hook_e2e(self):
        def setup_optim_in_bwd(model):
            optims = {}
            handles = []

            def optim_step_hook(param):
                optims[param].step()
                optims[param].zero_grad()

            for p in model.parameters():
                optims[p] = torch.optim.Adam([p])
                handles.append(p.register_post_accumulate_grad_hook(optim_step_hook))

            return handles

        model = torch.nn.Linear(3, 2)
        input = torch.rand(2, 3)
        handles = setup_optim_in_bwd(model)

        # make a copy for reference
        model_copy = deepcopy(model)
        optim_copy = torch.optim.Adam(model_copy.parameters())

        iters = 5
        for _ in range(iters):
            loss = model(input).sum()
            loss.backward()

            loss_copy = model_copy(input).sum()
            loss_copy.backward()
            optim_copy.step()
            optim_copy.zero_grad()

        params_copy = []  # freeze a copy of the params to compare later
        for p_reference, p in zip(model_copy.parameters(), model.parameters()):
            self.assertEqual(p_reference, p)
            params_copy.append(p_reference.clone().detach())

        # After removing the handle, the model should no longer update.
        for h in handles:
            h.remove()

        for _ in range(iters):
            loss = model(input).sum()
            loss.backward()

            loss_copy = model_copy(input).sum()
            loss_copy.backward()
            optim_copy.step()
            optim_copy.zero_grad()

        for p_static, p_reference, p in zip(params_copy, model_copy.parameters(), model.parameters()):
            self.assertEqual(p_static, p)
            self.assertNotEqual(p_reference, p)

    def test_post_accumulate_grad_hook_gets_cleaned_up(self):
        def fun_stuff_with_hook():
            thing_to_put_in_hook = torch.rand(3)

            def hook(tensor):
                tensor.sub_(tensor.grad)
                tensor.add_(thing_to_put_in_hook)

            tensor = torch.rand(3, requires_grad=True)
            tensor.register_post_accumulate_grad_hook(hook)
            tensor.sum().backward()
            ref = weakref.ref(thing_to_put_in_hook)
            gc.collect()
            return tensor, ref

        with disable_gc():
            tensor, ref = fun_stuff_with_hook()
            self.assertIsNotNone(ref())  # thing_to_put_in_hook should be kept alive by tensor

            del tensor
            gc.collect()
            self.assertIsNone(ref())  # thing_to_put_in_hook should be cleaned

    def test_post_accumulate_grad_hook_ordering(self):
        tensor = torch.rand(3, requires_grad=True)

        def pre_hook(grad):
            return grad.sub(2.)

        def acc_grad_node_pre_hook(grad_out):
            return (grad_out[0].div(5.),)

        def post_acc_grad_hook(tensor):
            tensor.grad.add_(0.5)

        def acc_grad_node_post_hook(grad_in, grad_out):
            tensor.grad = grad_out[0].mul(10)

        acc_grad = tensor.view_as(tensor).grad_fn.next_functions[0][0]
        tensor.register_hook(pre_hook)
        acc_grad.register_prehook(acc_grad_node_pre_hook)
        tensor.register_post_accumulate_grad_hook(post_acc_grad_hook)
        acc_grad.register_hook(acc_grad_node_post_hook)
        tensor.sum().backward()
        # the hooks should run in the order of:
        # 1. tensor prehook
        # 2. acc_grad prehook
        # 3. tensor post acc_grad hook
        # 4. acc_grad posthook
        # so that would be ((1 - 2) / 5 + 0.5) * 10 = 3
        self.assertEqual(torch.tensor([3., 3., 3.]), tensor.grad)

    def test_hook_with_no_name(self):
        # Create a hook that does not have a __name__ attribute
        class MyHookClass:
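For reference while this change is reverted: the pattern exercised by test_post_accumulate_grad_hook_e2e above is the "optimizer in backward" use case, where each parameter's optimizer step runs as soon as that parameter's gradient has been accumulated. A minimal, illustrative sketch of that usage follows; it assumes a build that still exposes Tensor.register_post_accumulate_grad_hook (i.e. before this revert, or after the planned re-land).

import torch

model = torch.nn.Linear(3, 2)
# one optimizer per parameter, stepped from inside the backward pass
optims = {p: torch.optim.Adam([p]) for p in model.parameters()}

def optim_step_hook(param):
    # runs only after param.grad has been fully accumulated
    optims[param].step()
    optims[param].zero_grad()

handles = [
    p.register_post_accumulate_grad_hook(optim_step_hook)
    for p in model.parameters()
]

loss = model(torch.rand(2, 3)).sum()
loss.backward()  # each parameter is updated as soon as its gradient is ready

for h in handles:
    h.remove()  # subsequent backward() calls no longer update the parameters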

@@ -530,7 +530,7 @@ class Tensor(torch._C._TensorBase):
            return handle_torch_function(Tensor.register_hook, (self,), self, hook)
        if not self.requires_grad:
            raise RuntimeError(
                "cannot register a hook on a tensor that doesn't require gradient"
                "cannot register a hook on a tensor that " "doesn't require gradient"
            )
        if self._backward_hooks is None:
            self._backward_hooks = OrderedDict()
@@ -540,62 +540,6 @@ class Tensor(torch._C._TensorBase):
        self._backward_hooks[handle.id] = hook
        return handle

    def register_post_accumulate_grad_hook(self, hook):
        r"""Registers a backward hook that runs after grad accumulation.

        The hook will be called after all gradients for a tensor have been accumulated,
        meaning that the .grad field has been updated on that tensor. The post
        accumulate grad hook is ONLY applicable for leaf tensors (tensors without a
        .grad_fn field). Registering this hook on a non-leaf tensor will error!

        The hook should have the following signature::

            hook(param: Tensor) -> None

        Note that, unlike other autograd hooks, this hook operates on the tensor
        that requires grad and not the grad itself. The hook can in-place modify
        and access its Tensor argument, including its .grad field.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the tensor.

        .. note::
            See :ref:`backward-hooks-execution` for more information on how and when this
            hook is executed, and how its execution is ordered relative to other hooks.
            Since this hook runs during the backward pass, it will run in no_grad mode
            (unless create_graph is True). You can use torch.enable_grad() to re-enable
            autograd within the hook if you need it.

        Example::

            >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> lr = 0.01
            >>> # simulate a simple SGD update
            >>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
            >>> v.backward(torch.tensor([1., 2., 3.]))
            >>> v
            tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)
            >>> h.remove()  # removes the hook
        """
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.register_post_accumulate_grad_hook, (self,), self, hook
            )
        if not self.requires_grad:
            raise RuntimeError(
                "cannot register a hook on a tensor that doesn't require gradient"
            )
        if self.grad_fn is not None:
            raise RuntimeError(
                "post accumulate grad hooks cannot be registered on non-leaf tensors"
            )
        if self._post_accumulate_grad_hooks is None:
            self._post_accumulate_grad_hooks: Dict[Any, Any] = OrderedDict()
        handle = hooks.RemovableHandle(self._post_accumulate_grad_hooks)
        self._post_accumulate_grad_hooks[handle.id] = hook
        return handle

    def reinforce(self, reward):
        def trim(str):
            return "\n".join([line.strip() for line in str.split("\n")])
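The two RuntimeError branches above guard the API's preconditions: the tensor must require grad, and it must be a leaf. A quick illustrative sketch of both failure modes, again assuming a build that still includes the reverted API:

import torch

plain = torch.rand(3)  # requires_grad=False
try:
    plain.register_post_accumulate_grad_hook(lambda t: None)
except RuntimeError as e:
    print(e)  # cannot register a hook on a tensor that doesn't require gradient

leaf = torch.rand(3, requires_grad=True)
non_leaf = 2. * leaf  # has a grad_fn, so it is not a leaf
try:
    non_leaf.register_post_accumulate_grad_hook(lambda t: None)
except RuntimeError as e:
    print(e)  # post accumulate grad hooks cannot be registered on non-leaf tensors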

@@ -523,12 +523,6 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
    return tensor_pre_hooks_;
  }

  virtual std::unique_ptr<PostAccumulateGradHook>&
  tensor_post_acc_grad_hooks() noexcept {
    static std::unique_ptr<PostAccumulateGradHook> empty = nullptr;
    return empty;
  }

  std::unordered_map<int, std::unique_ptr<FunctionPreHook>>&
  retains_grad_hooks() noexcept {
    return retains_grad_hooks_;

@@ -42,17 +42,5 @@ struct TORCH_API FunctionPostHook {
  }
};

struct TORCH_API PostAccumulateGradHook {
  virtual ~PostAccumulateGradHook() = default;
  virtual void operator()(const Variable& tensor) = 0;
  // only implemented for python hooks on nodes, registers hook with compiled
  // autograd
  virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
    throw std::runtime_error(
        std::string("not yet implemented for compiled autograd: ") +
        typeid(*this).name());
  }
};

} // namespace autograd
} // namespace torch

@@ -57,11 +57,6 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
      1 + !post_hooks().empty() /* num_expected_refs */,
      [&grad](at::Tensor&& grad_update) { grad = std::move(grad_update); });

  auto& hook = tensor_post_acc_grad_hooks();
  if (hook != nullptr) {
    (*hook)(variable);
  }

  return variable_list();
}
@@ -93,7 +88,6 @@ variable_list AccumulateGrad::apply_with_saved(
      });
  saved.after(variable_copy);
  saved.after(grad_copy);
  return variable_list();
}
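This call site is what backs the documented contract: AccumulateGrad first folds the incoming gradient into the tensor's .grad and only then invokes (*hook)(variable), so a Python hook always observes the already-updated .grad. A small illustrative sketch from the Python side, assuming a build with the reverted API:

import torch

t = torch.rand(3, requires_grad=True)

def report(tensor):
    # .grad is already populated by the time the hook fires
    print("grad inside hook:", tensor.grad)

t.register_post_accumulate_grad_hook(report)
t.sum().backward()  # prints: grad inside hook: tensor([1., 1., 1.])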

@@ -50,15 +50,6 @@ struct TORCH_API AccumulateGrad : public Node {
    // to all other Nodes). So we must lazily read the Tensor hooks here.
    return impl::hooks(variable);
  }

  std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks() noexcept
      override {
    // NB: Since the AccumulateGrad Node is only a weak ref from the Tensor,
    // it can be destroyed even though the Tensor is still alive (contrary
    // to all other Nodes). So we must lazily read the Tensor hooks here.
    return impl::post_acc_grad_hooks(variable);
  }

  // Given a variable with its current grad as variable_grad, accumulates
  // new_grad into variable_grad if in place accumulation is possible.
  // Otherwise, uses 'update_grad' to update the grad for the variable.

@@ -30,11 +30,10 @@ namespace autograd {
namespace {

// This function is called in 4 different cases:
// This function is called in 3 different cases:
//   1) TensorPreHook
//   2) PreHook
//   3) PostHook
//   4) TensorPostAccGradHook
//
// Depending on the case, args and res can hold different types of objects:
//
@@ -42,13 +41,11 @@ namespace {
// TensorPreHook          (Tensor,)
// PreHook                ((Tensor, ...),)                (grad_outputs,)
// PostHook               ((Tensor, ...), (Tensor, ...))  (grad_inputs, grad_outputs)
// TensorPostAccGradHook  ((Tensor), ())                  (tensor,)
//
// res:
// TensorPreHook          Tensor
// PreHook                ((Tensor, ...),)                (grad_outputs,)
// PostHook               ((Tensor, ...),)                (grad_inputs,)
// TensorPostAccGradHook  None
//
// This function returns True if any hook returned non-None value, and False
// otherwise.
@@ -196,30 +193,6 @@ void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) {
  }
}

PyFunctionTensorPostAccGradHooks::PyFunctionTensorPostAccGradHooks(
    PyObject* dict)
    : dict(dict) {
  Py_INCREF(dict);
}

PyFunctionTensorPostAccGradHooks::~PyFunctionTensorPostAccGradHooks() {
  // If python is already dead, leak the wrapped python objects
  if (Py_IsInitialized()) {
    pybind11::gil_scoped_acquire gil;
    Py_DECREF(dict);
  }
}

auto PyFunctionTensorPostAccGradHooks::operator()(const Variable& tensor)
    -> void {
  pybind11::gil_scoped_acquire gil;
  THPObjectPtr tup(PyTuple_New(1));
  PyTuple_SET_ITEM(tup.get(), 0, THPVariable_Wrap(tensor));
  bool returned_none = !_call_hooks(dict, tup.get());
  TORCH_CHECK(
      returned_none, "Tensor post accumulate grad hooks should return None.");
}

} // namespace autograd
} // namespace torch

@@ -34,17 +34,5 @@ struct PyFunctionPostHook : public FunctionPostHook {
  PyObject* dict;
};

// PyFunctionTensorPostAccGradHooks is a dictionary of PostAccumulateGradHooks,
// and it is understandable if you are confused by why it's a subclass. We are
// simply following the precedent of PyFunctionPreHook and PyFunctionPostHook
// above to easily enroll into existing infrastructure.
struct PyFunctionTensorPostAccGradHooks : public PostAccumulateGradHook {
  PyFunctionTensorPostAccGradHooks(PyObject* dict);
  ~PyFunctionTensorPostAccGradHooks() override;
  void operator()(const Variable& tensor) override;
  // fall back to the compiled_args of PostAccumulateGradHook superclass
  PyObject* dict;
};

} // namespace autograd
} // namespace torch

@@ -428,7 +428,6 @@ static int THPVariable_clear(THPVariable* self) {
    return 0;
  }
  Py_CLEAR(self->backward_hooks);
  Py_CLEAR(self->post_accumulate_grad_hooks);
  const auto& tensor = THPVariable_Unpack(self);
  if (tensor.defined()) {
    // Two situations to consider:
@@ -1166,47 +1165,6 @@ int THPVariable_set_backwards_hooks(
  END_HANDLE_TH_ERRORS_RET(-1)
}

PyObject* THPVariable_get_post_accumulate_grad_hooks(
    THPVariable* self,
    void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "_post_accumulate_grad_hooks");
  }
  if (self->post_accumulate_grad_hooks) {
    Py_INCREF(self->post_accumulate_grad_hooks);
    return self->post_accumulate_grad_hooks;
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

int THPVariable_set_post_accumulate_grad_hooks(
    THPVariable* self,
    PyObject* obj,
    void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(
        self, "_post_accumulate_grad_hooks", obj);
  }
  THPUtils_assertRet(
      -1, obj, "Deletion of _post_accumulate_grad_hooks not allowed!");
  if (obj == Py_None) {
    obj = nullptr;
  }
  Py_XINCREF(obj);
  Py_CLEAR(self->post_accumulate_grad_hooks);
  self->post_accumulate_grad_hooks = obj;
  const auto& tensor = THPVariable_Unpack(self);
  if (obj) {
    torch::autograd::impl::set_post_acc_grad_hooks(
        tensor, std::make_unique<PyFunctionTensorPostAccGradHooks>(obj));
  }
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}

PyObject* THPVariable_get_base(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
@@ -1521,11 +1479,6 @@ static struct PyGetSetDef THPVariable_properties[] = {
     (setter)THPVariable_set_backwards_hooks,
     nullptr,
     nullptr},
    {"_post_accumulate_grad_hooks",
     (getter)THPVariable_get_post_accumulate_grad_hooks,
     (setter)THPVariable_set_post_accumulate_grad_hooks,
     nullptr,
     nullptr},
    {"name", (getter)THPVariable_get_name, nullptr, nullptr, nullptr},
    {"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr},
    {"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
@@ -2083,7 +2036,6 @@ static int THPVariable_subclass_traverse(
  // Finally traverse THPVariable special stuff
  Py_VISIT(var->backward_hooks);
  Py_VISIT(var->post_accumulate_grad_hooks);
  if (!var->cdata.unsafeIsBorrowed()) {
    const auto& tensor = THPVariable_Unpack(var);
    if (tensor.defined()) {

@@ -22,10 +22,6 @@ struct THPVariable {
  // Hooks to be run on backwards pass (corresponds to Python attr
  // '_backwards_hooks', set by 'register_hook')
  PyObject* backward_hooks = nullptr;
  // Hooks to be run in the backwards pass after accumulate grad,
  // i.e., after the .grad has been set (corresponds to Python attr
  // '_post_accumulate_grad_hooks', set by 'register_post_accumulate_grad_hook')
  PyObject* post_accumulate_grad_hooks = nullptr;
};

TORCH_PYTHON_API void registerPythonTensorClass(

@@ -363,19 +363,6 @@ void clear_hooks(const at::TensorBase& self) {
  materialize_autograd_meta(self)->hooks_.clear();
}

void set_post_acc_grad_hooks(
    const at::TensorBase& self,
    std::unique_ptr<PostAccumulateGradHook> dict) {
  AutogradMeta* meta = materialize_autograd_meta(self);
  meta->post_acc_grad_hooks_ = std::move(dict);
}

std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
    const Variable& self) {
  TORCH_INTERNAL_ASSERT(get_autograd_meta(self));
  return get_autograd_meta(self)->post_acc_grad_hooks_;
}

void set_name(const Variable& self, const std::string& name) {
  materialize_autograd_meta(self)->name_ = name;
}

@@ -186,12 +186,6 @@ TORCH_API void add_hook(
TORCH_API std::vector<std::unique_ptr<FunctionPreHook>>& hooks(const Variable&);
TORCH_API void clear_hooks(const at::TensorBase&);

TORCH_API void set_post_acc_grad_hooks(
    const at::TensorBase&,
    std::unique_ptr<PostAccumulateGradHook> dict);
TORCH_API std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
    const Variable&);

TORCH_API void create_cpp_hook(
    const at::TensorBase&,
    bool is_retains_grad_hooks = false);
@@ -237,12 +231,6 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
  std::vector<std::unique_ptr<FunctionPreHook>> hooks_;
  std::shared_ptr<hooks_list> cpp_hooks_list_;

  // The post_acc_grad_hooks_ field stores only Python hooks
  // (PyFunctionTensorPostAccGradHooks) that are called after the
  // .grad field has been accumulated into. This is less complicated
  // than the hooks_ field, which encapsulates a lot more.
  std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_ = nullptr;

  // Only meaningful on leaf variables (must be false otherwise)
  bool requires_grad_{false};

@@ -347,9 +347,6 @@ class CompiledNodeArgs {
    TORCH_CHECK(
        fn->retains_grad_hooks().empty(),
        "retains_grad_hooks not implemented for compiled autograd");
    TORCH_CHECK(
        fn->tensor_post_acc_grad_hooks() == nullptr,
        "tensor_post_acc_grad_hooks not implemented for compiled autograd");
    for (auto& i : fn->tensor_pre_hooks()) {
      i->compiled_args(*this);
    }

@@ -1205,7 +1205,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        Tensor.mT.__get__: lambda self: -1,
        Tensor.mH.__get__: lambda self: -1,
        Tensor._backward_hooks.__get__: lambda self: -1,
        Tensor._post_accumulate_grad_hooks.__get__: lambda self: -1,
        Tensor._base.__get__: lambda self: -1,
        Tensor._cdata.__get__: lambda self: -1,
        Tensor.grad.__get__: lambda self: -1,
@@ -1329,7 +1328,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        Tensor.record_stream: lambda self, stream: -1,
        Tensor.refine_names: lambda self, names: -1,
        Tensor.register_hook: lambda self, hook: -1,
        Tensor.register_post_accumulate_grad_hook: lambda self, hook: -1,
        Tensor.rename: lambda self, name: -1,
        Tensor.repeat: lambda self, *size: -1,
        Tensor.requires_grad_: lambda self, requires_grad=True: -1,