From 6a2838eec551e893e97bab05ed91c4892717a8a5 Mon Sep 17 00:00:00 2001
From: Ivan Kobzarev
Date: Tue, 31 Jan 2023 11:56:53 -0800
Subject: [PATCH] [jit] jit._drop function modifier to allow non-scriptable
 declarations in scripted classes (#93012)

`@torch.jit.unused` and `@torch.jit.ignore` do not allow a scripted class to
keep a member function whose declaration itself is not scriptable (e.g.
because of its return type annotation).

This adds a FunctionModifiers._DROP modifier that skips such functions
entirely during scripting while keeping them in the code of the scripted
class.

For example:
```
@torch.jit._drop
def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument:
    # torch.fx classes are not scriptable
    return tracer.create_node(
        "call_function",
        CFX,
        args=(tracer.create_arg(self.a),),
        kwargs={},
    )

def __iter__(self) -> Iterator[torch.Tensor]:
    return iter(self.a)
```

Testing: added a test case in `test/jit/test_types.py` with non-scriptable
type annotations (`fx.*` classes) that fails before the fix and passes after:
```
python test/test_jit.py
```

Differential Revision: [D42774830](https://our.internmc.facebook.com/intern/diff/D42774830)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93012
Approved by: https://github.com/davidberard98
---
 test/jit/test_types.py | 28 +++++++++++++++++++++++++++-
 torch/_jit_internal.py | 19 +++++++++++++++++++--
 torch/jit/__init__.py  |  1 +
 torch/jit/frontend.py  |  7 ++++++-
 4 files changed, 51 insertions(+), 4 deletions(-)

diff --git a/test/jit/test_types.py b/test/jit/test_types.py
index fd28448387d9..9ad04ce7148b 100644
--- a/test/jit/test_types.py
+++ b/test/jit/test_types.py
@@ -1,7 +1,7 @@
 # Owner(s): ["oncall: jit"]
 
 from collections import namedtuple
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Iterator, List, Optional, Tuple
 
 from torch.testing._internal.jit_utils import JitTestCase
 from torch.testing import FileCheck
@@ -244,6 +244,32 @@ class TestTypesAndAnnotation(JitTestCase):
         with self.assertRaisesRegexWithHighlight(RuntimeError, r"attribute was ignored during compilation", "self.sub"):
             scripted_mod = torch.jit.script(mod)
+
+    def test_ignoring_fn_with_nonscriptable_types(self):
+        class CFX(object):
+            def __init__(self, a: List[torch.Tensor]) -> None:
+                self.a = a
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return torch.sin(x)
+
+            @torch.jit._drop
+            def __iter__(self) -> Iterator[torch.Tensor]:
+                return iter(self.a)
+
+            @torch.jit._drop
+            def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument:
+                # torch.fx classes are not scriptable
+                return tracer.create_node(
+                    "call_function",
+                    CFX,
+                    args=(tracer.create_arg(self.a),),
+                    kwargs={},
+                )
+
+        torch.jit.script(CFX)
+
+
     def test_unimported_type_resolution(self):
         # verify fallback from the python resolver to the c++ resolver

diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py
index 6177ed0f6798..28bb78858e46 100644
--- a/torch/_jit_internal.py
+++ b/torch/_jit_internal.py
@@ -526,6 +526,7 @@ class FunctionModifiers(object):
     COPY_TO_SCRIPT_WRAPPER = (
         "if this method is not scripted, copy the python method onto the scripted model"
     )
+    _DROP = "_drop (function is fully ignored, declaration can be unscriptable)"
 
 
 def export(fn):
@@ -740,6 +741,11 @@ def ignore(drop=False, **kwargs):
     return decorator
 
 
+def _drop(fn):
+    fn._torchscript_modifier = FunctionModifiers._DROP
+    return fn
+
+
 def _copy_to_script_wrapper(fn):
     fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
     return fn
@@ -762,12 +768,21 @@ def should_drop(fn) -> bool:
attr = get_torchscript_modifier(fn) if attr is None: return False - return attr is FunctionModifiers.UNUSED + return attr is FunctionModifiers.UNUSED or attr is FunctionModifiers._DROP def is_ignored_fn(fn) -> bool: mod = get_torchscript_modifier(fn) - return mod is FunctionModifiers.UNUSED or mod is FunctionModifiers.IGNORE + return ( + mod is FunctionModifiers.UNUSED + or mod is FunctionModifiers.IGNORE + or mod is FunctionModifiers._DROP + ) + + +def _is_drop_fn(fn) -> bool: + mod = get_torchscript_modifier(fn) + return mod is FunctionModifiers._DROP def is_static_fn(cls, fn) -> bool: diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index ed2652786c11..a473ecb94139 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -11,6 +11,7 @@ from torch._jit_internal import ( Final, Future, _Await, + _drop, _IgnoreContextManager, _overload, _overload_method, diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index c3d3ba350848..a53046bd2156 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -23,7 +23,7 @@ from torch._sources import get_source_lines_and_file, parse_def, make_source_con from torch._sources import ParsedDef as _ParsedDef from torch.jit._dataclass_impls import DATACLASS_MAGIC_METHODS from torch.jit._monkeytype_config import monkeytype_trace, get_qualified_name -from torch._jit_internal import should_drop, is_static_fn, FunctionModifiers # noqa: F401 +from torch._jit_internal import should_drop, _is_drop_fn, is_static_fn, FunctionModifiers # noqa: F401 from torch import _jit_internal import torch.jit.annotations @@ -195,6 +195,7 @@ def get_jit_class_def(cls, self_name): predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m)) and not is_static_fn(cls, m.__name__) and m.__name__ in cls.__dict__ + and not _is_drop_fn(m) ) def is_classmethod(fn): @@ -281,6 +282,10 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False): for arg in fn_def.args.args + fn_def.args.kwonlyargs: # Replace potentially unsupported type annotations by "Any" arg.annotation = unused_def.args.args[0].annotation + if _is_drop_fn(fn): + # Dropping potentially unsupported return type annotation for jit._drop + fn_def.returns = None + fn_def.type_comment = None # If MonkeyType is installed, get all the consolidated type traces # for the arguments from type_trace_db
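
Usage note: below is a minimal sketch of how the new modifier is meant to be
used, assuming a PyTorch build that includes this patch. The class `TensorBag`
and its members are illustrative names, not part of the change.

```python
# Minimal sketch (assumes a PyTorch build that includes this patch).
# `TensorBag`, `total`, and `tensors` are illustrative, not from the patch.
from typing import Iterator, List

import torch


class TensorBag(object):
    def __init__(self, tensors: List[torch.Tensor]) -> None:
        self.tensors = tensors

    def total(self) -> torch.Tensor:
        # Scriptable method; compiled as usual.
        return torch.stack(self.tensors).sum()

    # Iterator[...] is not a scriptable annotation. @torch.jit.unused or
    # @torch.jit.ignore would still fail here because the declaration itself
    # gets compiled; @torch.jit._drop skips the whole function, declaration
    # included, while keeping it callable from eager Python.
    @torch.jit._drop
    def __iter__(self) -> Iterator[torch.Tensor]:
        return iter(self.tensors)


# Scripting succeeds: the dropped function is never seen by the compiler.
torch.jit.script(TensorBag)

# The dropped method still works in eager mode.
bag = TensorBag([torch.zeros(2), torch.ones(2)])
print([t for t in bag])
```

Unlike `unused`/`ignore`, which swap out the body but still process the
signature, `_drop` removes the function from the class definition the compiler
sees (the `get_jit_class_def` predicate change above), which is why even its
annotations may reference non-scriptable types such as `torch.fx.Tracer`.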