[jit] jit._drop fun modifier to allow in jit class non-jit decl funs (#93012)

`@torch.jit.unused` and `@torch.jit.ignore` do not allow to keep in torch scripted class member function, that has non scriptable declaration (e.g. return type)

Adding FunctionModifier _DROP to allow fully skip those functions from scripting and keep them in the code of the scripted class.

E.g. it can be used for:

```
@torch.jit._drop
def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument:
    # torch.fx classes are not scriptable
    return tracer.create_node(
        "call_function",
        CFX,
        args=(tracer.create_arg(self.features),),
        kwargs={},
    )

def __iter__(self) -> Iterator[torch.Tensor]:
    return iter(self.a)
```

Testing:
Added test case in `test/jit/test_types.py` with non-scriptable type annotations (fx.* classes) that fails before fix and passes after.

```
python test/test_jit.py
```

Differential Revision: [D42774830](https://our.internmc.facebook.com/intern/diff/D42774830)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93012
Approved by: https://github.com/davidberard98
This commit is contained in:
Ivan Kobzarev
2023-01-31 11:56:53 -08:00
committed by PyTorch MergeBot
parent 994f85d639
commit 6a2838eec5
4 changed files with 51 additions and 4 deletions

View File

@ -23,7 +23,7 @@ from torch._sources import get_source_lines_and_file, parse_def, make_source_con
from torch._sources import ParsedDef as _ParsedDef
from torch.jit._dataclass_impls import DATACLASS_MAGIC_METHODS
from torch.jit._monkeytype_config import monkeytype_trace, get_qualified_name
from torch._jit_internal import should_drop, is_static_fn, FunctionModifiers # noqa: F401
from torch._jit_internal import should_drop, _is_drop_fn, is_static_fn, FunctionModifiers # noqa: F401
from torch import _jit_internal
import torch.jit.annotations
@ -195,6 +195,7 @@ def get_jit_class_def(cls, self_name):
predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m))
and not is_static_fn(cls, m.__name__)
and m.__name__ in cls.__dict__
and not _is_drop_fn(m)
)
def is_classmethod(fn):
@ -281,6 +282,10 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
for arg in fn_def.args.args + fn_def.args.kwonlyargs:
# Replace potentially unsupported type annotations by "Any"
arg.annotation = unused_def.args.args[0].annotation
if _is_drop_fn(fn):
# Dropping potentially unsupported return type annotation for jit._drop
fn_def.returns = None
fn_def.type_comment = None
# If MonkeyType is installed, get all the consolidated type traces
# for the arguments from type_trace_db