Add torch.func.debug_unwrap (#146528)

Use it to unwrap any functorch-wrapped tensor. I don't recommend using
the output in a program since it breaks the semantics of the transforms,
but it seems useful for debugging.

I will note that some people have wanted to get intermediate values out
of, e.g., a grad transform, so this might be a way to do that (sketched below).
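
For illustration, a minimal sketch (not part of this diff) of pulling an
intermediate value out of a grad transform:

    import torch
    from torch.func import debug_unwrap, grad

    def f(x):
        y = x.sin()
        # Inside grad, y is a GradTrackingTensor; debug_unwrap peels off the
        # functorch wrapper so the plain value can be printed while debugging.
        print(debug_unwrap(y))
        return y.sum()

    grad(f)(torch.randn(3))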

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146528
Approved by: https://github.com/Chillee
Author: rzou
Date: 2025-02-06 07:42:55 -08:00
Committed by: PyTorch MergeBot
Parent: 49082f9dba
Commit: 15b1ac3e86
4 changed files with 41 additions and 0 deletions


@@ -76,3 +76,12 @@ guidance here
     :maxdepth: 1
 
     func.batch_norm
+
+Debug utilities
+---------------
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    debug_unwrap


@@ -3261,6 +3261,18 @@ class TestHelpers(TestCase):
         out = A.apply(x, y)
         out.backward()
 
+    def test_debug_unwrap(self):
+        stuff = []
+
+        def f(x):
+            stuff.append(torch.func.debug_unwrap(x))
+            return x.sin()
+
+        x = torch.randn(2, 3)
+        _ = vmap(vmap(f))(x)
+        self.assertEqual(stuff[0], x)
+        self.assertTrue(stuff[0] is x)
+
     def test_reductify_leaf(self, device):
         reductify_leaf = torch._functorch.autograd_function.reductify_leaf
         B = 2


@@ -26,6 +26,8 @@ from torch._C._functorch import (
     _wrap_for_grad,
     _wrap_functional_tensor,
     get_inplace_requires_grad_allowed,
+    get_unwrapped,
+    is_functorch_wrapped_tensor,
     set_inplace_requires_grad_allowed,
 )
 from torch._functorch.utils import argnums_t, exposed_in
@@ -1795,3 +1797,19 @@ def linearize(func: Callable, *primals) -> tuple[Any, Callable]:
         return tree_unflatten(flat_output, output_spec)
 
     return output, jvp_fn
+
+
+@exposed_in("torch.func")
+def debug_unwrap(tensor: torch.Tensor, *, recurse=True) -> torch.Tensor:
+    """Unwraps a functorch tensor (e.g. BatchedTensor, GradTrackingTensor) to its underlying tensor.
+
+    This function should only be used in a debug setting (e.g. trying to print the
+    value of a Tensor in a debugger). Otherwise, using the result of this function
+    inside of a function being transformed will lead to undefined behavior.
+    """
+    if not is_functorch_wrapped_tensor(tensor):
+        return tensor
+    result = get_unwrapped(tensor)
+    if recurse:
+        return debug_unwrap(result)
+    return result
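
As an aside, a minimal sketch (not part of this diff) of how the recurse
flag behaves: under nested vmap, recurse=False peels off a single wrapper
layer, while the default recurse=True unwraps all the way down to the
original tensor (which is why the test above can assert stuff[0] is x).
is_batchedtensor is a private helper from torch._C._functorch, used here
only for illustration:

    import torch
    from torch._C._functorch import is_batchedtensor
    from torch.func import debug_unwrap, vmap

    def f(x):
        once = debug_unwrap(x, recurse=False)  # peel one layer: still a BatchedTensor
        full = debug_unwrap(x)                 # default recurse=True: the plain tensor
        assert is_batchedtensor(once)
        assert not is_batchedtensor(full)
        return x.sin()

    _ = vmap(vmap(f))(torch.randn(2, 3))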


@@ -1,6 +1,7 @@
 from torch._functorch.apis import grad, grad_and_value, vmap
 from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
 from torch._functorch.eager_transforms import (
+    debug_unwrap,
     functionalize,
     hessian,
     jacfwd,
@@ -26,4 +27,5 @@ __all__ = [
     "vjp",
     "functional_call",
     "stack_module_state",
+    "debug_unwrap",
 ]