mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Use pyupgrade (https://github.com/asottile/pyupgrade) and flynt to modernize Python syntax:

```sh
pyupgrade --py36-plus --keep-runtime-typing torch/onnx/**/*.py
pyupgrade --py36-plus --keep-runtime-typing test/onnx/**/*.py
flynt torch/onnx/ --line-length 120
```

- Use f-strings for string formatting
- Use the new `super()` syntax for class initialization
- Use dictionary / set comprehension

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77935
Approved by: https://github.com/BowenBao
30 lines
1.0 KiB
Python
30 lines
1.0 KiB
Python
import torch.nn as nn
|
|
import torch.nn.init as init
|
|
|
|
|
|
class SuperResolutionNet(nn.Module):
    """Single-channel super-resolution network built on sub-pixel convolution.

    Four 2-D convolutions extract features at the input resolution; the first
    three are each followed by a ReLU. Every convolution uses stride 1 with
    padding equal to half its kernel size, so height and width are unchanged
    until the final ``nn.PixelShuffle`` rearranges the ``upscale_factor ** 2``
    output channels of ``conv4`` into one channel that is ``upscale_factor``
    times larger in each spatial dimension.

    Args:
        upscale_factor: spatial magnification. Used as the ``PixelShuffle``
            factor and, squared, as the output channel count of ``conv4``.

    Shape:
        input: ``(N, 1, H, W)`` (``conv1`` takes exactly one input channel)
        output: ``(N, 1, H * upscale_factor, W * upscale_factor)``
    """

    def __init__(self, upscale_factor):
        super().__init__()

        # Submodules are assigned in a fixed order; it determines the order
        # of named_children() / state_dict() entries, so keep it stable.
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(
            32, upscale_factor**2, kernel_size=3, stride=1, padding=1
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        """Return ``x`` upscaled by ``upscale_factor`` along height and width."""
        features = self.relu(self.conv1(x))
        features = self.relu(self.conv2(features))
        features = self.relu(self.conv3(features))
        # conv4 has no activation: its channels go straight into the shuffle.
        return self.pixel_shuffle(self.conv4(features))

    def _initialize_weights(self):
        """Orthogonally initialise the convolution weights in place.

        ``conv1``..``conv3`` (each followed by a ReLU in ``forward``) are scaled
        by the ReLU gain; ``conv4`` uses ``orthogonal_``'s default gain.
        Biases are left untouched here.
        """
        relu_gain = init.calculate_gain("relu")
        for layer in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(layer.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)
|