mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	Add example for torch.serialization.add_safe_globals (#129396)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129396 Approved by: https://github.com/albanD ghstack dependencies: #129244, #129251, #129239
This commit is contained in:
		
				
					committed by
					
						
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							381ce0821c
						
					
				
				
					commit
					f18becaaf1
				
			@ -209,6 +209,22 @@ def add_safe_globals(safe_globals: List[Any]) -> None:
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        safe_globals (List[Any]): list of globals to mark as safe
 | 
			
		||||
 | 
			
		||||
    Example:
 | 
			
		||||
        >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
 | 
			
		||||
        >>> import tempfile
 | 
			
		||||
        >>> class MyTensor(torch.Tensor):
 | 
			
		||||
        ...     pass
 | 
			
		||||
        >>> t = MyTensor(torch.randn(2, 3))
 | 
			
		||||
        >>> with tempfile.NamedTemporaryFile() as f:
 | 
			
		||||
        ...     torch.save(t, f.name)
 | 
			
		||||
        # Running `torch.load(f.name, weights_only=True)` will fail with
 | 
			
		||||
        # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
 | 
			
		||||
        # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
 | 
			
		||||
        ...     torch.serialization.add_safe_globals([MyTensor])
 | 
			
		||||
        ...     torch.load(f.name, weights_only=True)
 | 
			
		||||
        # MyTensor([[-0.5024, -1.8152, -0.5455],
 | 
			
		||||
        #          [-0.8234,  2.0500, -0.3657]])
 | 
			
		||||
    """
 | 
			
		||||
    _weights_only_unpickler._add_safe_globals(safe_globals)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user