add torch.jit.is_scripting() api (#25263)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25263

This adds an API that returns ``True`` when running in TorchScript and ``False`` when running in eager mode. Together with ``torch.jit.ignore``, this allows guarding JIT features that are not yet supported. Bikeshedding on the name is welcome.

cc zou3519

```
def foo():
    if not torch.jit.is_scripting():
        # Eager-only path (may use ops the JIT cannot compile yet)
        return torch.nn.functional.linear(...)
    else:
        # TorchScript-compatible fallback
        return torch.addmm(...)
```
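
To make the intended pattern concrete, here is a minimal sketch of how ``torch.jit.is_scripting()`` combines with ``@torch.jit.ignore`` to guard eager-only code. The helper name ``fancy_linear`` and the specific ops are illustrative assumptions, not part of this PR:

```
import torch
import torch.nn.functional as F

@torch.jit.ignore
def fancy_linear(x, w, b):
    # Eager-only helper: free to use anything the JIT cannot compile yet.
    return F.linear(x, w, b)

def foo(x, w, b):
    if not torch.jit.is_scripting():
        # Taken when running under the plain Python interpreter.
        return fancy_linear(x, w, b)
    else:
        # Script-compatible fallback used once this function is compiled.
        return torch.addmm(b, x, w.t())

scripted = torch.jit.script(foo)
```

The ignored helper remains callable in eager mode, while the compiled function takes the ``addmm`` branch.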

Test Plan: Imported from OSS

Differential Revision: D17272443

Pulled By: eellison

fbshipit-source-id: de0f769c7eaae91de0007b98969183df93a91f42
Author: Elias Ellison
Date: 2019-09-09 20:22:54 -07:00
Committed by: Facebook GitHub Bot
Parent: 36bdde255e
Commit: 7ab4ad7b6d
8 changed files with 152 additions and 75 deletions


@@ -6,6 +6,7 @@ circular dependency problems
import inspect
import weakref
import warnings
import torch._C
from torch._six import builtins
@@ -168,7 +169,7 @@ class FunctionModifiers(object):
Used to denote the behavior of a function in TorchScript. See export() and
ignore() for details.
"""
IGNORE_AND_DROP = "ignore (leave as a call to Python, replace with a 'raise' on torch.jit.save)"
UNUSED = "unused (ignored and replaced with raising of an exception)"
IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)"
EXPORT = "export (compile this function even if nothing calls it)"
DEFAULT = "default (compile if called from a exported function / forward)"
@@ -219,23 +220,52 @@ def export(fn):
return fn
def ignore(drop_on_export=False):
def unused(fn):
"""
This decorator indicates to the compiler that a function or method should
be ignored and replaced with the raising of an exception. This allows you
to leave code in your model that is not yet TorchScript compatible and still
export your model: the decorated function remains callable from ordinary
Python (eager mode), but any call to it from compiled TorchScript code
raises an exception.

Example (using ``@torch.jit.unused`` on a method)::
    import torch
    import torch.nn as nn

    class MyModule(nn.Module):
        def __init__(self, use_memory_efficient):
            super(MyModule, self).__init__()
            self.use_memory_efficient = use_memory_efficient

        @torch.jit.unused
        def memory_efficient(self, x):
            import pdb
            pdb.set_trace()
            return x + 10

        def forward(self, x):
            # Use not-yet-scriptable memory efficient mode
            if self.use_memory_efficient:
                return self.memory_efficient(x)
            else:
                return x + 10

    m = torch.jit.script(MyModule(use_memory_efficient=False))
    m.save("m.pt")

    m = torch.jit.script(MyModule(use_memory_efficient=True))
    # exception raised when forward() calls the unused method
    m(torch.rand(100))
"""
fn._torchscript_modifier = FunctionModifiers.UNUSED
return fn
def ignore(drop=False, **kwargs):
"""
This decorator indicates to the compiler that a function or method should
be ignored and left as a Python function. This allows you to leave code in
your model that is not yet TorchScript compatible. Models with ignored
functions cannot be exported; use torch.jit.unused instead.
Example (using ``@torch.jit.ignore`` on a method)::
@@ -261,7 +291,7 @@ def ignore(drop_on_export=False):
# Error! The call `debugger` cannot be saved since it calls into Python
m.save("m.pt")
Example (using ``@torch.jit.ignore(drop_on_export=True)`` on a method):
Example (using ``@torch.jit.ignore(drop=True)`` on a method):
.. testcode::
@@ -269,7 +299,7 @@ def ignore(drop_on_export=False):
import torch.nn as nn
class MyModule(nn.Module):
@torch.jit.ignore(drop_on_export=True)
@torch.jit.ignore(drop=True)
def training_method(self, x):
import pdb
pdb.set_trace()
@@ -290,24 +320,37 @@ def ignore(drop_on_export=False):
import os
os.remove('m.pt')
"""
if callable(drop_on_export):
# used without any args, so drop_on_export is actually a function
if callable(drop):
# used without any args, so drop is actually a function
# @torch.jit.ignore
# def fn(...):
fn = drop_on_export
fn = drop
fn._torchscript_modifier = FunctionModifiers.IGNORE
return fn
if isinstance(drop_on_export, bool):
    def decorator(fn):
        if drop_on_export:
            fn._torchscript_modifier = FunctionModifiers.IGNORE_AND_DROP
        else:
            fn._torchscript_modifier = FunctionModifiers.IGNORE
        return fn
    return decorator
raise RuntimeError("Argument to @torch.jit.ignore must be a bool or "
                   "a function but got {}".format(drop_on_export))
if not isinstance(drop, bool):
    raise RuntimeError("Argument to @torch.jit.ignore must be a bool or "
                       "a function but got {}".format(drop))

# for backwards compat
drop_on_export = kwargs.pop("drop_on_export", None)
if drop_on_export:
    warnings.warn("ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function "
                  "call on compilation. Use torch.jit.unused instead.", category=DeprecationWarning)
    drop = drop_on_export
elif drop:
    warnings.warn("ignore(True) has been deprecated. TorchScript will now drop the function "
                  "call on compilation. Use torch.jit.unused instead.", category=DeprecationWarning)

def decorator(fn):
    if drop:
        fn._torchscript_modifier = FunctionModifiers.UNUSED
    else:
        fn._torchscript_modifier = FunctionModifiers.IGNORE
    return fn
return decorator
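
For reference, a small hypothetical check of the decorator-to-modifier mapping this hunk establishes, using only the helpers defined in this file, would be expected to behave as follows:

```
from torch._jit_internal import (FunctionModifiers, get_torchscript_modifier,
                                 ignore, unused)

@unused
def a(x):
    return x

@ignore
def b(x):
    return x

@ignore(drop=True)  # deprecated spelling; warns and behaves like @unused
def c(x):
    return x

assert get_torchscript_modifier(a) is FunctionModifiers.UNUSED
assert get_torchscript_modifier(b) is FunctionModifiers.IGNORE
assert get_torchscript_modifier(c) is FunctionModifiers.UNUSED
```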
def module_has_exports(mod):
@@ -318,16 +361,16 @@ def module_has_exports(mod):
return True
return False
def should_drop_on_export(fn):
def should_drop(fn):
attr = get_torchscript_modifier(fn)
if attr is None:
return False
return attr is FunctionModifiers.IGNORE_AND_DROP
return attr is FunctionModifiers.UNUSED
def is_ignored_fn(fn):
mod = get_torchscript_modifier(fn)
return mod is FunctionModifiers.IGNORE_AND_DROP or mod is FunctionModifiers.IGNORE
return mod is FunctionModifiers.UNUSED or mod is FunctionModifiers.IGNORE
def get_torchscript_modifier(fn):