Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
As part of the follow-up for https://github.com/pytorch/pytorch/issues/133520, this adapts existing unused tests for use in MPS CI runs, focusing on NHWC and other memory-format tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134356
Approved by: https://github.com/malfet, https://github.com/eqy, https://github.com/huydhn
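For context, a minimal standalone sketch (not part of the test file below; assumes a recent PyTorch build) of the channels_last / NHWC property these memory-format tests exercise, namely that dropout preserves the memory format of its input:

import torch
import torch.nn as nn

# Hypothetical example mirroring what _test_dropout asserts for nn.Dropout2d
# with memory_format=torch.channels_last (NHWC).
x = torch.ones(2, 3, 4, 4).to(memory_format=torch.channels_last)
out = nn.Dropout2d(p=0.2)(x)  # a freshly constructed module is in training mode
assert out.is_contiguous(memory_format=torch.channels_last)  # layout is preserved
assert out.shape == x.shape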
332 lines
13 KiB
Python
# Owner(s): ["module: nn"]
import itertools
import random
import unittest
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfMPS,
    expectedFailureMPS,
    expectedFailureMPSPre15,
    expectedFailureXLA,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_nn import freeze_rng_state, NNTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    set_default_dtype,
    TEST_PRIVATEUSE1,
)


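# Tests for AlphaDropout / FeatureAlphaDropout, native_dropout corner cases, and
# rejection of invalid dropout probabilities.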
class TestDropoutNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    def _test_alpha_dropout(self, cls, input):
        mean = input.mean()
        std = input.std()

        for p in [0.2, 0.5, 0.8]:
            module = cls(p)
            input_var = input.detach().clone().requires_grad_()
            output = module(input_var)
            # output mean should be close to input mean
            self.assertLess(abs(output.data.mean() - mean), 0.1)
            # output std should be close to input std
            self.assertLess(abs(output.data.std() - std), 0.1)
            output.backward(input)

    def test_AlphaDropout(self):
        # generate random tensor with zero mean and unit std
        input = torch.randn(5000)
        self._test_alpha_dropout(nn.AlphaDropout, input)

    def test_FeatureAlphaDropout(self):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        d = random.randint(1, 2)
        num_features = 1000
        input = torch.randn(num_features, b, d, w, h)
        self._test_alpha_dropout(nn.FeatureAlphaDropout, input)

        # no batch dims
        input = torch.randn(50, 20, 64, 64)
        self._test_alpha_dropout(nn.FeatureAlphaDropout, input)

    @unittest.skipIf(
        not (TEST_CUDA or TEST_PRIVATEUSE1), "CUDA and PRIVATEUSE1 unavailable"
    )
    def test_native_dropout_corner_case(self):
        if TEST_CUDA:
            device = "cuda"
        elif TEST_PRIVATEUSE1:
            device = torch._C._get_privateuse1_backend_name()
        for train in [True, False]:
            for p in [0.0, 1.0]:
                for current_device in [device, "cpu"]:
                    x = torch.randn(5).to(device=current_device).requires_grad_()
                    x_ref = x.detach().requires_grad_()
                    o = torch.native_dropout(x, p, train)[0]
                    o_ref = torch.dropout(x_ref, p, train)
                    o.sum().backward()
                    o_ref.sum().backward()
                    assert o.equal(o_ref)
                    assert x.grad.equal(x_ref.grad)

    def test_invalid_dropout_p(self):
        v = torch.ones(1)
        self.assertRaises(ValueError, lambda: nn.Dropout(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout(1.1))
        self.assertRaises(ValueError, lambda: nn.Dropout1d(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout1d(1.1))
        self.assertRaises(ValueError, lambda: nn.Dropout2d(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout2d(1.1))
        self.assertRaises(ValueError, lambda: nn.Dropout3d(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout3d(1.1))
        self.assertRaises(ValueError, lambda: F.dropout(v, -0.1))
        self.assertRaises(ValueError, lambda: F.dropout(v, 1.1))


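# Dropout tests parametrized over device type; instantiated for each available
# backend (including MPS) by instantiate_device_type_tests() at the bottom of this file.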
class TestDropoutNNDeviceType(NNTestCase):
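    # Shared helper: applies cls(p) out-of-place and in-place, and checks that the
    # output and the gradient keep the requested memory format, that the mean stays
    # close to 1 - p, and that eval mode leaves the input unchanged.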
    def _test_dropout(self, cls, device, input, memory_format=torch.contiguous_format):
        p = 0.2
        input = input.to(device).fill_(1 - p)

        module = cls(p)
        input_var = input.clone(memory_format=memory_format).requires_grad_()
        output = module(input_var)
        self.assertTrue(output.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
        output.backward(input)
        self.assertTrue(input_var.grad.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)

        module = cls(p, True)
        input_var = input.clone(memory_format=memory_format).requires_grad_()
        output = module(input_var + 0)
        self.assertTrue(output.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
        output.backward(input)
        self.assertTrue(input_var.grad.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)

        # check eval mode doesn't change anything
        for inplace in [True, False]:
            module = cls(p, inplace).eval()
            self.assertEqual(input, module(input))

        # Check that these don't raise errors
        module.__repr__()
        str(module)

    def _test_dropout_discontiguous(
        self, cls, device, memory_format=torch.contiguous_format
    ):
        # In this test, we verify that dropout preserves the layout and data for different memory formats.
        # We check whether we get the same values for the output of dropout when the
        # probability of dropout is 0 or very close to 0.
        # Reference: https://github.com/pytorch/pytorch/issues/47176
        close_to_zero_p = 1e-10  # Should be almost zero but not zero, as a different code path is taken for p=0
        for p in [0, close_to_zero_p]:
            inp = torch.ones(2, 3, 3, 3, device=device)
            inp_discontiguous = torch.empty(
                2, 3, 3, 6, device=device, memory_format=memory_format
            )[..., ::2]
            inp_discontiguous.copy_(inp)
            mod = cls(p=p)
            out = mod(inp_discontiguous)
            if p != 0:  # p == 0 keeps the input strides as-is.
                # When prob == 0, input stride (54, 18, 6, 2) -> output stride (54, 18, 6, 2)
                # When prob != 0, input stride (54, 18, 6, 2) -> output stride (27, 9, 3, 1)
                self.assertTrue(out.is_contiguous(memory_format=memory_format))
            self.assertEqual(inp_discontiguous, out)

    def _test_dropout_stride_mean_preserve(self, cls, device):
        def invert_perm(p):
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2], d[3])

        inp = torch.ones(2, 3, 4, 5, device=device)
        shifts = [(0, 0), (1, 0), (0, 1), (1, 1)]
        for perm in itertools.permutations((0, 1, 2, 3), r=4):
            for shift in shifts:
                for p in [1e-10, 0.3, 0.5, 0.7]:
                    mod = cls(p=p)
                    permuted_inp = (
                        inp.permute(perm).contiguous().permute(invert_perm(perm))
                    )
                    permuted_inp = permuted_inp[shift[0] :, shift[1] :, :, :]
                    out = mod(permuted_inp)

                    self.assertTrue(out.permute(perm).is_contiguous())
                    self.assertEqual(inp.mean(), out.mean(), rtol=0.5, atol=0.5)
                    if p == 1e-10:
                        self.assertEqual(permuted_inp, out)
                    else:
                        self.assertNotEqual(permuted_inp, out)

    @expectedFailureMPSPre15
    def test_Dropout(self, device):
        input = torch.empty(1000)
        self._test_dropout(nn.Dropout, device, input)

        self._test_dropout_discontiguous(nn.Dropout, device)
        self._test_dropout_discontiguous(
            nn.Dropout, device, memory_format=torch.channels_last
        )

        self._test_dropout_stride_mean_preserve(nn.Dropout, device)

        if self.device_type == "cuda" or self.device_type == "cpu":
            input = input.bfloat16()
            self._test_dropout(nn.Dropout, device, input)

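    # Under the same RNG state, a no-batch-dim input should give the same result as
    # the same input with a leading batch dimension of 1.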
    def _test_dropoutNd_no_batch(self, dropout, input):
        input_clone = input.clone()
        with freeze_rng_state():
            res_no_batch = dropout(input)

        with freeze_rng_state():
            res_batched = dropout(input_clone.unsqueeze(0)).squeeze(0)

        self.assertEqual(res_no_batch, res_batched)

    def _test_dropoutNd_channel_zero(self, dropout, input):
        # Verify that the number of zeros in a channel is either 0 or the number of
        # elements in the channel, for a fully positive input tensor
        shape = input.shape
        B = shape[0]
        C = shape[1]
        channel_numel = torch.tensor(shape[2:]).prod()
        result = dropout(input)

        for b, c in product(range(B), range(C)):
            self.assertTrue(result[b, c].count_nonzero() in (0, channel_numel))

    @expectedFailureXLA  # seems like freeze_rng_state is not honoured by XLA
    @dtypes(torch.double)
    @dtypesIfMPS(torch.float32)
    @expectedFailureMPS
    def test_Dropout1d(self, device, dtype):
        with set_default_dtype(dtype):
            N, C, L = (
                random.randint(10, 15),
                random.randint(10, 15),
                random.randint(10, 15),
            )
            input = torch.empty(N, C, L)
            self._test_dropout(nn.Dropout1d, device, input)

            with self.assertRaisesRegex(
                RuntimeError, "Expected 2D or 3D input, but received a 4D input"
            ):
                nn.Dropout1d(p=0.5)(torch.rand(1, 2, 2, 2, device=device))

            with self.assertRaisesRegex(
                RuntimeError, "Expected 2D or 3D input, but received a 1D input"
            ):
                nn.Dropout1d(p=0.5)(torch.rand(2, device=device))

            # no batch dims
            input = torch.rand(50, 2, device=device)
            self._test_dropoutNd_no_batch(nn.Dropout1d(p=0.5), input)
            self._test_dropoutNd_no_batch(nn.Dropout1d(p=0.5, inplace=True), input)

            # check that complete channels are dropped
            input = torch.ones(10, 4, 2, device=device)
            self._test_dropoutNd_channel_zero(nn.Dropout1d(p=0.5), input)
            self._test_dropoutNd_channel_zero(nn.Dropout1d(p=0.5, inplace=True), input)

    @expectedFailureXLA  # seems like freeze_rng_state is not honoured by XLA
    def test_Dropout2d(self, device):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        num_features = 1000
        input = torch.empty(num_features, b, w, h)
        self._test_dropout(nn.Dropout2d, device, input)
        self._test_dropout(
            nn.Dropout2d, device, input, memory_format=torch.channels_last
        )

        self._test_dropout_discontiguous(nn.Dropout2d, device)
        self._test_dropout_discontiguous(
            nn.Dropout2d, device, memory_format=torch.channels_last
        )

        with self.assertWarnsRegex(UserWarning, "Received a 5-D input to dropout2d"):
            nn.Dropout2d(p=0.5)(torch.rand(1, 2, 2, 2, 2, device=device))

        with self.assertWarnsRegex(UserWarning, "Received a 2-D input to dropout2d"):
            nn.Dropout2d(p=0.5)(torch.rand(1, 2, device=device))

        # TODO: Uncomment these lines once no-batch-dim inputs are supported.
        # For now, the historical dropout1d behavior is performed for 3D inputs.
        # See https://github.com/pytorch/pytorch/issues/77081

        # input = torch.rand(50, 2, 2, device=device)
        # self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5), input)
        # self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5, inplace=True), input)

        with self.assertWarnsRegex(
            UserWarning, "assuming that channel-wise 1D dropout behavior is desired"
        ):
            nn.Dropout2d(p=0.5)(torch.rand(1, 2, 2, device=device))

        # check that complete channels are dropped
        input = torch.ones(10, 4, 2, 2, device=device)
        self._test_dropoutNd_channel_zero(nn.Dropout2d(p=0.5), input)
        self._test_dropoutNd_channel_zero(nn.Dropout2d(p=0.5, inplace=True), input)

    @expectedFailureXLA  # seems like freeze_rng_state is not honoured by XLA
    @expectedFailureMPS  # Failing on current pytorch MPS
    def test_Dropout3d(self, device):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        d = random.randint(1, 2)
        num_features = 1000
        input = torch.empty(num_features, b, d, w, h)
        self._test_dropout(nn.Dropout3d, device, input)

        self._test_dropout_discontiguous(nn.Dropout3d, device)
        self._test_dropout_discontiguous(
            nn.Dropout3d, device, memory_format=torch.channels_last
        )

        with self.assertWarnsRegex(UserWarning, "Received a 6-D input to dropout3d"):
            nn.Dropout3d(p=0.5)(torch.rand(1, 2, 2, 2, 2, 2, device=device))

        with self.assertWarnsRegex(UserWarning, "Received a 3-D input to dropout3d"):
            nn.Dropout3d(p=0.5)(torch.rand(1, 2, 2, device=device))

        # no batch dims
        input = torch.rand(50, 2, 2, 2, device=device)
        self._test_dropoutNd_no_batch(nn.Dropout3d(p=0.5), input)
        self._test_dropoutNd_no_batch(nn.Dropout3d(p=0.5, inplace=True), input)

        # check that complete channels are dropped
        input = torch.ones(10, 4, 2, 2, 2, device=device)
        self._test_dropoutNd_channel_zero(nn.Dropout3d(p=0.5), input)
        self._test_dropoutNd_channel_zero(nn.Dropout3d(p=0.5, inplace=True), input)

    def test_empty_dropout(self, device):
        x = torch.tensor([]).to(device)
        out = torch.nn.functional.dropout(x)
        self.assertEqual(out.size(), x.size())


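# Instantiate the device-type tests for every available backend; allow_mps=True opts
# them into the MPS runs as well.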
instantiate_device_type_tests(TestDropoutNNDeviceType, globals(), allow_mps=True)
instantiate_parametrized_tests(TestDropoutNN)

if __name__ == "__main__":
    run_tests()