mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132352 Approved by: https://github.com/ezyang ghstack dependencies: #132335, #132351
52 lines
1.2 KiB
Python
52 lines
1.2 KiB
Python
# Owner(s): ["module: onnx"]
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class DummyNet(nn.Module):
    """Small LeakyReLU -> BatchNorm -> AvgPool stack whose output is flattened to 1-D.

    Args:
        num_classes: Accepted for signature compatibility; the model does not
            use it.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # The position of each layer determines its state_dict prefix
        # (e.g. the BatchNorm parameters live under ``features.1.*``).
        layers = [
            nn.LeakyReLU(0.02),
            nn.BatchNorm2d(3),  # num_features=3: input must have 3 channels
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False),
        ]
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feature stack and return its elements as a 1-D tensor."""
        pooled = self.features(x)
        # Flattened in two steps (view to a column, then drop the unit dim).
        # NOTE(review): presumably kept this way so exported graphs contain
        # these exact ops — confirm before collapsing into a single reshape.
        column = pooled.view(-1, 1)
        return column.squeeze(1)
|
|
|
|
|
|
class ConcatNet(nn.Module):
    """Parameter-free module that concatenates its inputs along dimension 1."""

    def forward(self, inputs):
        """Join the tensors in the ``inputs`` sequence along dim 1.

        All tensors must agree in every dimension except dim 1.
        """
        return torch.cat(inputs, dim=1)
|
|
|
|
|
|
class PermuteNet(nn.Module):
    """Parameter-free module that reorders the axes of a 4-D input."""

    def forward(self, input):
        """Return a view of ``input`` with axes ordered as (2, 3, 0, 1).

        ``input`` must have exactly four dimensions; the original axes 2 and 3
        move to the front, followed by axes 0 and 1.
        """
        new_order = (2, 3, 0, 1)
        return input.permute(*new_order)
|
|
|
|
|
|
class PReluNet(nn.Module):
    """Single PReLU layer with ``num_parameters=3``, wrapped in ``nn.Sequential``.

    The Sequential wrapper places the learnable slope under the state_dict key
    ``features.0.weight``.
    """

    def __init__(self) -> None:
        super().__init__()
        prelu = nn.PReLU(3)  # one learnable slope per channel (3 channels)
        self.features = nn.Sequential(prelu)

    def forward(self, x):
        """Apply the PReLU activation to ``x``."""
        return self.features(x)
|
|
|
|
|
|
class FakeQuantNet(nn.Module):
    """Wrap a default-configured ``FakeQuantize`` module with its observer off.

    Because the observer is disabled, forward passes fake-quantize with the
    module's current scale/zero point instead of recalibrating them from the
    incoming data.
    """

    def __init__(self) -> None:
        super().__init__()
        quantizer = torch.ao.quantization.FakeQuantize()
        quantizer.disable_observer()
        self.fake_quant = quantizer

    def forward(self, x):
        """Return ``x`` passed through the fake-quantization module."""
        return self.fake_quant(x)
|