[jit] jit._drop fun modifier to allow in jit class non-jit decl funs (#93012)

`@torch.jit.unused` and `@torch.jit.ignore` do not allow to keep in torch scripted class member function, that has non scriptable declaration (e.g. return type)

Adding FunctionModifier _DROP to allow fully skip those functions from scripting and keep them in the code of the scripted class.

E.g. it can be used for:

```
@torch.jit._drop
def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument:
    # torch.fx classes are not scriptable
    return tracer.create_node(
        "call_function",
        CFX,
        args=(tracer.create_arg(self.features),),
        kwargs={},
    )

def __iter__(self) -> Iterator[torch.Tensor]:
    return iter(self.a)
```

Testing:
Added test case in `test/jit/test_types.py` with non-scriptable type annotations (fx.* classes) that fails before fix and passes after.

```
python test/test_jit.py
```

Differential Revision: [D42774830](https://our.internmc.facebook.com/intern/diff/D42774830)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93012
Approved by: https://github.com/davidberard98
This commit is contained in:
Ivan Kobzarev
2023-01-31 11:56:53 -08:00
committed by PyTorch MergeBot
parent 994f85d639
commit 6a2838eec5
4 changed files with 51 additions and 4 deletions

View File

@ -1,7 +1,7 @@
# Owner(s): ["oncall: jit"]
from collections import namedtuple
from typing import Dict, List, Optional, Tuple
from typing import Dict, Iterator, List, Optional, Tuple
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing import FileCheck
@ -244,6 +244,32 @@ class TestTypesAndAnnotation(JitTestCase):
with self.assertRaisesRegexWithHighlight(RuntimeError, r"attribute was ignored during compilation", "self.sub"):
scripted_mod = torch.jit.script(mod)
def test_ignoring_fn_with_nonscriptable_types(self):
class CFX(object):
def __init__(self, a: List[torch.Tensor]) -> None:
self.a = a
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.sin(x)
@torch.jit._drop
def __iter__(self) -> Iterator[torch.Tensor]:
return iter(self.a)
@torch.jit._drop
def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument:
# torch.fx classes are not scriptable
return tracer.create_node(
"call_function",
CFX,
args=(tracer.create_arg(self.features),),
kwargs={},
)
torch.jit.script(CFX)
def test_unimported_type_resolution(self):
# verify fallback from the python resolver to the c++ resolver

View File

@ -526,6 +526,7 @@ class FunctionModifiers(object):
COPY_TO_SCRIPT_WRAPPER = (
"if this method is not scripted, copy the python method onto the scripted model"
)
_DROP = "_drop (function is fully ignored, declaration can be unscriptable)"
def export(fn):
@ -740,6 +741,11 @@ def ignore(drop=False, **kwargs):
return decorator
def _drop(fn):
fn._torchscript_modifier = FunctionModifiers._DROP
return fn
def _copy_to_script_wrapper(fn):
fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
return fn
@ -762,12 +768,21 @@ def should_drop(fn) -> bool:
attr = get_torchscript_modifier(fn)
if attr is None:
return False
return attr is FunctionModifiers.UNUSED
return attr is FunctionModifiers.UNUSED or attr is FunctionModifiers._DROP
def is_ignored_fn(fn) -> bool:
mod = get_torchscript_modifier(fn)
return mod is FunctionModifiers.UNUSED or mod is FunctionModifiers.IGNORE
return (
mod is FunctionModifiers.UNUSED
or mod is FunctionModifiers.IGNORE
or mod is FunctionModifiers._DROP
)
def _is_drop_fn(fn) -> bool:
mod = get_torchscript_modifier(fn)
return mod is FunctionModifiers._DROP
def is_static_fn(cls, fn) -> bool:

View File

@ -11,6 +11,7 @@ from torch._jit_internal import (
Final,
Future,
_Await,
_drop,
_IgnoreContextManager,
_overload,
_overload_method,

View File

@ -23,7 +23,7 @@ from torch._sources import get_source_lines_and_file, parse_def, make_source_con
from torch._sources import ParsedDef as _ParsedDef
from torch.jit._dataclass_impls import DATACLASS_MAGIC_METHODS
from torch.jit._monkeytype_config import monkeytype_trace, get_qualified_name
from torch._jit_internal import should_drop, is_static_fn, FunctionModifiers # noqa: F401
from torch._jit_internal import should_drop, _is_drop_fn, is_static_fn, FunctionModifiers # noqa: F401
from torch import _jit_internal
import torch.jit.annotations
@ -195,6 +195,7 @@ def get_jit_class_def(cls, self_name):
predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m))
and not is_static_fn(cls, m.__name__)
and m.__name__ in cls.__dict__
and not _is_drop_fn(m)
)
def is_classmethod(fn):
@ -281,6 +282,10 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
for arg in fn_def.args.args + fn_def.args.kwonlyargs:
# Replace potentially unsupported type annotations by "Any"
arg.annotation = unused_def.args.args[0].annotation
if _is_drop_fn(fn):
# Dropping potentially unsupported return type annotation for jit._drop
fn_def.returns = None
fn_def.type_comment = None
# If MonkeyType is installed, get all the consolidated type traces
# for the arguments from type_trace_db