mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	add torch.jit.is_scripting() api (#25263)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25263 This adds an api to return true in script and false in eager, which together with ignore allows guarding of not yet supported JIT features. Bikeshedding requested please. cc zou3519 ``` def foo(): if not torch.jit.is_scripting(): return torch.linear(...) else: return addmm(...) ``` Test Plan: Imported from OSS Differential Revision: D17272443 Pulled By: eellison fbshipit-source-id: de0f769c7eaae91de0007b98969183df93a91f42
This commit is contained in:
		
				
					committed by
					
						
						Facebook Github Bot
					
				
			
			
				
	
			
			
			
						parent
						
							36bdde255e
						
					
				
				
					commit
					7ab4ad7b6d
				
			@ -6,6 +6,7 @@ circular dependency problems
 | 
			
		||||
 | 
			
		||||
import inspect
 | 
			
		||||
import weakref
 | 
			
		||||
import warnings
 | 
			
		||||
import torch._C
 | 
			
		||||
from torch._six import builtins
 | 
			
		||||
 | 
			
		||||
@ -168,7 +169,7 @@ class FunctionModifiers(object):
 | 
			
		||||
    Used to denote the behavior of a function in TorchScript. See export() and
 | 
			
		||||
    ignore() for details.
 | 
			
		||||
    """
 | 
			
		||||
    IGNORE_AND_DROP = "ignore (leave as a call to Python, replace with a 'raise' on torch.jit.save)"
 | 
			
		||||
    UNUSED = "unused (ignored and replaced with raising of an exception)"
 | 
			
		||||
    IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)"
 | 
			
		||||
    EXPORT = "export (compile this function even if nothing calls it)"
 | 
			
		||||
    DEFAULT = "default (compile if called from a exported function / forward)"
 | 
			
		||||
@ -219,23 +220,52 @@ def export(fn):
 | 
			
		||||
    return fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ignore(drop_on_export=False):
 | 
			
		||||
def unused(fn):
 | 
			
		||||
    """
 | 
			
		||||
    This decorator indicates to the compiler that a function or method should
 | 
			
		||||
    be ignored and left as a Python function.
 | 
			
		||||
    be ignored and replaced with the raising of an exception. This allows you
 | 
			
		||||
    to leave code in your model that is not yet TorchScript compatible and still
 | 
			
		||||
    export your model.
 | 
			
		||||
 | 
			
		||||
    Arguments:
 | 
			
		||||
        Example (using ``@torch.jit.unused`` on a method)::
 | 
			
		||||
 | 
			
		||||
        drop_on_export (bool):  When ``False``, calls to this function will
 | 
			
		||||
                                that will be run with ``example_inputs``.
 | 
			
		||||
                                arguments and returns to ``func`` must be tensors
 | 
			
		||||
                                or (possibly nested) tuples that
 | 
			
		||||
                                contain tensors. When ``True``, any calls to
 | 
			
		||||
                                this function from other TorchScript code will be replaced
 | 
			
		||||
                                with a `raise` when the model is saved.
 | 
			
		||||
                                This allows you to leave code in your TorchScript model that is only ever
 | 
			
		||||
                                run when the Python interpreter is present, but not run after you save
 | 
			
		||||
                                and load your model.
 | 
			
		||||
            import torch
 | 
			
		||||
            import torch.nn as nn
 | 
			
		||||
 | 
			
		||||
            class MyModule(nn.Module):
 | 
			
		||||
                def __init__(self, use_memory_efficent):
 | 
			
		||||
                    super(MyModule, self).__init__()
 | 
			
		||||
                    self.use_memory_efficent = use_memory_efficent
 | 
			
		||||
 | 
			
		||||
                @torch.jit.unused
 | 
			
		||||
                def memory_efficient(self, x):
 | 
			
		||||
                    import pdb
 | 
			
		||||
                    pdb.set_trace()
 | 
			
		||||
                    return x + 10
 | 
			
		||||
 | 
			
		||||
                def forward(self, x):
 | 
			
		||||
                    # Use not-yet-scriptable memory efficient mode
 | 
			
		||||
                    if self.use_memory_efficient:
 | 
			
		||||
                        return self.memory_efficient(x)
 | 
			
		||||
                    else:
 | 
			
		||||
                        return x + 10
 | 
			
		||||
 | 
			
		||||
            m = torch.jit.script(MyModule(use_memory_efficent=False))
 | 
			
		||||
            m.save("m.pt")
 | 
			
		||||
 | 
			
		||||
            m = torch.jit.script(MyModule(use_memory_efficient=True))
 | 
			
		||||
            # exception raised
 | 
			
		||||
            m(torch.rand(100))
 | 
			
		||||
    """
 | 
			
		||||
    fn._torchscript_modifier = FunctionModifiers.UNUSED
 | 
			
		||||
    return fn
 | 
			
		||||
 | 
			
		||||
def ignore(drop=False, **kwargs):
 | 
			
		||||
    """
 | 
			
		||||
    This decorator indicates to the compiler that a function or method should
 | 
			
		||||
    be ignored and left as a Python function. This allows you to leave code in
 | 
			
		||||
    your model that is not yet TorchScript compatible. Models with ignored
 | 
			
		||||
    functions cannot be exported; use torch.jit.unused instead.
 | 
			
		||||
 | 
			
		||||
    Example (using ``@torch.jit.ignore`` on a method)::
 | 
			
		||||
 | 
			
		||||
@ -261,7 +291,7 @@ def ignore(drop_on_export=False):
 | 
			
		||||
        # Error! The call `debugger` cannot be saved since it calls into Python
 | 
			
		||||
        m.save("m.pt")
 | 
			
		||||
 | 
			
		||||
    Example (using ``@torch.jit.ignore(drop_on_export=True)`` on a method):
 | 
			
		||||
    Example (using ``@torch.jit.ignore(drop=True)`` on a method):
 | 
			
		||||
 | 
			
		||||
    .. testcode::
 | 
			
		||||
 | 
			
		||||
@ -269,7 +299,7 @@ def ignore(drop_on_export=False):
 | 
			
		||||
        import torch.nn as nn
 | 
			
		||||
 | 
			
		||||
        class MyModule(nn.Module):
 | 
			
		||||
            @torch.jit.ignore(drop_on_export=True)
 | 
			
		||||
            @torch.jit.ignore(drop=True)
 | 
			
		||||
            def training_method(self, x):
 | 
			
		||||
                import pdb
 | 
			
		||||
                pdb.set_trace()
 | 
			
		||||
@ -290,24 +320,37 @@ def ignore(drop_on_export=False):
 | 
			
		||||
        import os
 | 
			
		||||
        os.remove('m.pt')
 | 
			
		||||
    """
 | 
			
		||||
    if callable(drop_on_export):
 | 
			
		||||
        # used without any args, so drop_on_export is actually a function
 | 
			
		||||
 | 
			
		||||
    if callable(drop):
 | 
			
		||||
        # used without any args, so drop is actually a function
 | 
			
		||||
        #   @torch.jit.ignore
 | 
			
		||||
        #   def fn(...):
 | 
			
		||||
        fn = drop_on_export
 | 
			
		||||
        fn = drop
 | 
			
		||||
        fn._torchscript_modifier = FunctionModifiers.IGNORE
 | 
			
		||||
        return fn
 | 
			
		||||
 | 
			
		||||
    if isinstance(drop_on_export, bool):
 | 
			
		||||
        def decorator(fn):
 | 
			
		||||
            if drop_on_export:
 | 
			
		||||
                fn._torchscript_modifier = FunctionModifiers.IGNORE_AND_DROP
 | 
			
		||||
            else:
 | 
			
		||||
                fn._torchscript_modifier = FunctionModifiers.IGNORE
 | 
			
		||||
            return fn
 | 
			
		||||
        return decorator
 | 
			
		||||
    raise RuntimeError("Argument to @torch.jit.ignore must be a bool or "
 | 
			
		||||
                       "a function but got {}".format(drop_on_export))
 | 
			
		||||
    if not isinstance(drop, bool):
 | 
			
		||||
        raise RuntimeError("Argument to @torch.jit.ignore must be a bool or "
 | 
			
		||||
                           "a function but got {}".format(drop))
 | 
			
		||||
 | 
			
		||||
    # for backwards compat
 | 
			
		||||
    drop_on_export = kwargs.pop("drop_on_export", None)
 | 
			
		||||
    if drop_on_export:
 | 
			
		||||
        warnings.warn("ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function "
 | 
			
		||||
                      "call on compilation. Use torch.jit.unused now. {}", category=DeprecationWarning)
 | 
			
		||||
 | 
			
		||||
        drop = drop_on_export
 | 
			
		||||
    elif drop:
 | 
			
		||||
        warnings.warn("ignore(True) has been deprecated. TorchScript will now drop the function "
 | 
			
		||||
                      "call on compilation. Use torch.jit.unused now. {}", category=DeprecationWarning)
 | 
			
		||||
 | 
			
		||||
    def decorator(fn):
 | 
			
		||||
        if drop:
 | 
			
		||||
            fn._torchscript_modifier = FunctionModifiers.UNUSED
 | 
			
		||||
        else:
 | 
			
		||||
            fn._torchscript_modifier = FunctionModifiers.IGNORE
 | 
			
		||||
        return fn
 | 
			
		||||
    return decorator
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def module_has_exports(mod):
 | 
			
		||||
@ -318,16 +361,16 @@ def module_has_exports(mod):
 | 
			
		||||
                return True
 | 
			
		||||
    return False
 | 
			
		||||
 | 
			
		||||
def should_drop_on_export(fn):
 | 
			
		||||
def should_drop(fn):
 | 
			
		||||
    attr = get_torchscript_modifier(fn)
 | 
			
		||||
    if attr is None:
 | 
			
		||||
        return False
 | 
			
		||||
    return attr is FunctionModifiers.IGNORE_AND_DROP
 | 
			
		||||
    return attr is FunctionModifiers.UNUSED
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_ignored_fn(fn):
 | 
			
		||||
    mod = get_torchscript_modifier(fn)
 | 
			
		||||
    return mod is FunctionModifiers.IGNORE_AND_DROP or mod is FunctionModifiers.IGNORE
 | 
			
		||||
    return mod is FunctionModifiers.UNUSED or mod is FunctionModifiers.IGNORE
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_torchscript_modifier(fn):
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user