[functorch] update compile example imports (pytorch/functorch#834)

This commit is contained in:
Andre
2022-05-26 12:09:18 -07:00
committed by Jon Janzen
parent 36fa9d8295
commit 1feff6bb69
3 changed files with 8 additions and 8 deletions

View File

@ -4,7 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functorch import grad, nnc_jit, make_fx, make_nnc
from functorch import grad, make_fx
from functorch.compile import nnc_jit
import torch
import time
@ -16,9 +17,7 @@ def f(x):
inp = torch.randn(100)
grad_pt = grad(f)
grad_fx = make_fx(grad_pt)(inp)
grad_nnc = nnc_jit(grad_pt, skip_specialization=True)
loopnest = make_nnc(grad_pt)(inp)
print(loopnest)
grad_nnc = nnc_jit(grad_pt)
def bench(name, f, iters=10000, warmup=3):