Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "Add tensor post accumulate grad hook API (#107063)"
This reverts commit 3f655277d44909e0770e77e1b4fe1c9b0f39d7b9. Reverted https://github.com/pytorch/pytorch/pull/107063 on behalf of https://github.com/ZainRizvi due to Diff train weirdness. Need to temporarily revert this PR and will right land it soon afterwards ([comment](https://github.com/pytorch/pytorch/pull/107063#issuecomment-1690799057))
@@ -889,166 +889,6 @@ class TestAutograd(TestCase):
        with self.assertRaisesRegex(ValueError, "Expected allow_unused to be True or not passed when"):
            torch.autograd.grad(y, x, allow_unused=False, materialize_grads=True)

    def test_post_accumulate_grad_hook_on_non_leaf(self):
        def hook(tensor):
            tensor.sub_(1.)
        leaf = torch.rand(3, requires_grad=True)
        non_leaf = 2. * leaf

        with self.assertRaisesRegex(
                RuntimeError,
                "post accumulate grad hooks cannot be registered on non-leaf tensors"):
            non_leaf.register_post_accumulate_grad_hook(hook)

    def test_post_accumulate_grad_hook_multiple_hooks(self):
        def hook1(tensor):
            tensor.sub_(tensor.grad)

        def hook2(tensor):
            tensor.mul_(4.)
        tensor = torch.rand(3, requires_grad=True)
        tensor_ref = tensor.clone().detach()
        tensor.register_post_accumulate_grad_hook(hook1)
        tensor.register_post_accumulate_grad_hook(hook2)
        sum = tensor.sum()
        sum.backward()
        # both hooks should be called, in order
        self.assertEqual(4. * (tensor_ref - 1.), tensor)

    def test_post_accumulate_grad_hook_multiple_tensors(self):
        def hook(tensor):
            tensor.sub_(tensor.grad)
        tensor1 = torch.rand(3, requires_grad=True)
        tensor1_ref = tensor1.clone().detach()
        tensor2 = torch.rand(5, requires_grad=True)
        tensor2_ref = tensor2.clone().detach()
        tensor1.register_post_accumulate_grad_hook(hook)
        tensor2.register_post_accumulate_grad_hook(hook)
        tensor1.sum().backward()
        tensor2.sum().backward()
        # both tensors should have been modified
        self.assertEqual(tensor1_ref - 1., tensor1)
        self.assertEqual(tensor2_ref - 1., tensor2)

    def test_post_accumulate_grad_hook_returns_not_None(self):
        def bad_hook(tensor):
            return tensor.grad
        tensor = torch.rand(2, 3, requires_grad=True)
        tensor.register_post_accumulate_grad_hook(bad_hook)
        # should error!
        with self.assertRaisesRegex(RuntimeError, "hooks should return None."):
            tensor.sum().backward()

    def test_post_accumulate_grad_hook_e2e(self):
        def setup_optim_in_bwd(model):
            optims = {}
            handles = []

            def optim_step_hook(param):
                optims[param].step()
                optims[param].zero_grad()

            for p in model.parameters():
                optims[p] = torch.optim.Adam([p])
                handles.append(p.register_post_accumulate_grad_hook(optim_step_hook))

            return handles

        model = torch.nn.Linear(3, 2)
        input = torch.rand(2, 3)
        handles = setup_optim_in_bwd(model)

        # make a copy for reference
        model_copy = deepcopy(model)
        optim_copy = torch.optim.Adam(model_copy.parameters())

        iters = 5

        for _ in range(iters):
            loss = model(input).sum()
            loss.backward()

            loss_copy = model_copy(input).sum()
            loss_copy.backward()
            optim_copy.step()
            optim_copy.zero_grad()

        params_copy = []  # freeze a copy of the params to compare later
        for p_reference, p in zip(model_copy.parameters(), model.parameters()):
            self.assertEqual(p_reference, p)
            params_copy.append(p_reference.clone().detach())

        # After removing the handle, the model should no longer update.
        for h in handles:
            h.remove()

        for _ in range(iters):
            loss = model(input).sum()
            loss.backward()

            loss_copy = model_copy(input).sum()
            loss_copy.backward()
            optim_copy.step()
            optim_copy.zero_grad()

        for p_static, p_reference, p in zip(params_copy, model_copy.parameters(), model.parameters()):
            self.assertEqual(p_static, p)
            self.assertNotEqual(p_reference, p)

    def test_post_accumulate_grad_hook_gets_cleaned_up(self):

        def fun_stuff_with_hook():
            thing_to_put_in_hook = torch.rand(3)

            def hook(tensor):
                tensor.sub_(tensor.grad)
                tensor.add_(thing_to_put_in_hook)

            tensor = torch.rand(3, requires_grad=True)
            tensor.register_post_accumulate_grad_hook(hook)
            tensor.sum().backward()
            ref = weakref.ref(thing_to_put_in_hook)
            gc.collect()
            return tensor, ref

        with disable_gc():
            tensor, ref = fun_stuff_with_hook()
            self.assertIsNotNone(ref())  # thing_to_put_in_hook should be kept alive by tensor

            del tensor
            gc.collect()
            self.assertIsNone(ref())  # thing_to_put_in_hook should be cleaned

    def test_post_accumulate_grad_hook_ordering(self):
        tensor = torch.rand(3, requires_grad=True)

        def pre_hook(grad):
            return grad.sub(2.)

        def acc_grad_node_pre_hook(grad_out):
            return (grad_out[0].div(5.),)

        def post_acc_grad_hook(tensor):
            tensor.grad.add_(0.5)

        def acc_grad_node_post_hook(grad_in, grad_out):
            tensor.grad = grad_out[0].mul(10)

        acc_grad = tensor.view_as(tensor).grad_fn.next_functions[0][0]
        tensor.register_hook(pre_hook)
        acc_grad.register_prehook(acc_grad_node_pre_hook)
        tensor.register_post_accumulate_grad_hook(post_acc_grad_hook)
        acc_grad.register_hook(acc_grad_node_post_hook)
        tensor.sum().backward()

        # the hooks should run in the order of:
        # 1. tensor prehook
        # 2. acc_grad prehook
        # 3. tensor post acc_grad hook
        # 4. acc_grad posthook
        # so that would be ((1 - 2) / 5 + 0.5) * 10 = 3
        self.assertEqual(torch.tensor([3., 3., 3.]), tensor.grad)

    def test_hook_with_no_name(self):
        # Create a hook that do not have a __name__ attribute
        class MyHookClass:
@@ -530,7 +530,7 @@ class Tensor(torch._C._TensorBase):
            return handle_torch_function(Tensor.register_hook, (self,), self, hook)
        if not self.requires_grad:
            raise RuntimeError(
                "cannot register a hook on a tensor that doesn't require gradient"
                "cannot register a hook on a tensor that " "doesn't require gradient"
            )
        if self._backward_hooks is None:
            self._backward_hooks = OrderedDict()
@@ -540,62 +540,6 @@ class Tensor(torch._C._TensorBase):
        self._backward_hooks[handle.id] = hook
        return handle

    def register_post_accumulate_grad_hook(self, hook):
        r"""Registers a backward hook that runs after grad accumulation.

        The hook will be called after all gradients for a tensor have been accumulated,
        meaning that the .grad field has been updated on that tensor. The post
        accumulate grad hook is ONLY applicable for leaf tensors (tensors without a
        .grad_fn field). Registering this hook on a non-leaf tensor will error!

        The hook should have the following signature::

            hook(param: Tensor) -> None

        Note that, unlike other autograd hooks, this hook operates on the tensor
        that requires grad and not the grad itself. The hook can in-place modify
        and access its Tensor argument, including its .grad field.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the module.

        .. note::
            See :ref:`backward-hooks-execution` for more information on how when this hook
            is executed, and how its execution is ordered relative to other hooks. Since
            this hook runs during the backward pass, it will run in no_grad mode (unless
            create_graph is True). You can use torch.enable_grad() to re-enable autograd
            within the hook if you need it.

        Example::

            >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> lr = 0.01
            >>> # simulate a simple SGD update
            >>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
            >>> v.backward(torch.tensor([1., 2., 3.]))
            >>> v
            tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)

            >>> h.remove()  # removes the hook
        """
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.register_post_accumulate_grad_hook, (self,), self, hook
            )
        if not self.requires_grad:
            raise RuntimeError(
                "cannot register a hook on a tensor that doesn't require gradient"
            )
        if self.grad_fn is not None:
            raise RuntimeError(
                "post accumulate grad hooks cannot be registered on non-leaf tensors"
            )
        if self._post_accumulate_grad_hooks is None:
            self._post_accumulate_grad_hooks: Dict[Any, Any] = OrderedDict()
        handle = hooks.RemovableHandle(self._post_accumulate_grad_hooks)
        self._post_accumulate_grad_hooks[handle.id] = hook
        return handle

    def reinforce(self, reward):
        def trim(str):
            return "\n".join([line.strip() for line in str.split("\n")])
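For orientation, here is a minimal usage sketch (not part of the reverted diff) of the API documented in the docstring above, following the same "optimizer in backward" pattern as `test_post_accumulate_grad_hook_e2e`: each parameter gets its own optimizer, and the hook steps it as soon as that parameter's `.grad` has been accumulated. The `optims` dict and the `optim_step_hook` name are illustrative, not taken from the PR.

```python
import torch

model = torch.nn.Linear(3, 2)
# one optimizer per parameter, keyed by the parameter tensor itself
optims = {p: torch.optim.Adam([p]) for p in model.parameters()}

def optim_step_hook(param):
    # runs after param.grad has been accumulated; must return None
    optims[param].step()
    optims[param].zero_grad()

handles = [
    p.register_post_accumulate_grad_hook(optim_step_hook)
    for p in model.parameters()
]

loss = model(torch.rand(2, 3)).sum()
loss.backward()  # parameters are updated during the backward pass itself

for h in handles:
    h.remove()  # the hooks stop firing once the handles are removed
```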
@@ -523,12 +523,6 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
    return tensor_pre_hooks_;
  }

  virtual std::unique_ptr<PostAccumulateGradHook>&
  tensor_post_acc_grad_hooks() noexcept {
    static std::unique_ptr<PostAccumulateGradHook> empty = nullptr;
    return empty;
  }

  std::unordered_map<int, std::unique_ptr<FunctionPreHook>>&
  retains_grad_hooks() noexcept {
    return retains_grad_hooks_;
@@ -42,17 +42,5 @@ struct TORCH_API FunctionPostHook {
  }
};

struct TORCH_API PostAccumulateGradHook {
  virtual ~PostAccumulateGradHook() = default;
  virtual void operator()(const Variable& tensor) = 0;
  // only implemented for python hooks on nodes, registers hook with compiled
  // autograd
  virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
    throw std::runtime_error(
        std::string("not yet implemented for compiled autograd: ") +
        typeid(*this).name());
  }
};

} // namespace autograd
} // namespace torch
@@ -57,11 +57,6 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
      1 + !post_hooks().empty() /* num_expected_refs */,
      [&grad](at::Tensor&& grad_update) { grad = std::move(grad_update); });

  auto& hook = tensor_post_acc_grad_hooks();
  if (hook != nullptr) {
    (*hook)(variable);
  }

  return variable_list();
}
@@ -93,7 +88,6 @@ variable_list AccumulateGrad::apply_with_saved(
      });
  saved.after(variable_copy);
  saved.after(grad_copy);

  return variable_list();
}
@@ -50,15 +50,6 @@ struct TORCH_API AccumulateGrad : public Node {
    // to all other Nodes). So we must lazily read the Tensor hooks here.
    return impl::hooks(variable);
  }

  std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks() noexcept
      override {
    // NB: Since the AccumulateGrad Node is only a weak ref from the Tensor,
    // it can be destroyed even though the Tensor is still alive (contrary
    // to all other Nodes). So we must lazily read the Tensor hooks here.
    return impl::post_acc_grad_hooks(variable);
  }

  // Given a variable with its current grad as variable_grad, accumulates
  // new_grad into variable_grad if in place accumulation is possible.
  // Otherwise, uses 'update_grad' to update the grad for the variable.
@@ -30,11 +30,10 @@ namespace autograd {

namespace {

// This function is called in 4 different cases:
// This function is called in 3 different cases:
// 1) TensorPreHook
// 2) PreHook
// 3) PostHook
// 4) TensorPostAccGradHook
//
// Depending on the case, args and res can hold different types of objects:
//
@@ -42,13 +41,11 @@ namespace {
// TensorPreHook          (Tensor,)
// PreHook                ((Tensor, ...),)                (grad_outputs,)
// PostHook               ((Tensor, ...), (Tensor, ...))  (grad_inputs, grad_outputs)
// TensorPostAccGradHook  ((Tensor), ())                  (tensor,)
//
// res:
// TensorPreHook          Tensor
// PreHook                ((Tensor, ...),)                (grad_outputs,)
// PostHook               ((Tensor, ...),)                (grad_inputs,)
// TensorPostAccGradHook  None
//
// This function returns True if any hook returned non-None value, and False
// otherwise.
@@ -196,30 +193,6 @@ void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) {
  }
}

PyFunctionTensorPostAccGradHooks::PyFunctionTensorPostAccGradHooks(
    PyObject* dict)
    : dict(dict) {
  Py_INCREF(dict);
}

PyFunctionTensorPostAccGradHooks::~PyFunctionTensorPostAccGradHooks() {
  // If python is already dead, leak the wrapped python objects
  if (Py_IsInitialized()) {
    pybind11::gil_scoped_acquire gil;
    Py_DECREF(dict);
  }
}

auto PyFunctionTensorPostAccGradHooks::operator()(const Variable& tensor)
    -> void {
  pybind11::gil_scoped_acquire gil;
  THPObjectPtr tup(PyTuple_New(1));
  PyTuple_SET_ITEM(tup.get(), 0, THPVariable_Wrap(tensor));
  bool returned_none = !_call_hooks(dict, tup.get());
  TORCH_CHECK(
      returned_none, "Tensor post accumulate grad hooks should return None.");
}

} // namespace autograd
} // namespace torch
@@ -34,17 +34,5 @@ struct PyFunctionPostHook : public FunctionPostHook {
  PyObject* dict;
};

// PyFunctionTensorPostAccGradHooks is a dictionary of PostAccumulateGradHooks,
// and it is understandable if you are confused by why it's a subclass. We are
// simply following the precedent of PyFunctionPreHook and PyFunctionPostHook
// above to easily enroll into existing infrastructure.
struct PyFunctionTensorPostAccGradHooks : public PostAccumulateGradHook {
  PyFunctionTensorPostAccGradHooks(PyObject* dict);
  ~PyFunctionTensorPostAccGradHooks() override;
  void operator()(const Variable& tensor) override;
  // fall back to the compiled_args of PostAccumulateGradHook superclass
  PyObject* dict;
};

} // namespace autograd
} // namespace torch
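As the comment above explains, `PyFunctionTensorPostAccGradHooks` simply wraps the per-tensor Python dict of hooks. Below is a small illustrative sketch (not from the diff) of what that looks like from Python, assuming the private `_post_accumulate_grad_hooks` attribute behaves as the getter added in this PR suggests: hooks are stored in an `OrderedDict` keyed by handle id, which is why multiple hooks fire in registration order.

```python
import torch

t = torch.rand(3, requires_grad=True)

def double_grad(p):
    p.grad.mul_(2.0)  # post accumulate grad hooks must return None

def add_one(p):
    p.grad.add_(1.0)

h1 = t.register_post_accumulate_grad_hook(double_grad)
h2 = t.register_post_accumulate_grad_hook(add_one)

# private storage: an OrderedDict mapping handle id -> hook (illustration only)
print(t._post_accumulate_grad_hooks)

t.sum().backward()
print(t.grad)  # ones -> *2 -> +1 == tensor([3., 3., 3.])

h1.remove()
h2.remove()
```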
@@ -428,7 +428,6 @@ static int THPVariable_clear(THPVariable* self) {
    return 0;
  }
  Py_CLEAR(self->backward_hooks);
  Py_CLEAR(self->post_accumulate_grad_hooks);
  const auto& tensor = THPVariable_Unpack(self);
  if (tensor.defined()) {
    // Two situations to consider:
@@ -1166,47 +1165,6 @@ int THPVariable_set_backwards_hooks(
  END_HANDLE_TH_ERRORS_RET(-1)
}

PyObject* THPVariable_get_post_accumulate_grad_hooks(
    THPVariable* self,
    void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "_post_accumulate_grad_hooks");
  }
  if (self->post_accumulate_grad_hooks) {
    Py_INCREF(self->post_accumulate_grad_hooks);
    return self->post_accumulate_grad_hooks;
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

int THPVariable_set_post_accumulate_grad_hooks(
    THPVariable* self,
    PyObject* obj,
    void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(
        self, "_post_accumulate_grad_hooks", obj);
  }
  THPUtils_assertRet(
      -1, obj, "Deletion of _post_accumulate_grad_hooks not allowed!");
  if (obj == Py_None) {
    obj = nullptr;
  }
  Py_XINCREF(obj);
  Py_CLEAR(self->post_accumulate_grad_hooks);
  self->post_accumulate_grad_hooks = obj;
  const auto& tensor = THPVariable_Unpack(self);
  if (obj) {
    torch::autograd::impl::set_post_acc_grad_hooks(
        tensor, std::make_unique<PyFunctionTensorPostAccGradHooks>(obj));
  }
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}

PyObject* THPVariable_get_base(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
@@ -1521,11 +1479,6 @@ static struct PyGetSetDef THPVariable_properties[] = {
     (setter)THPVariable_set_backwards_hooks,
     nullptr,
     nullptr},
    {"_post_accumulate_grad_hooks",
     (getter)THPVariable_get_post_accumulate_grad_hooks,
     (setter)THPVariable_set_post_accumulate_grad_hooks,
     nullptr,
     nullptr},
    {"name", (getter)THPVariable_get_name, nullptr, nullptr, nullptr},
    {"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr},
    {"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
@@ -2083,7 +2036,6 @@ static int THPVariable_subclass_traverse(

  // Finally traverse THPVariable special stuff
  Py_VISIT(var->backward_hooks);
  Py_VISIT(var->post_accumulate_grad_hooks);
  if (!var->cdata.unsafeIsBorrowed()) {
    const auto& tensor = THPVariable_Unpack(var);
    if (tensor.defined()) {
@@ -22,10 +22,6 @@ struct THPVariable {
  // Hooks to be run on backwards pass (corresponds to Python attr
  // '_backwards_hooks', set by 'register_hook')
  PyObject* backward_hooks = nullptr;
  // Hooks to be run in the backwards pass after accumulate grad,
  // i.e., after the .grad has been set (corresponds to Python attr
  // '_post_accumulate_grad_hooks', set by 'register_post_accumulate_grad_hook')
  PyObject* post_accumulate_grad_hooks = nullptr;
};

TORCH_PYTHON_API void registerPythonTensorClass(
@@ -363,19 +363,6 @@ void clear_hooks(const at::TensorBase& self) {
  materialize_autograd_meta(self)->hooks_.clear();
}

void set_post_acc_grad_hooks(
    const at::TensorBase& self,
    std::unique_ptr<PostAccumulateGradHook> dict) {
  AutogradMeta* meta = materialize_autograd_meta(self);
  meta->post_acc_grad_hooks_ = std::move(dict);
}

std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
    const Variable& self) {
  TORCH_INTERNAL_ASSERT(get_autograd_meta(self));
  return get_autograd_meta(self)->post_acc_grad_hooks_;
}

void set_name(const Variable& self, const std::string& name) {
  materialize_autograd_meta(self)->name_ = name;
}
@@ -186,12 +186,6 @@ TORCH_API void add_hook(
TORCH_API std::vector<std::unique_ptr<FunctionPreHook>>& hooks(const Variable&);
TORCH_API void clear_hooks(const at::TensorBase&);

TORCH_API void set_post_acc_grad_hooks(
    const at::TensorBase&,
    std::unique_ptr<PostAccumulateGradHook> dict);
TORCH_API std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
    const Variable&);

TORCH_API void create_cpp_hook(
    const at::TensorBase&,
    bool is_retains_grad_hooks = false);
@@ -237,12 +231,6 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
  std::vector<std::unique_ptr<FunctionPreHook>> hooks_;
  std::shared_ptr<hooks_list> cpp_hooks_list_;

  // The post_acc_grad_hooks_ field stores only Python hooks
  // (PyFunctionTensorPostAccGradHooks) that are called after the
  // .grad field has been accumulated into. This is less complicated
  // than the hooks_ field, which encapsulates a lot more.
  std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_ = nullptr;

  // Only meaningful on leaf variables (must be false otherwise)
  bool requires_grad_{false};
@@ -347,9 +347,6 @@ class CompiledNodeArgs {
    TORCH_CHECK(
        fn->retains_grad_hooks().empty(),
        "retains_grad_hooks not implemented for compiled autograd");
    TORCH_CHECK(
        fn->tensor_post_acc_grad_hooks() == nullptr,
        "tensor_post_acc_grad_hooks not implemented for compiled autograd");
    for (auto& i : fn->tensor_pre_hooks()) {
      i->compiled_args(*this);
    }
@@ -1205,7 +1205,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        Tensor.mT.__get__: lambda self: -1,
        Tensor.mH.__get__: lambda self: -1,
        Tensor._backward_hooks.__get__: lambda self: -1,
        Tensor._post_accumulate_grad_hooks.__get__: lambda self: -1,
        Tensor._base.__get__: lambda self: -1,
        Tensor._cdata.__get__: lambda self: -1,
        Tensor.grad.__get__: lambda self: -1,
@@ -1329,7 +1328,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        Tensor.record_stream: lambda self, stream: -1,
        Tensor.refine_names: lambda self, names: -1,
        Tensor.register_hook: lambda self, hook: -1,
        Tensor.register_post_accumulate_grad_hook: lambda self, hook: -1,
        Tensor.rename: lambda self, name: -1,
        Tensor.repeat: lambda self, *size: -1,
        Tensor.requires_grad_: lambda self, requires_grad=True: -1,