Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Clear custom autograd Function ctx.to_save earlier (#161171)
Fixes https://github.com/pytorch/pytorch/issues/161186
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161171
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: d5e0f4202b
Commit: 8171d6052e
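For context: the fix concerns custom autograd Functions that stash tensors with ctx.save_for_backward in forward and read them back through ctx.saved_tensors in backward; save_for_backward records the tensors on the ctx's to_save attribute, which this commit now clears as soon as they have been handed off to C++. A minimal sketch of that round trip (illustrative only, not code from the commit; the Function name and shapes are made up):

import torch

class Square(torch.autograd.Function):  # hypothetical example Function
    @staticmethod
    def forward(ctx, x):
        # Records x on ctx (the to_save bookkeeping this commit clears earlier)
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        # Saved tensors are handed back from the autograd-managed storage
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

x = torch.randn(4, requires_grad=True)
y = Square.apply(x)
y.sum().backward()
assert torch.allclose(x.grad, 2 * x)  # d(x^2)/dx = 2x

The regression test added below wraps the same save_for_backward pattern in non-reentrant checkpointing and uses weak references to check that the saved tensor is actually freed.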
@@ -5183,6 +5183,7 @@ known_graph_breaks_tests = {
     "test_nested_checkpoint_set_early_stop", # dynamo disable
     "test_nested_checkpoint_two_children_early_stop_False", # dynamo disable
     "test_nested_checkpoint_two_children_early_stop_True", # dynamo disable
+    "test_custom_autograd_ac_early_stop", # marked as skipped
     "test_dropout", # dynamo disable
     "test_dropout_inductor", # dynamo disable
     "test_function_with_kwargs", # dynamo disable
@@ -3888,6 +3888,38 @@ class TestAutograd(TestCase):
         torch.autograd.grad(y, x, create_graph=True)
         torch.autograd.grad(y, x) # should not error!
 
+    def test_custom_autograd_ac_early_stop(self):
+        refs = []
+
+        class Test(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                y = x.clone()
+                ctx.save_for_backward(y)
+                refs.append(weakref.ref(y))
+                return y
+
+            @staticmethod
+            def backward(ctx, *args):
+                _ = ctx.saved_tensors
+                return None
+
+        def fn(inp):
+            return Test.apply(inp)
+
+        inp = torch.randn(5, 5, requires_grad=True)
+
+        def scope():
+            # Early-stop is true by default in non-reentrant torch.utils.checkpoint
+            out = torch.utils.checkpoint.checkpoint(fn, inp, use_reentrant=False)
+            out.sum().backward()
+
+        with disable_gc():
+            scope()
+
+        for ref in refs:
+            self.assertIsNone(ref())
+
     def test_detach(self):
         x = torch.randn(10, 10, requires_grad=True)
         y = x + 2
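The new test checks the user-visible guarantee: with garbage collection disabled, every weak reference recorded inside forward must be dead after backward under non-reentrant checkpointing (where early stop is on by default), i.e. the tensor passed to ctx.save_for_backward must not be kept alive by leftover bookkeeping. Assuming the usual layout of the test suite, it can be run on its own with something like:

    pytest test/test_autograd.py -k test_custom_autograd_ac_early_stop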
@@ -803,6 +803,7 @@ static void _get_tensors_to_save(
         }
       }
     }
+    Py_CLEAR(self->to_save);
   }
 }
 // Save any variables that requested by to_save
@@ -810,7 +811,7 @@ static void _save_variables(
     const std::vector<std::optional<at::Tensor>>& tensors_to_save,
     const std::shared_ptr<PyNode>& cdata_ptr,
     THPFunction* self) {
-  if (!self->to_save)
+  if (tensors_to_save.size() == 0)
     return;
   size_t num_saved = tensors_to_save.size();
   self->saved_variables.clear();
@@ -823,8 +824,6 @@ static void _save_variables(
       self->saved_variables.emplace_back(opt_tensor.value(), is_output);
     }
   }
-  // Free .to_save
-  Py_CLEAR(self->to_save);
 }
 
 // Mark requires_grad = 0 on non-differentiable variables (as per
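Net effect of the C++ change: _get_tensors_to_save now calls Py_CLEAR(self->to_save) as soon as the requested tensors have been copied into tensors_to_save, and _save_variables no longer consults self->to_save at all; it returns early when tensors_to_save is empty and drops its trailing Py_CLEAR. The Python-side to_save reference is therefore released earlier and cannot keep tensors alive on paths where _save_variables is never reached, which is what the checkpoint early-stop test above exercises.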