mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add torch.func.debug_unwrap (#146528)
Use it to unwrap any functorch-wrapped tensor. I don't recommend using the output in a program since it breaks the semantics of the transforms, but it seems useful for debugging. I will note that some people have wanted to get intermediate values out of an e.g. grad transform, so this might be a way to do that... Test Plan: - tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/146528 Approved by: https://github.com/Chillee
This commit is contained in:
@ -76,3 +76,12 @@ guidance here
|
|||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
|
|
||||||
func.batch_norm
|
func.batch_norm
|
||||||
|
|
||||||
|
Debug utilities
|
||||||
|
---------------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
|
||||||
|
debug_unwrap
|
||||||
|
@ -3261,6 +3261,18 @@ class TestHelpers(TestCase):
|
|||||||
out = A.apply(x, y)
|
out = A.apply(x, y)
|
||||||
out.backward()
|
out.backward()
|
||||||
|
|
||||||
|
def test_debug_unwrap(self):
|
||||||
|
stuff = []
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
stuff.append(torch.func.debug_unwrap(x))
|
||||||
|
return x.sin()
|
||||||
|
|
||||||
|
x = torch.randn(2, 3)
|
||||||
|
_ = vmap(vmap(f))(x)
|
||||||
|
self.assertEqual(stuff[0], x)
|
||||||
|
self.assertTrue(stuff[0] is x)
|
||||||
|
|
||||||
def test_reductify_leaf(self, device):
|
def test_reductify_leaf(self, device):
|
||||||
reductify_leaf = torch._functorch.autograd_function.reductify_leaf
|
reductify_leaf = torch._functorch.autograd_function.reductify_leaf
|
||||||
B = 2
|
B = 2
|
||||||
|
@ -26,6 +26,8 @@ from torch._C._functorch import (
|
|||||||
_wrap_for_grad,
|
_wrap_for_grad,
|
||||||
_wrap_functional_tensor,
|
_wrap_functional_tensor,
|
||||||
get_inplace_requires_grad_allowed,
|
get_inplace_requires_grad_allowed,
|
||||||
|
get_unwrapped,
|
||||||
|
is_functorch_wrapped_tensor,
|
||||||
set_inplace_requires_grad_allowed,
|
set_inplace_requires_grad_allowed,
|
||||||
)
|
)
|
||||||
from torch._functorch.utils import argnums_t, exposed_in
|
from torch._functorch.utils import argnums_t, exposed_in
|
||||||
@ -1795,3 +1797,19 @@ def linearize(func: Callable, *primals) -> tuple[Any, Callable]:
|
|||||||
return tree_unflatten(flat_output, output_spec)
|
return tree_unflatten(flat_output, output_spec)
|
||||||
|
|
||||||
return output, jvp_fn
|
return output, jvp_fn
|
||||||
|
|
||||||
|
|
||||||
|
@exposed_in("torch.func")
|
||||||
|
def debug_unwrap(tensor: torch.Tensor, *, recurse=True) -> torch.Tensor:
|
||||||
|
"""Unwraps a functorch tensor (e.g. BatchedTensor, GradTrackingTensor) to its underlying tensor.
|
||||||
|
|
||||||
|
This function should only be used in a debug setting (e.g. trying to print the
|
||||||
|
value of a Tensor in a debugger). Otherwise, using the result of function
|
||||||
|
inside of a function being transformed will lead to undefined behavior.
|
||||||
|
"""
|
||||||
|
if not is_functorch_wrapped_tensor(tensor):
|
||||||
|
return tensor
|
||||||
|
result = get_unwrapped(tensor)
|
||||||
|
if recurse:
|
||||||
|
return debug_unwrap(result)
|
||||||
|
return result
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from torch._functorch.apis import grad, grad_and_value, vmap
|
from torch._functorch.apis import grad, grad_and_value, vmap
|
||||||
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
|
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
|
||||||
from torch._functorch.eager_transforms import (
|
from torch._functorch.eager_transforms import (
|
||||||
|
debug_unwrap,
|
||||||
functionalize,
|
functionalize,
|
||||||
hessian,
|
hessian,
|
||||||
jacfwd,
|
jacfwd,
|
||||||
@ -26,4 +27,5 @@ __all__ = [
|
|||||||
"vjp",
|
"vjp",
|
||||||
"functional_call",
|
"functional_call",
|
||||||
"stack_module_state",
|
"stack_module_state",
|
||||||
|
"debug_unwrap",
|
||||||
]
|
]
|
||||||
|
Reference in New Issue
Block a user