Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-02 14:34:54 +08:00)
Removes torchtest, expands generic device testing (#26374)
Summary:
- Removes torchtest
- ~~Moves test_torch tests skipped on ROCm to generic device test class~~
- Creates test_nn generic device test class

Next: adding dtypes to generic device testing framework.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/26374
Test Plan: Change is to tests themselves.
Differential Revision: D17442218
Pulled By: mruberry
fbshipit-source-id: d7e4451d09fc9049478b35a7efb8bb580071e8c8
committed by Facebook Github Bot
parent ed09704899
commit 388cfdf2ac
@@ -216,52 +216,6 @@ def _test_function(fn, device):
    return run_test_function


class torchtest():
    """Allows to generate and run per-device unittests.

    This decorator class allows to generate and run per-device unittest.

    Example:

    class _TestTorchMixin(torchtest):

        @torchtest.for_all_device_types()
        def test_zeros_like(self, device):
            expected = torch.zeros((100, 100,), device=device)

    Will execute:

    test_zeros_like (__main__.TestTorch) ... skipped 'Look at test_zeros_like_cpu, test_zeros_like_cuda results.'
    test_zeros_like_cpu (__main__.TestTorch) ... ok
    test_zeros_like_cuda (__main__.TestTorch) ... ok

    To work properly, test class should be inherited from `torchtest`.
    for_all_device_types decorator does not guarantee proper functionality in
    combination with other decorators.

    Please do not extend this decorator to support other cases (such as dtype,
    layouts, etc) without consulting with bigger group. Devices is the special
    case as build flags control additions/removals (see
    https://github.com/pytorch/pytorch/pull/23824 for the reference).
    """
    @classmethod
    def for_all_device_types(cls):
        def wrapper(fn):
            test_names = []

            for device in torch.testing.get_all_device_types():
                test_name = fn.__name__ + '_' + device
                assert not hasattr(cls, test_name), "Duplicated test name: " + test_name
                setattr(cls, test_name, _test_function(fn, device))
                test_names.append(test_name)

            @wraps(fn)
            def empty_test(*args, **kwargs):
                raise unittest.SkipTest("Look at {} results.".format(", ".join(test_names)))
            return empty_test
        return wrapper


def skipIfNoLapack(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
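For context, a minimal sketch (not part of this diff) of how the docstring's test_zeros_like example is written once torchtest is removed and the generic device-type framework from common_device_type is used instead; the class name TestZerosLike and the exact generated test names below are illustrative:

import torch
from common_device_type import instantiate_device_type_tests
from common_utils import TestCase, run_tests


class TestZerosLike(TestCase):
    # The framework passes the device string ('cpu', 'cuda', ...) to each test.
    def test_zeros_like(self, device):
        expected = torch.zeros((100, 100,), device=device)
        self.assertEqual(expected, torch.zeros_like(expected))


# Generates per-device variants, e.g. test_zeros_like_cpu and, when CUDA is
# available, test_zeros_like_cuda, on device-specific test classes.
instantiate_device_type_tests(TestZerosLike, globals())

if __name__ == '__main__':
    run_tests()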
799  test/test_nn.py
@@ -35,6 +35,7 @@ from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION
from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
    module_tests, criterion_tests, new_criterion_tests, loss_reference_fns, \
    ctcloss_reference, new_module_tests
from common_device_type import instantiate_device_type_tests

from torch.nn import MultiheadAttention
@@ -966,34 +967,6 @@ class TestNN(NNTestCase):
        with self.assertRaisesRegex(RuntimeError, 'negative stride is not supported'):
            module(input)

    def _test_dropout(self, cls, cuda, input):
        p = 0.2
        device = torch.device("cuda") if cuda else torch.device("cpu")
        input = input.to(device).fill_(1 - p)

        module = cls(p)
        input_var = input.clone().requires_grad_()
        output = module(input_var)
        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
        output.backward(input)
        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)

        module = cls(p, True)
        input_var = input.clone().requires_grad_()
        output = module(input_var + 0)
        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
        output.backward(input)
        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)

        # check eval mode doesn't change anything
        for inplace in [True, False]:
            module = cls(p, inplace).eval()
            self.assertEqual(input, module(input))

        # Check that these don't raise errors
        module.__repr__()
        str(module)

    def _test_alpha_dropout(self, cls, input):
        mean = input.mean()
        std = input.std()
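A note on the tolerance used in _test_dropout above: in training mode dropout zeroes each element with probability p and scales survivors by 1 / (1 - p), so with the input filled with the constant 1 - p every output element is either 0 or 1 and the sample mean lands near 1 - p. A standalone check of that expectation (illustrative only, not part of the test suite):

import torch
import torch.nn.functional as F

p = 0.2
x = torch.full((1000,), 1 - p)
y = F.dropout(x, p=p, training=True)
# Survivors become (1 - p) / (1 - p) = 1 and dropped elements are 0, so the
# mean is approximately the survival probability 1 - p.
assert abs(y.mean().item() - (1 - p)) < 0.05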
@@ -3160,51 +3133,6 @@ class TestNN(NNTestCase):
        gradcheck(func, [x])
        gradgradcheck(func, [x])

    def test_Dropout(self):
        input = torch.Tensor(1000)
        self._test_dropout(nn.Dropout, False, input)

    def test_Dropout2d(self):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        num_features = 1000
        input = torch.Tensor(num_features, b, w, h)
        self._test_dropout(nn.Dropout2d, False, input)

    def test_Dropout3d(self):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        d = random.randint(1, 2)
        num_features = 1000
        input = torch.Tensor(num_features, b, d, w, h)
        self._test_dropout(nn.Dropout3d, False, input)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_Dropout_cuda(self):
        input = torch.Tensor(1000)
        self._test_dropout(nn.Dropout, True, input)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_Dropout2d_cuda(self):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        num_features = 1000
        input = torch.Tensor(num_features, b, w, h)
        self._test_dropout(nn.Dropout2d, True, input)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_Dropout3d_cuda(self):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        d = random.randint(1, 2)
        num_features = 1000
        input = torch.Tensor(num_features, b, d, w, h)
        self._test_dropout(nn.Dropout3d, True, input)

    def test_AlphaDropout(self):
        # generate random tensor with zero mean and unit std
        input = torch.randn(5000)
@@ -3219,259 +3147,6 @@ class TestNN(NNTestCase):
        input = torch.randn(num_features, b, d, w, h)
        self._test_alpha_dropout(nn.FeatureAlphaDropout, input)

    def _test_InstanceNorm_general(self, cls, input, device="cpu", dtype=torch.float):
        # default case track_running_stats=False
        b, c = input.size(0), input.size(1)
        input_var = input.to(device=device, dtype=dtype).requires_grad_()

        IN = cls(c, eps=0).to(device, dtype)

        output = IN(input_var)
        out_reshaped = output.view(b * c, -1)

        mean = out_reshaped.mean(1)
        var = out_reshaped.var(1, unbiased=False)

        self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
        self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)

        # check that eval mode doesn't change behavior
        grad_out = torch.randn_like(output)
        res1 = output.data.clone()
        output.backward(grad_out)
        grad1 = input_var.grad.data.clone()

        IN.eval()
        output = IN(input_var)
        input_var.grad = None
        output.backward(grad_out)
        res2 = output.data
        grad2 = input_var.grad.data
        self.assertEqual(res1, res2)
        self.assertEqual(grad1, grad2)

        # If track_running_stats=True and momentum=1, running_mean/var should be
        # equal to mean/var of the input (with unbias correction)
        IN = cls(c, momentum=1, eps=0, track_running_stats=True).to(device, dtype)

        output = IN(input_var)

        input_reshaped = input_var.transpose(1, 0).reshape(c, -1)
        mean = input_reshaped.mean(1)

        input_reshaped = input_var.transpose(1, 0).reshape(c, b, -1)
        var = input_reshaped.var(2, unbiased=True)[:, :]

        self.assertAlmostEqual(torch.abs(mean.data - IN.running_mean).mean(), 0, delta=1e-5)
        self.assertAlmostEqual(torch.abs(var.data.mean(1) - IN.running_var).mean(), 0, delta=1e-5)

        # in eval mode, adding X * std to a channel in input should make the
        # corresponding channel in output have mean X
        IN.eval()
        delta = IN.running_var.sqrt() * torch.arange(c, device=device, dtype=dtype)
        delta = delta.view(-1, *[1 for _ in range(2, input.dim())])
        output = IN(input_var + delta)
        self.assertEqual(output.transpose(0, 1).reshape(c, -1).mean(1), torch.arange(c))

    def _test_InstanceNorm_cuda_half(self, cls, input):
        # THNN
        input = input.to(device='cuda', dtype=torch.half).random_(1, 10).requires_grad_(True)
        m = cls(input.size(1), affine=True, track_running_stats=True).to("cuda", torch.half)
        thnn_output = m(input)
        thnn_output.sum().backward()
        thnn_input_grad = input.grad.data.clone()
        self.assertEqual(thnn_output.type(), input.type())
        # cuDNN
        if TEST_CUDNN:
            input.grad = None
            m = m.float()
            cudnn_output = m(input)
            cudnn_output.sum().backward()
            cudnn_input_grad = input.grad.data.clone()
            self.assertEqual(cudnn_output.type(), input.type())
            self.assertAlmostEqual(cudnn_output, thnn_output, delta=1e-4)
            self.assertAlmostEqual(cudnn_input_grad, thnn_input_grad, delta=1e-3)

    def test_InstanceNorm1d_general(self):
        b = random.randint(3, 5)
        c = random.randint(3, 5)
        d = random.randint(8, 10)

        input = torch.rand(b, c, d)
        self._test_InstanceNorm_general(nn.InstanceNorm1d, input, dtype=torch.float)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_InstanceNorm1d_general_cuda(self):
        b = random.randint(3, 5)
        c = random.randint(3, 5)
        d = random.randint(8, 10)

        input = torch.rand(b, c, d)
        self._test_InstanceNorm_general(nn.InstanceNorm1d, input, "cuda", torch.float)
        self._test_InstanceNorm_cuda_half(nn.InstanceNorm1d, input)

    def test_InstanceNorm2d_general(self):
        b = random.randint(3, 5)
        c = random.randint(3, 5)
        w = random.randint(3, 6)
        h = random.randint(6, 8)

        input = torch.rand(b, c, h, w)
        self._test_InstanceNorm_general(nn.InstanceNorm2d, input, dtype=torch.float)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_InstanceNorm2d_general_cuda(self):
        b = random.randint(3, 5)
        c = random.randint(3, 5)
        w = random.randint(3, 6)
        h = random.randint(6, 8)

        input = torch.rand(b, c, h, w)
        self._test_InstanceNorm_general(nn.InstanceNorm2d, input, "cuda", torch.float)
        self._test_InstanceNorm_cuda_half(nn.InstanceNorm2d, input)

    def test_InstanceNorm3d_general(self):
        b = random.randint(3, 5)
        c = random.randint(3, 5)
        w = random.randint(2, 5)
        h = random.randint(2, 5)
        d = random.randint(2, 5)

        input = torch.rand(b, c, h, w, d)
        self._test_InstanceNorm_general(nn.InstanceNorm3d, input, dtype=torch.float)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_InstanceNorm3d_general_cuda(self):
        b = random.randint(3, 5)
        c = random.randint(2, 5)
        w = random.randint(2, 5)
        h = random.randint(2, 5)
        d = random.randint(2, 5)

        input = torch.rand(b, c, h, w, d)
        self._test_InstanceNorm_general(nn.InstanceNorm3d, input, "cuda", torch.float)
        self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input)

    def _test_LayerNorm_general(self, device="cpu", dtype=torch.float):
        for i in range(2, 6):
            shape = torch.randint(3, 6, (i,), dtype=torch.long).tolist()
            x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
            normalized_ndim = random.randint(1, i - 1)  # inclusive
            normalized_shape = shape[-normalized_ndim:]
            unnormalized_shape = shape[:-normalized_ndim]

            # test that LN normalizes to mean 0 and stddev 1
            ln = nn.LayerNorm(normalized_shape, eps=0).to(device, dtype)
            ln.weight.data.fill_(1)
            ln.bias.data.fill_(0)
            output = ln(x)
            out_reshaped = output.view(*(unnormalized_shape + [-1]))
            mean = out_reshaped.mean(-1)
            var = out_reshaped.var(-1, unbiased=False)
            self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
            self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)

            # test that LN applies weight and bias correctly
            scale, bias = torch.empty(2).uniform_(0.2, 2).tolist()
            ln.weight.data.fill_(scale)
            ln.bias.data.fill_(bias)
            output = ln(x)
            out_reshaped = output.view(*(unnormalized_shape + [-1]))
            mean = out_reshaped.mean(-1)
            var = out_reshaped.var(-1, unbiased=False)
            self.assertAlmostEqual(torch.abs(mean.data).mean(), bias, delta=1e-5)
            self.assertAlmostEqual(torch.abs(var.data).mean(), scale ** 2, delta=1e-5)

        bad_norm_shape_input_shape = {
            (): (),
            (2, 3): (3,),
            (2,): (1, 2, 3),
            (10,): (2, 3),
            10: (2, 3),
        }
        for norm_shape, input_shape in bad_norm_shape_input_shape.items():
            ln = nn.LayerNorm(norm_shape)
            input = torch.empty(input_shape, device=device, dtype=dtype).uniform_(0, 10)
            self.assertRaises(RuntimeError, lambda: ln(input))

    def _test_LayerNorm_cuda_half(self):
        input = torch.empty(2, 3, 3, 2, device="cuda", dtype=torch.half).random_(1, 10).requires_grad_(True)
        m = nn.LayerNorm([3, 2]).to("cuda", torch.half)
        output = m(input)
        output.sum().backward()
        self.assertEqual(output.type(), input.type())

    def test_LayerNorm_general(self):
        self._test_LayerNorm_general()

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_LayerNorm_general_cuda(self):
        self._test_LayerNorm_general("cuda")
        self._test_LayerNorm_cuda_half()

    def _test_GroupNorm_general(self, device="cpu", dtype=torch.float):
        good_shape_g = {
            (1, 2, 3, 4): 2,
            (2, 3, 10): 3,
            (3, 1, 1, 1, 2): 1,
            (2, 6, 4, 2, 2): 3,
        }
        for shape, g in good_shape_g.items():
            x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
            b = shape[0]
            c = shape[1]

            # test that GN normalizes to mean 0 and stddev 1
            gn = nn.GroupNorm(g, c, eps=0).to(device, dtype)
            gn.weight.data.fill_(1)
            gn.bias.data.fill_(0)
            output = gn(x)
            out_reshaped = output.view(b, g, -1)
            mean = out_reshaped.mean(-1)
            var = out_reshaped.var(-1, unbiased=False)
            self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5)
            self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5)

            # test that GN applies weight and bias correctly
            scale = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
            bias = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
            gn.weight.data.copy_(scale)
            gn.bias.data.copy_(bias)
            output = gn(x)
            out_reshaped = output.view(b, c, -1)
            out_normed = (out_reshaped - bias.view(c, 1)) / scale.view(c, 1)
            out_normed_reshaped = out_normed.view(b, g, -1)
            mean = out_normed_reshaped.mean(-1)
            var = out_normed_reshaped.var(-1, unbiased=False)
            self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5)
            self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5)

        bad_shape_g = {
            (1, 2, 3, 4): 3,
            (2, 3, 10): 2,
            (3, 1, 1, 1, 2): 10,
            (2, 6, 4, 2, 2): 4,
        }
        for shape, g in bad_shape_g.items():
            gn = nn.GroupNorm(g, shape[1])
            input = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
            self.assertRaises(RuntimeError, lambda: gn(input))

    def _test_GroupNorm_cuda_half(self):
        input = torch.zeros(2, 4, 3, 2, requires_grad=True).cuda().half().random_(1, 10)
        m = nn.GroupNorm(2, 4).to("cuda", torch.half)
        output = m(input)
        output.sum().backward()
        self.assertEqual(output.type(), input.type())

    def test_GroupNorm_general(self):
        self._test_GroupNorm_general(dtype=torch.float)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_GroupNorm_general_cuda(self):
        self._test_GroupNorm_general("cuda", torch.float)
        self._test_GroupNorm_cuda_half()

    def test_pad(self):
        inputs = torch.randn(1, 3, 4, 4, requires_grad=True)
        _assertGradAndGradgradChecks(self, lambda x: F.pad(x, (1, 1, 1, 1)), (inputs,))
@@ -3490,115 +3165,11 @@ class TestNN(NNTestCase):
        self.assertRaisesRegex(RuntimeError, expected_err_msg,
                               lambda: F.pad(torch.randn(1, 1, 2), (2, 1), mode='reflect'))

    @staticmethod
    def _test_one_hot(self, use_cuda=False):
        device = torch.device('cuda' if use_cuda else 'cpu')
        with self.assertRaises(RuntimeError):
            torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1)

        with self.assertRaises(RuntimeError):
            torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 3)

        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device))
        expected = torch.tensor([[0, 0, 0, 1, 0],
                                 [0, 0, 0, 0, 1],
                                 [0, 1, 0, 0, 0],
                                 [1, 0, 0, 0, 0]], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1)
        expected = torch.tensor([[0, 0, 0, 1, 0],
                                 [0, 0, 0, 0, 1],
                                 [0, 1, 0, 0, 0],
                                 [1, 0, 0, 0, 0]], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6)
        expected = torch.tensor([[0, 0, 0, 1, 0, 0],
                                 [0, 0, 0, 0, 1, 0],
                                 [0, 1, 0, 0, 0, 0],
                                 [1, 0, 0, 0, 0, 0]], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device))
        expected = torch.tensor([[[0, 0, 0, 1, 0],
                                  [0, 0, 0, 0, 1]],
                                 [[0, 1, 0, 0, 0],
                                  [1, 0, 0, 0, 0]]], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.tensor(4, device=device))
        expected = torch.tensor([0, 0, 0, 0, 1], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100)
        expected = torch.empty([4, 0, 100])
        self.assertEqual(t, expected)

        with self.assertRaises(RuntimeError):
            torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device))

        with self.assertRaises(RuntimeError):
            torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2)

    def test_one_hot(self):
        self._test_one_hot(self)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_one_hot_cuda(self):
        self._test_one_hot(self, use_cuda=True)

    def test_pad_scalar_error(self):
        inputs = torch.tensor(0., requires_grad=True)
        self.assertRaises(AssertionError, lambda: F.pad(inputs, (1, 1)))
        self.assertRaises(AssertionError, lambda: F.pad(inputs, (1,)))

    def test_nn_scalars(self):
        # One off tests to ensure scalars from nn.yaml are properly applied
        def verify_scalars(input, output):
            if input.dim() == 0:
                self.assertEqual((), output.shape)
            else:
                self.assertNotEqual((), output.shape)
            output.sum().backward()
            self.assertEqual(input.shape, input.grad.shape)

        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        for device in devices:
            for input_shape in [(5, 6), ()]:
                for module in [torch.nn.ELU, torch.nn.Hardtanh, torch.nn.LeakyReLU, torch.nn.LogSigmoid,
                               torch.nn.RReLU, torch.nn.Softshrink, torch.nn.Softplus, torch.nn.Sigmoid,
                               torch.nn.Tanh]:
                    input = torch.randn(input_shape, device=device, requires_grad=True)
                    m = module()
                    output = m(input)
                    verify_scalars(input, output)

    def test_nn_scalars_reductions(self):
        # One off tests to ensure scalars from nn.yaml are properly applied
        def verify_reduction_scalars(input, reduction, output):
            if reduction != 'none' or input.dim() == 0:
                self.assertEqual((), output.shape)
            else:
                self.assertNotEqual((), output.shape)
            output.sum().backward()
            self.assertEqual(input.shape, input.grad.shape)

        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        for device in devices:
            for input_shape in [(5, 6), ()]:
                for reduction in ['none', 'mean', 'sum']:
                    for module in [torch.nn.BCELoss, torch.nn.L1Loss, torch.nn.MSELoss,
                                   torch.nn.SmoothL1Loss, torch.nn.SoftMarginLoss]:
                        input = torch.randn(input_shape, device=device, requires_grad=True)
                        target = torch.empty(input_shape, device=device).random_(2)
                        sigmoid = nn.Sigmoid()

                        input = torch.randn(input_shape, device=device, requires_grad=True)
                        m = module(reduction=reduction)
                        output = m(sigmoid(input), target)
                        verify_reduction_scalars(input, reduction, output)

    @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
                     "Scipy v1.0 and/or numpy not found")
    def test_multihead_attention(self):
@@ -10090,6 +9661,374 @@ def _buildEquivalentAffineTransforms3d(device, input_size, output_size, angle_ra
    return transform_tensor, transform_ary, grid_ary
# end TestNN.test_affine_* helpers

class GenericDeviceTypeHelpers(object):
    def _test_dropout(self, cls, device, input):
        p = 0.2
        input = input.to(device).fill_(1 - p)

        module = cls(p)
        input_var = input.clone().requires_grad_()
        output = module(input_var)
        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
        output.backward(input)
        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)

        module = cls(p, True)
        input_var = input.clone().requires_grad_()
        output = module(input_var + 0)
        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
        output.backward(input)
        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)

        # check eval mode doesn't change anything
        for inplace in [True, False]:
            module = cls(p, inplace).eval()
            self.assertEqual(input, module(input))

        # Check that these don't raise errors
        module.__repr__()
        str(module)

    def _test_InstanceNorm_general(self, cls, input, device, dtype=torch.float):
        # default case track_running_stats=False
        b, c = input.size(0), input.size(1)
        input_var = input.to(device=device, dtype=dtype).requires_grad_()

        IN = cls(c, eps=0).to(device, dtype)

        output = IN(input_var)
        out_reshaped = output.view(b * c, -1)

        mean = out_reshaped.mean(1)
        var = out_reshaped.var(1, unbiased=False)

        self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
        self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)

        # check that eval mode doesn't change behavior
        grad_out = torch.randn_like(output)
        res1 = output.data.clone()
        output.backward(grad_out)
        grad1 = input_var.grad.data.clone()

        IN.eval()
        output = IN(input_var)
        input_var.grad = None
        output.backward(grad_out)
        res2 = output.data
        grad2 = input_var.grad.data
        self.assertEqual(res1, res2)
        self.assertEqual(grad1, grad2)

        # If track_running_stats=True and momentum=1, running_mean/var should be
        # equal to mean/var of the input (with unbias correction)
        IN = cls(c, momentum=1, eps=0, track_running_stats=True).to(device, dtype)

        output = IN(input_var)

        input_reshaped = input_var.transpose(1, 0).reshape(c, -1)
        mean = input_reshaped.mean(1)

        input_reshaped = input_var.transpose(1, 0).reshape(c, b, -1)
        var = input_reshaped.var(2, unbiased=True)[:, :]

        self.assertAlmostEqual(torch.abs(mean.data - IN.running_mean).mean(), 0, delta=1e-5)
        self.assertAlmostEqual(torch.abs(var.data.mean(1) - IN.running_var).mean(), 0, delta=1e-5)

        # in eval mode, adding X * std to a channel in input should make the
        # corresponding channel in output have mean X
        IN.eval()
        delta = IN.running_var.sqrt() * torch.arange(c, device=device, dtype=dtype)
        delta = delta.view(-1, *[1 for _ in range(2, input.dim())])
        output = IN(input_var + delta)
        self.assertEqual(output.transpose(0, 1).reshape(c, -1).mean(1), torch.arange(c))
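# Note (not part of this diff): the momentum=1 check above relies on the
# running-stat update rule running_stat = (1 - momentum) * running_stat +
# momentum * batch_stat, so a single forward pass with momentum=1 leaves
# running_mean equal to the batch statistics. A standalone sketch (the shapes
# and the InstanceNorm1d choice are arbitrary):
import torch
import torch.nn as nn

x = torch.randn(4, 3, 10)  # (batch, channels, length)
m = nn.InstanceNorm1d(3, momentum=1, eps=0, track_running_stats=True)
m(x)
# Per-channel mean over batch and length matches the updated running buffer.
print(torch.allclose(m.running_mean, x.transpose(0, 1).reshape(3, -1).mean(1), atol=1e-5))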
    def _test_InstanceNorm_cuda_half(self, cls, input):
        # THNN
        input = input.to(device='cuda', dtype=torch.half).random_(1, 10).requires_grad_(True)
        m = cls(input.size(1), affine=True, track_running_stats=True).to("cuda", torch.half)
        thnn_output = m(input)
        thnn_output.sum().backward()
        thnn_input_grad = input.grad.data.clone()
        self.assertEqual(thnn_output.type(), input.type())
        # cuDNN
        if TEST_CUDNN:
            input.grad = None
            m = m.float()
            cudnn_output = m(input)
            cudnn_output.sum().backward()
            cudnn_input_grad = input.grad.data.clone()
            self.assertEqual(cudnn_output.type(), input.type())
            self.assertAlmostEqual(cudnn_output, thnn_output, delta=1e-4)
            self.assertAlmostEqual(cudnn_input_grad, thnn_input_grad, delta=1e-3)

    def _test_LayerNorm_general(self, device, dtype=torch.float):
        for i in range(2, 6):
            shape = torch.randint(3, 6, (i,), dtype=torch.long).tolist()
            x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
            normalized_ndim = random.randint(1, i - 1)  # inclusive
            normalized_shape = shape[-normalized_ndim:]
            unnormalized_shape = shape[:-normalized_ndim]

            # test that LN normalizes to mean 0 and stddev 1
            ln = nn.LayerNorm(normalized_shape, eps=0).to(device, dtype)
            ln.weight.data.fill_(1)
            ln.bias.data.fill_(0)
            output = ln(x)
            out_reshaped = output.view(*(unnormalized_shape + [-1]))
            mean = out_reshaped.mean(-1)
            var = out_reshaped.var(-1, unbiased=False)
            self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
            self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)

            # test that LN applies weight and bias correctly
            scale, bias = torch.empty(2).uniform_(0.2, 2).tolist()
            ln.weight.data.fill_(scale)
            ln.bias.data.fill_(bias)
            output = ln(x)
            out_reshaped = output.view(*(unnormalized_shape + [-1]))
            mean = out_reshaped.mean(-1)
            var = out_reshaped.var(-1, unbiased=False)
            self.assertAlmostEqual(torch.abs(mean.data).mean(), bias, delta=1e-5)
            self.assertAlmostEqual(torch.abs(var.data).mean(), scale ** 2, delta=1e-5)

        bad_norm_shape_input_shape = {
            (): (),
            (2, 3): (3,),
            (2,): (1, 2, 3),
            (10,): (2, 3),
            10: (2, 3),
        }
        for norm_shape, input_shape in bad_norm_shape_input_shape.items():
            ln = nn.LayerNorm(norm_shape)
            input = torch.empty(input_shape, device=device, dtype=dtype).uniform_(0, 10)
            self.assertRaises(RuntimeError, lambda: ln(input))
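# Note (not part of this diff): the weight/bias check in _test_LayerNorm_general
# uses the fact that LayerNorm computes y = weight * x_hat + bias with x_hat
# normalized to mean 0 / variance 1 over the normalized shape, so a scalar fill
# gives a per-slice mean of `bias` and a biased variance of `weight ** 2`.
# A standalone sketch (the shapes are arbitrary):
import torch
import torch.nn as nn

x = torch.randn(4, 8) * 3 + 5
ln = nn.LayerNorm(8, eps=0)
with torch.no_grad():
    ln.weight.fill_(1.5)
    ln.bias.fill_(0.25)
y = ln(x)
print(y.mean(-1))                 # each entry is ~0.25
print(y.var(-1, unbiased=False))  # each entry is ~1.5 ** 2 = 2.25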
    def _test_LayerNorm_cuda_half(self):
        input = torch.empty(2, 3, 3, 2, device="cuda", dtype=torch.half).random_(1, 10).requires_grad_(True)
        m = nn.LayerNorm([3, 2]).to("cuda", torch.half)
        output = m(input)
        output.sum().backward()
        self.assertEqual(output.type(), input.type())

    def _test_GroupNorm_general(self, device, dtype=torch.float):
        good_shape_g = {
            (1, 2, 3, 4): 2,
            (2, 3, 10): 3,
            (3, 1, 1, 1, 2): 1,
            (2, 6, 4, 2, 2): 3,
        }
        for shape, g in good_shape_g.items():
            x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
            b = shape[0]
            c = shape[1]

            # test that GN normalizes to mean 0 and stddev 1
            gn = nn.GroupNorm(g, c, eps=0).to(device, dtype)
            gn.weight.data.fill_(1)
            gn.bias.data.fill_(0)
            output = gn(x)
            out_reshaped = output.view(b, g, -1)
            mean = out_reshaped.mean(-1)
            var = out_reshaped.var(-1, unbiased=False)
            self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5)
            self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5)

            # test that GN applies weight and bias correctly
            scale = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
            bias = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
            gn.weight.data.copy_(scale)
            gn.bias.data.copy_(bias)
            output = gn(x)
            out_reshaped = output.view(b, c, -1)
            out_normed = (out_reshaped - bias.view(c, 1)) / scale.view(c, 1)
            out_normed_reshaped = out_normed.view(b, g, -1)
            mean = out_normed_reshaped.mean(-1)
            var = out_normed_reshaped.var(-1, unbiased=False)
            self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5)
            self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5)

        bad_shape_g = {
            (1, 2, 3, 4): 3,
            (2, 3, 10): 2,
            (3, 1, 1, 1, 2): 10,
            (2, 6, 4, 2, 2): 4,
        }
        for shape, g in bad_shape_g.items():
            gn = nn.GroupNorm(g, shape[1])
            input = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
            self.assertRaises(RuntimeError, lambda: gn(input))

    def _test_GroupNorm_cuda_half(self):
        input = torch.zeros(2, 4, 3, 2, requires_grad=True).cuda().half().random_(1, 10)
        m = nn.GroupNorm(2, 4).to("cuda", torch.half)
        output = m(input)
        output.sum().backward()
        self.assertEqual(output.type(), input.type())

class TestNNDeviceType(NNTestCase, GenericDeviceTypeHelpers):
    def test_Dropout(self, device):
        input = torch.Tensor(1000)
        self._test_dropout(nn.Dropout, device, input)

    def test_Dropout2d(self, device):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        num_features = 1000
        input = torch.Tensor(num_features, b, w, h)
        self._test_dropout(nn.Dropout2d, device, input)

    def test_Dropout3d(self, device):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        d = random.randint(1, 2)
        num_features = 1000
        input = torch.Tensor(num_features, b, d, w, h)
        self._test_dropout(nn.Dropout3d, device, input)

    def test_InstanceNorm1d_general(self, device):
        b = random.randint(3, 5)
        c = random.randint(3, 5)
        d = random.randint(8, 10)

        input = torch.rand(b, c, d)
        self._test_InstanceNorm_general(nn.InstanceNorm1d, input, device)

        if device == 'cuda':
            self._test_InstanceNorm_cuda_half(nn.InstanceNorm1d, input)

    def test_InstanceNorm2d_general(self, device):
        b = random.randint(3, 5)
        c = random.randint(3, 5)
        w = random.randint(3, 6)
        h = random.randint(6, 8)

        input = torch.rand(b, c, h, w)
        self._test_InstanceNorm_general(nn.InstanceNorm2d, input, device)

        if device == 'cuda':
            self._test_InstanceNorm_cuda_half(nn.InstanceNorm2d, input)

    def test_InstanceNorm3d_general(self, device):
        b = random.randint(3, 5)
        c = random.randint(3, 5)
        w = random.randint(2, 5)
        h = random.randint(2, 5)
        d = random.randint(2, 5)

        input = torch.rand(b, c, h, w, d)
        self._test_InstanceNorm_general(nn.InstanceNorm3d, input, device)

        if device == 'cuda':
            self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input)

    def test_LayerNorm_general(self, device):
        self._test_LayerNorm_general(device)

        if device == 'cuda':
            self._test_LayerNorm_cuda_half()

    def test_GroupNorm_general(self, device):
        self._test_GroupNorm_general(device)

        if device == 'cuda':
            self._test_GroupNorm_cuda_half()

    def test_one_hot(self, device):
        with self.assertRaises(RuntimeError):
            torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1)

        with self.assertRaises(RuntimeError):
            torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 3)

        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device))
        expected = torch.tensor([[0, 0, 0, 1, 0],
                                 [0, 0, 0, 0, 1],
                                 [0, 1, 0, 0, 0],
                                 [1, 0, 0, 0, 0]], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1)
        expected = torch.tensor([[0, 0, 0, 1, 0],
                                 [0, 0, 0, 0, 1],
                                 [0, 1, 0, 0, 0],
                                 [1, 0, 0, 0, 0]], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6)
        expected = torch.tensor([[0, 0, 0, 1, 0, 0],
                                 [0, 0, 0, 0, 1, 0],
                                 [0, 1, 0, 0, 0, 0],
                                 [1, 0, 0, 0, 0, 0]], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device))
        expected = torch.tensor([[[0, 0, 0, 1, 0],
                                  [0, 0, 0, 0, 1]],
                                 [[0, 1, 0, 0, 0],
                                  [1, 0, 0, 0, 0]]], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.tensor(4, device=device))
        expected = torch.tensor([0, 0, 0, 0, 1], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100)
        expected = torch.empty([4, 0, 100])
        self.assertEqual(t, expected)

        with self.assertRaises(RuntimeError):
            torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device))

        with self.assertRaises(RuntimeError):
            torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2)

    def test_nn_scalars(self, device):
        # One off tests to ensure scalars from nn.yaml are properly applied
        def verify_scalars(input, output):
            if input.dim() == 0:
                self.assertEqual((), output.shape)
            else:
                self.assertNotEqual((), output.shape)
            output.sum().backward()
            self.assertEqual(input.shape, input.grad.shape)

        for input_shape in [(5, 6), ()]:
            for module in [torch.nn.ELU, torch.nn.Hardtanh, torch.nn.LeakyReLU, torch.nn.LogSigmoid,
                           torch.nn.RReLU, torch.nn.Softshrink, torch.nn.Softplus, torch.nn.Sigmoid,
                           torch.nn.Tanh]:
                input = torch.randn(input_shape, device=device, requires_grad=True)
                m = module()
                output = m(input)
                verify_scalars(input, output)

    def test_nn_scalars_reductions(self, device):
        # One off tests to ensure scalars from nn.yaml are properly applied
        def verify_reduction_scalars(input, reduction, output):
            if reduction != 'none' or input.dim() == 0:
                self.assertEqual((), output.shape)
            else:
                self.assertNotEqual((), output.shape)
            output.sum().backward()
            self.assertEqual(input.shape, input.grad.shape)

        for input_shape in [(5, 6), ()]:
            for reduction in ['none', 'mean', 'sum']:
                for module in [torch.nn.BCELoss, torch.nn.L1Loss, torch.nn.MSELoss,
                               torch.nn.SmoothL1Loss, torch.nn.SoftMarginLoss]:
                    input = torch.randn(input_shape, device=device, requires_grad=True)
                    target = torch.empty(input_shape, device=device).random_(2)
                    sigmoid = nn.Sigmoid()

                    input = torch.randn(input_shape, device=device, requires_grad=True)
                    m = module(reduction=reduction)
                    output = m(sigmoid(input), target)
                    verify_reduction_scalars(input, reduction, output)


instantiate_device_type_tests(TestNNDeviceType, globals())

if __name__ == '__main__':
    run_tests()
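The generic class above is not run directly: instantiate_device_type_tests() replaces the removed torchtest decorator by stamping out one concrete test class per available device and suffixing every test with the device name. A rough, illustrative sketch of the effect (the exact generated names are an implementation detail of common_device_type):

# After instantiate_device_type_tests(TestNNDeviceType, globals()) the module
# namespace contains per-device classes, conceptually along the lines of:
#
#     class TestNNDeviceTypeCPU(TestNNDeviceType):
#         def test_Dropout_cpu(self):
#             return self.test_Dropout('cpu')
#
#     class TestNNDeviceTypeCUDA(TestNNDeviceType):  # only when CUDA is available
#         def test_Dropout_cuda(self):
#             return self.test_Dropout('cuda')
#
# so device-specific branches such as `if device == 'cuda':` above run only in
# the CUDA-instantiated class.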
test/test_torch.py

@@ -30,7 +30,7 @@ from common_methods_invocations import tri_tests_args, run_additional_tri_tests,
from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
    TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
    IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, skipIfRocm, do_test_dtypes, do_test_empty_full, \
-    IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, torchtest, \
+    IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, \
    skipCUDANonDefaultStreamIf
from multiprocessing.reduction import ForkingPickler
from common_device_type import instantiate_device_type_tests, \

@@ -104,7 +104,7 @@ class BytesIOContext(io.BytesIO):

# This is intentionally prefixed by an underscore. Otherwise pytest will try to
# run its methods as test cases.
-class _TestTorchMixin(torchtest):
+class _TestTorchMixin(object):
    def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True):
        float_types = [torch.double,
                       torch.float]