Use it to unwrap any functorch-wrapped tensor. I don't recommend using the output in a program since it breaks the semantics of the transforms, but it seems useful for debugging. I will note that some people have wanted to get intermediate values out of, e.g., a grad transform, so this might be a way to do that...

Test Plan:
- tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146528
Approved by: https://github.com/Chillee
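As a rough illustration only (not part of the PR), here is a minimal sketch of how debug_unwrap might be used to peek at an intermediate value inside a grad transform, assuming it is exported from torch.func as the file below suggests; the function f and the printed intermediate are made up for this example, and the unwrapped tensor is for inspection only, not for reuse in the computation.

import torch
from torch.func import grad, debug_unwrap

def f(x):
    y = x.sin()
    # Inside grad, y is a functorch-wrapped tensor. debug_unwrap peels the
    # wrapper off so the underlying values can be printed. Feeding the
    # unwrapped value back into the computation would break the semantics
    # of the transform, so keep it to debugging only.
    print(debug_unwrap(y))
    return y.sum()

grad(f)(torch.randn(3))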
32 lines | 656 B | Python
from torch._functorch.apis import grad, grad_and_value, vmap
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
from torch._functorch.eager_transforms import (
    debug_unwrap,
    functionalize,
    hessian,
    jacfwd,
    jacrev,
    jvp,
    linearize,
    vjp,
)
from torch._functorch.functional_call import functional_call, stack_module_state


__all__ = [
    "grad",
    "grad_and_value",
    "vmap",
    "replace_all_batch_norm_modules_",
    "functionalize",
    "hessian",
    "jacfwd",
    "jacrev",
    "jvp",
    "linearize",
    "vjp",
    "functional_call",
    "stack_module_state",
    "debug_unwrap",
]