[jit] jit._drop fun modifier to allow in jit class non-jit decl funs (#93012)
`@torch.jit.unused` and `@torch.jit.ignore` do not allow keeping a member function in a TorchScript class when that function has a non-scriptable declaration (e.g. its return type). This adds the FunctionModifier `_DROP`, which fully skips such functions during scripting while keeping them in the source of the scripted class. E.g. it can be used for:

```python
@torch.jit._drop
def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument:
    # torch.fx classes are not scriptable
    return tracer.create_node(
        "call_function",
        CFX,
        args=(tracer.create_arg(self.features),),
        kwargs={},
    )

def __iter__(self) -> Iterator[torch.Tensor]:
    return iter(self.a)
```

Testing: Added a test case in `test/jit/test_types.py` with non-scriptable type annotations (fx.* classes) that fails before the fix and passes after.

```
python test/test_jit.py
```

Differential Revision: [D42774830](https://our.internmc.facebook.com/intern/diff/D42774830)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93012
Approved by: https://github.com/davidberard98
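For contrast, a minimal sketch of the failure mode this addresses (hypothetical class name `Wrapper`; assumes the `@torch.jit._drop` decorator is absent or replaced with `@torch.jit.ignore`/`@torch.jit.unused`): the declaration itself still gets parsed, so scripting the class fails on the fx annotations.

```python
import torch
from typing import List

class Wrapper(object):  # hypothetical example class
    def __init__(self, a: List[torch.Tensor]) -> None:
        self.a = a

    @torch.jit.unused  # not enough: the declaration below is still parsed
    def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument:
        return tracer.create_arg(self.a)

try:
    torch.jit.script(Wrapper)
except Exception as e:
    # The fx annotations cannot be resolved by the TorchScript frontend.
    print("scripting failed:", type(e).__name__)
```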
committed by PyTorch MergeBot
parent 994f85d639
commit 6a2838eec5
`test/jit/test_types.py`:

```diff
@@ -1,7 +1,7 @@
 # Owner(s): ["oncall: jit"]
 
 from collections import namedtuple
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Iterator, List, Optional, Tuple
 
 from torch.testing._internal.jit_utils import JitTestCase
 from torch.testing import FileCheck
```
```diff
@@ -244,6 +244,32 @@ class TestTypesAndAnnotation(JitTestCase):
         with self.assertRaisesRegexWithHighlight(RuntimeError, r"attribute was ignored during compilation", "self.sub"):
             scripted_mod = torch.jit.script(mod)
 
+    def test_ignoring_fn_with_nonscriptable_types(self):
+        class CFX(object):
+            def __init__(self, a: List[torch.Tensor]) -> None:
+                self.a = a
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return torch.sin(x)
+
+            @torch.jit._drop
+            def __iter__(self) -> Iterator[torch.Tensor]:
+                return iter(self.a)
+
+            @torch.jit._drop
+            def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument:
+                # torch.fx classes are not scriptable
+                return tracer.create_node(
+                    "call_function",
+                    CFX,
+                    args=(tracer.create_arg(self.features),),
+                    kwargs={},
+                )
+
+        torch.jit.script(CFX)
+
     def test_unimported_type_resolution(self):
         # verify fallback from the python resolver to the c++ resolver
```
`torch/_jit_internal.py`:

```diff
@@ -526,6 +526,7 @@ class FunctionModifiers(object):
     COPY_TO_SCRIPT_WRAPPER = (
         "if this method is not scripted, copy the python method onto the scripted model"
     )
+    _DROP = "_drop (function is fully ignored, declaration can be unscriptable)"
 
 
 def export(fn):
```
```diff
@@ -740,6 +741,11 @@ def ignore(drop=False, **kwargs):
     return decorator
 
 
+def _drop(fn):
+    fn._torchscript_modifier = FunctionModifiers._DROP
+    return fn
+
+
 def _copy_to_script_wrapper(fn):
     fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
     return fn
```
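`_drop` only tags the function and returns it unchanged, so in eager Python the function stays fully callable; a quick check (hypothetical helper name, assuming a PyTorch build that includes this change):

```python
import torch
from torch._jit_internal import FunctionModifiers

@torch.jit._drop
def helper(x: int) -> int:  # hypothetical free function, used only to show the tag
    return x + 1

assert helper(1) == 2  # still an ordinary Python function in eager mode
assert helper._torchscript_modifier is FunctionModifiers._DROP
```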
```diff
@@ -762,12 +768,21 @@ def should_drop(fn) -> bool:
     attr = get_torchscript_modifier(fn)
     if attr is None:
         return False
-    return attr is FunctionModifiers.UNUSED
+    return attr is FunctionModifiers.UNUSED or attr is FunctionModifiers._DROP
 
 
 def is_ignored_fn(fn) -> bool:
     mod = get_torchscript_modifier(fn)
-    return mod is FunctionModifiers.UNUSED or mod is FunctionModifiers.IGNORE
+    return (
+        mod is FunctionModifiers.UNUSED
+        or mod is FunctionModifiers.IGNORE
+        or mod is FunctionModifiers._DROP
+    )
+
+
+def _is_drop_fn(fn) -> bool:
+    mod = get_torchscript_modifier(fn)
+    return mod is FunctionModifiers._DROP
 
 
 def is_static_fn(cls, fn) -> bool:
```
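To see how the three predicates relate, here is a self-contained sketch with simplified stand-ins (shortened sentinel strings, no bound-method unwrapping; not the real torch internals):

```python
# Simplified stand-ins for the helpers in torch._jit_internal.
class FunctionModifiers(object):
    UNUSED = "unused"
    IGNORE = "ignore"
    _DROP = "_drop"

def get_torchscript_modifier(fn):
    return getattr(fn, "_torchscript_modifier", None)

def should_drop(fn):
    attr = get_torchscript_modifier(fn)
    if attr is None:
        return False
    return attr is FunctionModifiers.UNUSED or attr is FunctionModifiers._DROP

def is_ignored_fn(fn):
    mod = get_torchscript_modifier(fn)
    return (
        mod is FunctionModifiers.UNUSED
        or mod is FunctionModifiers.IGNORE
        or mod is FunctionModifiers._DROP
    )

def _is_drop_fn(fn):
    return get_torchscript_modifier(fn) is FunctionModifiers._DROP

def tagged(modifier):
    def fn():
        pass
    fn._torchscript_modifier = modifier
    return fn

assert should_drop(tagged(FunctionModifiers._DROP))       # dropped like UNUSED...
assert _is_drop_fn(tagged(FunctionModifiers._DROP))       # ...but also fully skipped
assert not _is_drop_fn(tagged(FunctionModifiers.UNUSED))  # UNUSED keeps its declaration
assert is_ignored_fn(tagged(FunctionModifiers.IGNORE))
```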
Import of `_drop` into the `torch.jit` namespace:

```diff
@@ -11,6 +11,7 @@ from torch._jit_internal import (
     Final,
     Future,
     _Await,
+    _drop,
     _IgnoreContextManager,
     _overload,
     _overload_method,
```
TorchScript frontend (`get_jit_class_def` / `get_jit_def`):

```diff
@@ -23,7 +23,7 @@ from torch._sources import get_source_lines_and_file, parse_def, make_source_context
 from torch._sources import ParsedDef as _ParsedDef
 from torch.jit._dataclass_impls import DATACLASS_MAGIC_METHODS
 from torch.jit._monkeytype_config import monkeytype_trace, get_qualified_name
-from torch._jit_internal import should_drop, is_static_fn, FunctionModifiers  # noqa: F401
+from torch._jit_internal import should_drop, _is_drop_fn, is_static_fn, FunctionModifiers  # noqa: F401
 from torch import _jit_internal
 import torch.jit.annotations
 
```
```diff
@@ -195,6 +195,7 @@ def get_jit_class_def(cls, self_name):
         predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m))
         and not is_static_fn(cls, m.__name__)
         and m.__name__ in cls.__dict__
+        and not _is_drop_fn(m)
     )
 
     def is_classmethod(fn):
```
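A standalone sketch of the effect of the new predicate term (hypothetical `Example` class, simplified `_drop`/`_is_drop_fn` stand-ins): methods tagged with `_drop` never reach class compilation because `inspect.getmembers` filters them out up front.

```python
import inspect

def _drop(fn):  # simplified stand-in for torch.jit._drop
    fn._torchscript_modifier = "_drop"
    return fn

def _is_drop_fn(fn):  # simplified stand-in
    return getattr(fn, "_torchscript_modifier", None) == "_drop"

class Example(object):
    def kept(self):
        return 1

    @_drop
    def skipped(self):
        return 2

methods = inspect.getmembers(
    Example,
    predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m))
    and m.__name__ in Example.__dict__
    and not _is_drop_fn(m),
)
print([name for name, _ in methods])  # ['kept'] -- 'skipped' is never compiled
```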
```diff
@@ -281,6 +282,10 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
         for arg in fn_def.args.args + fn_def.args.kwonlyargs:
             # Replace potentially unsupported type annotations by "Any"
             arg.annotation = unused_def.args.args[0].annotation
+    if _is_drop_fn(fn):
+        # Dropping potentially unsupported return type annotation for jit._drop
+        fn_def.returns = None
+        fn_def.type_comment = None
 
     # If MonkeyType is installed, get all the consolidated type traces
     # for the arguments from type_trace_db
```
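The return-annotation drop happens on the parsed Python AST before the frontend tries to resolve types; a standalone illustration with the stdlib `ast` module (requires Python 3.9+ for `ast.unparse`; the type name is made up):

```python
import ast

src = "def f(x) -> SomeUnscriptableType:\n    return x\n"
fn_def = ast.parse(src).body[0]

# Mirror the change above: erase the (potentially unresolvable) return
# annotation so later compilation never has to look the type up.
fn_def.returns = None
fn_def.type_comment = None

print(ast.unparse(fn_def))  # "def f(x):\n    return x" -- annotation gone
```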