Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 13:44:15 +08:00)
Relax use_count constraints for swap_tensors when AccumulateGrad holds a reference (#127313)
### Before this PR:

`torch.utils.swap_tensors(a, b)` required the `use_count` of `a` and `b` to be 1.

```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here would fail due to the reference held by the
# AccumulateGrad node, which is not cleaned up after backward
# torch.utils.swap_tensors(a, b)
del out
# Calling swap_tensors here would pass
torch.utils.swap_tensors(a, b)
```

### After this PR:

`torch.utils.swap_tensors(a, b)` requires the `use_count` of `a` and `b` to be 1, or 2 if the second reference is held by an `AccumulateGrad` node. In the latter case, a pre-hook is registered on the `AccumulateGrad` node so that it fails if it is ever called (i.e. if the user attempts to backward through the graph).

```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here is ok
torch.utils.swap_tensors(a, b)
# If we ever backward to the AccumulateGrad node, it will error that it was
# poisoned by swap_tensors
```

### Application to `nn.Module`

This issue is especially pertinent in the context of `nn.Module`, where parameters have `AccumulateGrad` nodes initialized after forward. Specifically, this is intended to address https://github.com/pytorch/pytorch/pull/126814#issuecomment-2127777866. Previously, the following would fail at `m.cpu()`; we want users to be able to run it, and instead raise an error only if they ever attempt to backward through the poisoned `AccumulateGrad` node.

```python
import torch
import torch.nn as nn
m = nn.Linear(3, 5)
inp = torch.randn(2, 3)
out = m(inp)
out.sum().backward()
m.cpu()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127313
Approved by: https://github.com/soulitzer
Committed by: PyTorch MergeBot
Parent: d44ab8ba6d
Commit: cd06ae0cb8
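The "poisoning" behavior described in the commit message can be approximated at the Python level. This is an illustrative sketch using the public autograd `Node.register_prehook` API, not the PR's actual C++ implementation, and the error message here is made up for the example:

```python
import torch

a = torch.randn(2, 3, requires_grad=True)
out = a * 2

# grad_fn.next_functions exposes the autograd graph; for `a * 2` the first
# next function of MulBackward0 is the AccumulateGrad node for `a`.
acc_grad = out.grad_fn.next_functions[0][0]
assert type(acc_grad).__name__ == "AccumulateGrad"

def poison(grad_outputs):
    # Mimic the error raised when backward reaches a swapped tensor's node.
    raise RuntimeError("AccumulateGrad node was poisoned by swap_tensors")

acc_grad.register_prehook(poison)

try:
    out.sum().backward()
except RuntimeError as e:
    assert "poisoned" in str(e)
```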
```diff
@@ -375,22 +375,14 @@ PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) {
   THPVariable* a = reinterpret_cast<THPVariable*>(a_);
   THPVariable* b = reinterpret_cast<THPVariable*>(b_);
 
-  TORCH_CHECK(
-      a->cdata->use_count() == 1,
-      "Expected single reference to a's Tensor object but got ",
-      a->cdata->use_count());
-  TORCH_CHECK(
-      b->cdata->use_count() == 1,
-      "Expected single reference to b's Tensor object but got ",
-      b->cdata->use_count());
   // weak_use_count() adds 1 if use_count is non-zero
   TORCH_CHECK(
       a->cdata->weak_use_count() == 1,
-      "Expected no weakrefs to a's Tensor object but got ",
+      "Expected no weakrefs to t1's Tensor object but got ",
       a->cdata->weak_use_count() - 1);
   TORCH_CHECK(
       b->cdata->weak_use_count() == 1,
-      "Expected no weakrefs to b's Tensor object but got ",
+      "Expected no weakrefs to t2's Tensor object but got ",
       b->cdata->weak_use_count() - 1);
 
   // Swap the Tensor Impl
```