Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-26 16:44:54 +08:00
PT2/TorchScript interoperability fix (#94678)

Allows torch.compile() to inline into a ScriptFunction.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94678
Approved by: https://github.com/ezyang
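In practice this means a torch.compile'd frame can now call a torch.jit.script'ed function without a graph break. A minimal sketch of the behavior this enables, mirroring the new test below (the "eager" backend and the 0.67 constant come from that test; any backend works the same way):

import torch

def fn(a, b):
    return a + b * 0.67

script_fn = torch.jit.script(fn)  # returns a torch.jit.ScriptFunction

def wrapper(a, b):
    # Dynamo now inlines the ScriptFunction's original Python body here
    # instead of failing under fullgraph=True.
    return script_fn(a, b) + 1

opt_fn = torch.compile(wrapper, backend="eager", fullgraph=True)
opt_fn(torch.randn(10), torch.randn(10))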
Committed by: PyTorch MergeBot
Parent: b6443fca86
Commit: ae57bd6630
test/dynamo/test_interop.py (new file, 38 lines added)
@@ -0,0 +1,38 @@
+# Owner(s): ["module: dynamo"]
+import torch
+
+import torch._dynamo.test_case
+import torch._dynamo.testing
+import torch.onnx.operators
+from torch._dynamo.testing import same
+
+
+def fn(a, b):
+    return a + b * 0.67
+
+
+class InteropTests(torch._dynamo.test_case.TestCase):
+    def _common(self, fn):
+        inputs = [torch.randn(10), torch.randn(10)]
+        ref = fn(*inputs)
+        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
+        res = opt_fn(*inputs)
+        self.assertTrue(same(ref, res))
+
+    def test_fx_fn(self):
+        fx_fn = torch.fx.symbolic_trace(fn)
+        self._common(lambda a, b: fx_fn(a, b) + 1)
+
+    def test_script_fn(self):
+        script_fn = torch.jit.script(fn)
+        self._common(lambda a, b: script_fn(a, b) + 1)
+
+    def test_trace_fn(self):
+        trace_fn = torch.jit.trace(fn, [torch.zeros(10), torch.zeros(10)])
+        self._common(lambda a, b: trace_fn(a, b) + 1)
+
+
+if __name__ == "__main__":
+    from torch._dynamo.test_case import run_tests
+
+    run_tests()
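The test-suite hunks that follow are mechanical: TorchScript-focused test classes gain a @skipIfTorchDynamo() decorator, presumably because these tests exercise the TorchScript executor directly and should not be rerouted through Dynamo now that it traces into scripted functions.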
@@ -2,9 +2,12 @@
 
 import torch
 
+from torch.testing._internal.common_utils import skipIfTorchDynamo
 from torch.testing._internal.jit_utils import JitTestCase
 from typing import List
 
 
+@skipIfTorchDynamo()
 class TestAutodiffJit(JitTestCase):
     def test_undefined_tensor_lists(self):
         def fn(tensor_list: List[torch.Tensor], add_tensor):
@@ -4,6 +4,7 @@ import os
 import sys
 
 import torch
+from torch.testing._internal.common_utils import skipIfTorchDynamo
 
 # Make the helper files in test/ importable
 pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@@ -15,6 +16,7 @@ if __name__ == '__main__':
                        "\tpython test/test_jit.py TESTNAME\n\n"
                        "instead.")
 
+@skipIfTorchDynamo()
 class TestProfiler(JitTestCase):
     def setUp(self):
         self.prev_exec = torch._C._jit_set_profiling_executor(True)
@@ -342,6 +342,8 @@ class FooToPickle(torch.nn.Module):
         super().__init__()
         self.bar = torch.jit.ScriptModule()
 
+
+@skipIfTorchDynamo()
 class TestJit(JitTestCase):
     @unittest.skip("Requires a lot of RAM")
     def test_big(self):
@@ -2982,6 +2984,7 @@ graph(%Ra, %Rb):
         self.assertRegex(graph.__repr__(), source_range_regex)
 
 
+@skipIfTorchDynamo()
 class TestFrontend(JitTestCase):
 
     def test_instancing_error(self):
@@ -3038,6 +3041,7 @@ class TestFrontend(JitTestCase):
         res_2 = traced_model_2(**{'x': torch.rand([2]), 'z': torch.rand([2])})
 
 
+@skipIfTorchDynamo()
 class TestScript(JitTestCase):
 
     # Tests that calling torch.jit.script repeatedly on a function is allowed.
@@ -15989,10 +15993,12 @@ EXCLUDE_ALIAS = {
 }
 
 
+@skipIfTorchDynamo()
 class TestJitGeneratedModule(JitTestCase):
     pass
 
 
+@skipIfTorchDynamo()
 class TestJitGeneratedFunctional(JitTestCase):
     pass
@@ -80,6 +80,8 @@ def inline_fusion_groups():
     finally:
         torch._C._debug_set_fusion_group_inlining(old_inlining)
 
+
+@skipIfTorchDynamo()
 class TestTEFuser(JitTestCase):
     def setUp(self):
         super().setUp()
@@ -2622,6 +2624,7 @@ def get_name(op):
 # super() [with no arguments] fails, presumably because of how instantiate_device_type_tests works.
 # super(TestNNCOpInfo, self) fails because TestNNCOpInfo gets deleted from global scope.
 # super(JitCommonTestCase, self).fn() would skip JitCommonTestCase.fn() implementation
+@skipIfTorchDynamo()
 class TestNNCOpInfoParent(JitCommonTestCase):
     pass
@@ -2739,6 +2742,7 @@ only_for = ("cpu", "cuda")
 instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for)
 
 # Purpose of this class is to allow super() calls. (See TestNNCOpInfoParent)
+@skipIfTorchDynamo()
 class TestLoopnestRandomizationParent(JitTestCase):
     pass
@@ -7,7 +7,7 @@ from torch import nn
 import unittest
 import itertools
 
-from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests
+from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests, skipIfTorchDynamo
 
 from torch.testing._internal.jit_utils import JitTestCase, TensorExprTestOptions
@@ -34,6 +34,7 @@ def warmup_and_run_forward(f, *args):
     return results
 
 
+@skipIfTorchDynamo()
 class TestTensorExprFuser(BaseTestClass):
     def test_easy(self):
         def easy(x, y):
@@ -459,7 +459,7 @@ class VariableBuilder:
                 source=self.source,
                 guards=make_guards(GuardBuilder.FUNCTION_MATCH),
             )
-        elif istype(value, types.FunctionType):
+        elif istype(value, (types.FunctionType, torch.jit.ScriptFunction)):
            return UserFunctionVariable(
                value,
                source=self.source,
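The widened istype() check matters because scripting a plain function does not return a types.FunctionType. A small, self-contained illustration:

import types

import torch

def fn(x):
    return x + 1

script_fn = torch.jit.script(fn)
# A ScriptFunction is not a plain Python function, so the old
# istype(value, types.FunctionType) check never matched it and Dynamo
# treated it as an opaque callable rather than something it could inline.
assert isinstance(script_fn, torch.jit.ScriptFunction)
assert not isinstance(script_fn, types.FunctionType)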
@@ -112,7 +112,7 @@ class UserFunctionVariable(BaseUserFunctionVariable):
         self.is_constant = False
 
         assert isinstance(
-            fn, types.FunctionType
+            fn, (types.FunctionType, torch.jit.ScriptFunction)
         ), f"expected FunctionType found {typestr(fn)} {fn}"
         # unpack @torch._dynamo.optimize()(fn) wrapped function
         fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
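The widened assertion works together with the existing _torchdynamo_inline unwrap on the following line: once script() (below) stashes the original Python function on the ScriptFunction, UserFunctionVariable recovers it the same way it already unwrapped @torch._dynamo.optimize()-wrapped functions. A sketch, assuming a PyTorch build that includes this change:

import inspect

import torch

def fn(a, b):
    return a + b * 0.67

script_fn = torch.jit.script(fn)
# getattr_static looks up the stashed original without running any custom
# attribute hooks; the fallback returns the object unchanged for callables
# that were never wrapped.
orig = inspect.getattr_static(script_fn, "_torchdynamo_inline", script_fn)
assert orig is fn  # Dynamo traces fn's Python source, not the TorchScript IR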
@@ -1343,6 +1343,8 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None,
         )
         # Forward docstrings
         fn.__doc__ = obj.__doc__
+        # Allow torch.compile() to inline
+        fn._torchdynamo_inline = obj  # type: ignore[attr-defined]
         _set_jit_function_cache(obj, fn)
         return fn
     else:
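Note that the attribute is assigned before _set_jit_function_cache(obj, fn), so cache hits from repeated torch.jit.script(obj) calls presumably return a ScriptFunction that already carries the inlining hook.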
@@ -893,6 +893,8 @@ def trace(
         example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
     )
 
+    # Allow torch.compile() to inline
+    traced._torchdynamo_inline = func  # type: ignore[attr-defined]
     return traced
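The traced path gets the same hook. A minimal sketch, mirroring test_trace_fn above (again assuming a build with this change):

import torch

def fn(a, b):
    return a + b * 0.67

trace_fn = torch.jit.trace(fn, [torch.zeros(10), torch.zeros(10)])
# trace() now tags the traced callable with the original Python function,
# so torch.compile inlines fn rather than graph-breaking on the trace.
assert trace_fn._torchdynamo_inline is fn

opt_fn = torch.compile(lambda a, b: trace_fn(a, b) + 1,
                       backend="eager", fullgraph=True)
opt_fn(torch.randn(10), torch.randn(10))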