[functorch] Moved to Keops example
@@ -4,35 +4,17 @@ import torch
 from functools import partial
 import time
 
-class Lambda(nn.Module):
-    def __init__(self, lambd):
-        super(Lambda, self).__init__()
-        self.lambd = lambd
-    def forward(self, x):
-        return self.lambd(x)
-
-mod = nn.Sequential(
-    nn.Linear(128, 1024),
-    Lambda(lambda x: 2*x),
-    Lambda(lambda x: 2*x),
-    Lambda(lambda x: 2*x),
-    Lambda(lambda x: 2*x),
-    Lambda(lambda x: 2*x),
-    nn.Flatten(),
-)
-
+a = torch.randn(10000, 1, 4, requires_grad=True)
+b = torch.randn(1, 10000, 4)
 def f(a):
-    return mod(a)
+    return (a * b).sin().sum(dim=0)
 
-fw_compiler = partial(tvm_compile, name='fw_mlp_1')
-bw_compiler = partial(tvm_compile, name='bw_mlp_1')
+fw_compiler = partial(tvm_compile, name='fw_keops')
+bw_compiler = partial(tvm_compile, name='bw_keops')
 # fw_compiler = lambda x, _: x
 # bw_compiler = lambda x, _: x
 compiled_f = compiled_function(f, fw_compiler, bw_compiler).apply
 
-for param in mod.parameters():
-    param.requires_grad_(False)
-a = torch.randn(512, 128, requires_grad=True)
 iters = 10
 out = compiled_f(a)
 out.sum().backward()
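The new benchmark replaces the MLP with a KeOps-style pairwise reduction: `a` (shape [10000, 1, 4]) broadcasts against `b` (shape [1, 10000, 4]) into a [10000, 10000, 4] intermediate, which is then reduced over dim 0. Below is a minimal eager-mode sketch of that computation, not part of the diff, with sizes shrunk from 10000 to 100 so the broadcasted intermediate stays cheap:

import torch

# Stand-ins for the benchmark operands; 100 replaces 10000 for illustration.
a = torch.randn(100, 1, 4, requires_grad=True)
b = torch.randn(1, 100, 4)

# (a * b) broadcasts to [100, 100, 4]; sin() is elementwise;
# sum(dim=0) reduces the pairwise dimension, leaving a [100, 4] result.
out = (a * b).sin().sum(dim=0)
assert out.shape == (100, 4)

# The benchmark then backpropagates through the pairwise reduction.
out.sum().backward()
assert a.grad.shape == a.shape

Run eagerly, the full pairwise intermediate is materialized; routing f through compiled_function with the tvm_compile backends gives the compiler a chance to fuse the broadcast, sine, and reduction, which is the workload KeOps specializes in.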
@@ -50,7 +32,7 @@ def bench_jax():
     jax_a = jnp.array(a.detach().numpy())
     jax_b = jnp.array(b.detach().numpy())
     def f(a):
-        return jnp.sin((a * jax_b).sum(axis=[0,1])).sum()
+        return jnp.sin((a * jax_b).sum(axis=[0])).sum()
     jit_f = jax.jit(jax.grad(f))
     jit_f(jax_a)
     begin = time.time()
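The bench_jax change keeps the JAX baseline in step with the new example: it now reduces over axis 0 only, mirroring the dim=0 reduction on the PyTorch side. A standalone sketch of the updated baseline, not part of the diff, with the same hypothetical 100-element sizes assumed above:

import jax
import jax.numpy as jnp

# Hypothetical stand-ins for jax_a and jax_b; shapes shrunk for illustration.
jax_a = jnp.ones((100, 1, 4))
jax_b = jnp.ones((1, 100, 4))

def f(a):
    # Matches the new line in the diff: reduce over axis 0 only, then take the sine.
    return jnp.sin((a * jax_b).sum(axis=[0])).sum()

jit_f = jax.jit(jax.grad(f))  # JIT-compiled gradient w.r.t. the first argument
grads = jit_f(jax_a)          # same shape as jax_a: (100, 1, 4)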