mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 21:49:24 +08:00
Use pyupgrade (https://github.com/asottile/pyupgrade) and flynt to modernize Python syntax:

```sh
pyupgrade --py36-plus --keep-runtime-typing torch/onnx/**/*.py
pyupgrade --py36-plus --keep-runtime-typing test/onnx/**/*.py
flynt torch/onnx/ --line-length 120
```

- Use f-strings for string formatting
- Use the new `super()` syntax for class initialization
- Use dictionary / set comprehensions

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77935
Approved by: https://github.com/BowenBao
26 lines
658 B
Python
26 lines
658 B
Python
import torch.nn as nn
|
|
|
|
|
|
class EmbeddingNetwork1(nn.Module):
    """Embedding table with 10 rows followed by a linear layer with one output.

    The two layers are stored both as direct attributes (``emb``, ``lin1``)
    and as the entries of ``seq``; the same module objects are used in both
    places, so their parameters are shared rather than duplicated.
    """

    def __init__(self, dim=5):
        """Build the layers.

        Args:
            dim: Size of each embedding vector, and the number of input
                features of the linear layer.
        """
        super().__init__()
        # Construct the embedding before the linear layer, then register the
        # attributes in the order emb -> lin1 -> seq.
        lookup = nn.Embedding(10, dim)
        projection = nn.Linear(dim, 1)
        self.emb = lookup
        self.lin1 = projection
        self.seq = nn.Sequential(lookup, projection)

    def forward(self, input):
        """Run ``input`` through ``self.seq`` and return the result."""
        result = self.seq(input)
        return result
|
|
|
|
|
|
class EmbeddingNetwork2(nn.Module):
    """Embedding table followed by a one-output linear layer and a sigmoid.

    The embedding is stored as ``embedding`` and is also the first entry of
    ``seq``; the linear and sigmoid layers exist only inside ``seq``.
    """

    def __init__(self, in_space=10, dim=3):
        """Build the layers.

        Args:
            in_space: Number of rows in the embedding table.
            dim: Size of each embedding vector, and the number of input
                features of the linear layer.
        """
        super().__init__()
        self.embedding = nn.Embedding(in_space, dim)
        # Layers are created in the order they are applied:
        # embedding -> linear -> sigmoid.
        scorer = nn.Linear(dim, 1)
        squash = nn.Sigmoid()
        self.seq = nn.Sequential(self.embedding, scorer, squash)

    def forward(self, indices):
        """Run ``indices`` through ``self.seq`` and return the result."""
        result = self.seq(indices)
        return result
|