# mypy: allow-untyped-defs
import contextlib
from typing import Tuple, Union

import torch
from torch._C._functorch import (
    get_single_level_autograd_function_allowed,
    set_single_level_autograd_function_allowed,
    unwrap_if_dead,
)
from torch.utils._exposed_in import exposed_in


__all__ = [
    "exposed_in",
    "argnums_t",
    "enable_single_level_autograd_function",
    "unwrap_dead_wrappers",
]


@contextlib.contextmanager
def enable_single_level_autograd_function():
    try:
        prev_state = get_single_level_autograd_function_allowed()
        set_single_level_autograd_function_allowed(True)
        yield
    finally:
        set_single_level_autograd_function_allowed(prev_state)


def unwrap_dead_wrappers(args):
    # NB: doesn't use tree_map_only for performance reasons
    result = tuple(
        unwrap_if_dead(arg) if isinstance(arg, torch.Tensor) else arg for arg in args
    )
    return result


argnums_t = Union[int, Tuple[int, ...]]
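

# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream module): a minimal, hypothetical
# illustration of how the helpers above might be exercised. It assumes the
# private torch._C._functorch flag behaves as a process-wide boolean that the
# context manager saves and restores, and that unwrap_if_dead returns plain
# tensors unchanged.
if __name__ == "__main__":
    # Temporarily allow single-level autograd.Function usage; the previous
    # flag value is restored when the block exits, even on error.
    with enable_single_level_autograd_function():
        assert get_single_level_autograd_function_allowed()

    # unwrap_dead_wrappers passes non-Tensor arguments through unchanged and
    # only unwraps dead functorch tensor wrappers.
    x = torch.randn(3)
    out = unwrap_dead_wrappers((x, 1, "keep"))
    assert out[1] == 1 and out[2] == "keep"

    # argnums_t is the annotation for "argnums"-style parameters: either a
    # single index or a tuple of indices.
    argnums: argnums_t = (0, 1)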