Update docs of saved_tensors_hooks to avoid ref cycle (#153049)

Fixes #115255

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153049
Approved by: https://github.com/Skylion007, https://github.com/soulitzer
Author: Yuxin Wu
Date: 2025-05-07 18:54:51 +00:00
Committed by: PyTorch MergeBot
Parent: 7cf8049d63
Commit: 2cf7fd0d2b

2 changed files with 8 additions and 3 deletions

@@ -793,7 +793,7 @@ Example:
     def pack_hook(x):
         if x.numel() < SAVE_ON_DISK_THRESHOLD:
-            return x
+            return x.detach()
         temp_file = SelfDeletingTempFile()
         torch.save(x, temp_file.name)
         return temp_file
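
For context, here is a self-contained, runnable sketch of the pattern this hunk edits, with the fix applied. The `SelfDeletingTempFile` helper and the `SAVE_ON_DISK_THRESHOLD` value are stand-ins modeled on the surrounding documentation example, not quoted from this diff:

    import os
    import tempfile
    import uuid

    import torch

    SAVE_ON_DISK_THRESHOLD = 512  # arbitrary element-count cutoff for this sketch

    class SelfDeletingTempFile:
        # Owns a temp-file path; removes the file when the object is collected.
        def __init__(self):
            self.name = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))

        def __del__(self):
            os.remove(self.name)

    def pack_hook(x):
        if x.numel() < SAVE_ON_DISK_THRESHOLD:
            # detach() so the packed value cannot keep the autograd graph
            # alive; that is the reference cycle this commit documents.
            return x.detach()
        temp_file = SelfDeletingTempFile()
        torch.save(x, temp_file.name)
        return temp_file

    def unpack_hook(packed):
        if isinstance(packed, torch.Tensor):
            return packed
        return torch.load(packed.name)

    with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        a = torch.randn(5, requires_grad=True)
        loss = (a * a).sum()
    loss.backward()  # unpack_hook restores the saved tensors here
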
@@ -833,7 +833,7 @@ Tensor object creation. For example:

 .. code::

-    with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
+    with torch.autograd.graph.saved_tensors_hooks(lambda x: x.detach(), lambda x: x):
         x = torch.randn(5, requires_grad=True)
         y = x * x
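
Why the old `lambda x: x` form leaked: for an operation that saves its own output (such as `exp`), the identity pack hook makes the `grad_fn` hold the output tensor while the tensor holds its `grad_fn`, a cycle that only the cyclic garbage collector can break. Below is a minimal sketch, not part of the diff, that makes the difference observable:

    import gc
    import weakref

    import torch

    def run(pack_hook):
        with torch.autograd.graph.saved_tensors_hooks(pack_hook, lambda x: x):
            x = torch.randn(5, requires_grad=True)
            y = x.exp()  # exp() saves its output tensor for backward
        return weakref.ref(y)

    gc.disable()  # rely on reference counting alone, to expose the cycle
    try:
        # Identity hook: y -> grad_fn -> packed y is a cycle, so y survives.
        print(run(lambda x: x)() is None)           # False: y is still alive
        # detach() shares storage but has no grad_fn, so no cycle is formed.
        print(run(lambda x: x.detach())() is None)  # True: y was freed
    finally:
        gc.enable()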


@@ -272,7 +272,7 @@ class saved_tensors_hooks:
         >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
         >>> def pack_hook(x):
         ...     print("Packing", x)
-        ...     return x
+        ...     return x.detach()
         >>>
         >>> def unpack_hook(x):
         ...     print("Unpacking", x)
@@ -295,6 +295,11 @@ class saved_tensors_hooks:
     .. warning ::
         Only one pair of hooks is allowed at a time. When recursively nesting this
         context-manager, only the inner-most pair of hooks will be applied.
+
+    .. warning ::
+        To avoid a reference cycle, the return value of ``pack_hook`` cannot hold a
+        reference to the input tensor. For example, use ``lambda x: x.detach()``
+        instead of ``lambda x: x`` as the pack hook.
     """

     def __init__(
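
Finally, a short sketch, not part of the diff, of what "hold a reference" means in the new warning: the packed value must not reach the input tensor through a container either. The dict wrapper below is hypothetical:

    import torch

    def bad_pack(x):
        # Still references x through the dict, so the cycle remains.
        return {"t": x, "shape": x.shape}

    def good_pack(x):
        # detach() shares storage with x but carries no grad_fn, so the
        # packed value holds no path back into the autograd graph.
        return {"t": x.detach(), "shape": x.shape}

    def unpack(packed):
        return packed["t"]

    with torch.autograd.graph.saved_tensors_hooks(good_pack, unpack):
        x = torch.randn(3, requires_grad=True)
        y = x.exp()
    y.sum().backward()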