Files
pytorch/test/create_dummy_torchscript_model.py
Han Qi d65414d145 Add test for FC/BC for torchscript file.
Summary:
title

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75136
Approved by: https://github.com/gmagogsfm
2022-04-13 23:23:13 +00:00

29 lines
693 B
Python

# Usage: python create_dummy_torchscript_model.py <output_file>
# Compiles a small MLP with torch.jit.script and saves it to <output_file>.
import sys
import torch
from torch import nn
class NeuralNetwork(nn.Module):
    """Small fully-connected classifier used as a dummy TorchScript model.

    Architecture: flatten -> Linear(784, 512) -> ReLU -> Linear(512, 512)
    -> ReLU -> Linear(512, 10).

    The attribute names and layer order determine the parameter names stored
    in a saved file (e.g. ``linear_relu_stack.0.weight``), so keep them stable
    if files produced by different versions are compared.
    """

    def __init__(self) -> None:
        super().__init__()
        # nn.Flatten defaults to start_dim=1: the batch dim is kept and every
        # remaining dim is flattened into one feature dim.
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (unnormalized) scores of shape ``(batch, 10)``.

        ``x`` must contain ``28 * 28`` elements per sample once flattened from
        dim 1 onward, e.g. shape ``(batch, 28, 28)`` or ``(batch, 1, 28, 28)``.
        No softmax is applied.
        """
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
if __name__ == '__main__':
    # The single required argument is the path the TorchScript file is written
    # to. Fail with a usage message instead of an IndexError when it is missing;
    # any extra arguments are ignored, as before.
    if len(sys.argv) < 2:
        sys.exit('Usage: python {} <output_file>'.format(sys.argv[0]))
    # Compile with torch.jit.script, then serialize to the requested path.
    jit_module = torch.jit.script(NeuralNetwork())
    torch.jit.save(jit_module, sys.argv[1])