Relax use_count constraints for swap_tensors when AccumulateGrad holds a reference (#127313)
### Before this PR:
`torch.utils.swap_tensors(a, b)` required the `use_count` of `a` and `b` to be 1.

```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here would fail due to the reference held by the AccumulateGrad node,
# which is not cleaned up after backward
# torch.utils.swap_tensors(a, b)
del out
# Calling swap_tensors here would pass
torch.utils.swap_tensors(a, b)
```

### After this PR:
`torch.utils.swap_tensors(a, b)` requires the `use_count` of `a` and `b` to be 1, or 2 if the second reference is held by `AccumulateGrad`. A pre-hook is registered on the `AccumulateGrad` node so that it fails if it is ever called (i.e. if the user attempts to backward through the graph).

```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here is ok
torch.utils.swap_tensors(a, b)
# If we ever backward to the AccumulateGrad node it will error that it was poisoned by swap_tensors
```

### Application to `nn.Module`
This issue is especially pertinent in the context of `nn.Module`, where parameters have `AccumulateGrad` nodes initialized after forward. Specifically, this is intended to address https://github.com/pytorch/pytorch/pull/126814#issuecomment-2127777866. Previously, the snippet below would fail at `m.cpu()`; we want users to be able to run it, and instead raise an error only if they ever attempt to backward through the poisoned `AccumulateGrad` node.

```python
import torch
import torch.nn as nn

m = nn.Linear(3, 5)
inp = torch.randn(2, 3)
out = m(inp)
out.sum().backward()
m.cpu()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127313
Approved by: https://github.com/soulitzer
Committed by: PyTorch MergeBot
Parent: d44ab8ba6d
Commit: cd06ae0cb8
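The poisoning described above can be expressed with public autograd APIs. Below is a minimal standalone sketch of the mechanism, assuming the helper name `poison_accumulate_grad` and its error message are illustrative rather than part of this PR's API: it registers an erroring pre-hook on a leaf tensor's `AccumulateGrad` node, which is what `swap_tensors` does when it finds the second reference.

```python
import torch

def poison_accumulate_grad(t):
    # Illustrative helper: make any future execution of t's AccumulateGrad node fail.
    def error_pre_hook(grad_outputs):
        raise RuntimeError("AccumulateGrad node was poisoned by swap_tensors")

    # For a leaf tensor, get_gradient_edge returns an edge whose node is the AccumulateGrad node.
    node = torch.autograd.graph.get_gradient_edge(t).node
    node.register_prehook(error_pre_hook)

a = torch.randn(2, 3, requires_grad=True)
out = (a * 2).sum()
out.backward(retain_graph=True)  # AccumulateGrad stays alive because `out` still references the graph
poison_accumulate_grad(a)        # analogous to what swap_tensors now does instead of erroring
out.backward()                   # raises RuntimeError: the pre-hook fires before the grad is accumulated
```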
```diff
@@ -863,7 +863,8 @@ class TestModule(TestCase):
else:
raise NotImplementedError(f"Unknown error type {error_input.error_on}")

@modules([module for module in module_db if not module.is_lazy])
# Only run this test for float32 because the test loops over all the dtypes
@modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32])
@parametrize('swap', [True, False])
@parametrize('set_grad', [True, False])
@wrapSwapTensorsTest()
@@ -879,6 +880,7 @@ class TestModule(TestCase):

for module_input in module_inputs:
c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs

m = module_cls(*c_args, **c_kwargs)

@@ -896,6 +898,17 @@ class TestModule(TestCase):
setattr(m, n, new_b)
_to(m, set_grad=set_grad)

# Check .to() can be run after forward and backward with swap
has_params = len(list(m.parameters())) > 0
if swap and not set_grad and has_params:
out = m(*args, **kwargs)
if isinstance(out, tuple):
out = out[0]
out.sum().backward()
m.to(dtype=torch.half)
# reset
m.to(dtype=torch.float32)

prev_device, prev_dtype = device, dtype
for device_, dtype_ in product(devices, dtypes):
# if device/dtype do not change, grad.to(device, dtype) is a no-op so
@@ -903,6 +916,7 @@ class TestModule(TestCase):
# parameters will be wrapped in an nn.Parameter before swapping
# which will cause the ._cdata to change
g_no_swap = device_ == prev_device and dtype_ == prev_dtype
prev_prev_device, prev_prev_dtype = prev_device, prev_dtype
prev_device, prev_dtype = device_, dtype_

p_ids_before = [id(p) for p in m.parameters()]
@@ -940,7 +954,6 @@ class TestModule(TestCase):
self.assertTrue(all(a == b for a, b in zip(g_cdatas_before, g_cdatas_after)))
self.assertTrue(all(a == b for a, b in zip(g_ids_before, g_ids_after)))


@modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32])
@parametrize('swap', [True, False])
@wrapSwapTensorsTest()
```
```diff
@@ -1594,19 +1594,29 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
finally:
torch.__future__.set_overwrite_module_params_on_conversion(False)

def test_swap_module_params_fails_after_forward(self):
def test_swap_module_params_poisons_acc_grad(self):
try:
torch.__future__.set_swap_module_params_on_conversion(True)
# (1) backward cannot be run after _apply
# forward will init AccumulateGrad nodes, which bumps use_count of parameters' at::Tensors
# additionally, if any Tensors are saved for backward, their use_count will be bumped
m = torch.nn.Linear(2, 3)
inp = torch.randn(2, 2)
# forward will init AccumulateGrad nodes, which bumps use_count of parameters' at::Tensors
out = m(inp)
with self.assertRaisesRegex(RuntimeError, re.escape("_apply(): Couldn't swap Linear.weight")):
m.half()
del out
# works as expected now
m.half()
self.assertTrue(all(p.dtype == torch.float16 for p in m.parameters()))
with self.assertRaisesRegex(RuntimeError, "Trying to execute AccumulateGrad node that was poisoned by swap_tensors"):
out.sum().backward()
# (2) _apply can be run after backward()
# After running backward, all the references generated by "save for backward" will be cleared
# So the use_count will be 2 (1 from Tensor itself, and 1 from AccumulateGrad node), swap_tensors
# should allow this.
inp2 = torch.randn(2, 2, dtype=torch.half)
out2 = m(inp2)
out2.sum().backward()
m.float()
self.assertTrue(all(p.dtype == torch.float32 for p in m.parameters()))
out3 = m(inp)
finally:
torch.__future__.set_swap_module_params_on_conversion(False)
```
```diff
@@ -10623,12 +10623,9 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
if t1.is_floating_point():
t3 = t1.clone().detach().requires_grad_(True)
out = t3 * 2
with self.assertRaisesRegex(RuntimeError, "Expected single reference to a's"):
torch.utils.swap_tensors(t3, t2)
del out
# Now succeeds
torch.utils.swap_tensors(t3, t2)
torch.utils.swap_tensors(t1, t2)
with self.assertRaisesRegex(RuntimeError, "AccumulateGrad node that was poisoned by swap_tensors"):
out.sum().backward()

wr = weakref.ref(t1)
with self.assertRaisesRegex(RuntimeError, "has weakref"):
```
```diff
@@ -375,22 +375,14 @@ PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) {
THPVariable* a = reinterpret_cast<THPVariable*>(a_);
THPVariable* b = reinterpret_cast<THPVariable*>(b_);

TORCH_CHECK(
a->cdata->use_count() == 1,
"Expected single reference to a's Tensor object but got ",
a->cdata->use_count());
TORCH_CHECK(
b->cdata->use_count() == 1,
"Expected single reference to b's Tensor object but got ",
b->cdata->use_count());
// weak_use_count() adds 1 if use_count is non-zero
TORCH_CHECK(
a->cdata->weak_use_count() == 1,
"Expected no weakrefs to a's Tensor object but got ",
"Expected no weakrefs to t1's Tensor object but got ",
a->cdata->weak_use_count() - 1);
TORCH_CHECK(
b->cdata->weak_use_count() == 1,
"Expected no weakrefs to b's Tensor object but got ",
"Expected no weakrefs to t2's Tensor object but got ",
b->cdata->weak_use_count() - 1);

// Swap the Tensor Impl
```
```diff
@@ -1615,6 +1615,13 @@ int THPVariable_set_imag(PyObject* self, PyObject* imag, void* unused) {
END_HANDLE_TH_ERRORS_RET(-1)
}

PyObject* THPVariable__use_count(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
const auto& t = THPVariable_Unpack(self);
return THPUtils_packUInt64(t.use_count());
END_HANDLE_TH_ERRORS
}

// properties are registered here because we are currently only able to bind
// them manually. TODO: make declarable in native_functions
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
@@ -1766,6 +1773,7 @@ static PyMethodDef extra_methods[] = {
THPVariable_rev_view_func_unsafe,
METH_O,
nullptr},
{"_use_count", THPVariable__use_count, METH_NOARGS, nullptr},
{nullptr}};

struct THPVariableMeta {
```
```diff
@@ -357,6 +357,7 @@ def get_ignored_functions() -> Set[Callable]:
Tensor._is_any_true,
Tensor._addmm_activation,
Tensor.to_padded_tensor,
Tensor._use_count,
}
```
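The `Tensor._use_count` binding added above simply exposes the C++ `TensorImpl::use_count()` that the Python-level check relies on. The following is a rough sketch of how the count evolves in the scenario this PR targets; the printed values assume nothing else in the program holds a reference to `a`.

```python
import torch

a = torch.randn(2, 3, requires_grad=True)
print(a._use_count())  # 1: only the Python tensor owns the TensorImpl

out = (a * 2).sum()    # building the graph creates a's AccumulateGrad node,
                       # which holds a second strong reference
print(a._use_count())  # 2

out.backward()         # clears saved-for-backward references, but not the AccumulateGrad node
print(a._use_count())  # still 2: the case swap_tensors now accepts

del out                # dropping the graph releases the AccumulateGrad node
print(a._use_count())  # back to 1
```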
```diff
@@ -46,6 +46,32 @@ def swap_tensors(t1, t2):
setattr(t1, name, (getattr(t2, name)))
setattr(t2, name, tmp)

def error_pre_hook(grad_outputs):
raise RuntimeError("Trying to execute AccumulateGrad node that was poisoned by swap_tensors "
"this can happen when you try to run backward on a tensor that was swapped. "
"For a module m with `torch.__future__.set_swap_module_params_on_conversion(True)` "
"you should not change the device or dtype of the module (e.g. `m.cpu()` or `m.half()`) "
"between running forward and backward. To resolve this, please only change the "
"device/dtype before running forward (or after both forward and backward).")

def check_use_count(t, name='t1'):
use_count = t._use_count()
error_str = (f"Expected use_count of {name} to be 1 or 2 with an AccumulateGrad node but got {use_count} "
f"make sure you are not holding references to the tensor in other places.")
if use_count > 1:
if use_count == 2 and t.is_leaf:
accum_grad_node = torch.autograd.graph.get_gradient_edge(t).node
# Make sure that the accumulate_grad node was not lazy_init-ed by get_gradient_edge
if t._use_count() == 2:
accum_grad_node.register_prehook(error_pre_hook)
else:
raise RuntimeError(error_str)
else:
raise RuntimeError(error_str)

check_use_count(t1, 't1')
check_use_count(t2, 't2')

# Swap the types
# Note that this will fail if there are mismatched slots
swap_attr("__class__")
```
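Putting the pieces together, here is a short usage sketch of the relaxed constraint, mirroring the updated tests above: swapping after backward now succeeds, and a later backward through the retained graph hits the poisoned `AccumulateGrad` node.

```python
import torch

a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)

out = (a * 2).sum()
out.backward(retain_graph=True)  # keep the graph so a's AccumulateGrad node stays reachable

# use_count of `a` is 2 (the tensor itself plus the AccumulateGrad node), which is now allowed
torch.utils.swap_tensors(a, b)

try:
    out.backward()  # execution reaches the poisoned AccumulateGrad node
except RuntimeError as e:
    print(e)  # "Trying to execute AccumulateGrad node that was poisoned by swap_tensors ..."
```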