Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-27 00:54:52 +08:00).
			
		
		
		
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132352
Approved by: https://github.com/ezyang
ghstack dependencies: #132335, #132351
		
			
				
	
	
		
			37 lines
		
	
	
		
			900 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			37 lines
		
	
	
		
			900 B
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Usage: python create_dummy_model.py <name_of_the_file>
 | |
| import sys
 | |
| 
 | |
| import torch
 | |
| from torch import nn
 | |
| 
 | |
| 
 | |
class NeuralNetwork(nn.Module):
    """Small fully connected classifier used as a dummy model.

    Each sample is flattened (all dims after the batch dim) and passed
    through Linear(784 -> 512) -> ReLU -> Linear(512 -> 512) -> ReLU ->
    Linear(512 -> 10). The flattened sample must therefore have
    28 * 28 = 784 features (presumably a 28x28 image -- assumption).

    The attribute names ``flatten`` and ``linear_relu_stack`` define the
    state-dict keys, so they are kept stable.
    """

    def __init__(self) -> None:
        super().__init__()
        in_features = 28 * 28
        hidden_width = 512
        num_outputs = 10
        self.flatten = nn.Flatten()
        # Layer order matters: it fixes both the state-dict indices
        # (0, 2, 4 for the Linear layers) and the RNG consumption order
        # during parameter initialization.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, num_outputs),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (unnormalized) scores, one row of 10 per sample."""
        return self.linear_relu_stack(self.flatten(x))
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Fail with a usage message instead of an IndexError traceback when the
    # output path is missing. Extra arguments are ignored, as before.
    if len(sys.argv) < 2:
        sys.exit("Usage: python create_dummy_model.py <name_of_the_file>")
    out_path = sys.argv[1]

    # TorchScript archive of the full model (flatten + MLP stack).
    jit_module = torch.jit.script(NeuralNetwork())
    torch.jit.save(jit_module, out_path)

    # Plain (non-scripted) nn.Sequential with the same layer layout as
    # NeuralNetwork.linear_relu_stack, saved via torch.save to "<path>.orig".
    # Taking it from a fresh NeuralNetwork keeps the two architectures in
    # sync instead of duplicating the layer list by hand.
    # NOTE(review): its weights are freshly initialized and unrelated to
    # jit_module's weights, matching the original script; confirm consumers
    # do not expect them to match.
    orig_module = NeuralNetwork().linear_relu_stack
    torch.save(orig_module, out_path + ".orig")
 |