PT2/TorchScript interoperability fix (#94678)

Allows torch.compile() to inline into ScriptFunction

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94678
Approved by: https://github.com/ezyang
This commit is contained in:
Jason Ansel
2023-02-14 19:06:50 +00:00
committed by PyTorch MergeBot
parent b6443fca86
commit ae57bd6630
10 changed files with 61 additions and 3 deletions

View File

@ -0,0 +1,38 @@
# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.onnx.operators
from torch._dynamo.testing import same
def fn(a, b):
return a + b * 0.67
class InteropTests(torch._dynamo.test_case.TestCase):
def _common(self, fn):
inputs = [torch.randn(10), torch.randn(10)]
ref = fn(*inputs)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(*inputs)
self.assertTrue(same(ref, res))
def test_fx_fn(self):
fx_fn = torch.fx.symbolic_trace(fn)
self._common(lambda a, b: fx_fn(a, b) + 1)
def test_script_fn(self):
script_fn = torch.jit.script(fn)
self._common(lambda a, b: script_fn(a, b) + 1)
def test_trace_fn(self):
trace_fn = torch.jit.trace(fn, [torch.zeros(10), torch.zeros(10)])
self._common(lambda a, b: trace_fn(a, b) + 1)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -2,9 +2,12 @@
import torch
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.jit_utils import JitTestCase
from typing import List
@skipIfTorchDynamo()
class TestAutodiffJit(JitTestCase):
def test_undefined_tensor_lists(self):
def fn(tensor_list: List[torch.Tensor], add_tensor):

View File

@ -4,6 +4,7 @@ import os
import sys
import torch
from torch.testing._internal.common_utils import skipIfTorchDynamo
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@ -15,6 +16,7 @@ if __name__ == '__main__':
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
@skipIfTorchDynamo()
class TestProfiler(JitTestCase):
def setUp(self):
self.prev_exec = torch._C._jit_set_profiling_executor(True)

View File

@ -342,6 +342,8 @@ class FooToPickle(torch.nn.Module):
super().__init__()
self.bar = torch.jit.ScriptModule()
@skipIfTorchDynamo()
class TestJit(JitTestCase):
@unittest.skip("Requires a lot of RAM")
def test_big(self):
@ -2982,6 +2984,7 @@ graph(%Ra, %Rb):
self.assertRegex(graph.__repr__(), source_range_regex)
@skipIfTorchDynamo()
class TestFrontend(JitTestCase):
def test_instancing_error(self):
@ -3038,6 +3041,7 @@ class TestFrontend(JitTestCase):
res_2 = traced_model_2(**{'x': torch.rand([2]), 'z': torch.rand([2])})
@skipIfTorchDynamo()
class TestScript(JitTestCase):
# Tests that calling torch.jit.script repeated on function is allowed.
@ -15989,10 +15993,12 @@ EXCLUDE_ALIAS = {
}
@skipIfTorchDynamo()
class TestJitGeneratedModule(JitTestCase):
pass
@skipIfTorchDynamo()
class TestJitGeneratedFunctional(JitTestCase):
pass

View File

@ -80,6 +80,8 @@ def inline_fusion_groups():
finally:
torch._C._debug_set_fusion_group_inlining(old_inlining)
@skipIfTorchDynamo()
class TestTEFuser(JitTestCase):
def setUp(self):
super().setUp()
@ -2622,6 +2624,7 @@ def get_name(op):
# super() [with no arguments] fails, presumably because of how instantiate_device_type_tests works.
# super(TestNNCOpInfo, self) fails because TestNNCOpInfo gets deleted from global scope.
# super(JitCommonTestCase, self).fn() would skip JitCommonTestCase.fn() implementation
@skipIfTorchDynamo()
class TestNNCOpInfoParent(JitCommonTestCase):
pass
@ -2739,6 +2742,7 @@ only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for)
# Purpose of this class is to allow super() calls. (See TestNNCOpInfoParent)
@skipIfTorchDynamo()
class TestLoopnestRandomizationParent(JitTestCase):
pass

View File

@ -7,7 +7,7 @@ from torch import nn
import unittest
import itertools
from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests
from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests, skipIfTorchDynamo
from torch.testing._internal.jit_utils import JitTestCase, TensorExprTestOptions
@ -34,6 +34,7 @@ def warmup_and_run_forward(f, *args):
return results
@skipIfTorchDynamo()
class TestTensorExprFuser(BaseTestClass):
def test_easy(self):
def easy(x, y):

View File

@ -459,7 +459,7 @@ class VariableBuilder:
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
)
elif istype(value, types.FunctionType):
elif istype(value, (types.FunctionType, torch.jit.ScriptFunction)):
return UserFunctionVariable(
value,
source=self.source,

View File

@ -112,7 +112,7 @@ class UserFunctionVariable(BaseUserFunctionVariable):
self.is_constant = False
assert isinstance(
fn, types.FunctionType
fn, (types.FunctionType, torch.jit.ScriptFunction)
), f"expected FunctionType found {typestr(fn)} {fn}"
# unpack @torch._dynamo.optimize()(fn) wrapped function
fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)

View File

@ -1343,6 +1343,8 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None,
)
# Forward docstrings
fn.__doc__ = obj.__doc__
# Allow torch.compile() to inline
fn._torchdynamo_inline = obj # type: ignore[attr-defined]
_set_jit_function_cache(obj, fn)
return fn
else:

View File

@ -893,6 +893,8 @@ def trace(
example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
)
# Allow torch.compile() to inline
traced._torchdynamo_inline = func # type: ignore[attr-defined]
return traced