Removes torchtest, expands generic device testing (#26374)

Summary:
- Removes torchtest
- ~~Moves test_torch tests skipped on ROCm to generic device test class~~
- Creates test_nn generic device test class

Next: adding dtypes to the generic device testing framework.
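
For context, a minimal sketch of the generic device-type test pattern this PR moves test_nn toward (the example class and test below are hypothetical; the imports and the instantiate_device_type_tests call mirror the test_nn.py changes in this diff, and assume the script runs from the PyTorch test directory):

    import torch
    from common_utils import TestCase, run_tests
    from common_device_type import instantiate_device_type_tests

    class TestExampleDeviceType(TestCase):
        # Generic tests take the device string ('cpu', 'cuda', ...) as an argument
        # instead of being written once per device.
        def test_zeros_like(self, device):
            expected = torch.zeros((100, 100), device=device)
            self.assertEqual(expected, torch.zeros_like(expected))

    # Generates a concrete test class per available device type and registers it
    # in this module's globals so unittest discovery picks up the per-device variants.
    instantiate_device_type_tests(TestExampleDeviceType, globals())

    if __name__ == '__main__':
        run_tests()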
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26374

Test Plan: The change is to the tests themselves.

Differential Revision: D17442218

Pulled By: mruberry

fbshipit-source-id: d7e4451d09fc9049478b35a7efb8bb580071e8c8
Mike Ruberry
2019-09-18 10:22:43 -07:00
committed by Facebook Github Bot
parent ed09704899
commit 388cfdf2ac
3 changed files with 371 additions and 478 deletions

test/common_utils.py

@@ -216,52 +216,6 @@ def _test_function(fn, device):
return run_test_function
class torchtest():
"""Allows to generate and run per-device unittests.
This decorator class allows to generate and run per-device unittest.
Example:
class _TestTorchMixin(torchtest):
@torchtest.for_all_device_types()
def test_zeros_like(self, device):
expected = torch.zeros((100, 100,), device=device)
Will execute:
test_zeros_like (__main__.TestTorch) ... skipped 'Look at test_zeros_like_cpu, test_zeros_like_cuda results.'
test_zeros_like_cpu (__main__.TestTorch) ... ok
test_zeros_like_cuda (__main__.TestTorch) ... ok
To work properly, the test class must inherit from `torchtest`.
The for_all_device_types decorator does not guarantee proper functionality in
combination with other decorators.
Please do not extend this decorator to support other cases (such as dtypes,
layouts, etc.) without consulting a bigger group. Devices are a special
case because build flags control their addition and removal (see
https://github.com/pytorch/pytorch/pull/23824 for reference).
"""
@classmethod
def for_all_device_types(cls):
def wrapper(fn):
test_names = []
for device in torch.testing.get_all_device_types():
test_name = fn.__name__ + '_' + device
assert not hasattr(cls, test_name), "Duplicated test name: " + test_name
setattr(cls, test_name, _test_function(fn, device))
test_names.append(test_name)
@wraps(fn)
def empty_test(*args, **kwargs):
raise unittest.SkipTest("Look at {} results.".format(", ".join(test_names)))
return empty_test
return wrapper
def skipIfNoLapack(fn):
@wraps(fn)
def wrapper(*args, **kwargs):

test/test_nn.py

@@ -35,6 +35,7 @@ from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION
from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
module_tests, criterion_tests, new_criterion_tests, loss_reference_fns, \
ctcloss_reference, new_module_tests
from common_device_type import instantiate_device_type_tests
from torch.nn import MultiheadAttention
@@ -966,34 +967,6 @@ class TestNN(NNTestCase):
with self.assertRaisesRegex(RuntimeError, 'negative stride is not supported'):
module(input)
def _test_dropout(self, cls, cuda, input):
p = 0.2
device = torch.device("cuda") if cuda else torch.device("cpu")
input = input.to(device).fill_(1 - p)
module = cls(p)
input_var = input.clone().requires_grad_()
output = module(input_var)
self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
output.backward(input)
self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)
module = cls(p, True)
input_var = input.clone().requires_grad_()
output = module(input_var + 0)
self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
output.backward(input)
self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)
# check eval mode doesn't change anything
for inplace in [True, False]:
module = cls(p, inplace).eval()
self.assertEqual(input, module(input))
# Check that these don't raise errors
module.__repr__()
str(module)
def _test_alpha_dropout(self, cls, input):
mean = input.mean()
std = input.std()
@@ -3160,51 +3133,6 @@ class TestNN(NNTestCase):
gradcheck(func, [x])
gradgradcheck(func, [x])
def test_Dropout(self):
input = torch.Tensor(1000)
self._test_dropout(nn.Dropout, False, input)
def test_Dropout2d(self):
b = random.randint(1, 5)
w = random.randint(1, 5)
h = random.randint(1, 5)
num_features = 1000
input = torch.Tensor(num_features, b, w, h)
self._test_dropout(nn.Dropout2d, False, input)
def test_Dropout3d(self):
b = random.randint(1, 5)
w = random.randint(1, 5)
h = random.randint(1, 5)
d = random.randint(1, 2)
num_features = 1000
input = torch.Tensor(num_features, b, d, w, h)
self._test_dropout(nn.Dropout3d, False, input)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_Dropout_cuda(self):
input = torch.Tensor(1000)
self._test_dropout(nn.Dropout, True, input)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_Dropout2d_cuda(self):
b = random.randint(1, 5)
w = random.randint(1, 5)
h = random.randint(1, 5)
num_features = 1000
input = torch.Tensor(num_features, b, w, h)
self._test_dropout(nn.Dropout2d, True, input)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_Dropout3d_cuda(self):
b = random.randint(1, 5)
w = random.randint(1, 5)
h = random.randint(1, 5)
d = random.randint(1, 2)
num_features = 1000
input = torch.Tensor(num_features, b, d, w, h)
self._test_dropout(nn.Dropout3d, True, input)
def test_AlphaDropout(self):
# generate random tensor with zero mean and unit std
input = torch.randn(5000)
@@ -3219,259 +3147,6 @@ class TestNN(NNTestCase):
input = torch.randn(num_features, b, d, w, h)
self._test_alpha_dropout(nn.FeatureAlphaDropout, input)
def _test_InstanceNorm_general(self, cls, input, device="cpu", dtype=torch.float):
# default case track_running_stats=False
b, c = input.size(0), input.size(1)
input_var = input.to(device=device, dtype=dtype).requires_grad_()
IN = cls(c, eps=0).to(device, dtype)
output = IN(input_var)
out_reshaped = output.view(b * c, -1)
mean = out_reshaped.mean(1)
var = out_reshaped.var(1, unbiased=False)
self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)
# check that eval mode doesn't change behavior
grad_out = torch.randn_like(output)
res1 = output.data.clone()
output.backward(grad_out)
grad1 = input_var.grad.data.clone()
IN.eval()
output = IN(input_var)
input_var.grad = None
output.backward(grad_out)
res2 = output.data
grad2 = input_var.grad.data
self.assertEqual(res1, res2)
self.assertEqual(grad1, grad2)
# If track_running_stats=True and momentum=1, running_mean/var should be
# equal to mean/var of the input (with unbias correction)
IN = cls(c, momentum=1, eps=0, track_running_stats=True).to(device, dtype)
output = IN(input_var)
input_reshaped = input_var.transpose(1, 0).reshape(c, -1)
mean = input_reshaped.mean(1)
input_reshaped = input_var.transpose(1, 0).reshape(c, b, -1)
var = input_reshaped.var(2, unbiased=True)[:, :]
self.assertAlmostEqual(torch.abs(mean.data - IN.running_mean).mean(), 0, delta=1e-5)
self.assertAlmostEqual(torch.abs(var.data.mean(1) - IN.running_var).mean(), 0, delta=1e-5)
# in eval mode, adding X * std to a channel in input should make the
# corresponding channel in output have mean X
IN.eval()
delta = IN.running_var.sqrt() * torch.arange(c, device=device, dtype=dtype)
delta = delta.view(-1, *[1 for _ in range(2, input.dim())])
output = IN(input_var + delta)
self.assertEqual(output.transpose(0, 1).reshape(c, -1).mean(1), torch.arange(c))
def _test_InstanceNorm_cuda_half(self, cls, input):
# THNN
input = input.to(device='cuda', dtype=torch.half).random_(1, 10).requires_grad_(True)
m = cls(input.size(1), affine=True, track_running_stats=True).to("cuda", torch.half)
thnn_output = m(input)
thnn_output.sum().backward()
thnn_input_grad = input.grad.data.clone()
self.assertEqual(thnn_output.type(), input.type())
# cuDNN
if TEST_CUDNN:
input.grad = None
m = m.float()
cudnn_output = m(input)
cudnn_output.sum().backward()
cudnn_input_grad = input.grad.data.clone()
self.assertEqual(cudnn_output.type(), input.type())
self.assertAlmostEqual(cudnn_output, thnn_output, delta=1e-4)
self.assertAlmostEqual(cudnn_input_grad, thnn_input_grad, delta=1e-3)
def test_InstanceNorm1d_general(self):
b = random.randint(3, 5)
c = random.randint(3, 5)
d = random.randint(8, 10)
input = torch.rand(b, c, d)
self._test_InstanceNorm_general(nn.InstanceNorm1d, input, dtype=torch.float)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_InstanceNorm1d_general_cuda(self):
b = random.randint(3, 5)
c = random.randint(3, 5)
d = random.randint(8, 10)
input = torch.rand(b, c, d)
self._test_InstanceNorm_general(nn.InstanceNorm1d, input, "cuda", torch.float)
self._test_InstanceNorm_cuda_half(nn.InstanceNorm1d, input)
def test_InstanceNorm2d_general(self):
b = random.randint(3, 5)
c = random.randint(3, 5)
w = random.randint(3, 6)
h = random.randint(6, 8)
input = torch.rand(b, c, h, w)
self._test_InstanceNorm_general(nn.InstanceNorm2d, input, dtype=torch.float)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_InstanceNorm2d_general_cuda(self):
b = random.randint(3, 5)
c = random.randint(3, 5)
w = random.randint(3, 6)
h = random.randint(6, 8)
input = torch.rand(b, c, h, w)
self._test_InstanceNorm_general(nn.InstanceNorm2d, input, "cuda", torch.float)
self._test_InstanceNorm_cuda_half(nn.InstanceNorm2d, input)
def test_InstanceNorm3d_general(self):
b = random.randint(3, 5)
c = random.randint(3, 5)
w = random.randint(2, 5)
h = random.randint(2, 5)
d = random.randint(2, 5)
input = torch.rand(b, c, h, w, d)
self._test_InstanceNorm_general(nn.InstanceNorm3d, input, dtype=torch.float)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_InstanceNorm3d_general_cuda(self):
b = random.randint(3, 5)
c = random.randint(2, 5)
w = random.randint(2, 5)
h = random.randint(2, 5)
d = random.randint(2, 5)
input = torch.rand(b, c, h, w, d)
self._test_InstanceNorm_general(nn.InstanceNorm3d, input, "cuda", torch.float)
self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input)
def _test_LayerNorm_general(self, device="cpu", dtype=torch.float):
for i in range(2, 6):
shape = torch.randint(3, 6, (i,), dtype=torch.long).tolist()
x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
normalized_ndim = random.randint(1, i - 1) # inclusive
normalized_shape = shape[-normalized_ndim:]
unnormalized_shape = shape[:-normalized_ndim]
# test that LN normalizes to mean 0 and stddev 1
ln = nn.LayerNorm(normalized_shape, eps=0).to(device, dtype)
ln.weight.data.fill_(1)
ln.bias.data.fill_(0)
output = ln(x)
out_reshaped = output.view(*(unnormalized_shape + [-1]))
mean = out_reshaped.mean(-1)
var = out_reshaped.var(-1, unbiased=False)
self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)
# test that LN applies weight and bias correctly
scale, bias = torch.empty(2).uniform_(0.2, 2).tolist()
ln.weight.data.fill_(scale)
ln.bias.data.fill_(bias)
output = ln(x)
out_reshaped = output.view(*(unnormalized_shape + [-1]))
mean = out_reshaped.mean(-1)
var = out_reshaped.var(-1, unbiased=False)
self.assertAlmostEqual(torch.abs(mean.data).mean(), bias, delta=1e-5)
self.assertAlmostEqual(torch.abs(var.data).mean(), scale ** 2, delta=1e-5)
bad_norm_shape_input_shape = {
(): (),
(2, 3): (3,),
(2,): (1, 2, 3),
(10,): (2, 3),
10: (2, 3),
}
for norm_shape, input_shape in bad_norm_shape_input_shape.items():
ln = nn.LayerNorm(norm_shape)
input = torch.empty(input_shape, device=device, dtype=dtype).uniform_(0, 10)
self.assertRaises(RuntimeError, lambda: ln(input))
def _test_LayerNorm_cuda_half(self):
input = torch.empty(2, 3, 3, 2, device="cuda", dtype=torch.half).random_(1, 10).requires_grad_(True)
m = nn.LayerNorm([3, 2]).to("cuda", torch.half)
output = m(input)
output.sum().backward()
self.assertEqual(output.type(), input.type())
def test_LayerNorm_general(self):
self._test_LayerNorm_general()
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_LayerNorm_general_cuda(self):
self._test_LayerNorm_general("cuda")
self._test_LayerNorm_cuda_half()
def _test_GroupNorm_general(self, device="cpu", dtype=torch.float):
good_shape_g = {
(1, 2, 3, 4): 2,
(2, 3, 10): 3,
(3, 1, 1, 1, 2): 1,
(2, 6, 4, 2, 2): 3,
}
for shape, g in good_shape_g.items():
x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
b = shape[0]
c = shape[1]
# test that GN normalizes to mean 0 and stddev 1
gn = nn.GroupNorm(g, c, eps=0).to(device, dtype)
gn.weight.data.fill_(1)
gn.bias.data.fill_(0)
output = gn(x)
out_reshaped = output.view(b, g, -1)
mean = out_reshaped.mean(-1)
var = out_reshaped.var(-1, unbiased=False)
self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5)
self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5)
# test that GN applies weight and bias correctly
scale = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
bias = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
gn.weight.data.copy_(scale)
gn.bias.data.copy_(bias)
output = gn(x)
out_reshaped = output.view(b, c, -1)
out_normed = (out_reshaped - bias.view(c, 1)) / scale.view(c, 1)
out_normed_reshaped = out_normed.view(b, g, -1)
mean = out_normed_reshaped.mean(-1)
var = out_normed_reshaped.var(-1, unbiased=False)
self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5)
self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5)
bad_shape_g = {
(1, 2, 3, 4): 3,
(2, 3, 10): 2,
(3, 1, 1, 1, 2): 10,
(2, 6, 4, 2, 2): 4,
}
for shape, g in bad_shape_g.items():
gn = nn.GroupNorm(g, shape[1])
input = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
self.assertRaises(RuntimeError, lambda: gn(input))
def _test_GroupNorm_cuda_half(self):
input = torch.zeros(2, 4, 3, 2, requires_grad=True).cuda().half().random_(1, 10)
m = nn.GroupNorm(2, 4).to("cuda", torch.half)
output = m(input)
output.sum().backward()
self.assertEqual(output.type(), input.type())
def test_GroupNorm_general(self):
self._test_GroupNorm_general(dtype=torch.float)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_GroupNorm_general_cuda(self):
self._test_GroupNorm_general("cuda", torch.float)
self._test_GroupNorm_cuda_half()
def test_pad(self):
inputs = torch.randn(1, 3, 4, 4, requires_grad=True)
_assertGradAndGradgradChecks(self, lambda x: F.pad(x, (1, 1, 1, 1)), (inputs,))
@@ -3490,115 +3165,11 @@ class TestNN(NNTestCase):
self.assertRaisesRegex(RuntimeError, expected_err_msg,
lambda: F.pad(torch.randn(1, 1, 2), (2, 1), mode='reflect'))
@staticmethod
def _test_one_hot(self, use_cuda=False):
device = torch.device('cuda' if use_cuda else 'cpu')
with self.assertRaises(RuntimeError):
torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1)
with self.assertRaises(RuntimeError):
torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 3)
t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device))
expected = torch.tensor([[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
[0, 1, 0, 0, 0],
[1, 0, 0, 0, 0]], device=device)
self.assertEqual(t, expected)
t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1)
expected = torch.tensor([[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
[0, 1, 0, 0, 0],
[1, 0, 0, 0, 0]], device=device)
self.assertEqual(t, expected)
t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6)
expected = torch.tensor([[0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 1, 0],
[0, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0]], device=device)
self.assertEqual(t, expected)
t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device))
expected = torch.tensor([[[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1]],
[[0, 1, 0, 0, 0],
[1, 0, 0, 0, 0]]], device=device)
self.assertEqual(t, expected)
t = torch.nn.functional.one_hot(torch.tensor(4, device=device))
expected = torch.tensor([0, 0, 0, 0, 1], device=device)
self.assertEqual(t, expected)
t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100)
expected = torch.empty([4, 0, 100])
self.assertEqual(t, expected)
with self.assertRaises(RuntimeError):
torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device))
with self.assertRaises(RuntimeError):
torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2)
def test_one_hot(self):
self._test_one_hot(self)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_one_hot_cuda(self):
self._test_one_hot(self, use_cuda=True)
def test_pad_scalar_error(self):
inputs = torch.tensor(0., requires_grad=True)
self.assertRaises(AssertionError, lambda: F.pad(inputs, (1, 1)))
self.assertRaises(AssertionError, lambda: F.pad(inputs, (1,)))
def test_nn_scalars(self):
# One off tests to ensure scalars from nn.yaml are properly applied
def verify_scalars(input, output):
if input.dim() == 0:
self.assertEqual((), output.shape)
else:
self.assertNotEqual((), output.shape)
output.sum().backward()
self.assertEqual(input.shape, input.grad.shape)
devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
for device in devices:
for input_shape in [(5, 6), ()]:
for module in [torch.nn.ELU, torch.nn.Hardtanh, torch.nn.LeakyReLU, torch.nn.LogSigmoid,
torch.nn.RReLU, torch.nn.Softshrink, torch.nn.Softplus, torch.nn.Sigmoid,
torch.nn.Tanh]:
input = torch.randn(input_shape, device=device, requires_grad=True)
m = module()
output = m(input)
verify_scalars(input, output)
def test_nn_scalars_reductions(self):
# One off tests to ensure scalars from nn.yaml are properly applied
def verify_reduction_scalars(input, reduction, output):
if reduction != 'none' or input.dim() == 0:
self.assertEqual((), output.shape)
else:
self.assertNotEqual((), output.shape)
output.sum().backward()
self.assertEqual(input.shape, input.grad.shape)
devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
for device in devices:
for input_shape in [(5, 6), ()]:
for reduction in ['none', 'mean', 'sum']:
for module in [torch.nn.BCELoss, torch.nn.L1Loss, torch.nn.MSELoss,
torch.nn.SmoothL1Loss, torch.nn.SoftMarginLoss]:
input = torch.randn(input_shape, device=device, requires_grad=True)
target = torch.empty(input_shape, device=device).random_(2)
sigmoid = nn.Sigmoid()
input = torch.randn(input_shape, device=device, requires_grad=True)
m = module(reduction=reduction)
output = m(sigmoid(input), target)
verify_reduction_scalars(input, reduction, output)
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
"Scipy v1.0 and/or numpy not found")
def test_multihead_attention(self):
@@ -10090,6 +9661,374 @@ def _buildEquivalentAffineTransforms3d(device, input_size, output_size, angle_ra
return transform_tensor, transform_ary, grid_ary
# end TestNN.test_affine_* helpers
class GenericDeviceTypeHelpers(object):
def _test_dropout(self, cls, device, input):
p = 0.2
input = input.to(device).fill_(1 - p)
module = cls(p)
input_var = input.clone().requires_grad_()
output = module(input_var)
self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
output.backward(input)
self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)
module = cls(p, True)
input_var = input.clone().requires_grad_()
output = module(input_var + 0)
self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
output.backward(input)
self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)
# check eval mode doesn't change anything
for inplace in [True, False]:
module = cls(p, inplace).eval()
self.assertEqual(input, module(input))
# Check that these don't raise errors
module.__repr__()
str(module)
def _test_InstanceNorm_general(self, cls, input, device, dtype=torch.float):
# default case track_running_stats=False
b, c = input.size(0), input.size(1)
input_var = input.to(device=device, dtype=dtype).requires_grad_()
IN = cls(c, eps=0).to(device, dtype)
output = IN(input_var)
out_reshaped = output.view(b * c, -1)
mean = out_reshaped.mean(1)
var = out_reshaped.var(1, unbiased=False)
self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)
# check that eval mode doesn't change behavior
grad_out = torch.randn_like(output)
res1 = output.data.clone()
output.backward(grad_out)
grad1 = input_var.grad.data.clone()
IN.eval()
output = IN(input_var)
input_var.grad = None
output.backward(grad_out)
res2 = output.data
grad2 = input_var.grad.data
self.assertEqual(res1, res2)
self.assertEqual(grad1, grad2)
# If track_running_stats=True and momentum=1, running_mean/var should be
# equal to mean/var of the input (with unbias correction)
IN = cls(c, momentum=1, eps=0, track_running_stats=True).to(device, dtype)
output = IN(input_var)
input_reshaped = input_var.transpose(1, 0).reshape(c, -1)
mean = input_reshaped.mean(1)
input_reshaped = input_var.transpose(1, 0).reshape(c, b, -1)
var = input_reshaped.var(2, unbiased=True)[:, :]
self.assertAlmostEqual(torch.abs(mean.data - IN.running_mean).mean(), 0, delta=1e-5)
self.assertAlmostEqual(torch.abs(var.data.mean(1) - IN.running_var).mean(), 0, delta=1e-5)
# in eval mode, adding X * std to a channel in input should make the
# corresponding channel in output have mean X
IN.eval()
delta = IN.running_var.sqrt() * torch.arange(c, device=device, dtype=dtype)
delta = delta.view(-1, *[1 for _ in range(2, input.dim())])
output = IN(input_var + delta)
self.assertEqual(output.transpose(0, 1).reshape(c, -1).mean(1), torch.arange(c))
def _test_InstanceNorm_cuda_half(self, cls, input):
# THNN
input = input.to(device='cuda', dtype=torch.half).random_(1, 10).requires_grad_(True)
m = cls(input.size(1), affine=True, track_running_stats=True).to("cuda", torch.half)
thnn_output = m(input)
thnn_output.sum().backward()
thnn_input_grad = input.grad.data.clone()
self.assertEqual(thnn_output.type(), input.type())
# cuDNN
if TEST_CUDNN:
input.grad = None
m = m.float()
cudnn_output = m(input)
cudnn_output.sum().backward()
cudnn_input_grad = input.grad.data.clone()
self.assertEqual(cudnn_output.type(), input.type())
self.assertAlmostEqual(cudnn_output, thnn_output, delta=1e-4)
self.assertAlmostEqual(cudnn_input_grad, thnn_input_grad, delta=1e-3)
def _test_LayerNorm_general(self, device, dtype=torch.float):
for i in range(2, 6):
shape = torch.randint(3, 6, (i,), dtype=torch.long).tolist()
x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
normalized_ndim = random.randint(1, i - 1) # inclusive
normalized_shape = shape[-normalized_ndim:]
unnormalized_shape = shape[:-normalized_ndim]
# test that LN normalizes to mean 0 and stddev 1
ln = nn.LayerNorm(normalized_shape, eps=0).to(device, dtype)
ln.weight.data.fill_(1)
ln.bias.data.fill_(0)
output = ln(x)
out_reshaped = output.view(*(unnormalized_shape + [-1]))
mean = out_reshaped.mean(-1)
var = out_reshaped.var(-1, unbiased=False)
self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)
# test that LN applies weight and bias correctly
scale, bias = torch.empty(2).uniform_(0.2, 2).tolist()
ln.weight.data.fill_(scale)
ln.bias.data.fill_(bias)
output = ln(x)
out_reshaped = output.view(*(unnormalized_shape + [-1]))
mean = out_reshaped.mean(-1)
var = out_reshaped.var(-1, unbiased=False)
self.assertAlmostEqual(torch.abs(mean.data).mean(), bias, delta=1e-5)
self.assertAlmostEqual(torch.abs(var.data).mean(), scale ** 2, delta=1e-5)
bad_norm_shape_input_shape = {
(): (),
(2, 3): (3,),
(2,): (1, 2, 3),
(10,): (2, 3),
10: (2, 3),
}
for norm_shape, input_shape in bad_norm_shape_input_shape.items():
ln = nn.LayerNorm(norm_shape)
input = torch.empty(input_shape, device=device, dtype=dtype).uniform_(0, 10)
self.assertRaises(RuntimeError, lambda: ln(input))
def _test_LayerNorm_cuda_half(self):
input = torch.empty(2, 3, 3, 2, device="cuda", dtype=torch.half).random_(1, 10).requires_grad_(True)
m = nn.LayerNorm([3, 2]).to("cuda", torch.half)
output = m(input)
output.sum().backward()
self.assertEqual(output.type(), input.type())
def _test_GroupNorm_general(self, device, dtype=torch.float):
good_shape_g = {
(1, 2, 3, 4): 2,
(2, 3, 10): 3,
(3, 1, 1, 1, 2): 1,
(2, 6, 4, 2, 2): 3,
}
for shape, g in good_shape_g.items():
x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
b = shape[0]
c = shape[1]
# test that GN normalizes to mean 0 and stddev 1
gn = nn.GroupNorm(g, c, eps=0).to(device, dtype)
gn.weight.data.fill_(1)
gn.bias.data.fill_(0)
output = gn(x)
out_reshaped = output.view(b, g, -1)
mean = out_reshaped.mean(-1)
var = out_reshaped.var(-1, unbiased=False)
self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5)
self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5)
# test that GN applies weight and bias correctly
scale = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
bias = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
gn.weight.data.copy_(scale)
gn.bias.data.copy_(bias)
output = gn(x)
out_reshaped = output.view(b, c, -1)
out_normed = (out_reshaped - bias.view(c, 1)) / scale.view(c, 1)
out_normed_reshaped = out_normed.view(b, g, -1)
mean = out_normed_reshaped.mean(-1)
var = out_normed_reshaped.var(-1, unbiased=False)
self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5)
self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5)
bad_shape_g = {
(1, 2, 3, 4): 3,
(2, 3, 10): 2,
(3, 1, 1, 1, 2): 10,
(2, 6, 4, 2, 2): 4,
}
for shape, g in bad_shape_g.items():
gn = nn.GroupNorm(g, shape[1])
input = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
self.assertRaises(RuntimeError, lambda: gn(input))
def _test_GroupNorm_cuda_half(self):
input = torch.zeros(2, 4, 3, 2, requires_grad=True).cuda().half().random_(1, 10)
m = nn.GroupNorm(2, 4).to("cuda", torch.half)
output = m(input)
output.sum().backward()
self.assertEqual(output.type(), input.type())
class TestNNDeviceType(NNTestCase, GenericDeviceTypeHelpers):
def test_Dropout(self, device):
input = torch.Tensor(1000)
self._test_dropout(nn.Dropout, device, input)
def test_Dropout2d(self, device):
b = random.randint(1, 5)
w = random.randint(1, 5)
h = random.randint(1, 5)
num_features = 1000
input = torch.Tensor(num_features, b, w, h)
self._test_dropout(nn.Dropout2d, device, input)
def test_Dropout3d(self, device):
b = random.randint(1, 5)
w = random.randint(1, 5)
h = random.randint(1, 5)
d = random.randint(1, 2)
num_features = 1000
input = torch.Tensor(num_features, b, d, w, h)
self._test_dropout(nn.Dropout3d, device, input)
def test_InstanceNorm1d_general(self, device):
b = random.randint(3, 5)
c = random.randint(3, 5)
d = random.randint(8, 10)
input = torch.rand(b, c, d)
self._test_InstanceNorm_general(nn.InstanceNorm1d, input, device)
if device == 'cuda':
self._test_InstanceNorm_cuda_half(nn.InstanceNorm1d, input)
def test_InstanceNorm2d_general(self, device):
b = random.randint(3, 5)
c = random.randint(3, 5)
w = random.randint(3, 6)
h = random.randint(6, 8)
input = torch.rand(b, c, h, w)
self._test_InstanceNorm_general(nn.InstanceNorm2d, input, device)
if device == 'cuda':
self._test_InstanceNorm_cuda_half(nn.InstanceNorm2d, input)
def test_InstanceNorm3d_general(self, device):
b = random.randint(3, 5)
c = random.randint(3, 5)
w = random.randint(2, 5)
h = random.randint(2, 5)
d = random.randint(2, 5)
input = torch.rand(b, c, h, w, d)
self._test_InstanceNorm_general(nn.InstanceNorm3d, input, device)
if device == 'cuda':
self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input)
def test_LayerNorm_general(self, device):
self._test_LayerNorm_general(device)
if device == 'cuda':
self._test_LayerNorm_cuda_half()
def test_GroupNorm_general(self, device):
self._test_GroupNorm_general(device)
if device == 'cuda':
self._test_GroupNorm_cuda_half()
def test_one_hot(self, device):
with self.assertRaises(RuntimeError):
torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1)
with self.assertRaises(RuntimeError):
torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 3)
t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device))
expected = torch.tensor([[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
[0, 1, 0, 0, 0],
[1, 0, 0, 0, 0]], device=device)
self.assertEqual(t, expected)
t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1)
expected = torch.tensor([[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
[0, 1, 0, 0, 0],
[1, 0, 0, 0, 0]], device=device)
self.assertEqual(t, expected)
t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6)
expected = torch.tensor([[0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 1, 0],
[0, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0]], device=device)
self.assertEqual(t, expected)
t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device))
expected = torch.tensor([[[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1]],
[[0, 1, 0, 0, 0],
[1, 0, 0, 0, 0]]], device=device)
self.assertEqual(t, expected)
t = torch.nn.functional.one_hot(torch.tensor(4, device=device))
expected = torch.tensor([0, 0, 0, 0, 1], device=device)
self.assertEqual(t, expected)
t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100)
expected = torch.empty([4, 0, 100])
self.assertEqual(t, expected)
with self.assertRaises(RuntimeError):
torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device))
with self.assertRaises(RuntimeError):
torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2)
def test_nn_scalars(self, device):
# One off tests to ensure scalars from nn.yaml are properly applied
def verify_scalars(input, output):
if input.dim() == 0:
self.assertEqual((), output.shape)
else:
self.assertNotEqual((), output.shape)
output.sum().backward()
self.assertEqual(input.shape, input.grad.shape)
for input_shape in [(5, 6), ()]:
for module in [torch.nn.ELU, torch.nn.Hardtanh, torch.nn.LeakyReLU, torch.nn.LogSigmoid,
torch.nn.RReLU, torch.nn.Softshrink, torch.nn.Softplus, torch.nn.Sigmoid,
torch.nn.Tanh]:
input = torch.randn(input_shape, device=device, requires_grad=True)
m = module()
output = m(input)
verify_scalars(input, output)
def test_nn_scalars_reductions(self, device):
# One off tests to ensure scalars from nn.yaml are properly applied
def verify_reduction_scalars(input, reduction, output):
if reduction != 'none' or input.dim() == 0:
self.assertEqual((), output.shape)
else:
self.assertNotEqual((), output.shape)
output.sum().backward()
self.assertEqual(input.shape, input.grad.shape)
for input_shape in [(5, 6), ()]:
for reduction in ['none', 'mean', 'sum']:
for module in [torch.nn.BCELoss, torch.nn.L1Loss, torch.nn.MSELoss,
torch.nn.SmoothL1Loss, torch.nn.SoftMarginLoss]:
input = torch.randn(input_shape, device=device, requires_grad=True)
target = torch.empty(input_shape, device=device).random_(2)
sigmoid = nn.Sigmoid()
input = torch.randn(input_shape, device=device, requires_grad=True)
m = module(reduction=reduction)
output = m(sigmoid(input), target)
verify_reduction_scalars(input, reduction, output)
instantiate_device_type_tests(TestNNDeviceType, globals())
if __name__ == '__main__':
run_tests()
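
A condensed illustration of the migration pattern applied above (the class and test here are hypothetical, reusing the same scaffolding as the sketch near the top of this commit): a hand-written CPU/CUDA pair such as test_Dropout / test_Dropout_cuda collapses into a single device-parameterized test, and CUDA-only checks (like the half-precision helpers) sit behind a device check rather than a skipIf decorator:

    import torch
    from common_utils import TestCase
    from common_device_type import instantiate_device_type_tests

    class TestMigrationExample(TestCase):
        def test_add(self, device):
            x = torch.ones(4, device=device)
            self.assertEqual(x + x, torch.full((4,), 2.0, device=device))
            if device == 'cuda':
                # CUDA-only coverage (e.g. half precision) lives behind a device
                # check instead of a separate *_cuda test with @unittest.skipIf.
                y = x.half()
                self.assertEqual((y + y).float(), x + x)

    instantiate_device_type_tests(TestMigrationExample, globals())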

test/test_torch.py

@@ -30,7 +30,7 @@ from common_methods_invocations import tri_tests_args, run_additional_tri_tests,
from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, skipIfRocm, do_test_dtypes, do_test_empty_full, \
IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, torchtest, \
IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, \
skipCUDANonDefaultStreamIf
from multiprocessing.reduction import ForkingPickler
from common_device_type import instantiate_device_type_tests, \
@@ -104,7 +104,7 @@ class BytesIOContext(io.BytesIO):
# This is intentionally prefixed by an underscore. Otherwise pytest will try to
# run its methods as test cases.
class _TestTorchMixin(torchtest):
class _TestTorchMixin(object):
def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True):
float_types = [torch.double,
torch.float]