pytorch/torch/func/__init__.py
rzou 15b1ac3e86 Add torch.func.debug_unwrap (#146528)
Use it to unwrap any functorch-wrapped tensor. I don't recommend using
the output in a program since it breaks the semantics of the transforms,
but it seems useful for debugging.

I will note that some people have wanted to get intermediate values out
of, e.g., a grad transform, so this might be a way to do that.

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146528
Approved by: https://github.com/Chillee
2025-02-06 18:48:09 +00:00

32 lines
656 B
Python

from torch._functorch.apis import grad, grad_and_value, vmap
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
from torch._functorch.eager_transforms import (
    debug_unwrap,
    functionalize,
    hessian,
    jacfwd,
    jacrev,
    jvp,
    linearize,
    vjp,
)
from torch._functorch.functional_call import functional_call, stack_module_state


__all__ = [
    "grad",
    "grad_and_value",
    "vmap",
    "replace_all_batch_norm_modules_",
    "functionalize",
    "hessian",
    "jacfwd",
    "jacrev",
    "jvp",
    "linearize",
    "vjp",
    "functional_call",
    "stack_module_state",
    "debug_unwrap",
]