# Owner(s): ["oncall: jit"]

import os
import sys
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.common_cuda import tf32_on_and_off
from torch.testing._internal.common_utils import (
    enable_profiling_mode_for_profiling_tests,
    GRAPH_EXECUTOR,
    ProfilingMode,
    raise_on_run_directly,
    set_default_dtype,
    slowTest,
    suppress_warnings,
)


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA


try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
except RuntimeError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

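# Small LeNet-style convolutional network (two conv layers with max pooling and
# dropout, two fully connected layers, log-softmax output) used by the MNIST
# tracing tests below.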
class MnistNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.reshape(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

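# Each test below builds a small reference model and checks it with
# JitTestCase.checkTrace: the model is traced with TorchScript and its outputs
# are compared against eager execution; unless export_import is disabled
# (e.g. for CUDA modules, see #11480), the traced module is also round-tripped
# through save/load.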
class TestModels(JitTestCase):
    @staticmethod
    def _test_dcgan_models(self, device, check_export_import=True):
        class DCGANGenerator(nn.Module):
            def __init__(self, nz, ngf, nc):
                super().__init__()
                self.main = nn.Sequential(
                    # input is Z, going into a convolution
                    nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
                    nn.BatchNorm2d(ngf * 8),
                    nn.ReLU(True),
                    # state size. (ngf*8) x 4 x 4
                    nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf * 4),
                    nn.ReLU(True),
                    # state size. (ngf*4) x 8 x 8
                    nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf * 2),
                    nn.ReLU(True),
                    # state size. (ngf*2) x 16 x 16
                    nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf),
                    nn.ReLU(True),
                    # state size. (ngf) x 32 x 32
                    nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
                    nn.Tanh(),
                    # state size. (nc) x 64 x 64
                )

            def forward(self, input):
                return self.main(input)

        class DCGANDiscriminator(nn.Module):
            def __init__(self, nc, ndf):
                super().__init__()
                self.main = nn.Sequential(
                    # input is (nc) x 64 x 64
                    nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf) x 32 x 32
                    nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * 2),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*2) x 16 x 16
                    nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * 4),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*4) x 8 x 8
                    nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * 8),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*8) x 4 x 4
                    nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
                    nn.Sigmoid(),
                )

            def forward(self, input):
                return self.main(input).view(-1, 1).squeeze(1)

        bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10
        self.checkTrace(
            DCGANGenerator(nz, ngf, nc).to(device),
            (torch.rand(bs, nz, 1, 1, device=device),),
            export_import=check_export_import,
        )
        example_input = DCGANGenerator(nz, ngf, nc).to(device)(
            torch.rand(bs, nz, 1, 1, device=device)
        )
        self.checkTrace(
            DCGANDiscriminator(nc, ndf).to(device),
            (example_input,),
            export_import=check_export_import,
        )

    def test_dcgan_models(self):
        # Note: Can sometimes fail with low precision if run with float dtype
        with set_default_dtype(torch.double):
            self._test_dcgan_models(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_dcgan_models_cuda(self):
        # Note: Can sometimes fail with low precision if run with float dtype
        with set_default_dtype(torch.double):
            # XXX: export_import on CUDA modules doesn't work (#11480)
            self._test_dcgan_models(self, device="cuda", check_export_import=False)

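    # Fast neural style transfer network (TransformerNet with reflection-padded
    # convolutions, residual blocks, and upsample-then-convolve layers), modeled
    # on the pytorch/examples fast_neural_style demo.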
    @staticmethod
    def _test_neural_style(self, device, check_export_import=True):
        class TransformerNet(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Initial convolution layers
                self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
                self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
                self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
                self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
                self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
                self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
                # Residual layers
                self.res1 = ResidualBlock(128)
                self.res2 = ResidualBlock(128)
                self.res3 = ResidualBlock(128)
                self.res4 = ResidualBlock(128)
                self.res5 = ResidualBlock(128)
                # Upsampling Layers
                self.deconv1 = UpsampleConvLayer(
                    128, 64, kernel_size=3, stride=1, upsample=2
                )
                self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
                self.deconv2 = UpsampleConvLayer(
                    64, 32, kernel_size=3, stride=1, upsample=2
                )
                self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
                self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
                # Non-linearities
                self.relu = torch.nn.ReLU()

            def forward(self, X):
                y = self.relu(self.in1(self.conv1(X)))
                y = self.relu(self.in2(self.conv2(y)))
                y = self.relu(self.in3(self.conv3(y)))
                y = self.res1(y)
                y = self.res2(y)
                y = self.res3(y)
                y = self.res4(y)
                y = self.res5(y)
                y = self.relu(self.in4(self.deconv1(y)))
                y = self.relu(self.in5(self.deconv2(y)))
                y = self.deconv3(y)
                return y

        class ConvLayer(torch.nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, stride):
                super().__init__()
                reflection_padding = kernel_size // 2
                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
                self.conv2d = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, stride
                )

            def forward(self, x):
                out = self.reflection_pad(x)
                out = self.conv2d(out)
                return out

        class ResidualBlock(torch.nn.Module):
            """ResidualBlock
            introduced in: https://arxiv.org/abs/1512.03385
            recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
            """

            def __init__(self, channels):
                super().__init__()
                self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
                self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
                self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
                self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                residual = x
                out = self.relu(self.in1(self.conv1(x)))
                out = self.in2(self.conv2(out))
                out = out + residual
                return out

        class UpsampleConvLayer(torch.nn.Module):
            """UpsampleConvLayer
            Upsamples the input and then does a convolution. This method gives better results
            compared to ConvTranspose2d.
            ref: http://distill.pub/2016/deconv-checkerboard/
            """

            def __init__(
                self, in_channels, out_channels, kernel_size, stride, upsample=None
            ):
                super().__init__()
                self.upsample = upsample
                if upsample:
                    self.upsample_layer = torch.nn.Upsample(
                        mode="nearest", scale_factor=upsample
                    )
                reflection_padding = kernel_size // 2
                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
                self.conv2d = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, stride
                )

            def forward(self, x):
                x_in = x
                if self.upsample:
                    x_in = self.upsample_layer(x_in)
                out = self.reflection_pad(x_in)
                out = self.conv2d(out)
                return out

        self.checkTrace(
            TransformerNet(),
            (torch.rand(5, 3, 16, 16),),
            export_import=check_export_import,
        )

    @slowTest
    def test_neural_style(self):
        self._test_neural_style(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_neural_style_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_neural_style(self, device="cuda", check_export_import=False)

    @unittest.skipIf(
        GRAPH_EXECUTOR == ProfilingMode.LEGACY, "Bug found in deprecated executor"
    )
    @staticmethod
    def _test_mnist(self, device, check_export_import=True):
        # eval() is present because dropout makes this nondeterministic
        with enable_profiling_mode_for_profiling_tests():
            self.checkTrace(
                MnistNet().to(device).eval(),
                (torch.rand(5, 1, 28, 28, device=device),),
                export_import=check_export_import,
            )

    def test_mnist(self):
        self._test_mnist(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_mnist_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_mnist(self, device="cuda", check_export_import=False)

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_mnist_training_leaks_no_memory_cuda(self):
        net = MnistNet().cuda()
        # MnistNet uses dropout, don't check its trace
        traced_net = torch.jit.trace(
            net, [torch.randn(5, 1, 28, 28, device="cuda")], check_trace=False
        )

        def train(iters):
            for _ in range(iters):
                # Get some fake data
                inp = torch.randn(5, 1, 28, 28, device="cuda")
                out = traced_net(inp)

                # Here's some fake loss
                out.sum().backward()

                # Zero out grads
                traced_net.zero_grad()

        # Set it up so the params have .grad fields so they are not reported as leaks
        train(1)

        with self.assertLeaksNoCudaTensors():
            train(5)

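    # Tiny two-layer policy network with a softmax head, similar to the policy
    # used in pytorch/examples' REINFORCE (CartPole) demo; only the forward pass
    # is traced here.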
    @staticmethod
    def _test_reinforcement_learning(self, device, test_export_import=True):
        class Policy(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.affine1 = nn.Linear(4, 128)
                self.affine2 = nn.Linear(128, 2)

            def forward(self, x):
                x = F.relu(self.affine1(x))
                action_scores = self.affine2(x)
                return F.softmax(action_scores, dim=1)

        with enable_profiling_mode_for_profiling_tests():
            self.checkTrace(
                Policy().to(device),
                (torch.rand(1, 4, device=device),),
                export_import=test_export_import,
            )

    def test_reinforcement_learning(self):
        self._test_reinforcement_learning(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_reinforcement_learning_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_reinforcement_learning(self, device="cuda", test_export_import=False)

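    # SNLI sentence-pair classifier: shared embedding, linear projection,
    # bidirectional LSTM encoder, and an MLP head over the concatenated sentence
    # encodings. Random integer token tensors stand in for real data, and
    # dropout is disabled (dp_ratio = 0.0) so the trace check is deterministic.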
    @staticmethod
    def _test_snli(self, device, check_export_import=True):
        class Bottle(nn.Module):
            def forward(self, input):
                if len(input.size()) <= 2:
                    return super().forward(input)
                size = input.size()[:2]
                out = super().forward(input.view(size[0] * size[1], -1))
                return out.view(size[0], size[1], -1)

        class Linear(Bottle, nn.Linear):
            pass

        class Encoder(nn.Module):
            def __init__(self, config):
                super().__init__()
                self.config = config
                input_size = config.d_proj if config.projection else config.d_embed
                dropout = 0 if config.n_layers == 1 else config.dp_ratio
                self.rnn = nn.LSTM(
                    input_size=input_size,
                    hidden_size=config.d_hidden,
                    num_layers=config.n_layers,
                    dropout=dropout,
                    bidirectional=config.birnn,
                )

            def forward(self, inputs):
                batch_size = inputs.size()[1]
                state_shape = self.config.n_cells, batch_size, self.config.d_hidden
                h0 = c0 = inputs.new_zeros(state_shape)
                _, (ht, _) = self.rnn(inputs, (h0, c0))
                return (
                    ht[-1]
                    if not self.config.birnn
                    else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
                )

        class SNLIClassifier(nn.Module):
            def __init__(self, config):
                super().__init__()
                self.config = config
                self.embed = nn.Embedding(config.n_embed, config.d_embed)
                self.projection = Linear(config.d_embed, config.d_proj)
                self.encoder = Encoder(config)
                self.dropout = nn.Dropout(p=config.dp_ratio)
                self.relu = nn.ReLU()
                seq_in_size = 2 * config.d_hidden
                if self.config.birnn:
                    seq_in_size *= 2
                lin_config = [seq_in_size] * 2
                self.out = nn.Sequential(
                    Linear(*lin_config),
                    self.relu,
                    self.dropout,
                    Linear(*lin_config),
                    self.relu,
                    self.dropout,
                    Linear(*lin_config),
                    self.relu,
                    self.dropout,
                    Linear(seq_in_size, config.d_out),
                )

            def forward(self, premise, hypothesis):
                prem_embed = self.embed(premise)
                hypo_embed = self.embed(hypothesis)
                if self.config.fix_emb:
                    prem_embed = prem_embed.detach()
                    hypo_embed = hypo_embed.detach()
                if self.config.projection:
                    prem_embed = self.relu(self.projection(prem_embed))
                    hypo_embed = self.relu(self.projection(hypo_embed))
                premise = self.encoder(prem_embed)
                hypothesis = self.encoder(hypo_embed)
                scores = self.out(torch.cat([premise, hypothesis], 1))
                return scores

        class Config:
            n_embed = 100
            d_embed = 100
            d_proj = 300
            dp_ratio = 0.0  # For deterministic testing TODO: change by fixing seed in checkTrace?
            d_hidden = 30
            birnn = True
            d_out = 300
            fix_emb = True
            projection = True
            n_layers = 2
            n_cells = 4  # 2 * n_layers because birnn = True

        premise = torch.LongTensor(48, 64).random_(0, 100).to(device)
        hypothesis = torch.LongTensor(24, 64).random_(0, 100).to(device)

        self.checkTrace(
            SNLIClassifier(Config()).to(device),
            (premise, hypothesis),
            inputs_require_grads=False,
            export_import=check_export_import,
        )

    @slowTest
    def test_snli(self):
        self._test_snli(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_snli_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_snli(self, device="cuda", check_export_import=False)

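    # Sub-pixel super-resolution network: a small convolution stack followed by
    # nn.PixelShuffle, as in the pytorch/examples super_resolution demo.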
    @staticmethod
    def _test_super_resolution(self, device, check_export_import=True):
        class Net(nn.Module):
            def __init__(self, upscale_factor):
                super().__init__()

                self.relu = nn.ReLU()
                self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
                self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
                self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
                self.conv4 = nn.Conv2d(32, upscale_factor**2, (3, 3), (1, 1), (1, 1))
                self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

            def forward(self, x):
                x = self.relu(self.conv1(x))
                x = self.relu(self.conv2(x))
                x = self.relu(self.conv3(x))
                x = self.pixel_shuffle(self.conv4(x))
                return x

        net = Net(upscale_factor=4).to(device)
        self.checkTrace(
            net,
            (torch.rand(5, 1, 32, 32, device=device),),
            export_import=check_export_import,
        )

    @slowTest
    def test_super_resolution(self):
        self._test_super_resolution(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    @tf32_on_and_off(0.02)
    def test_super_resolution_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_super_resolution(self, device="cuda", check_export_import=False)

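    # LSTM-based sequence predictor written directly as a ScriptModule (scripted
    # control flow with loops and torch.cat), then wrapped in a plain nn.Module
    # and run through checkTrace.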
    @suppress_warnings
    def test_time_sequence_prediction(self):
        class Sequence(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.lstm1 = nn.LSTMCell(1, 51)
                self.lstm2 = nn.LSTMCell(51, 51)
                self.linear = nn.Linear(51, 1)

            @torch.jit.script_method
            def forward(self, input):
                # TODO: add future as input with default val
                # see https://github.com/pytorch/pytorch/issues/8724
                outputs = torch.empty((3, 0))
                h_t = torch.zeros((3, 51))
                c_t = torch.zeros((3, 51))
                h_t2 = torch.zeros((3, 51))
                c_t2 = torch.zeros((3, 51))

                output = torch.zeros([3, 51])
                future = 2

                # TODO: chunk call should appear as the for loop iterable
                # We hard-code it to 4 for now.
                a, b, c, d = input.chunk(input.size(1), dim=1)
                for input_t in (a, b, c, d):
                    h_t, c_t = self.lstm1(input_t, (h_t, c_t))
                    h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
                    output = self.linear(h_t2)
                    outputs = torch.cat((outputs, output), 1)
                for _ in range(future):  # if we should predict the future
                    h_t, c_t = self.lstm1(output, (h_t, c_t))
                    h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
                    output = self.linear(h_t2)
                    outputs = torch.cat((outputs, output), 1)
                return outputs

        class Traced(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.seq = Sequence()

            def forward(self, input):
                return self.seq.forward(input)

        # disabled due to jitter issues that will be fixed by using load/store in the compiler
        with torch._jit_internal._disable_emit_hooks():
            # TODO: toggle export_import once above issues are fixed
            self.checkTrace(Traced(), (torch.rand(3, 4),), export_import=False)

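    # Standard MNIST VAE (MLP encoder/decoder with reparameterization). The
    # model is checked in eval mode so reparameterize returns mu and the forward
    # pass stays deterministic.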
    @staticmethod
    def _test_vae(self, device, check_export_import=True):
        class VAE(nn.Module):
            def __init__(self) -> None:
                super().__init__()

                self.fc1 = nn.Linear(784, 400)
                self.fc21 = nn.Linear(400, 20)
                self.fc22 = nn.Linear(400, 20)
                self.fc3 = nn.Linear(20, 400)
                self.fc4 = nn.Linear(400, 784)

            def encode(self, x):
                h1 = F.relu(self.fc1(x))
                return self.fc21(h1), self.fc22(h1)

            def reparameterize(self, mu, logvar):
                if self.training:
                    std = torch.exp(0.5 * logvar)
                    eps = torch.randn_like(std)
                    return eps.mul(std).add_(mu)
                else:
                    return mu

            def decode(self, z):
                h3 = F.relu(self.fc3(z))
                return torch.sigmoid(self.fc4(h3))

            def forward(self, x):
                mu, logvar = self.encode(x.view(-1, 784))
                z = self.reparameterize(mu, logvar)
                return self.decode(z), mu, logvar

        with enable_profiling_mode_for_profiling_tests():
            # eval() is present because randn_like makes this nondeterministic
            self.checkTrace(
                VAE().to(device).eval(),
                (torch.rand(128, 1, 28, 28, device=device),),
                export_import=check_export_import,
            )

    def test_vae(self):
        self._test_vae(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_vae_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_vae(self, device="cuda", check_export_import=False)

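    # Trace torchvision's ResNet-18, round-trip it through export/import, and
    # check that outputs and input gradients match the original traced module.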
    @slowTest
    @skipIfNoTorchVision
    def test_script_module_trace_resnet18(self):
        m_orig = torch.jit.trace(
            torchvision.models.resnet18(), torch.ones(1, 3, 224, 224)
        )
        m_import = self.getExportImportCopy(m_orig)

        input = torch.randn(1, 3, 224, 224, requires_grad=True)
        output_orig = m_orig(input)
        output_orig.sum().backward()
        grad_orig = input.grad.clone()
        input.grad.zero_()

        output_import = m_import(input)
        output_import.sum().backward()
        grad_import = input.grad.clone()

        self.assertEqual(output_orig, output_import)
        self.assertEqual(grad_orig, grad_import)

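    # Hand-written ScriptModule version of ResNet-18 (BasicBlock and ResNet
    # reimplemented with @torch.jit.script_method); the scripted model is
    # compared against an export/import copy on outputs and input gradients.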
    @slowTest
    @skipIfNoTorchVision
    def test_script_module_script_resnet(self):
        def conv1x1(in_planes, out_planes, stride=1):
            """1x1 convolution"""
            return nn.Conv2d(
                in_planes, out_planes, kernel_size=1, stride=stride, bias=False
            )

        def conv3x3(in_planes, out_planes, stride=1):
            """3x3 convolution with padding"""
            return nn.Conv2d(
                in_planes,
                out_planes,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False,
            )

        class BasicBlock(torch.jit.ScriptModule):
            expansion = 1
            __constants__ = ["downsample"]

            def __init__(self, inplanes, planes, stride=1, downsample=None):
                super().__init__()
                self.conv1 = conv3x3(inplanes, planes, stride)
                self.bn1 = nn.BatchNorm2d(planes)
                self.relu = nn.ReLU(inplace=True)
                self.conv2 = conv3x3(planes, planes)
                self.bn2 = nn.BatchNorm2d(planes)
                self.downsample = downsample
                self.stride = stride

            @torch.jit.script_method
            def forward(self, x):
                residual = x

                out = self.conv1(x)
                out = self.bn1(out)
                out = self.relu(out)

                out = self.conv2(out)
                out = self.bn2(out)

                if self.downsample is not None:
                    residual = self.downsample(x)

                out += residual
                out = self.relu(out)

                return out

        class ResNet(torch.jit.ScriptModule):
            __constants__ = ["layer1", "layer2", "layer3", "layer4"]

            def __init__(self, block, layers, num_classes=1000):
                super().__init__()
                self.inplanes = 64
                self.conv1 = nn.Conv2d(
                    3, 64, kernel_size=7, stride=2, padding=3, bias=False
                )
                self.bn1 = nn.BatchNorm2d(64)
                self.relu = nn.ReLU(inplace=True)
                self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                self.layer1 = self._make_layer(block, 64, layers[0])
                self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
                self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
                self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
                self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
                self.fc = nn.Linear(512 * block.expansion, num_classes)

                for m in self.modules():
                    if isinstance(m, nn.Conv2d):
                        nn.init.kaiming_normal_(
                            m.weight, mode="fan_out", nonlinearity="relu"
                        )
                    elif isinstance(m, nn.BatchNorm2d):
                        nn.init.constant_(m.weight, 1)
                        nn.init.constant_(m.bias, 0)

            def _make_layer(self, block, planes, blocks, stride=1):
                downsample = None
                if stride != 1 or self.inplanes != planes * block.expansion:
                    downsample = nn.Sequential(
                        conv1x1(self.inplanes, planes * block.expansion, stride),
                        nn.BatchNorm2d(planes * block.expansion),
                    )

                layers = []
                layers.append(block(self.inplanes, planes, stride, downsample))
                self.inplanes = planes * block.expansion
                for _ in range(1, blocks):
                    layers.append(block(self.inplanes, planes))

                return nn.Sequential(*layers)

            @torch.jit.script_method
            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu(x)
                x = self.maxpool(x)

                x = self.layer1(x)
                x = self.layer2(x)
                x = self.layer3(x)
                x = self.layer4(x)

                x = self.avgpool(x)
                x = x.view(x.size(0), -1)
                x = self.fc(x)

                return x

        resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])

        resnet18_imported = self.getExportImportCopy(resnet18)

        input = torch.randn(1, 3, 224, 224, requires_grad=True)
        output_orig = resnet18(input)
        output_orig.sum().backward()
        grad_orig = input.grad.clone()
        input.grad.zero_()
        output_import = resnet18_imported(input)
        output_import.sum().backward()
        grad_import = input.grad.clone()

        self.assertEqual(output_orig, output_import)
        self.assertEqual(grad_orig, grad_import)

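    # Trace torchvision's AlexNet with the internal torch.jit._get_trace_graph
    # API, run common subexpression elimination on the graph, and check that a
    # function built from the graph reproduces the traced outputs (under a
    # forked RNG so dropout sampling is repeatable).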
    @skipIfNoTorchVision
    def test_alexnet(self):
        x = torch.ones(1, 3, 224, 224)
        model = torchvision.models.AlexNet()
        with torch.random.fork_rng(devices=[]):
            g, outputs, inputs = torch.jit._get_trace_graph(
                model, x, return_inputs=True
            )
        self.run_pass("cse", g)
        m = self.createFunctionFromGraph(g)
        with torch.random.fork_rng(devices=[]):
            self.assertEqual(outputs, m(*inputs))


if __name__ == "__main__":
    raise_on_run_directly("test/test_jit.py")