diff --git a/docs/source/func.api.rst b/docs/source/func.api.rst
index 3e03382ffe48..362954f731af 100644
--- a/docs/source/func.api.rst
+++ b/docs/source/func.api.rst
@@ -76,3 +76,12 @@ guidance here
     :maxdepth: 1
 
     func.batch_norm
+
+Debug utilities
+---------------
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    debug_unwrap
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index 34ad36cbb417..ce424b4ee157 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -3261,6 +3261,18 @@ class TestHelpers(TestCase):
         out = A.apply(x, y)
         out.backward()
 
+    def test_debug_unwrap(self):
+        stuff = []
+
+        def f(x):
+            stuff.append(torch.func.debug_unwrap(x))
+            return x.sin()
+
+        x = torch.randn(2, 3)
+        _ = vmap(vmap(f))(x)
+        self.assertEqual(stuff[0], x)
+        self.assertTrue(stuff[0] is x)
+
     def test_reductify_leaf(self, device):
         reductify_leaf = torch._functorch.autograd_function.reductify_leaf
         B = 2
diff --git a/torch/_functorch/eager_transforms.py b/torch/_functorch/eager_transforms.py
index edb24232e417..f058c215c39e 100644
--- a/torch/_functorch/eager_transforms.py
+++ b/torch/_functorch/eager_transforms.py
@@ -26,6 +26,8 @@ from torch._C._functorch import (
     _wrap_for_grad,
     _wrap_functional_tensor,
     get_inplace_requires_grad_allowed,
+    get_unwrapped,
+    is_functorch_wrapped_tensor,
     set_inplace_requires_grad_allowed,
 )
 from torch._functorch.utils import argnums_t, exposed_in
@@ -1795,3 +1797,19 @@ def linearize(func: Callable, *primals) -> tuple[Any, Callable]:
             return tree_unflatten(flat_output, output_spec)
 
         return output, jvp_fn
+
+
+@exposed_in("torch.func")
+def debug_unwrap(tensor: torch.Tensor, *, recurse=True) -> torch.Tensor:
+    """Unwraps a functorch tensor (e.g. BatchedTensor, GradTrackingTensor) to its underlying tensor.
+
+    This function should only be used in a debug setting (e.g. trying to print the
+    value of a Tensor in a debugger). Otherwise, using the result of this function
+    inside of a function being transformed will lead to undefined behavior.
+    """
+    if not is_functorch_wrapped_tensor(tensor):
+        return tensor
+    result = get_unwrapped(tensor)
+    if recurse:
+        return debug_unwrap(result)
+    return result
diff --git a/torch/func/__init__.py b/torch/func/__init__.py
index 3dd1d391e2c0..35743fcf429a 100644
--- a/torch/func/__init__.py
+++ b/torch/func/__init__.py
@@ -1,6 +1,7 @@
 from torch._functorch.apis import grad, grad_and_value, vmap
 from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
 from torch._functorch.eager_transforms import (
+    debug_unwrap,
     functionalize,
     hessian,
     jacfwd,
@@ -26,4 +27,5 @@ __all__ = [
     "vjp",
     "functional_call",
     "stack_module_state",
+    "debug_unwrap",
 ]
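
A minimal usage sketch of the new API, mirroring the added test (illustrative only; it assumes the patch above is applied, and the variable names are made up):

    import torch
    from torch.func import debug_unwrap, vmap

    def f(x):
        # Inside vmap, x is a BatchedTensor; debug_unwrap exposes the plain
        # tensor underneath so it can be inspected (e.g. printed in a debugger).
        print(debug_unwrap(x).shape)  # prints the full torch.Size([2, 3])
        return x.sin()

    x = torch.randn(2, 3)
    out = vmap(f)(x)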