Clear custom autograd Function ctx.to_save earlier (#161171)

Fixes https://github.com/pytorch/pytorch/issues/161186

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161171
Approved by: https://github.com/albanD
Authored by soulitzer on 2025-09-01 19:26:15 -04:00
Committed by PyTorch MergeBot
parent d5e0f4202b
commit 8171d6052e
3 changed files with 35 additions and 3 deletions
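
Judging from the regression test added below, the change addresses tensors handed to ctx.save_for_backward being kept alive by the ctx under non-reentrant activation checkpointing, whose early-stop behavior is on by default. A minimal standalone sketch of that scenario, mirroring the new test, might look like the following (only public PyTorch APIs; the Function name is illustrative, and gc.disable() stands in for the test suite's disable_gc() helper so that any freeing observed comes from reference counting rather than the cycle collector):

import gc
import weakref

import torch
import torch.utils.checkpoint


refs = []


class Save(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = x.clone()
        ctx.save_for_backward(y)  # held on the ctx until autograd consumes it
        refs.append(weakref.ref(y))
        return y

    @staticmethod
    def backward(ctx, grad):
        _ = ctx.saved_tensors
        return grad


inp = torch.randn(5, 5, requires_grad=True)

gc.disable()
try:
    # Early stop is the default for non-reentrant checkpointing.
    out = torch.utils.checkpoint.checkpoint(Save.apply, inp, use_reentrant=False)
    out.sum().backward()
finally:
    gc.enable()

print([r() for r in refs])  # expected after this change: all None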

View File

@@ -5183,6 +5183,7 @@ known_graph_breaks_tests = {
     "test_nested_checkpoint_set_early_stop", # dynamo disable
     "test_nested_checkpoint_two_children_early_stop_False", # dynamo disable
     "test_nested_checkpoint_two_children_early_stop_True", # dynamo disable
+    "test_custom_autograd_ac_early_stop", # marked as skipped
     "test_dropout", # dynamo disable
     "test_dropout_inductor", # dynamo disable
     "test_function_with_kwargs", # dynamo disable

View File

@@ -3888,6 +3888,38 @@ class TestAutograd(TestCase):
         torch.autograd.grad(y, x, create_graph=True)
         torch.autograd.grad(y, x) # should not error!
 
+    def test_custom_autograd_ac_early_stop(self):
+        refs = []
+
+        class Test(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                y = x.clone()
+                ctx.save_for_backward(y)
+                refs.append(weakref.ref(y))
+                return y
+
+            @staticmethod
+            def backward(ctx, *args):
+                _ = ctx.saved_tensors
+                return None
+
+        def fn(inp):
+            return Test.apply(inp)
+
+        inp = torch.randn(5, 5, requires_grad=True)
+
+        def scope():
+            # Early-stop is true by default in non-reentrant torch.utils.checkpoint
+            out = torch.utils.checkpoint.checkpoint(fn, inp, use_reentrant=False)
+            out.sum().backward()
+
+        with disable_gc():
+            scope()
+
+        for ref in refs:
+            self.assertIsNone(ref())
+
     def test_detach(self):
         x = torch.randn(10, 10, requires_grad=True)
         y = x + 2

View File

@@ -803,6 +803,7 @@ static void _get_tensors_to_save(
         }
       }
     }
+    Py_CLEAR(self->to_save);
   }
 }
 // Save any variables that requested by to_save
@@ -810,7 +811,7 @@ static void _save_variables(
     const std::vector<std::optional<at::Tensor>>& tensors_to_save,
     const std::shared_ptr<PyNode>& cdata_ptr,
     THPFunction* self) {
-  if (!self->to_save)
+  if (tensors_to_save.size() == 0)
     return;
   size_t num_saved = tensors_to_save.size();
   self->saved_variables.clear();
@@ -823,8 +824,6 @@
       self->saved_variables.emplace_back(opt_tensor.value(), is_output);
     }
   }
-  // Free .to_save
-  Py_CLEAR(self->to_save);
 }
 
 // Mark requires_grad = 0 on non-differentiable variables (as per
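
Read together, the hunks above move the release of self->to_save from the end of _save_variables to the point where _get_tensors_to_save has finished gathering the tensors, and key _save_variables off the gathered vector instead of the already-cleared to_save tuple. A rough Python sketch of why the earlier clear matters, under the assumption that an exception can be raised while the gathered tensors are being packed into saved variables (as checkpoint's early stop is understood to do); names like Ctx, get_tensors_to_save and failing_pack are illustrative, not PyTorch APIs:

import weakref


class Obj:  # stand-in for a tensor
    pass


class Ctx:
    def __init__(self):
        self.to_save = None        # populated by save_for_backward()
        self.saved_variables = []


def get_tensors_to_save(ctx):
    tensors = list(ctx.to_save or ())
    ctx.to_save = None             # new behavior: cleared as soon as it is read
    return tensors


def save_variables(ctx, tensors, pack):
    if not tensors:                # new guard: keyed off the gathered list
        return
    for t in tensors:
        ctx.saved_variables.append(pack(t))  # may raise partway through
    # Previously ctx.to_save was cleared only here, so a raise above left the
    # ctx holding strong references to everything that was to be saved.


class Stop(Exception):
    pass


def failing_pack(t):
    raise Stop


ctx, obj = Ctx(), Obj()
ctx.to_save = (obj,)
ref = weakref.ref(obj)

tensors = get_tensors_to_save(ctx)
try:
    save_variables(ctx, tensors, failing_pack)
except Stop:
    pass

del obj, tensors
print(ref())  # None: the ctx no longer pins the object even though saving failed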