diff --git a/docs/source/notes/autograd.rst b/docs/source/notes/autograd.rst index f4caca9495a3..1ed6d292ab8d 100644 --- a/docs/source/notes/autograd.rst +++ b/docs/source/notes/autograd.rst @@ -793,7 +793,7 @@ Example: def pack_hook(x): if x.numel() < SAVE_ON_DISK_THRESHOLD: - return x + return x.detach() temp_file = SelfDeletingTempFile() torch.save(tensor, temp_file.name) return temp_file @@ -833,7 +833,7 @@ Tensor object creation. For example: .. code:: - with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x): + with torch.autograd.graph.saved_tensors_hooks(lambda x: x.detach(), lambda x: x): x = torch.randn(5, requires_grad=True) y = x * x diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index a0e182d57901..0e36f89ca085 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -272,7 +272,7 @@ class saved_tensors_hooks: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) >>> def pack_hook(x): ... print("Packing", x) - ... return x + ... return x.detach() >>> >>> def unpack_hook(x): ... print("Unpacking", x) @@ -295,6 +295,11 @@ class saved_tensors_hooks: .. warning :: Only one pair of hooks is allowed at a time. When recursively nesting this context-manager, only the inner-most pair of hooks will be applied. + + .. warning :: + To avoid reference cycles, the return value of ``pack_hook`` cannot hold a + reference to the input tensor. For example, use ``lambda x: x.detach()`` + instead of ``lambda x: x`` as the pack hook. """ def __init__(