mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[jit] jit._drop fun modifier to allow in jit class non-jit decl funs (#93012)
`@torch.jit.unused` and `@torch.jit.ignore` do not allow to keep in torch scripted class member function, that has non scriptable declaration (e.g. return type) Adding FunctionModifier _DROP to allow fully skip those functions from scripting and keep them in the code of the scripted class. E.g. it can be used for: ``` @torch.jit._drop def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument: # torch.fx classes are not scriptable return tracer.create_node( "call_function", CFX, args=(tracer.create_arg(self.features),), kwargs={}, ) def __iter__(self) -> Iterator[torch.Tensor]: return iter(self.a) ``` Testing: Added test case in `test/jit/test_types.py` with non-scriptable type annotations (fx.* classes) that fails before fix and passes after. ``` python test/test_jit.py ``` Differential Revision: [D42774830](https://our.internmc.facebook.com/intern/diff/D42774830) Pull Request resolved: https://github.com/pytorch/pytorch/pull/93012 Approved by: https://github.com/davidberard98
This commit is contained in:
committed by
PyTorch MergeBot
parent
994f85d639
commit
6a2838eec5
@ -23,7 +23,7 @@ from torch._sources import get_source_lines_and_file, parse_def, make_source_con
|
||||
from torch._sources import ParsedDef as _ParsedDef
|
||||
from torch.jit._dataclass_impls import DATACLASS_MAGIC_METHODS
|
||||
from torch.jit._monkeytype_config import monkeytype_trace, get_qualified_name
|
||||
from torch._jit_internal import should_drop, is_static_fn, FunctionModifiers # noqa: F401
|
||||
from torch._jit_internal import should_drop, _is_drop_fn, is_static_fn, FunctionModifiers # noqa: F401
|
||||
from torch import _jit_internal
|
||||
import torch.jit.annotations
|
||||
|
||||
@ -195,6 +195,7 @@ def get_jit_class_def(cls, self_name):
|
||||
predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m))
|
||||
and not is_static_fn(cls, m.__name__)
|
||||
and m.__name__ in cls.__dict__
|
||||
and not _is_drop_fn(m)
|
||||
)
|
||||
|
||||
def is_classmethod(fn):
|
||||
@ -281,6 +282,10 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
|
||||
for arg in fn_def.args.args + fn_def.args.kwonlyargs:
|
||||
# Replace potentially unsupported type annotations by "Any"
|
||||
arg.annotation = unused_def.args.args[0].annotation
|
||||
if _is_drop_fn(fn):
|
||||
# Dropping potentially unsupported return type annotation for jit._drop
|
||||
fn_def.returns = None
|
||||
fn_def.type_comment = None
|
||||
|
||||
# If MonkeyType is installed, get all the consolidated type traces
|
||||
# for the arguments from type_trace_db
|
||||
|
Reference in New Issue
Block a user