[functorch] removed python key stuff

This commit is contained in:
Horace He
2021-04-29 11:58:24 -07:00
committed by Jon Janzen
parent 2467773967
commit 033e6e30af
2 changed files with 0 additions and 47 deletions

View File

@ -5,13 +5,6 @@ from ._src.vmap import vmap
from ._src.eager_transforms import grad, grad_with_value, vjp, jacrev
from ._src.make_functional import make_functional, make_functional_with_buffers
# Python key stuff is not in pytorch/pytorch core
try:
from ._src.python_key import key_wrap, ModuleWrap
except:
pass
# Monkeypatching lol
_old_cross_entropy = torch.nn.functional.cross_entropy

View File

@ -1,40 +0,0 @@
import functools
import torch._C.key as key
from torch.fx import PythonTensor
import torch
class ModuleWrap(torch.nn.Module):
    """Wrap *mod* so calls to ``forward`` run on key-tagged ``PythonTensor`` proxies.

    ``inps`` is a template sequence matched positionally against the call
    arguments: positions holding a ``torch.Tensor`` template get their runtime
    argument wrapped in a keyed ``PythonTensor`` (for fx-style tracing);
    non-tensor positions pass the template value through unchanged.
    """

    def __init__(self, mod, inps):
        super().__init__()
        self.mod = mod
        self.inps = inps

        @functools.wraps(mod.forward)
        def forward_wrapped(self, *args):
            wrapped = []
            for template, actual in zip(inps, args):
                if isinstance(template, torch.Tensor):
                    # Tag the runtime arg with the Python key so the traced
                    # module sees a proxy shaped like the template.
                    wrapped.append(key.addKey(PythonTensor(template.shape, actual)))
                else:
                    # NOTE: the template value (not the runtime arg) is forwarded
                    # for non-tensor positions — preserved from the original.
                    wrapped.append(template)
            result = self.mod(*wrapped)
            # Strip the key wrapper and return the underlying fx proxy.
            return key.removeKey(result).proxy

        # Installed on the class so nn.Module.__call__ dispatches to it.
        type(self).forward = forward_wrapped
def key_wrap(f, inps):
    """Return a wrapper around *f* that runs it on key-tagged ``PythonTensor`` proxies.

    As with ``ModuleWrap``, *inps* is a positional template: tensor positions
    get their runtime argument wrapped in a keyed ``PythonTensor``; other
    positions forward the template value itself. The output is unwrapped to
    its fx proxy only when it actually carries the key.
    """

    @functools.wraps(f)
    def wrapped(*args):
        keyed = [
            key.addKey(PythonTensor(template.shape, actual))
            if isinstance(template, torch.Tensor)
            else template  # non-tensor slots pass the template through
            for template, actual in zip(inps, args)
        ]
        result = f(*keyed)
        # Unlike ModuleWrap.forward, the result may be an ordinary value:
        # only strip the key when one is present.
        if key.hasKey(result):
            return key.removeKey(result).proxy
        return result

    return wrapped