mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	[JIT] Document torch.jit.interface (#109356)
Good option for replacing "Callable" types; we should document it so it's searchable. Pull Request resolved: https://github.com/pytorch/pytorch/pull/109356 Approved by: https://github.com/eellison, https://github.com/gmagogsfm
This commit is contained in:
		
				
					committed by
					
						
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							ec8b58f5ba
						
					
				
				
					commit
					b4ea3260d7
				
			@ -1468,6 +1468,53 @@ def _check_directly_compile_overloaded(obj):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def interface(obj):
 | 
			
		||||
    r"""
 | 
			
		||||
    This decorator can be used to define an interface that can be used to annotate
 | 
			
		||||
    classes or modules of different types. This can be used for to annotate a submodule
 | 
			
		||||
    or attribute class that could have different types that implement the same
 | 
			
		||||
    interface, or which could be swapped at runtime; or to store a list of modules or
 | 
			
		||||
    classes of varying types.
 | 
			
		||||
 | 
			
		||||
    It is sometimes used to implement "Callables" - functions or modules that implement
 | 
			
		||||
    an interface but whose implementations differ and which can be swapped out.
 | 
			
		||||
 | 
			
		||||
    Example:
 | 
			
		||||
 | 
			
		||||
    .. testcode::
 | 
			
		||||
 | 
			
		||||
        import torch
 | 
			
		||||
        from typing import List
 | 
			
		||||
 | 
			
		||||
        @torch.jit.interface
 | 
			
		||||
        class InterfaceType:
 | 
			
		||||
            def run(self, x: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
        # implements InterfaceType
 | 
			
		||||
        @torch.jit.script
 | 
			
		||||
        class Impl1:
 | 
			
		||||
            def run(self, x: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
                return x.relu()
 | 
			
		||||
 | 
			
		||||
        class Impl2(torch.nn.Module):
 | 
			
		||||
            def __init__(self):
 | 
			
		||||
                super().__init__()
 | 
			
		||||
                self.val = torch.rand(())
 | 
			
		||||
 | 
			
		||||
            @torch.jit.export
 | 
			
		||||
            def run(self, x: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
                return x + self.val
 | 
			
		||||
 | 
			
		||||
        def user_fn(impls: List[InterfaceType], idx: int, val: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
            return impls[idx].run(val)
 | 
			
		||||
 | 
			
		||||
        user_fn_jit = torch.jit.script(user_fn)
 | 
			
		||||
 | 
			
		||||
        impls = [Impl1(), torch.jit.script(Impl2())]
 | 
			
		||||
        val = torch.rand(4, 4)
 | 
			
		||||
        user_fn_jit(impls, 0, val)
 | 
			
		||||
        user_fn_jit(impls, 1, val)
 | 
			
		||||
    """
 | 
			
		||||
    if not inspect.isclass(obj):
 | 
			
		||||
        raise RuntimeError("interface must be applied to a class")
 | 
			
		||||
    if not _is_new_style_class(obj):
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user