Fix mypy issues in fake_tensor.py (#124428)

fake_tensor.py had mypy error ignored. That seems less than desirable.

Also added SafePyObjectT<T> which is a tagged wrapper around a SafePyObject but provides static type checking (with no other guarantees).

Used `SafePyObjectT<TorchDispatchModeKey>` on some of the TorchDispatchModeTLS API to ensure that we don't accidentally inject a different type than expected into the stack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124428
Approved by: https://github.com/malfet
This commit is contained in:
Aaron Orenstein
2024-04-24 20:22:24 -07:00
committed by PyTorch MergeBot
parent 8d12ba9acf
commit 609c958281
9 changed files with 98 additions and 48 deletions

View File

@ -159,7 +159,7 @@ def _get_current_dispatch_mode_stack():
return [_get_dispatch_stack_at(i) for i in range(stack_len)]
def _push_mode(mode):
def _push_mode(mode: TorchDispatchMode):
k = mode._dispatch_key if hasattr(mode, "_dispatch_key") else None
assert k is None or k == torch._C.DispatchKey.PreDispatch
if k is None: