Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-31 20:34:54 +08:00)
			
		
		
		
	Pull Request resolved: https://github.com/pytorch/pytorch/pull/137602
	Approved by: https://github.com/malfet, https://github.com/albanD
	ghstack dependencies: #138936, #139221, #139433, #139541
		
			
				
	
	
		
			14 lines
		
	
	
		
			311 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			14 lines
		
	
	
		
			311 B
		
	
	
	
		
			Python
		
	
	
	
	
	
"""Load a TorchScript module plus a companion ``.orig`` module and smoke-test them.

Usage: python <this script> <path-to-torchscript-archive>

The companion object is read from ``<path>.orig`` next to the TorchScript
archive. Presumably it is the original (non-scripted) module; confirm with the
caller that writes these files.
"""

import sys

import torch


if __name__ == "__main__":
    model_path = sys.argv[1]

    # Deserialize the TorchScript archive first, then the companion pickle.
    scripted_module = torch.jit.load(model_path)
    # weights_only=False as this is loading a sharded model.
    # NOTE: this performs full unpickling; only point it at trusted files.
    original_module = torch.load(f"{model_path}.orig", weights_only=False)

    print(scripted_module)

    # Run one forward pass of the companion module on a (2, 784) random
    # batch. The output is intentionally discarded; only "does not raise"
    # is checked.
    sample = torch.rand(2, 28 * 28)
    original_module(sample)

    sys.exit(0)