mirror of https://github.com/pytorch/pytorch.git
[functorch] removed python key stuff
@@ -5,13 +5,6 @@ from ._src.vmap import vmap
 from ._src.eager_transforms import grad, grad_with_value, vjp, jacrev
 from ._src.make_functional import make_functional, make_functional_with_buffers
 
-# Python key stuff is not in pytorch/pytorch core
-try:
-    from ._src.python_key import key_wrap, ModuleWrap
-except:
-    pass
-
-
 # Monkeypatching lol
 _old_cross_entropy = torch.nn.functional.cross_entropy
 
@@ -1,40 +0,0 @@
-import functools
-import torch._C.key as key
-from torch.fx import PythonTensor
-import torch
-
-class ModuleWrap(torch.nn.Module):
-    def __init__(self, mod, inps):
-        super().__init__()
-        self.mod = mod
-        self.inps = inps
-        @functools.wraps(mod.forward)
-        def forward_wrapped(self, *args):
-            new_args = []
-            for inp, arg in zip(inps, args):
-                if isinstance(inp, torch.Tensor):
-                    new_arg = key.addKey(PythonTensor(inp.shape, arg))
-                else:
-                    new_arg = inp
-                new_args.append(new_arg)
-            out = self.mod(*new_args)
-            return key.removeKey(out).proxy
-
-        type(self).forward = forward_wrapped
-
-def key_wrap(f, inps):
-    @functools.wraps(f)
-    def wrapped(*args):
-        new_args = []
-        for inp, arg in zip(inps, args):
-            if isinstance(inp, torch.Tensor):
-                new_arg = key.addKey(PythonTensor(inp.shape, arg))
-            else:
-                new_arg = inp
-            new_args.append(new_arg)
-        out = f(*new_args)
-        if key.hasKey(out):
-            return key.removeKey(out).proxy
-        else:
-            return out
-    return wrapped