Relax use_count constraints for swap_tensors when AccumulateGrad holds a reference (#127313)

### Before this PR:
`torch.utils.swap_tensors(a, b)` required the `use_count` of both `a` and `b` to be exactly 1.

```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here would fail due to the reference held by the AccumulateGrad node,
# which is not cleaned up after backward
# torch.utils.swap_tensors(a, b)
del out
# Calling swap_tensors here would pass
torch.utils.swap_tensors(a, b)
```
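
For reference, the extra reference can be observed with the private `Tensor._use_count` helper that this PR adds (used here purely for illustration; it is not public API):

```python
import torch

a = torch.randn(2, 3, requires_grad=True)
print(a._use_count())  # 1: only the Python wrapper holds a reference to the TensorImpl
out = a * 2
print(a._use_count())  # 2: recording the graph created an AccumulateGrad node that holds a second reference
out.sum().backward()
print(a._use_count())  # still 2: the node stays alive as long as `out` (and hence the graph) does
del out
print(a._use_count())  # 1 again, which is why swap_tensors succeeds after `del out`
```
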
### After this PR:
`torch.utils.swap_tensors(a, b)` requires the `use_count` of `a` and `b` to be 1, or 2 *if* the second reference is held by an `AccumulateGrad` node.

In the latter case, a pre-hook is registered on the `AccumulateGrad` node so that it fails if it is ever executed (i.e. if the user attempts to backward through the graph); a condensed sketch of this check follows the example below.

```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here is ok
torch.utils.swap_tensors(a, b)
# If we ever backward through the AccumulateGrad node, it will error out, reporting that it was poisoned by swap_tensors
```
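
The check itself is small. Below is a condensed sketch of the new logic in `torch.utils.swap_tensors` (the full version is in the Python diff at the bottom of this page; the names `check_use_count`, `error_pre_hook`, and the private `Tensor._use_count` binding all come from this PR):

```python
import torch

def error_pre_hook(grad_outputs):
    raise RuntimeError("Trying to execute AccumulateGrad node that was poisoned by swap_tensors")

def check_use_count(t, name="t1"):
    use_count = t._use_count()
    if use_count == 1:
        return
    if use_count == 2 and t.is_leaf:
        # Assume the second reference is held by the AccumulateGrad node: fetch it and
        # register a pre-hook that raises if the node is ever executed.
        accum_grad_node = torch.autograd.graph.get_gradient_edge(t).node
        # Make sure get_gradient_edge did not lazily create the node just now
        if t._use_count() == 2:
            accum_grad_node.register_prehook(error_pre_hook)
            return
    raise RuntimeError(f"Expected use_count of {name} to be 1 or 2 with an AccumulateGrad node "
                       f"but got {use_count}")
```
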

### Application to `nn.Module`

This issue is especially pertinent in the context of `nn.Module`, where parameters have `AccumulateGrad` nodes initialized after the forward pass. Specifically, this is intended to address https://github.com/pytorch/pytorch/pull/126814#issuecomment-2127777866. Previously, the snippet below would fail at `m.cpu()`; we want it to succeed, and instead raise an error only if the user ever attempts to backward through the poisoned `AccumulateGrad` node.

```python
import torch
import torch.nn as nn
m = nn.Linear(3, 5)
inp = torch.randn(2, 3)
out = m(inp)
out.sum().backward()
m.cpu()
```
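
A minimal sketch of the new behavior, mirroring the new `test_swap_module_params_poisons_acc_grad` test in the diff below (it assumes swap-based conversion is enabled via `torch.__future__`):

```python
import torch
import torch.nn as nn

torch.__future__.set_swap_module_params_on_conversion(True)
m = nn.Linear(3, 5)
inp = torch.randn(2, 3)
out = m(inp)  # forward initializes the parameters' AccumulateGrad nodes (use_count becomes 2)
m.cpu()       # previously raised "_apply(): Couldn't swap Linear.weight"; now swaps and poisons the nodes
try:
    out.sum().backward()  # backward reaches a poisoned AccumulateGrad node
except RuntimeError as e:
    print(e)  # "Trying to execute AccumulateGrad node that was poisoned by swap_tensors ..."
```
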

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127313
Approved by: https://github.com/soulitzer
Author: Mikayla Gawarecki
Date: 2024-05-29 17:14:05 -07:00
Committed by: PyTorch MergeBot
Parent: d44ab8ba6d
Commit: cd06ae0cb8
7 changed files with 70 additions and 23 deletions


@@ -863,7 +863,8 @@ class TestModule(TestCase):
else:
raise NotImplementedError(f"Unknown error type {error_input.error_on}")
@modules([module for module in module_db if not module.is_lazy])
# Only run this test for float32 because the test loops over all the dtypes
@modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32])
@parametrize('swap', [True, False])
@parametrize('set_grad', [True, False])
@wrapSwapTensorsTest()
@@ -879,6 +880,7 @@ class TestModule(TestCase):
for module_input in module_inputs:
c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
m = module_cls(*c_args, **c_kwargs)
@@ -896,6 +898,17 @@ class TestModule(TestCase):
setattr(m, n, new_b)
_to(m, set_grad=set_grad)
# Check .to() can be run after forward and backward with swap
has_params = len(list(m.parameters())) > 0
if swap and not set_grad and has_params:
out = m(*args, **kwargs)
if isinstance(out, tuple):
out = out[0]
out.sum().backward()
m.to(dtype=torch.half)
# reset
m.to(dtype=torch.float32)
prev_device, prev_dtype = device, dtype
for device_, dtype_ in product(devices, dtypes):
# if device/dtype do not change, grad.to(device, dtype) is a no-op so
@@ -903,6 +916,7 @@ class TestModule(TestCase):
# parameters will be wrapped in an nn.Parameter before swapping
# which will cause the ._cdata to change
g_no_swap = device_ == prev_device and dtype_ == prev_dtype
prev_prev_device, prev_prev_dtype = prev_device, prev_dtype
prev_device, prev_dtype = device_, dtype_
p_ids_before = [id(p) for p in m.parameters()]
@@ -940,7 +954,6 @@ class TestModule(TestCase):
self.assertTrue(all(a == b for a, b in zip(g_cdatas_before, g_cdatas_after)))
self.assertTrue(all(a == b for a, b in zip(g_ids_before, g_ids_after)))
@modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32])
@parametrize('swap', [True, False])
@wrapSwapTensorsTest()


@@ -1594,19 +1594,29 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
finally:
torch.__future__.set_overwrite_module_params_on_conversion(False)
def test_swap_module_params_fails_after_forward(self):
def test_swap_module_params_poisons_acc_grad(self):
try:
torch.__future__.set_swap_module_params_on_conversion(True)
# (1) backward cannot be run after _apply
# forward will init AccumulateGrad nodes, which bumps use_count of parameters' at::Tensors
# additionally, if any Tensors are saved for backward, their use_count will be bumped
m = torch.nn.Linear(2, 3)
inp = torch.randn(2, 2)
# forward will init AccumulateGrad nodes, which bumps use_count of parameters' at::Tensors
out = m(inp)
with self.assertRaisesRegex(RuntimeError, re.escape("_apply(): Couldn't swap Linear.weight")):
m.half()
del out
# works as expected now
m.half()
self.assertTrue(all(p.dtype == torch.float16 for p in m.parameters()))
with self.assertRaisesRegex(RuntimeError, "Trying to execute AccumulateGrad node that was poisoned by swap_tensors"):
out.sum().backward()
# (2) _apply can be run after backward()
# After running backward, all the references generated by "save for backward" will be cleared
# So the use_count will be 2 (1 from Tensor itself, and 1 from AccumulateGrad node), swap_tensors
# should allow this.
inp2 = torch.randn(2, 2, dtype=torch.half)
out2 = m(inp2)
out2.sum().backward()
m.float()
self.assertTrue(all(p.dtype == torch.float32 for p in m.parameters()))
out3 = m(inp)
finally:
torch.__future__.set_swap_module_params_on_conversion(False)


@@ -10623,12 +10623,9 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
if t1.is_floating_point():
t3 = t1.clone().detach().requires_grad_(True)
out = t3 * 2
with self.assertRaisesRegex(RuntimeError, "Expected single reference to a's"):
torch.utils.swap_tensors(t3, t2)
del out
# Now succeeds
torch.utils.swap_tensors(t3, t2)
torch.utils.swap_tensors(t1, t2)
with self.assertRaisesRegex(RuntimeError, "AccumulateGrad node that was poisoned by swap_tensors"):
out.sum().backward()
wr = weakref.ref(t1)
with self.assertRaisesRegex(RuntimeError, "has weakref"):


@@ -375,22 +375,14 @@ PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) {
THPVariable* a = reinterpret_cast<THPVariable*>(a_);
THPVariable* b = reinterpret_cast<THPVariable*>(b_);
TORCH_CHECK(
a->cdata->use_count() == 1,
"Expected single reference to a's Tensor object but got ",
a->cdata->use_count());
TORCH_CHECK(
b->cdata->use_count() == 1,
"Expected single reference to b's Tensor object but got ",
b->cdata->use_count());
// weak_use_count() adds 1 if use_count is non-zero
TORCH_CHECK(
a->cdata->weak_use_count() == 1,
"Expected no weakrefs to a's Tensor object but got ",
"Expected no weakrefs to t1's Tensor object but got ",
a->cdata->weak_use_count() - 1);
TORCH_CHECK(
b->cdata->weak_use_count() == 1,
"Expected no weakrefs to b's Tensor object but got ",
"Expected no weakrefs to t2's Tensor object but got ",
b->cdata->weak_use_count() - 1);
// Swap the Tensor Impl


@@ -1615,6 +1615,13 @@ int THPVariable_set_imag(PyObject* self, PyObject* imag, void* unused) {
END_HANDLE_TH_ERRORS_RET(-1)
}
PyObject* THPVariable__use_count(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
const auto& t = THPVariable_Unpack(self);
return THPUtils_packUInt64(t.use_count());
END_HANDLE_TH_ERRORS
}
// properties are registered here because we are currently only able to bind
// them manually. TODO: make declarable in native_functions
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
@@ -1766,6 +1773,7 @@ static PyMethodDef extra_methods[] = {
THPVariable_rev_view_func_unsafe,
METH_O,
nullptr},
{"_use_count", THPVariable__use_count, METH_NOARGS, nullptr},
{nullptr}};
struct THPVariableMeta {


@@ -357,6 +357,7 @@ def get_ignored_functions() -> Set[Callable]:
Tensor._is_any_true,
Tensor._addmm_activation,
Tensor.to_padded_tensor,
Tensor._use_count,
}


@@ -46,6 +46,32 @@ def swap_tensors(t1, t2):
setattr(t1, name, (getattr(t2, name)))
setattr(t2, name, tmp)
def error_pre_hook(grad_outputs):
raise RuntimeError("Trying to execute AccumulateGrad node that was poisoned by swap_tensors "
"this can happen when you try to run backward on a tensor that was swapped. "
"For a module m with `torch.__future__.set_swap_module_params_on_conversion(True)` "
"you should not change the device or dtype of the module (e.g. `m.cpu()` or `m.half()`) "
"between running forward and backward. To resolve this, please only change the "
"device/dtype before running forward (or after both forward and backward).")
def check_use_count(t, name='t1'):
use_count = t._use_count()
error_str = (f"Expected use_count of {name} to be 1 or 2 with an AccumulateGrad node but got {use_count} "
f"make sure you are not holding references to the tensor in other places.")
if use_count > 1:
if use_count == 2 and t.is_leaf:
accum_grad_node = torch.autograd.graph.get_gradient_edge(t).node
# Make sure that the accumulate_grad node was not lazy_init-ed by get_gradient_edge
if t._use_count() == 2:
accum_grad_node.register_prehook(error_pre_hook)
else:
raise RuntimeError(error_str)
else:
raise RuntimeError(error_str)
check_use_count(t1, 't1')
check_use_count(t2, 't2')
# Swap the types
# Note that this will fail if there are mismatched slots
swap_attr("__class__")