[functorch] Moved to Keops example

This commit is contained in:
Horace He
2021-07-30 03:44:26 -07:00
committed by Jon Janzen
parent 9ca5ae86d0
commit 1a2e538580

View File

@ -4,35 +4,17 @@ import torch
from functools import partial
import time
class Lambda(nn.Module):
def __init__(self, lambd):
super(Lambda, self).__init__()
self.lambd = lambd
def forward(self, x):
return self.lambd(x)
mod = nn.Sequential(
nn.Linear(128, 1024),
Lambda(lambda x: 2*x),
Lambda(lambda x: 2*x),
Lambda(lambda x: 2*x),
Lambda(lambda x: 2*x),
Lambda(lambda x: 2*x),
nn.Flatten(),
)
a = torch.randn(10000, 1, 4, requires_grad=True)
b = torch.randn(1, 10000, 4)
def f(a):
return mod(a)
return (a * b).sin().sum(dim=0)
fw_compiler = partial(tvm_compile, name='fw_mlp_1')
bw_compiler = partial(tvm_compile, name='bw_mlp_1')
fw_compiler = partial(tvm_compile, name='fw_keops')
bw_compiler = partial(tvm_compile, name='bw_keops')
# fw_compiler = lambda x, _: x
# bw_compiler = lambda x, _: x
compiled_f = compiled_function(f, fw_compiler, bw_compiler).apply
for param in mod.parameters():
param.requires_grad_(False)
a = torch.randn(512, 128, requires_grad=True)
iters = 10
out = compiled_f(a)
out.sum().backward()
@ -50,7 +32,7 @@ def bench_jax():
jax_a = jnp.array(a.detach().numpy())
jax_b = jnp.array(b.detach().numpy())
def f(a):
return jnp.sin((a*jax_b).sum(axis=[0,1])).sum()
return jnp.sin((a*jax_b).sum(axis=[0])).sum()
jit_f = jax.jit(jax.grad(f))
jit_f(jax_a)
begin = time.time()