mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter. You can review these PRs with `git diff --ignore-all-space --ignore-blank-lines HEAD~1`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129762 Approved by: https://github.com/anijain2305
65 lines
1.7 KiB
Python
65 lines
1.7 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
import torch
|
|
import torch._dynamo.test_case
|
|
import torch._dynamo.testing
|
|
import torch.onnx.operators
|
|
|
|
|
|
def fn(a, b):
    """Return ``a + b * 0.67``.

    Shared target for the interop tests in this module: it is passed to
    ``torch.fx.symbolic_trace``, ``torch.jit.script`` and ``torch.jit.trace``.
    """
    scaled = b * 0.67
    return a + scaled
|
|
|
|
|
|
class InteropTests(torch._dynamo.test_case.TestCase):
    """Check that torch.compile'd callables can invoke other PyTorch front ends.

    The tests wrap an FX GraphModule, TorchScript functions (scripted and
    traced), and a ``torch.vmap`` transform inside a compiled callable. Each
    test compares the compiled result against eager execution.
    """

    def _common(self, fn):
        """Compile ``fn`` with the eager backend and compare it with eager execution.

        ``fn`` is called with two random 10-element tensors. ``fullgraph=True``
        asks dynamo to capture the whole call as a single graph.
        """
        lhs, rhs = torch.randn(10), torch.randn(10)
        expected = fn(lhs, rhs)
        compiled = torch.compile(fn, backend="eager", fullgraph=True)
        actual = compiled(lhs, rhs)
        self.assertEqual(expected, actual)

    def test_fx_fn(self):
        # Call an FX GraphModule built from the module-level ``fn``; the
        # trailing ``+ 1`` adds an op outside the wrapped module.
        traced_module = torch.fx.symbolic_trace(fn)
        self._common(lambda x, y: traced_module(x, y) + 1)

    def test_script_fn(self):
        # Same pattern, but the wrapped callable is a TorchScript-scripted ``fn``.
        scripted = torch.jit.script(fn)
        self._common(lambda x, y: scripted(x, y) + 1)

    def test_trace_fn(self):
        # Same pattern, with ``fn`` traced by TorchScript on zero example inputs.
        example_inputs = [torch.zeros(10), torch.zeros(10)]
        traced = torch.jit.trace(fn, example_inputs)
        self._common(lambda x, y: traced(x, y) + 1)

    def test_vmap_in_graph(self):
        from functools import wraps

        from torch._dynamo import allow_in_graph

        def traceable(f):
            """Register ``f`` with ``allow_in_graph`` and return a thin forwarding wrapper."""
            allowed = allow_in_graph(f)

            @wraps(allowed)
            def wrapper(*args, **kwargs):
                return allowed(*args, **kwargs)

            return wrapper

        counter = torch._dynamo.testing.CompileCounter()
        x = torch.randn(3, 5, 3)

        def transpose_batch(t):
            # vmap maps Tensor.t over the leading dimension of ``t``.
            return torch.vmap(torch.Tensor.t)(t)

        # NOTE: both compiled callables (including the allow_in_graph
        # registration done by ``traceable``) are created before either one is
        # invoked; keep this ordering.
        compiled_plain = torch.compile(
            transpose_batch, backend=counter, fullgraph=True
        )
        compiled_wrapped = torch.compile(
            traceable(transpose_batch), backend=counter, fullgraph=True
        )

        eager_out = transpose_batch(x)
        plain_out = compiled_plain(x)
        self.assertEqual(eager_out, plain_out)
        self.assertEqual(counter.frame_count, 1)

        # Calling the plain compiled function again and then the wrapper should
        # add exactly one more compiled frame (presumably the wrapper's).
        plain_again = compiled_plain(x)
        wrapped_out = compiled_wrapped(x)
        self.assertEqual(plain_again, wrapped_out)
        self.assertEqual(counter.frame_count, 2)
|
|
|
|
|
|
if __name__ == "__main__":
    # When executed directly, hand this module's tests to dynamo's test runner.
    from torch._dynamo.test_case import run_tests as _run_dynamo_tests

    _run_dynamo_tests()
|