Relax use_count constraints for swap_tensors when AccumulateGrad holds a reference (#127313)

### Before this PR:
`torch.utils.swap_tensors(a, b)` required the `use_count` of `a` and `b` to be 1

```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here would fail due to the reference held by AccumulateGrad node, which is not cleaned up after backward
# torch.utils.swap_tensors(a, b)
del out
# Calling swap_tensors here would pass
torch.utils.swap_tensors(a, b)
```
### After this PR:
`torch.utils.swap_tensors(a, b)` requires the `use_count` of `a` and `b` to be 1, or 2 if the second reference is held by the `AccumulateGrad` node

A pre-hook is registered on the `AccumulateGrad` node so that it fails if it is ever called (i.e. if the user attempts to backward through the graph again).
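
The poisoning mechanism can be sketched in Python using the public `Node.register_prehook` API. This is an illustrative stand-in for the hook that `swap_tensors` registers, not the PR's actual implementation:

```python
import torch

a = torch.randn(2, 3, requires_grad=True)
out = (a * 2).sum()
# The AccumulateGrad node for `a` is reachable through the graph:
# SumBackward0 -> MulBackward0 -> AccumulateGrad
acc = out.grad_fn.next_functions[0][0].next_functions[0][0]
assert type(acc).__name__ == "AccumulateGrad"

def poison(grad_outputs):
    # Illustrative stand-in for the hook swap_tensors registers
    raise RuntimeError("AccumulateGrad node was poisoned by swap_tensors")

acc.register_prehook(poison)
try:
    out.backward()
    raised = False
except RuntimeError:
    raised = True
assert raised  # backward through the poisoned node errors
```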

```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here is ok
torch.utils.swap_tensors(a, b)
# If we ever backward to the AccumulateGrad node it will error that it was poisoned by swap_tensors
```

### Application to `nn.Module`

This issue is especially pertinent in the context of `nn.Module`, where parameters have `AccumulateGrad` nodes initialized after the first forward pass. Specifically, this is intended to address https://github.com/pytorch/pytorch/pull/126814#issuecomment-2127777866. Previously, the following would fail at `m.cpu()`; after this PR it succeeds, and an error is raised only if the user attempts to backward through the poisoned `AccumulateGrad` node:

```python
import torch
import torch.nn as nn
m = nn.Linear(3, 5)
inp = torch.randn(2, 3)
out = m(inp)
out.sum().backward()
m.cpu()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127313
Approved by: https://github.com/soulitzer
This commit is contained in:
Mikayla Gawarecki
2024-05-29 17:14:05 -07:00
committed by PyTorch MergeBot
parent d44ab8ba6d
commit cd06ae0cb8
7 changed files with 70 additions and 23 deletions
```diff
@@ -375,22 +375,14 @@ PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) {
   THPVariable* a = reinterpret_cast<THPVariable*>(a_);
   THPVariable* b = reinterpret_cast<THPVariable*>(b_);
-  TORCH_CHECK(
-      a->cdata->use_count() == 1,
-      "Expected single reference to a's Tensor object but got ",
-      a->cdata->use_count());
-  TORCH_CHECK(
-      b->cdata->use_count() == 1,
-      "Expected single reference to b's Tensor object but got ",
-      b->cdata->use_count());
   // weak_use_count() adds 1 if use_count is non-zero
   TORCH_CHECK(
       a->cdata->weak_use_count() == 1,
-      "Expected no weakrefs to a's Tensor object but got ",
+      "Expected no weakrefs to t1's Tensor object but got ",
       a->cdata->weak_use_count() - 1);
   TORCH_CHECK(
       b->cdata->weak_use_count() == 1,
-      "Expected no weakrefs to b's Tensor object but got ",
+      "Expected no weakrefs to t2's Tensor object but got ",
       b->cdata->weak_use_count() - 1);
   // Swap the Tensor Impl
```