Add torch.func.debug_unwrap (#146528)
Use it to unwrap any functorch-wrapped tensor. I don't recommend using the output in a program, since it breaks the semantics of the transforms, but it seems useful for debugging. I will note that some people have wanted to get intermediate values out of (e.g.) a grad transform, so this might be a way to do that...

Test Plan:
- tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146528
Approved by: https://github.com/Chillee
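As a minimal sketch of the debugging pattern the message describes (not part of the diff; the function f and values below are made up for illustration), one might print an intermediate from inside a grad transform:

import torch
from torch.func import grad, debug_unwrap

def f(x):
    y = x.sin()
    # Peeling the GradTrackingTensor wrapper lets us look at the plain
    # intermediate value. The unwrapped tensor should only be inspected,
    # never fed back into the transformed computation.
    print(debug_unwrap(y))
    return y.sum()

grad(f)(torch.randn(3))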
@@ -76,3 +76,12 @@ guidance here
    :maxdepth: 1

    func.batch_norm
+
+Debug utilities
+---------------
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    debug_unwrap
@@ -3261,6 +3261,18 @@ class TestHelpers(TestCase):
         out = A.apply(x, y)
         out.backward()

+    def test_debug_unwrap(self):
+        stuff = []
+
+        def f(x):
+            stuff.append(torch.func.debug_unwrap(x))
+            return x.sin()
+
+        x = torch.randn(2, 3)
+        _ = vmap(vmap(f))(x)
+        self.assertEqual(stuff[0], x)
+        self.assertTrue(stuff[0] is x)
+
     def test_reductify_leaf(self, device):
         reductify_leaf = torch._functorch.autograd_function.reductify_leaf
         B = 2
@@ -26,6 +26,8 @@ from torch._C._functorch import (
     _wrap_for_grad,
     _wrap_functional_tensor,
     get_inplace_requires_grad_allowed,
+    get_unwrapped,
+    is_functorch_wrapped_tensor,
     set_inplace_requires_grad_allowed,
 )
 from torch._functorch.utils import argnums_t, exposed_in
@@ -1795,3 +1797,19 @@ def linearize(func: Callable, *primals) -> tuple[Any, Callable]:
         return tree_unflatten(flat_output, output_spec)

     return output, jvp_fn
+
+
+@exposed_in("torch.func")
+def debug_unwrap(tensor: torch.Tensor, *, recurse=True) -> torch.Tensor:
+    """Unwraps a functorch tensor (e.g. BatchedTensor, GradTrackingTensor) to its underlying tensor.
+
+    This function should only be used in a debug setting (e.g. trying to print the
+    value of a Tensor in a debugger). Otherwise, using the result of this function
+    inside of a function being transformed will lead to undefined behavior.
+    """
+    if not is_functorch_wrapped_tensor(tensor):
+        return tensor
+    result = get_unwrapped(tensor)
+    if recurse:
+        return debug_unwrap(result)
+    return result
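A small usage sketch of the recurse flag (my own, not from the patch), relying on the same private torch._C._functorch helper the hunk above imports: the default recurse=True peels every functorch wrapper, while recurse=False removes exactly one level.

import torch
from torch.func import vmap, debug_unwrap
from torch._C._functorch import is_functorch_wrapped_tensor

def inspect(x):
    once = debug_unwrap(x, recurse=False)  # peel only the outermost wrapper
    # Under vmap(vmap(...)) the input is wrapped twice, so one peel still
    # leaves a functorch-wrapped tensor; a full unwrap yields a plain one.
    assert is_functorch_wrapped_tensor(once)
    assert not is_functorch_wrapped_tensor(debug_unwrap(once))
    return x.cos()

vmap(vmap(inspect))(torch.randn(2, 3))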
@@ -1,6 +1,7 @@
 from torch._functorch.apis import grad, grad_and_value, vmap
 from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
 from torch._functorch.eager_transforms import (
+    debug_unwrap,
     functionalize,
     hessian,
     jacfwd,
@@ -26,4 +27,5 @@ __all__ = [
     "vjp",
     "functional_call",
     "stack_module_state",
+    "debug_unwrap",
 ]
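With the export in place, a quick sanity check (again an illustrative snippet, not from the commit) is that the name is public and that non-wrapped tensors pass through unchanged:

import torch

assert "debug_unwrap" in torch.func.__all__
t = torch.ones(2)
# Per the early return in debug_unwrap, a plain tensor comes back as-is.
assert torch.func.debug_unwrap(t) is t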