# Owner(s): ["module: nn"]
import itertools
import math
import os
import unittest
import warnings
from itertools import product

import torch
import torch.autograd.forward_ad as fwAD
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN, tf32_on_and_off
from torch.testing._internal.common_device_type import (
    disablecuDNN,
    disableMkldnn,
    dtypes,
    dtypesIfCUDA,
    dtypesIfMPS,
    expectedFailureMPS,
    instantiate_device_type_tests,
    largeTensorTest,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
    precisionOverride,
    skipCPUIfNoMkldnn,
    skipCUDAIfMiopen,
    skipCUDAIfNoCudnn,
    skipCUDAIfNoMiopen,
    skipCUDAIfRocm,
    skipMeta,
    skipMPS,
)
from torch.testing._internal.common_dtype import (
    floating_and_complex_types_and,
    floating_types_and,
)
from torch.testing._internal.common_nn import _test_module_empty_input, NNTestCase
from torch.testing._internal.common_utils import (
    download_file,
    dtype2prec_DONTUSE,
    gradcheck,
    GRADCHECK_NONDET_TOL,
    gradgradcheck,
    instantiate_parametrized_tests,
    MACOS_VERSION,
    parametrize as parametrize_test,
    run_tests,
    set_default_dtype,
    subtest,
    TEST_SCIPY,
    TEST_WITH_ROCM,
)


AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported()


if TEST_WITH_ROCM:
    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"


if TEST_SCIPY:
    import scipy.ndimage
    import scipy.signal


class TestConvolutionNN(NNTestCase):
|
|
_do_cuda_memory_leak_check = True
|
|
_do_cuda_non_default_stream = True
|
|
|
|
def test_conv_backcompat(self):
|
|
from torch.serialization import SourceChangeWarning
|
|
|
|
# This file was generated by running on PyTorch 1.0.1 on Python 2:
|
|
#
|
|
# import torch
|
|
# from torch import nn
|
|
# m = nn.Conv2d(1, 1, 1)
|
|
# torch.save(m, 'legacy_conv2d.pt')
|
|
#
|
|
# NB: This Pickle also contains some Unicode data!
|
|
path = download_file("https://download.pytorch.org/test_data/legacy_conv2d.pt")
|
|
with warnings.catch_warnings():
|
|
warnings.simplefilter("ignore", SourceChangeWarning)
|
|
# weights_only=False as this is legacy code that saves the model
|
|
m = torch.load(path, encoding="utf-8", weights_only=False)
|
|
input = torch.randn((1, 1, 1, 1), dtype=torch.float)
|
|
self.assertEqual(m(input).size(), (1, 1, 1, 1))
|
|
|
|
def test_invalid_conv1d(self):
|
|
for dtype in [
|
|
torch.half,
|
|
torch.bfloat16,
|
|
torch.float,
|
|
torch.double,
|
|
torch.cfloat,
|
|
torch.cdouble,
|
|
]:
|
|
module = nn.Conv1d(
|
|
in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True
|
|
).to(dtype)
|
|
input = torch.randn(1, 3, 4).to(dtype)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"Calculated padded input size per channel: \(4\). "
|
|
+ r"Kernel size: \(10\). Kernel size can\'t be greater than actual input size",
|
|
):
|
|
module(input)
|
|
|
|
# Negative stride check
|
|
module = nn.Conv1d(
|
|
in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True
|
|
).to(dtype)
|
|
input = torch.randn(1, 3, 4).to(dtype)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "non-positive stride is not supported"
|
|
):
|
|
module(input)
|
|
|
|
def test_mismatch_shape_conv2d(self):
|
|
for dtype in (torch.float, torch.cfloat):
|
|
x = torch.randn(1, 10, 1, 28, 28, dtype=dtype)
|
|
w = torch.randn(6, 1, 5, 5, dtype=dtype)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d, but got "
|
|
+ r"input of size: \[1, 10, 1, 28, 28\]",
|
|
):
|
|
F.conv2d(x, w)
|
|
|
|
def test_conv2d_discontiguous_weight(self):
|
|
for dtype in (torch.float, torch.cfloat):
|
|
# Test for https://github.com/pytorch/pytorch/issues/55781
|
|
x = torch.ones(64, 16, 16, 16, dtype=dtype)
|
|
weight = (
|
|
torch.arange(0, 1.0, 1 / 2.0**10)
|
|
.reshape(32, 16, 1, 2)
|
|
.to(dtype)[:, :, :, ::2]
|
|
)
|
|
self.assertFalse(weight.is_contiguous())
|
|
y = torch.nn.functional.conv2d(x, weight, None)
|
|
if torch.backends.mkldnn.is_available():
|
|
# Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used
|
|
with torch.backends.mkldnn.flags(enabled=False):
|
|
y_ = torch.nn.functional.conv2d(x, weight, None)
|
|
self.assertEqual(y, y_)
|
|
self.assertEqual(y.sum(), 4186112.0)
|
|
|
|
def test_invalid_conv2d(self):
|
|
for dtype in [
|
|
torch.half,
|
|
torch.bfloat16,
|
|
torch.float,
|
|
torch.double,
|
|
torch.cfloat,
|
|
torch.cdouble,
|
|
]:
|
|
module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(
|
|
dtype
|
|
)
|
|
input = torch.empty(1, 1, 4, 4).to(dtype)
|
|
self.assertRaises(RuntimeError, lambda: module(input))
|
|
|
|
module = nn.Conv2d(
|
|
in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True
|
|
)
|
|
input = torch.randn(1, 3, 1, 1)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"Calculated padded input size per channel: \(1 x 1\). "
|
|
+ r"Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size",
|
|
):
|
|
module(input)
|
|
|
|
# Negative stride check
|
|
module = nn.Conv2d(
|
|
in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True
|
|
).to(dtype)
|
|
input = torch.randn(1, 3, 4, 4).to(dtype)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "non-positive stride is not supported"
|
|
):
|
|
module(input)
|
|
|
|
# Zero stride check
|
|
module = nn.Conv2d(
|
|
in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True
|
|
).to(dtype)
|
|
input = torch.randn(1, 3, 4, 4).to(dtype)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "non-positive stride is not supported"
|
|
):
|
|
module(input)
|
|
|
|
def test_invalid_conv3d(self):
|
|
for dtype in [
|
|
torch.half,
|
|
torch.bfloat16,
|
|
torch.float,
|
|
torch.double,
|
|
torch.cfloat,
|
|
torch.cdouble,
|
|
]:
|
|
module = torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2).to(
|
|
dtype
|
|
)
|
|
input = torch.empty(1, 1, 4, 4, 4).to(dtype)
|
|
self.assertRaises(RuntimeError, lambda: module(input))
|
|
|
|
# Negative stride check
|
|
module = torch.nn.Conv3d(1, 1, kernel_size=3, stride=-2)
|
|
input = torch.empty(1, 1, 4, 4, 4)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "non-positive stride is not supported"
|
|
):
|
|
module(input)
|
|
|
|
def test_conv_invalid_groups(self):
|
|
with self.assertRaisesRegex(ValueError, "groups must be a positive integer"):
|
|
torch.nn.Conv1d(1, 1, kernel_size=3, dilation=2, stride=2, groups=0)
|
|
with self.assertRaisesRegex(ValueError, "groups must be a positive integer"):
|
|
torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-1)
|
|
with self.assertRaisesRegex(ValueError, "groups must be a positive integer"):
|
|
torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-2)
|
|
|
|
def test_Conv1d_module_same_padding(self):
|
|
# Compare module against functional: without strides/dilation, asymmetric padding
|
|
x = torch.rand(1, 1, 20)
|
|
module = nn.Conv1d(
|
|
in_channels=1, out_channels=1, kernel_size=10, padding="same"
|
|
)
|
|
expect = F.conv1d(x, module.weight, module.bias, padding="same")
|
|
self.assertEqual(expect, module(x))
|
|
|
|
# Test dilation, symmetric padding
|
|
module = nn.Conv1d(
|
|
in_channels=1, out_channels=1, kernel_size=10, padding="same", dilation=2
|
|
)
|
|
expect = F.conv1d(x, module.weight, module.bias, padding="same", dilation=2)
|
|
self.assertEqual(expect, module(x))
|
|
|
|
# Test non-zero padding_mode, requiring explicit padding
|
|
module = nn.Conv1d(
|
|
in_channels=1,
|
|
out_channels=1,
|
|
kernel_size=10,
|
|
padding="same",
|
|
padding_mode="replicate",
|
|
)
|
|
x_padded = F.pad(x, [4, 5], mode="replicate")
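        # Note: kernel_size=10 requires a total 'same' padding of kernel_size - 1 = 9,
        # which is split as 4 on the left and 5 on the right (the odd leftover goes to
        # the end), so an explicit [4, 5] replicate pad followed by a 'valid'
        # convolution should reproduce the module output.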
|
|
expect = F.conv1d(x_padded, module.weight, module.bias, padding="valid")
|
|
self.assertEqual(expect, module(x))
|
|
self.assertEqual(x.size(), expect.size())
|
|
|
|
        # Test construction with invalid padding string raises
|
|
with self.assertRaisesRegex(ValueError, "Invalid padding string"):
|
|
module = nn.Conv1d(
|
|
in_channels=3, out_channels=33, kernel_size=10, padding="foo"
|
|
)
|
|
|
|
        # Test construction with same padding and strides raises
|
|
with self.assertRaisesRegex(ValueError, "padding='same'"):
|
|
module = nn.Conv1d(
|
|
in_channels=3, out_channels=33, kernel_size=10, padding="same", stride=2
|
|
)
|
|
|
|
def test_Conv2d_module_same_padding(self):
|
|
# Compare module against functional:
|
|
# without strides/dilation, both symmetric and asymmetric padding
|
|
x = torch.rand(1, 1, 9, 20)
|
|
module = nn.Conv2d(
|
|
in_channels=1, out_channels=1, kernel_size=(5, 10), padding="same"
|
|
)
|
|
expect = F.conv2d(x, module.weight, module.bias, padding="same")
|
|
self.assertEqual(expect, module(x))
|
|
|
|
# with dilation, symmetric padding
|
|
module = nn.Conv2d(
|
|
in_channels=1,
|
|
out_channels=1,
|
|
kernel_size=(3, 4),
|
|
padding="same",
|
|
dilation=(1, 2),
|
|
)
|
|
expect = F.conv2d(
|
|
x, module.weight, module.bias, padding="same", dilation=(1, 2)
|
|
)
|
|
self.assertEqual(expect, module(x))
|
|
|
|
# Test non-zero padding_mode, requiring explicit padding
|
|
module = nn.Conv2d(
|
|
in_channels=1,
|
|
out_channels=1,
|
|
kernel_size=(3, 4),
|
|
padding="same",
|
|
padding_mode="reflect",
|
|
)
|
|
x_padded = F.pad(x, [1, 2, 1, 1], mode="reflect")
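        # Note: F.pad lists padding from the last dimension inward: [W_left, W_right,
        # H_left, H_right]. For kernel (3, 4), 'same' needs (1, 1) along H and
        # (1, 2) along W, with the odd leftover going to the right edge.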
|
|
expect = F.conv2d(x_padded, module.weight, module.bias, padding="valid")
|
|
self.assertEqual(expect, module(x))
|
|
self.assertEqual(x.size(), expect.size())
|
|
|
|
        # Test construction with invalid padding string raises
|
|
with self.assertRaisesRegex(ValueError, "Invalid padding string"):
|
|
module = nn.Conv2d(
|
|
in_channels=3, out_channels=33, kernel_size=10, padding="foo"
|
|
)
|
|
|
|
        # Test construction with same padding and strides raises
|
|
with self.assertRaisesRegex(ValueError, "padding='same'"):
|
|
module = nn.Conv2d(
|
|
in_channels=3, out_channels=33, kernel_size=10, padding="same", stride=2
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "padding='same'"):
|
|
module = nn.Conv2d(
|
|
in_channels=3,
|
|
out_channels=33,
|
|
kernel_size=10,
|
|
padding="same",
|
|
stride=(1, 3),
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "padding='same'"):
|
|
module = nn.Conv2d(
|
|
in_channels=3,
|
|
out_channels=33,
|
|
kernel_size=10,
|
|
padding="same",
|
|
stride=(4, 1),
|
|
)
|
|
|
|
def test_Conv3d_module_same_padding(self):
|
|
# Compare module against functional:
|
|
x = torch.rand(1, 1, 4, 4, 4)
|
|
# without dilation, both symmetric and asymmetric padding
|
|
module = nn.Conv3d(
|
|
in_channels=1, out_channels=1, kernel_size=(2, 3, 4), padding="same"
|
|
)
|
|
expect = F.conv3d(x, module.weight, module.bias, padding="same")
|
|
self.assertEqual(expect, module(x))
|
|
|
|
# with dilation, both symmetric and asymmetric padding
|
|
module = nn.Conv3d(
|
|
in_channels=1,
|
|
out_channels=1,
|
|
kernel_size=(2, 3, 4),
|
|
padding="same",
|
|
dilation=(3, 2, 1),
|
|
)
|
|
expect = F.conv3d(
|
|
x, module.weight, module.bias, padding="same", dilation=(3, 2, 1)
|
|
)
|
|
self.assertEqual(expect, module(x))
|
|
|
|
# Test non-zero padding_mode, requiring explicit padding
|
|
module = nn.Conv3d(
|
|
in_channels=1,
|
|
out_channels=1,
|
|
kernel_size=(2, 3, 4),
|
|
padding="same",
|
|
padding_mode="circular",
|
|
)
|
|
x_padded = F.pad(x, [1, 2, 1, 1, 0, 1], mode="circular")
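        # Note: the pad order is [W_left, W_right, H_left, H_right, D_left, D_right].
        # For kernel (2, 3, 4), 'same' needs (0, 1) along D, (1, 1) along H and
        # (1, 2) along W, with any odd leftover placed at the end.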
|
|
expect = F.conv3d(x_padded, module.weight, module.bias, padding="valid")
|
|
self.assertEqual(expect, module(x))
|
|
self.assertEqual(x.size(), expect.size())
|
|
|
|
        # Test construction with invalid padding string raises
|
|
with self.assertRaisesRegex(ValueError, "Invalid padding string"):
|
|
module = nn.Conv3d(
|
|
in_channels=3, out_channels=33, kernel_size=10, padding="foo"
|
|
)
|
|
|
|
        # Test construction with same padding and strides raises
|
|
with self.assertRaisesRegex(ValueError, "padding='same'"):
|
|
module = nn.Conv3d(
|
|
in_channels=3, out_channels=33, kernel_size=10, padding="same", stride=2
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "padding='same'"):
|
|
module = nn.Conv3d(
|
|
in_channels=3,
|
|
out_channels=33,
|
|
kernel_size=10,
|
|
padding="same",
|
|
stride=(1, 1, 3),
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "padding='same'"):
|
|
module = nn.Conv3d(
|
|
in_channels=3,
|
|
out_channels=33,
|
|
kernel_size=10,
|
|
padding="same",
|
|
stride=(1, 4, 1),
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "padding='same'"):
|
|
module = nn.Conv3d(
|
|
in_channels=3,
|
|
out_channels=33,
|
|
kernel_size=10,
|
|
padding="same",
|
|
stride=(5, 1, 1),
|
|
)
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
|
|
def test_thnn_conv_strided_padded_dilated(self):
|
|
for convfn, dims, transposed in (
|
|
(torch.nn.functional.conv2d, 2, False),
|
|
(torch.nn.functional.conv_transpose2d, 2, True),
|
|
(torch.nn.functional.conv3d, 3, False),
|
|
(torch.nn.functional.conv_transpose3d, 3, True),
|
|
):
|
|
for stride, padding, dilation in (
|
|
(2, 0, 1),
|
|
(1, 1, 1),
|
|
(2, 1, 1),
|
|
(1, 0, 2),
|
|
):
|
|
kwargs = {"stride": stride, "padding": padding, "dilation": dilation}
|
|
inp_shape = (1, 2) + dims * (4,)
|
|
weight_shape = (2, 2) + dims * (1,)
|
|
inputs = torch.randn(
|
|
inp_shape, dtype=torch.double, device="cuda", requires_grad=True
|
|
)
|
|
weight = torch.randn(
|
|
weight_shape, dtype=torch.double, device="cuda", requires_grad=True
|
|
)
|
|
bias = torch.randn(
|
|
2, dtype=torch.double, device="cuda", requires_grad=True
|
|
)
|
|
with torch.backends.cudnn.flags(enabled=False):
|
|
res = convfn(inputs, weight, bias, **kwargs)
|
|
res_cpu = convfn(inputs.cpu(), weight.cpu(), bias.cpu(), **kwargs)
|
|
self.assertEqual(res, res_cpu)
|
|
with torch.backends.cudnn.flags(enabled=False):
|
|
torch.autograd.gradcheck(
|
|
lambda x, w, b: convfn(x, w, b, **kwargs),
|
|
(inputs, weight, bias),
|
|
)
|
|
torch.autograd.gradcheck(
|
|
lambda x, w, b: convfn(x, w, b, **kwargs),
|
|
(inputs.cpu(), weight.cpu(), bias.cpu()),
|
|
)
|
|
|
|
def test_Conv2d_inconsistent_types(self):
|
|
inputs = torch.randn(4, 1, 7, 7, dtype=torch.float)
|
|
weights = torch.randn(1, 1, 3, 3, dtype=torch.double)
|
|
# inconsistent types should raise an exception
|
|
self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights))
|
|
# but it should work with the same type
|
|
nn.functional.conv2d(inputs.float(), weights.float())
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
|
|
def test_Conv2d_inconsistent_types_on_GPU_without_cudnn(self):
|
|
inputs = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda")
|
|
weights = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda")
|
|
bias = torch.randn(1, dtype=torch.double, device="cuda")
|
|
|
|
with torch.backends.cudnn.flags(enabled=False):
|
|
# inconsistent types should raise an exception
|
|
self.assertRaises(
|
|
RuntimeError, lambda: nn.functional.conv2d(inputs, weights)
|
|
)
|
|
self.assertRaises(
|
|
RuntimeError,
|
|
lambda: nn.functional.conv2d(inputs, weights.float(), bias),
|
|
)
|
|
|
|
# but it should work with the same type
|
|
nn.functional.conv2d(inputs.float(), weights.float(), bias.float())
|
|
|
|
def test_Conv2d_1x1(self):
|
|
in_channels = 2
|
|
mod = torch.nn.Conv2d(2, 2, 1, bias=False).to(dtype=torch.double)
|
|
input = torch.randn(
|
|
1, in_channels, 5, 5, requires_grad=True, dtype=torch.double
|
|
)
|
|
for enabled in (False, True):
|
|
with torch.backends.mkldnn.flags(enabled=enabled):
|
|
gradcheck(F.conv2d, (input, mod.weight))
|
|
|
|
def test_Conv2d_OneDNN(self):
|
|
def run_once(group_val=24, dilation=1):
|
|
ifm = torch.ones([1, group_val, 6, 6], dtype=torch.float32)
|
|
weights = torch.ones([group_val, 1, 3, 3], dtype=torch.float32)
|
|
op = torch.nn.Conv2d(
|
|
in_channels=group_val,
|
|
out_channels=group_val,
|
|
kernel_size=[3, 3],
|
|
stride=[2, 2],
|
|
padding=[1, 1],
|
|
dilation=[dilation, dilation],
|
|
groups=group_val,
|
|
bias=False,
|
|
padding_mode="zeros",
|
|
)
|
|
|
|
op.weight.data = weights
|
|
res = op(ifm)
|
|
grad_in = torch.ones(res.shape, dtype=torch.float32)
|
|
res.backward(grad_in)
|
|
return op.weight.grad
|
|
|
|
        for group_val in (24, 48, 23, 25):
|
|
for dilation in (1, 2):
|
|
with torch.backends.mkldnn.flags(enabled=False):
|
|
                    without_onednn = run_once(group_val, dilation)
|
|
|
|
with torch.backends.mkldnn.flags(enabled=True):
|
|
                    with_onednn = run_once(group_val, dilation)
|
|
|
|
self.assertEqual(without_onednn, with_onednn)
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
|
|
@unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
|
|
def test_cudnn_non_contiguous(self):
|
|
x = torch.randn(192, 16, 50).cuda()
|
|
x = x.permute(0, 2, 1).contiguous().permute(0, 2, 1)
|
|
m = torch.nn.Conv1d(
|
|
in_channels=16, out_channels=32, kernel_size=2, bias=True
|
|
).cuda()
|
|
m(x)
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
|
|
@unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
|
|
def test_cudnn_not_mutate_stride(self):
|
|
weight = torch.randn(64, 64, 1, 1)
|
|
x = torch.randn(2, 64, 10, 10).to(memory_format=torch.channels_last)
|
|
weight_stride = weight.stride()
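        # Record the original weight strides: running the convolution below with a
        # channels_last input should take the NHWC path without restriding the
        # weight in place.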
|
|
|
|
def conv(x, weight):
|
|
return torch.convolution(
|
|
x,
|
|
weight,
|
|
stride=(1, 1),
|
|
padding=(0, 0),
|
|
dilation=(1, 1),
|
|
transposed=False,
|
|
output_padding=(0, 0),
|
|
groups=1,
|
|
bias=None,
|
|
)
|
|
|
|
# should have run in nhwc without mutating input strides
|
|
out_nhwc = conv(x, weight)
|
|
self.assertEqual(weight.stride(), weight_stride)
|
|
self.assertTrue(out_nhwc.is_contiguous(memory_format=torch.channels_last))
|
|
|
|
x = x.contiguous(memory_format=torch.contiguous_format)
|
|
out_c = conv(x, weight)
|
|
self.assertTrue(out_c.is_contiguous(memory_format=torch.contiguous_format))
|
|
self.assertEqual(out_c, out_nhwc)
|
|
self.assertEqual(weight.stride(), weight_stride)
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
|
|
@unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
|
|
def test_Conv2d_inconsistent_types_on_GPU_with_cudnn(self):
|
|
inputs = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda")
|
|
weights = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda")
|
|
bias = torch.randn(1, dtype=torch.double, device="cuda")
|
|
|
|
with torch.backends.cudnn.flags(enabled=True):
|
|
# inconsistent types should raise an exception
|
|
self.assertRaises(
|
|
RuntimeError, lambda: nn.functional.conv2d(inputs, weights)
|
|
)
|
|
self.assertRaises(
|
|
RuntimeError,
|
|
lambda: nn.functional.conv2d(inputs, weights.float(), bias),
|
|
)
|
|
|
|
# but it should work with the same type
|
|
nn.functional.conv2d(inputs.float(), weights.float(), bias.float())
|
|
|
|
def test_Conv2d_missing_argument(self):
|
|
c = nn.Conv2d(3, 3, 3)
|
|
self.assertRaises(TypeError, lambda: c(None))
|
|
|
|
def test_Conv2d_backward_twice(self):
|
|
input = torch.randn(2, 3, 5, 5)
|
|
c = nn.Conv2d(3, 3, 3)
|
|
o1 = c(input)
|
|
o1.sum().backward()
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "Specify retain_graph=True", lambda: o1.sum().backward()
|
|
)
|
|
|
|
def test_conv_modules_raise_error_on_incorrect_input_size(self):
|
|
for dtype in [torch.half, torch.bfloat16, torch.double, torch.float]:
|
|
modules = [
|
|
nn.Conv1d(3, 8, 3).to(dtype),
|
|
nn.ConvTranspose1d(3, 8, 3).to(dtype),
|
|
nn.Conv2d(3, 8, 3).to(dtype),
|
|
nn.ConvTranspose2d(3, 8, 3).to(dtype),
|
|
nn.Conv3d(3, 8, 3).to(dtype),
|
|
nn.ConvTranspose3d(3, 8, 3).to(dtype),
|
|
]
|
|
|
|
invalid_input_dims = [(1, 4), (1, 4), (2, 5), (2, 5), (3, 6), (3, 6)]
|
|
|
|
for invalid_dims, module in zip(invalid_input_dims, modules):
|
|
for dims in invalid_dims:
|
|
input = torch.empty(torch.Size((3,) * dims))
|
|
self.assertRaises(RuntimeError, lambda: module(input))
|
|
|
|
def test_conv_shapecheck(self):
|
|
def test(should_raise, module, input_size, dtype):
|
|
input = torch.empty(3, *input_size).to(dtype)
|
|
if should_raise:
|
|
self.assertRaises(RuntimeError, lambda: module(input))
|
|
else:
|
|
                # just run it to ensure no exception is raised.
|
|
module(input)
|
|
|
|
for dtype in [
|
|
torch.half,
|
|
torch.bfloat16,
|
|
torch.float,
|
|
torch.double,
|
|
torch.cfloat,
|
|
torch.cdouble,
|
|
]:
|
|
# Conv1d
|
|
test(True, nn.Conv1d(1, 1, 3).to(dtype), (1, 2), dtype)
|
|
test(True, nn.Conv1d(1, 1, 3, stride=2).to(dtype), (1, 2), dtype)
|
|
test(False, nn.Conv1d(1, 1, 2).to(dtype), (1, 2), dtype)
|
|
test(False, nn.Conv1d(1, 1, 2, stride=2).to(dtype), (1, 2), dtype)
|
|
test(
|
|
False, nn.Conv1d(1, 1, 3, stride=2, padding=1).to(dtype), (1, 2), dtype
|
|
)
|
|
|
|
# Conv2d
|
|
test(True, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 2, 2), dtype)
|
|
test(False, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 3, 3), dtype)
|
|
test(False, nn.Conv2d(1, 1, (3, 3), padding=1).to(dtype), (1, 2, 2), dtype)
|
|
|
|
# Conv3D
|
|
test(True, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 2, 2, 2), dtype)
|
|
test(False, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 3, 3, 3), dtype)
|
|
test(
|
|
False,
|
|
nn.Conv3d(1, 1, (3, 3, 3), padding=1).to(dtype),
|
|
(1, 2, 2, 2),
|
|
dtype,
|
|
)
|
|
|
|
def test_ConvTranspose2d_output_size(self):
|
|
m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2)
|
|
i = torch.randn(2, 3, 6, 6)
|
|
for h in range(15, 22):
|
|
for w in range(15, 22):
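                # With a 6x6 input, stride 3, kernel 3 and no padding, the smallest
                # producible output is (6 - 1) * 3 + (3 - 1) + 1 = 18, and output_size
                # may exceed that by at most stride - 1 = 2, so only 18..20 should be
                # accepted.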
|
|
if 18 <= h <= 20 and 18 <= w <= 20:
|
|
output = m(i, output_size=(h, w))
|
|
self.assertEqual(output.size()[2:], (h, w))
|
|
else:
|
|
self.assertRaises(ValueError, lambda: m(i, (h, w)))
|
|
|
|
def test_ConvTranspose2d_output_size_downsample_upsample(self):
|
|
b, c, hid_c = 2, 3, 2
|
|
for h in range(13, 24):
|
|
for w in range(13, 17):
|
|
for k in range(2, 5):
|
|
for d in range(1, 5):
|
|
for s in range(1, 4):
|
|
for p in range(3):
|
|
conv = nn.Conv2d(
|
|
in_channels=c,
|
|
out_channels=hid_c,
|
|
kernel_size=k,
|
|
stride=s,
|
|
padding=p,
|
|
dilation=d,
|
|
)
|
|
|
|
t_conv = nn.ConvTranspose2d(
|
|
in_channels=hid_c,
|
|
out_channels=c,
|
|
kernel_size=k,
|
|
stride=s,
|
|
padding=p,
|
|
dilation=d,
|
|
)
|
|
|
|
i = torch.randn(b, c, h, w)
|
|
|
|
out = t_conv(conv(i), output_size=i.shape)
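                                # output_size selects, among the spatial sizes
                                # compatible with the transposed conv's parameters,
                                # the one matching the original input, so the
                                # conv -> conv_transpose round trip should recover
                                # i's height and width.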
|
|
|
|
self.assertEqual(out.size()[2:], i.size()[2:])
|
|
|
|
def test_ConvTranspose3d_correct_output_size(self):
|
|
# Check that ConvTranspose3d can take a 5d output_size.
|
|
m = nn.ConvTranspose3d(2, 2, 2)
|
|
i = torch.rand(1, 2, 1, 1, 1)
|
|
m(i, output_size=(1, 2, 2, 2, 2))
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
|
|
def test_ConvTranspose2d_half_cublas_gemm(self):
|
|
with torch.backends.cudnn.flags(enabled=False):
|
|
inputs = torch.randn(1, 1, 16, 16, device="cuda", dtype=torch.half)
|
|
deconv = (
|
|
nn.ConvTranspose2d(1, 1, 3, stride=2, padding=1, output_padding=1)
|
|
.cuda()
|
|
.half()
|
|
)
|
|
output = deconv(inputs)
|
|
output.mean().backward()
|
|
|
|
# For https://github.com/pytorch/pytorch/pull/1273
|
|
# Almost identical to the above `test_Conv2d_naive_groups`
|
|
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
|
|
@tf32_on_and_off(0.001)
|
|
def test_Conv2d_groups_nobias(self):
|
|
dev_dtypes = [("cpu", torch.float)]
|
|
if TEST_CUDA:
|
|
dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)]
|
|
if AMPERE_OR_ROCM:
|
|
dev_dtypes += [("cuda", torch.bfloat16)]
|
|
for device, dtype in dev_dtypes:
|
|
m = nn.Conv2d(4, 4, kernel_size=3, groups=2, bias=False).to(device, dtype)
|
|
i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
|
|
output = m(i)
|
|
grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype)
|
|
output.backward(grad_output)
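            # With groups=2 the channels split in half: the first two output
            # channels only see the first two input channels, and likewise for
            # the second half, so the grouped conv should match two independent
            # 2 -> 2 convolutions run on the channel slices below.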
|
|
|
|
m1 = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype)
|
|
m1.weight.data.copy_(m.weight.data[:2])
|
|
i1 = i.data[:, :2].contiguous().requires_grad_(True)
|
|
output1 = m1(i1)
|
|
output1.backward(grad_output[:, :2].contiguous())
|
|
|
|
m2 = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype)
|
|
m2.weight.data.copy_(m.weight.data[2:])
|
|
i2 = i.data[:, 2:].contiguous().requires_grad_(True)
|
|
output2 = m2(i2)
|
|
output2.backward(grad_output[:, 2:].contiguous())
|
|
|
|
self.assertEqual(output, torch.cat([output1, output2], 1))
|
|
self.assertEqual(
|
|
i.grad.data,
|
|
torch.cat([i1.grad.data, i2.grad.data], 1),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
m.weight.grad.data,
|
|
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
|
|
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
|
|
# Almost identical to the above `test_Conv2d_naive_groups`
|
|
    # Covers the special case where groups > 1, input-channels / groups < 16, and output-channels is a multiple of 16
|
|
# See also https://github.com/pytorch/pytorch/pull/18463#issuecomment-476563686
|
|
# and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024
|
|
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
|
|
@tf32_on_and_off(0.001)
|
|
def test_Conv2d_groups_nobias_v2(self):
|
|
torch.manual_seed(123)
|
|
dev_dtypes = [("cpu", torch.float)]
|
|
if TEST_CUDA:
|
|
dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)]
|
|
if AMPERE_OR_ROCM:
|
|
dev_dtypes += [("cuda", torch.bfloat16)]
|
|
for device, dtype in dev_dtypes:
|
|
m = nn.Conv2d(4, 16, kernel_size=3, groups=2, bias=False).to(device, dtype)
|
|
i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
|
|
output = m(i)
|
|
grad_output = torch.randn(2, 16, 4, 4, device=device, dtype=dtype)
|
|
output.backward(grad_output)
|
|
|
|
m1 = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype)
|
|
m1.weight.data.copy_(m.weight.data[:8])
|
|
i1 = i.data[:, :2].contiguous().requires_grad_(True)
|
|
output1 = m1(i1)
|
|
output1.backward(grad_output[:, :8].contiguous())
|
|
|
|
m2 = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype)
|
|
m2.weight.data.copy_(m.weight.data[8:])
|
|
i2 = i.data[:, 2:].contiguous().requires_grad_(True)
|
|
output2 = m2(i2)
|
|
output2.backward(grad_output[:, 8:].contiguous())
|
|
|
|
self.assertEqual(output, torch.cat([output1, output2], 1))
|
|
self.assertEqual(
|
|
i.grad.data,
|
|
torch.cat([i1.grad.data, i2.grad.data], 1),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
m.weight.grad.data,
|
|
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
|
|
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
|
|
# CPU-only test for group conv3d fast implementation using bmm
|
|
# See: https://github.com/pytorch/pytorch/pull/36355
|
|
def test_Conv3d_groups_nobias(self):
|
|
torch.manual_seed(123)
|
|
m = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=False).to("cpu", torch.float)
|
|
i = torch.randn(
|
|
2, 4, 6, 6, 6, device="cpu", dtype=torch.float, requires_grad=True
|
|
)
|
|
output = m(i)
|
|
grad_output = torch.randn(2, 16, 4, 4, 4, device="cpu", dtype=torch.float)
|
|
output.backward(grad_output)
|
|
|
|
m1 = nn.Conv3d(2, 8, kernel_size=3, bias=False).to("cpu", torch.float)
|
|
m1.weight.data.copy_(m.weight.data[:8])
|
|
i1 = i.data[:, :2].contiguous().requires_grad_(True)
|
|
output1 = m1(i1)
|
|
output1.backward(grad_output[:, :8].contiguous())
|
|
|
|
m2 = nn.Conv3d(2, 8, kernel_size=3, bias=False).to("cpu", torch.float)
|
|
m2.weight.data.copy_(m.weight.data[8:])
|
|
i2 = i.data[:, 2:].contiguous().requires_grad_(True)
|
|
output2 = m2(i2)
|
|
output2.backward(grad_output[:, 8:].contiguous())
|
|
|
|
self.assertEqual(output, torch.cat([output1, output2], 1))
|
|
self.assertEqual(
|
|
i.grad.data,
|
|
torch.cat([i1.grad.data, i2.grad.data], 1),
|
|
atol=dtype2prec_DONTUSE[torch.float],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
m.weight.grad.data,
|
|
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
|
|
atol=dtype2prec_DONTUSE[torch.float],
|
|
rtol=dtype2prec_DONTUSE[torch.float],
|
|
)
|
|
|
|
def test_Conv3d_groups_wbias(self):
|
|
torch.manual_seed(123)
|
|
m = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=True).to("cpu", torch.float)
|
|
i = torch.randn(
|
|
2, 4, 6, 6, 6, device="cpu", dtype=torch.float, requires_grad=True
|
|
)
|
|
output = m(i)
|
|
grad_output = torch.randn(2, 16, 4, 4, 4, device="cpu", dtype=torch.float)
|
|
output.backward(grad_output)
|
|
|
|
m1 = nn.Conv3d(2, 8, kernel_size=3, bias=True).to("cpu", torch.float)
|
|
m1.weight.data.copy_(m.weight.data[:8])
|
|
m1.bias.data.copy_(m.bias.data[:8])
|
|
i1 = i.data[:, :2].contiguous().requires_grad_(True)
|
|
output1 = m1(i1)
|
|
output1.backward(grad_output[:, :8].contiguous())
|
|
|
|
m2 = nn.Conv3d(2, 8, kernel_size=3, bias=True).to("cpu", torch.float)
|
|
m2.weight.data.copy_(m.weight.data[8:])
|
|
m2.bias.data.copy_(m.bias.data[8:])
|
|
i2 = i.data[:, 2:].contiguous().requires_grad_(True)
|
|
output2 = m2(i2)
|
|
output2.backward(grad_output[:, 8:].contiguous())
|
|
|
|
self.assertEqual(output, torch.cat([output1, output2], 1))
|
|
self.assertEqual(
|
|
i.grad.data,
|
|
torch.cat([i1.grad.data, i2.grad.data], 1),
|
|
atol=dtype2prec_DONTUSE[torch.float],
|
|
rtol=dtype2prec_DONTUSE[torch.float],
|
|
)
|
|
self.assertEqual(
|
|
m.weight.grad.data,
|
|
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
|
|
atol=dtype2prec_DONTUSE[torch.float],
|
|
rtol=dtype2prec_DONTUSE[torch.float],
|
|
)
|
|
self.assertEqual(
|
|
m.bias.grad.data,
|
|
torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
|
|
atol=dtype2prec_DONTUSE[torch.float],
|
|
rtol=dtype2prec_DONTUSE[torch.float],
|
|
)
|
|
|
|
def test_conv_tbc(self):
|
|
with set_default_dtype(torch.double):
|
|
inp = torch.randn(9, 4, 5, requires_grad=True)
|
|
weight = torch.randn(3, 5, 6, requires_grad=True)
|
|
bias = torch.randn(6, requires_grad=True)
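            # conv_tbc uses a (time, batch, channels) layout: input is
            # (time=9, batch=4, in_channels=5), weight is
            # (kernel_width=3, in_channels=5, out_channels=6), bias is (out_channels,).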
|
|
|
|
gradcheck(
|
|
lambda i, w, b, pad: F.conv_tbc(i, w, b, pad), (inp, weight, bias, 3)
|
|
)
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
|
@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
|
|
def test_grouped_conv_cudnn_nhwc_support(self):
|
|
        # In order to catch holes in grouped convolution NHWC support for earlier cuDNN versions
|
|
input = torch.randn((16, 16, 8, 8), dtype=torch.float16, device="cuda").to(
|
|
memory_format=torch.channels_last
|
|
)
|
|
weight = torch.randn((8, 4, 3, 3), dtype=torch.float16, device="cuda").to(
|
|
memory_format=torch.channels_last
|
|
)
|
|
torch.convolution(input, weight, None, (1, 1), (1, 1), (1, 1), False, (0, 0), 4)
|
|
input = torch.randn((16, 8, 8, 8), dtype=torch.float16, device="cuda").to(
|
|
memory_format=torch.channels_last
|
|
)
|
|
torch.convolution(input, weight, None, (1, 1), (1, 1), (1, 1), True, (0, 0), 4)
|
|
|
|
@unittest.expectedFailure
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
|
@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
|
|
def test_conv_cudnn_memory_layout_dominance(self):
|
|
        # The desired behavior here is for the memory layout of conv.weight to
        # dominate the layout of the output, which is not the current behavior;
        # we'll fix this in follow-up PRs and remove the `expectedFailure` tag.
|
|
input = torch.randint(
|
|
1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True
|
|
)
|
|
conv = nn.Conv2d(8, 4, 3).cuda().float()
|
|
|
|
out = conv(input)
|
|
self.assertTrue(out.is_contiguous())
|
|
|
|
input = input.contiguous(memory_format=torch.channels_last)
|
|
out = conv(input)
|
|
self.assertTrue(out.is_contiguous())
|
|
|
|
conv.weight.data = conv.weight.contiguous(memory_format=torch.channels_last)
|
|
out = conv(input)
|
|
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
|
|
|
|
input = input.contiguous()
|
|
out = conv(input)
|
|
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
|
def test_cudnn_noncontiguous_weight(self):
|
|
# Noncontiguous weights must be contiguous() before being
|
|
# passed to cuDNN
|
|
input = torch.tensor([1, 1, 1], dtype=torch.double, device="cuda").view(1, 1, 3)
|
|
weights1 = torch.tensor([1], dtype=torch.double, device="cuda").expand(1, 1, 2)
|
|
weights2 = (
|
|
torch.tensor([1], dtype=torch.double, device="cuda")
|
|
.expand(1, 1, 2)
|
|
.contiguous()
|
|
)
|
|
self.assertEqual(
|
|
F.conv1d(input, weights1, bias=None, stride=2, dilation=2),
|
|
F.conv1d(input, weights2, bias=None, stride=2, dilation=2),
|
|
)
|
|
|
|
def run_grad_conv_test(self, func_forward, func_backward, dim=1, gradient="input"):
|
|
for kern, inp_size in [(3, 6), (3, 7), (4, 9)]:
|
|
for batch, stride, padding, chan_in, chan_out, dilation in product(
|
|
[1, 2], [1, 2], [0, 1, 2], [2], [3], [1]
|
|
):
|
|
for has_bias in [True, False]:
|
|
input_shape = [batch, chan_in]
|
|
weight_shape = [chan_out, chan_in]
|
|
for _ in range(dim):
|
|
input_shape.append(inp_size)
|
|
weight_shape.append(kern)
|
|
|
|
input = torch.randn(input_shape, requires_grad=True)
|
|
weight = torch.randn(weight_shape, requires_grad=True)
|
|
if has_bias:
|
|
bias = torch.randn([chan_out], requires_grad=True)
|
|
output = func_forward(
|
|
input,
|
|
weight,
|
|
stride=stride,
|
|
padding=padding,
|
|
dilation=dilation,
|
|
bias=bias,
|
|
)
|
|
|
|
gradient_o = torch.randn(output.shape)
|
|
gradient_w = torch.autograd.grad(
|
|
output, input if (gradient == "input") else weight, gradient_o
|
|
)
|
|
|
|
self.assertEqual(
|
|
gradient_w[0],
|
|
func_backward(
|
|
input_shape if (gradient == "input") else input,
|
|
weight_shape if (gradient == "weight") else weight,
|
|
gradient_o,
|
|
stride=stride,
|
|
padding=padding,
|
|
dilation=dilation,
|
|
),
|
|
)
|
|
|
|
def test_grad_conv1d_input(self):
|
|
self.run_grad_conv_test(F.conv1d, F.grad.conv1d_input, 1, "input")
|
|
|
|
def test_grad_conv1d_weight(self):
|
|
self.run_grad_conv_test(F.conv1d, F.grad.conv1d_weight, 1, "weight")
|
|
|
|
def test_grad_conv2d_input(self):
|
|
self.run_grad_conv_test(F.conv2d, F.grad.conv2d_input, 2, "input")
|
|
|
|
def test_grad_conv2d_weight(self):
|
|
self.run_grad_conv_test(F.conv2d, F.grad.conv2d_weight, 2, "weight")
|
|
|
|
def test_grad_conv3d_input(self):
|
|
self.run_grad_conv_test(F.conv3d, F.grad.conv3d_input, 3, "input")
|
|
|
|
def test_grad_conv3d_weight(self):
|
|
self.run_grad_conv_test(F.conv3d, F.grad.conv3d_weight, 3, "weight")
|
|
|
|
@unittest.skipIf(not torch._nnpack_available(), "NNPACK unavailable")
|
|
def test_nnpack_conv(self):
|
|
for kern, inp_size in [(3, 6), (3, 7), (4, 9)]:
|
|
for batch, stride, padding, chan_in, chan_out in product(
|
|
[1, 2, 3, 4], [1, 2], [0, 1, 2], [2], [3]
|
|
):
|
|
for has_bias in [True, False]:
|
|
input_shape = [batch, chan_in]
|
|
weight_shape = [chan_out, chan_in]
|
|
for _ in range(2):
|
|
input_shape.append(inp_size)
|
|
weight_shape.append(kern)
|
|
|
|
input = torch.randn(
|
|
input_shape, requires_grad=True, dtype=torch.float
|
|
)
|
|
weight = torch.randn(
|
|
weight_shape, requires_grad=True, dtype=torch.float
|
|
)
|
|
if has_bias:
|
|
bias = torch.randn(
|
|
[chan_out], requires_grad=True, dtype=torch.float
|
|
)
|
|
output = torch._nnpack_spatial_convolution(
|
|
input, weight, stride=stride, padding=padding, bias=bias
|
|
)
|
|
output_expected = torch.nn.functional.conv2d(
|
|
input, weight, stride=stride, padding=padding, bias=bias
|
|
)
|
|
self.assertEqual(output, output_expected, atol=3e-4, rtol=0)
|
|
|
|
gradient_o = torch.randn(output.shape, dtype=torch.float)
|
|
|
|
grads = torch.autograd.grad(output, [input, weight], gradient_o)
|
|
grads_expected = torch.autograd.grad(
|
|
output_expected, [input, weight], gradient_o
|
|
)
|
|
for gr, gr_expected in zip(grads, grads_expected):
|
|
self.assertEqual(gr, gr_expected, atol=3e-4, rtol=0)
|
|
|
|
def test_conv_padding_mode(self):
|
|
with self.assertRaisesRegex(ValueError, "padding_mode must be one of"):
|
|
nn.Conv2d(3, 3, 3, padding_mode="xyz")
|
|
|
|
with self.assertRaisesRegex(ValueError, "padding_mode must be one of"):
|
|
nn.Conv2d(3, 3, 3, padding_mode=3)
|
|
|
|
with self.assertRaisesRegex(ValueError, 'Only "zeros" '):
|
|
nn.ConvTranspose2d(3, 3, 3, padding_mode="reflect")
|
|
|
|
def test_functional_grad_conv(self):
|
|
# Conv 1D
|
|
input = torch.randn(1, 1, 5, requires_grad=True)
|
|
weight = torch.randn(1, 1, 3, requires_grad=True)
|
|
output = F.conv1d(input, weight, dilation=2)
|
|
grad_output = torch.randn(output.shape)
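        # torch.nn.grad.conv1d_input takes the input *shape* (plus the weight and
        # grad_output), while conv1d_weight takes the input tensor and the weight
        # *shape*; both should agree with what autograd computes.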
|
|
|
|
grad_input_autograd, grad_weight_autograd = torch.autograd.grad(
|
|
output, (input, weight), grad_output
|
|
)
|
|
|
|
grad_input_functional = torch.nn.grad.conv1d_input(
|
|
input.shape, weight, grad_output, dilation=2
|
|
)
|
|
self.assertEqual(grad_input_functional, grad_input_autograd)
|
|
|
|
grad_weight_functional = torch.nn.grad.conv1d_weight(
|
|
input, weight.shape, grad_output, dilation=2
|
|
)
|
|
self.assertEqual(grad_weight_functional, grad_weight_autograd)
|
|
|
|
# Conv 2D
|
|
input = torch.randn(1, 1, 5, 5, requires_grad=True)
|
|
weight = torch.randn(1, 1, 3, 3, requires_grad=True)
|
|
output = F.conv2d(input, weight, dilation=2)
|
|
grad_output = torch.randn(output.shape)
|
|
|
|
(grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(
|
|
output, (input, weight), grad_output
|
|
)
|
|
|
|
grad_input_functional = torch.nn.grad.conv2d_input(
|
|
input.shape, weight, grad_output, dilation=2
|
|
)
|
|
self.assertEqual(grad_input_functional, grad_input_autograd)
|
|
|
|
grad_weight_functional = torch.nn.grad.conv2d_weight(
|
|
input, weight.shape, grad_output, dilation=2
|
|
)
|
|
self.assertEqual(grad_weight_functional, grad_weight_autograd)
|
|
|
|
# Conv 3D
|
|
input = torch.randn(1, 1, 5, 5, 5, requires_grad=True)
|
|
weight = torch.randn(1, 1, 3, 3, 3, requires_grad=True)
|
|
output = F.conv3d(input, weight, dilation=2)
|
|
grad_output = torch.randn(output.shape)
|
|
|
|
(grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(
|
|
output, (input, weight), grad_output
|
|
)
|
|
|
|
grad_input_functional = torch.nn.grad.conv3d_input(
|
|
input.shape, weight, grad_output, dilation=2
|
|
)
|
|
self.assertEqual(grad_input_functional, grad_input_autograd)
|
|
|
|
grad_weight_functional = torch.nn.grad.conv3d_weight(
|
|
input, weight.shape, grad_output, dilation=2
|
|
)
|
|
self.assertEqual(grad_weight_functional, grad_weight_autograd)
|
|
|
|
def test_functional_grad_conv2d(self):
|
|
BATCH_SIZE = 4
|
|
IN_CH = 8
|
|
OUT_CH = 16
|
|
SPATIAL = 32
|
|
|
|
def _test_conv2d(stride, kernel_size, groups, dilation):
|
|
padding = kernel_size // 2
|
|
|
|
input = (
|
|
torch.empty(BATCH_SIZE, IN_CH, SPATIAL, SPATIAL)
|
|
.uniform_(-8.0, 8.0)
|
|
.requires_grad_(True)
|
|
)
|
|
|
|
weight = (
|
|
torch.empty(OUT_CH, IN_CH // groups, kernel_size, kernel_size)
|
|
.uniform_(-4.0, 4.0)
|
|
.requires_grad_(True)
|
|
)
|
|
|
|
output = F.conv2d(
|
|
input,
|
|
weight,
|
|
stride=stride,
|
|
padding=padding,
|
|
dilation=dilation,
|
|
groups=groups,
|
|
)
|
|
|
|
grad_output = torch.randn(output.shape)
|
|
|
|
(grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(
|
|
output, (input, weight), grad_output
|
|
)
|
|
|
|
grad_input_functional = torch.nn.grad.conv2d_input(
|
|
input.shape,
|
|
weight,
|
|
grad_output,
|
|
stride=stride,
|
|
padding=padding,
|
|
dilation=dilation,
|
|
groups=groups,
|
|
)
|
|
self.assertEqual(grad_input_functional, grad_input_autograd)
|
|
|
|
grad_weight_functional = torch.nn.grad.conv2d_weight(
|
|
input,
|
|
weight.shape,
|
|
grad_output,
|
|
stride=stride,
|
|
padding=padding,
|
|
dilation=dilation,
|
|
groups=groups,
|
|
)
|
|
self.assertEqual(grad_weight_functional, grad_weight_autograd)
|
|
|
|
strides = [1, 2]
|
|
kernel_sizes = [1, 3, 5]
|
|
groups = [1, 2, 4]
|
|
dilates = [1, 2]
|
|
|
|
for s, k, g, d in product(strides, kernel_sizes, groups, dilates):
|
|
_test_conv2d(s, k, g, d)
|
|
|
|
def test_permute_conv2d_issue_120211(self):
|
|
def reproducer(radius: int):
|
|
image = torch.rand(1, 1024, 1024, 3)
|
|
image = image.permute(0, 3, 1, 2)
|
|
kernel_x = torch.zeros([3, 1, 1, radius * 2 + 1], device=image.device)
|
|
image = torch.nn.functional.conv2d(image, kernel_x, groups=image.shape[-3])
|
|
|
|
for i in range(0, 128):
|
|
# This should not fail
|
|
reproducer(radius=i)
|
|
|
|
def test_conv3d_issue_120406(self):
|
|
# This should not fail
|
|
F.conv3d(torch.ones(2, 3, 8, 9, 26), torch.ones(3, 1, 1, 1, 17), groups=3)
|
|
|
|
def test_conv1d_issue_120547(self):
|
|
weight = torch.ones([16, 1, 32])
|
|
bias = torch.ones([16])
|
|
stride, padding, dilation, groups = (1, 16, 1, 16)
|
|
input = torch.rand((1, 1, 16))
|
|
input = input.transpose(1, 2)
|
|
# This should not fail
|
|
F.conv1d(input, weight, bias, stride, padding, dilation, groups)
|
|
|
|
|
|
class TestConvolutionNNDeviceType(NNTestCase):
|
|
def run_conv_double_back_test(
|
|
self,
|
|
kern,
|
|
stride,
|
|
padding,
|
|
chan_in,
|
|
chan_out,
|
|
batch_size,
|
|
inp_size,
|
|
dilation,
|
|
no_weight,
|
|
groups=1,
|
|
use_cuda=False,
|
|
use_bias=True,
|
|
dtype=torch.double,
|
|
):
|
|
if use_cuda:
|
|
device = torch.device("cuda")
|
|
else:
|
|
device = torch.device("cpu")
|
|
|
|
x = torch.randn(
|
|
batch_size,
|
|
chan_in,
|
|
inp_size,
|
|
inp_size,
|
|
device=device,
|
|
dtype=dtype,
|
|
requires_grad=True,
|
|
)
|
|
weight = torch.randn(
|
|
chan_out,
|
|
chan_in // groups,
|
|
kern,
|
|
kern,
|
|
device=device,
|
|
dtype=dtype,
|
|
requires_grad=not no_weight,
|
|
)
|
|
if use_bias:
|
|
bias = torch.randn(chan_out, device=device, dtype=dtype, requires_grad=True)
|
|
else:
|
|
bias = None
|
|
|
|
def func(*inputs):
|
|
if use_bias:
|
|
lx, lweight, lbias = inputs
|
|
else:
|
|
lx, lweight = inputs
|
|
lbias = None
|
|
# We disable cudnn during forward to avoid finite difference imprecision issues
|
|
with cudnn.flags(enabled=False):
|
|
out = F.conv2d(lx, lweight, lbias, stride, padding, dilation, groups)
|
|
return out
|
|
|
|
if use_bias:
|
|
inputs = x, weight, bias
|
|
else:
|
|
inputs = x, weight
|
|
|
|
dummy_out = func(*inputs)
|
|
grad_y = torch.randn_like(
|
|
dummy_out, device=device, dtype=dtype, requires_grad=True
|
|
)
|
|
|
|
# Issue #15353: test mkldnn double backward, don't run gradgradcheck due
|
|
# to imprecision issues
|
|
if dtype == torch.float:
|
|
(g,) = torch.autograd.grad(dummy_out.sum(), x, create_graph=True)
|
|
return g.requires_grad
|
|
|
|
return gradgradcheck(func, inputs, (grad_y,))
|
|
|
|
@onlyCUDA
|
|
@skipCUDAIfNoCudnn
|
|
@dtypes(
|
|
*floating_and_complex_types_and(
|
|
torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []
|
|
)
|
|
)
|
|
@parametrize_test("dilation", [1, 2, 3])
|
|
def test_Conv2d_deterministic_cudnn(self, device, dtype, dilation):
|
|
inputs = torch.randn(2, 3, 7, 7, device=device, dtype=dtype, requires_grad=True)
|
|
with cudnn.flags(enabled=True, benchmark=True, deterministic=True):
|
|
conv1 = torch.nn.Conv2d(3, 3, 3, dilation=dilation).to(device, dtype)
|
|
conv2 = torch.nn.Conv2d(3, 3, 3, dilation=dilation).to(device, dtype)
|
|
conv2.bias.data.copy_(conv1.bias.data)
|
|
conv2.weight.data.copy_(conv1.weight.data)
|
|
out1 = conv1(inputs)
|
|
out2 = conv2(inputs)
|
|
self.assertEqual(out1, out2, atol=0.0, rtol=0)
|
|
y = torch.randn(out1.size(), device=device, dtype=dtype)
|
|
out1.backward(y)
|
|
out2.backward(y)
|
|
self.assertEqual(
|
|
conv1.bias.grad.data, conv2.bias.grad.data, atol=0.0, rtol=0
|
|
)
|
|
self.assertEqual(
|
|
conv1.weight.grad.data, conv2.weight.grad.data, atol=0.0, rtol=0
|
|
)
|
|
|
|
@onlyCUDA
|
|
@dtypes(
|
|
*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])
|
|
)
|
|
def test_Conv2d_large_workspace(self, device, dtype):
|
|
# These sizes require huge cuDNN workspaces. Make sure we choose a
|
|
# reasonable algorithm that does not run out of memory
|
|
sizes = [
|
|
(1, 256, 109, 175),
|
|
(1, 256, 80, 128),
|
|
(1, 256, 120, 192),
|
|
]
|
|
|
|
def run_test(benchmark):
|
|
with torch.backends.cudnn.flags(enabled=True, benchmark=benchmark):
|
|
conv = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1).to(
|
|
device, dtype
|
|
)
|
|
for size in sizes:
|
|
x = torch.randn(size, device=device, dtype=dtype)
|
|
out = conv(x.detach().clone().requires_grad_())
|
|
out.backward(torch.ones_like(out))
|
|
|
|
run_test(benchmark=False)
|
|
run_test(benchmark=True)
|
|
|
|
@onlyCUDA
|
|
@dtypes(torch.half, torch.float)
|
|
def test_ConvTranspose2d_large_output_padding(self, device, dtype):
|
|
net1 = torch.nn.ConvTranspose2d(
|
|
128, 64, kernel_size=3, stride=2, padding=1, output_padding=1
|
|
).to(device=device, dtype=dtype)
|
|
net2 = torch.nn.ConvTranspose2d(
|
|
64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
|
|
).to(device=device, dtype=dtype)
|
|
net3 = torch.nn.ConvTranspose2d(
|
|
32, 3, kernel_size=3, stride=2, padding=1, output_padding=1
|
|
).to(device=device, dtype=dtype)
|
|
x = torch.rand(1, 128, 6, 6, device=device, dtype=dtype, requires_grad=True)
|
|
x = net1(x)
|
|
x = net2(x)
|
|
x = net3(x)
|
|
x.backward(torch.randn_like(x))
|
|
torch.cuda.synchronize()
|
|
|
|
@onlyCUDA
|
|
@dtypes(torch.float, torch.double, torch.half)
|
|
# Very similar to test_Conv2d_naive_groups but with special care to handle
|
|
# the number of groups == number of input channels
|
|
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
|
|
@tf32_on_and_off(0.01)
|
|
def test_Conv2d_depthwise_naive_groups(self, device, dtype):
|
|
for depth_multiplier in [1, 2]:
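            # groups == in_channels makes this a depthwise conv: each of the two
            # input channels gets its own depth_multiplier filters, so the result
            # should match two independent 1 -> depth_multiplier convs applied to
            # the channel slices below.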
|
|
m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(
|
|
device, dtype
|
|
)
|
|
i = (
|
|
torch.randn(2, 2, 6, 6, device="cuda", dtype=dtype)
|
|
.div_(2)
|
|
.requires_grad_()
|
|
)
|
|
output = m(i)
|
|
grad_output = (
|
|
torch.randn(2, 2 * depth_multiplier, 4, 4, device=device, dtype=dtype)
|
|
/ 2
|
|
)
|
|
output.backward(grad_output)
|
|
|
|
offset = 1 * depth_multiplier
|
|
|
|
m1 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
|
|
m1.weight.data = m.weight.data[:offset].clone()
|
|
m1.bias.data = m.bias.data[:offset].clone()
|
|
i1 = i.detach()[:, :1].clone().requires_grad_()
|
|
output1 = m1(i1)
|
|
output1.backward(grad_output[:, :offset].contiguous())
|
|
|
|
m2 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
|
|
m2.weight.data.copy_(m.weight.data[offset:])
|
|
m2.bias.data.copy_(m.bias.data[offset:])
|
|
i2 = i.detach()[:, 1:].clone().requires_grad_()
|
|
output2 = m2(i2)
|
|
output2.backward(grad_output[:, offset:].contiguous())
|
|
|
|
self.assertEqual(
|
|
output,
|
|
torch.cat([output1, output2], 1),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
i.grad.data,
|
|
torch.cat([i1.grad.data, i2.grad.data], 1),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
m.bias.grad.data,
|
|
torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
m.weight.grad.data,
|
|
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
|
|
@onlyCUDA
|
|
@dtypes(torch.float, torch.double, torch.half)
|
|
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
|
|
@tf32_on_and_off(0.01)
|
|
def test_Conv3d_depthwise_naive_groups(self, device, dtype):
|
|
for depth_multiplier in [1, 2]:
|
|
m = nn.Conv3d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(
|
|
device, dtype
|
|
)
|
|
i = (
|
|
torch.randn(2, 2, 6, 6, 6, device="cuda", dtype=dtype)
|
|
.div_(2)
|
|
.requires_grad_()
|
|
)
|
|
output = m(i)
|
|
grad_output = (
|
|
torch.randn(
|
|
2, 2 * depth_multiplier, 4, 4, 4, device=device, dtype=dtype
|
|
)
|
|
/ 2
|
|
)
|
|
output.backward(grad_output)
|
|
|
|
offset = 1 * depth_multiplier
|
|
|
|
m1 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
|
|
m1.weight.data = m.weight.data[:offset].clone()
|
|
m1.bias.data = m.bias.data[:offset].clone()
|
|
i1 = i.detach()[:, :1].clone().requires_grad_()
|
|
output1 = m1(i1)
|
|
output1.backward(grad_output[:, :offset].contiguous())
|
|
|
|
m2 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
|
|
m2.weight.data.copy_(m.weight.data[offset:])
|
|
m2.bias.data.copy_(m.bias.data[offset:])
|
|
i2 = i.detach()[:, 1:].clone().requires_grad_()
|
|
output2 = m2(i2)
|
|
output2.backward(grad_output[:, offset:].contiguous())
|
|
is_cuda_sm86 = device.startswith(
|
|
"cuda"
|
|
) and torch.cuda.get_device_capability(0) == (8, 6)
|
|
atol, rtol = (
|
|
(3e-4, 3e-2)
|
|
if dtype == torch.float32 and is_cuda_sm86
|
|
else (dtype2prec_DONTUSE[dtype], 0)
|
|
)
|
|
|
|
self.assertEqual(
|
|
output, torch.cat([output1, output2], 1), atol=atol, rtol=rtol
|
|
)
|
|
self.assertEqual(
|
|
i.grad.data,
|
|
torch.cat([i1.grad.data, i2.grad.data], 1),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
m.bias.grad.data,
|
|
torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
m.weight.grad.data,
|
|
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
|
|
atol=atol,
|
|
rtol=rtol,
|
|
)
|
|
|
|
@onlyCUDA
|
|
@dtypes(
|
|
*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])
|
|
)
|
|
def test_noncontig_conv_grad(self, device, dtype):
|
|
# FIXME: remove after adding non-contiguous grad tests for all modules
|
|
module = nn.Conv2d(3, 5, kernel_size=3, padding=1).to(device, dtype)
|
|
input = torch.randn(
|
|
2, 3, 10, 10, dtype=dtype, device=device, requires_grad=True
|
|
)
|
|
output = module(input)
|
|
|
|
grad = torch.randn(2, 2, 5, 10, 10, dtype=dtype, device=device)[:, 1]
|
|
assert not grad.is_contiguous()
|
|
output.backward(grad, retain_graph=True)
|
|
self.assertIsNotNone(input.grad)
|
|
result = input.grad.data.clone()
|
|
input.grad.data.zero_()
|
|
|
|
output.backward(grad.contiguous())
|
|
self.assertEqual(
|
|
result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0
|
|
)
|
|
|
|
@onlyCUDA
|
|
@dtypes(torch.double)
|
|
def test_conv_double_backward(self, device, dtype):
|
|
with torch.backends.cudnn.flags(enabled=True, deterministic=True):
|
|
            # Double backward only runs with DoubleTensor due to precision reasons
|
|
batch_size = 1
|
|
for kern, inp_size, dilations in [(3, 5, [1, 2]), (4, 9, [1])]:
|
|
for stride, padding, chan_in, chan_out, dilation in product(
|
|
[1], [2], [2], [3], dilations
|
|
):
|
|
no_weight = stride == 2
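                    # no_weight drops the weight from the double-backward check;
                    # it is tied to the stride because ggW cannot be provided when
                    # stride > 1 (see test_conv_double_backward_stride below).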
|
|
result = self.run_conv_double_back_test(
|
|
kern,
|
|
stride,
|
|
padding,
|
|
chan_in,
|
|
chan_out,
|
|
batch_size,
|
|
inp_size,
|
|
dilation,
|
|
no_weight,
|
|
use_cuda=True,
|
|
dtype=dtype,
|
|
)
|
|
self.assertTrue(
|
|
result,
|
|
"Conv double backward test failed with parameters:"
|
|
+ "\nkern: "
|
|
+ str(kern)
|
|
+ "\nstride: "
|
|
+ str(stride)
|
|
+ "\npadding: "
|
|
+ str(padding)
|
|
+ "\nchan_in: "
|
|
+ str(chan_in)
|
|
+ "\nchan_out: "
|
|
+ str(chan_out)
|
|
+ "\nbatch_size: "
|
|
+ str(batch_size)
|
|
+ "\ninp_size: "
|
|
+ str(inp_size)
|
|
+ "\ndilation: "
|
|
+ str(dilation),
|
|
)
|
|
|
|
def test_conv_double_backward_no_bias(self):
|
|
kern = 3
|
|
stride = 2
|
|
chan_in, chan_out = 2, 4
|
|
batch_size = 2
|
|
inp_size = 5
|
|
padding = 1
|
|
dilation = 1
|
|
no_weight = False
|
|
use_bias = True
|
|
result = self.run_conv_double_back_test(
|
|
kern,
|
|
stride,
|
|
padding,
|
|
chan_in,
|
|
chan_out,
|
|
batch_size,
|
|
inp_size,
|
|
dilation,
|
|
no_weight,
|
|
use_bias=use_bias,
|
|
)
|
|
self.assertTrue(
|
|
result,
|
|
"Conv double backward test failed with parameters:"
|
|
+ "\nkern: "
|
|
+ str(kern)
|
|
+ "\nstride: "
|
|
+ str(stride)
|
|
+ "\npadding: "
|
|
+ str(padding)
|
|
+ "\nchan_in: "
|
|
+ str(chan_in)
|
|
+ "\nchan_out: "
|
|
+ str(chan_out)
|
|
+ "\nbatch_size: "
|
|
+ str(batch_size)
|
|
+ "\ninp_size: "
|
|
+ str(inp_size)
|
|
+ "\ndilation: "
|
|
+ str(dilation),
|
|
)
|
|
|
|
def test_conv_double_backward_groups(self):
|
|
kern = 3
|
|
stride = 1
|
|
padding = 2
|
|
chan_in, chan_out = 2, 4
|
|
batch_size = 2
|
|
inp_size = 6
|
|
dilation = 1
|
|
no_weight = False
|
|
groups = 2
|
|
result = self.run_conv_double_back_test(
|
|
kern,
|
|
stride,
|
|
padding,
|
|
chan_in * groups,
|
|
chan_out * groups,
|
|
batch_size,
|
|
inp_size,
|
|
dilation,
|
|
no_weight,
|
|
groups=groups,
|
|
)
|
|
self.assertTrue(
|
|
result,
|
|
"Conv double backward test failed with parameters:"
|
|
+ "\nkern: "
|
|
+ str(kern)
|
|
+ "\nstride: "
|
|
+ str(stride)
|
|
+ "\npadding: "
|
|
+ str(padding)
|
|
+ "\nchan_in: "
|
|
+ str(chan_in)
|
|
+ "\nchan_out: "
|
|
+ str(chan_out)
|
|
+ "\nbatch_size: "
|
|
+ str(batch_size)
|
|
+ "\ninp_size: "
|
|
+ str(inp_size)
|
|
+ "\ndilation: "
|
|
+ str(dilation)
|
|
+ "\ngroups: "
|
|
+ str(groups),
|
|
)
|
|
|
|
def test_conv_double_backward_stride(self):
|
|
batch_size = 2
|
|
|
|
# Cannot provide ggW when stride is > 1
|
|
for kern, inp_size, dilations in [(3, 5, [1, 2]), (3, 7, [1])]:
|
|
for stride, padding, chan_in, chan_out, dilation in product(
|
|
[2], [0, 1], [1], [2], dilations
|
|
):
|
|
no_weight = False
|
|
self.run_conv_double_back_test(
|
|
kern,
|
|
stride,
|
|
padding,
|
|
chan_in,
|
|
chan_out,
|
|
batch_size,
|
|
inp_size,
|
|
dilation,
|
|
no_weight,
|
|
)
|
|
|
|
@dtypes(torch.float, torch.cfloat)
|
|
@dtypesIfMPS(
|
|
*([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
|
|
) # Complex not supported on MacOS13
|
|
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
|
|
def test_conv1d_same_padding(self, device, dtype):
|
|
# Test padding='same' outputs the correct shape
|
|
test_args = [
|
|
# in_size
|
|
range(50, 55),
|
|
# kernel_size
|
|
[1, 2, 3, 8],
|
|
# dilation
|
|
range(1, 4),
|
|
# stride
|
|
[1],
|
|
]
|
|
for in_size, k_size, dilation, stride in itertools.product(*test_args):
|
|
x = torch.rand(1, 1, in_size, device=device, dtype=dtype)
|
|
y = torch.rand(1, 1, k_size, device=device, dtype=dtype)
|
|
z = F.conv1d(x, y, padding="same", dilation=dilation, stride=stride)
|
|
self.assertEqual(z.size(2), int(math.ceil(in_size / stride)))
|
|
|
|
# Compare F.conv1d padding='same' output against manual padding
|
|
# Without strides/dilation
|
|
x = torch.rand(1, 1, 12, device=device, dtype=dtype)
|
|
y = torch.rand(1, 1, 3, device=device, dtype=dtype)
|
|
expect = F.conv1d(x, y, padding=1)
|
|
actual = F.conv1d(x, y, padding="same")
|
|
self.assertEqual(expect, actual)
|
|
|
|
# With dilation
|
|
x = torch.rand(1, 1, 12, device=device, dtype=dtype)
|
|
y = torch.rand(1, 1, 4, device=device, dtype=dtype)
|
|
expect = F.conv1d(x, y, padding=3, dilation=2)
|
|
actual = F.conv1d(x, y, padding="same", dilation=2)
|
|
self.assertEqual(expect, actual)
|
|
|
|
# Dilation with asymmetric padding
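        # With dilation=3 and kernel 4, the total 'same' padding is 3 * (4 - 1) = 9,
        # split 4 left / 5 right; padding symmetrically by 5 and dropping the first
        # output element reproduces that split (stride is 1 here).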
|
|
expect = F.conv1d(x, y, padding=5, dilation=3)[..., 1:]
|
|
actual = F.conv1d(x, y, padding="same", dilation=3)
|
|
self.assertEqual(expect, actual)
|
|
|
|
@tf32_on_and_off(0.005)
|
|
@dtypesIfMPS(
|
|
*([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
|
|
) # Complex not supported on MacOS13
|
|
@dtypes(torch.float, torch.cfloat)
|
|
def test_conv2d_same_padding(self, device, dtype):
|
|
# Compare F.conv2d padding='same' output against manual padding
|
|
# Without strides/dilation
|
|
x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype)
|
|
y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype)
|
|
expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :]
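        # Kernel (4, 5) needs total padding of 3 along H and 4 along W; 'same'
        # splits the odd amount as 1 top / 2 bottom, so symmetric padding of 2
        # plus dropping the first output row should match it.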
|
|
actual = F.conv2d(x, y, padding="same")
|
|
self.assertEqual(expect, actual)
|
|
|
|
# With dilation
|
|
y = torch.rand(1, 1, 3, 4, device=device, dtype=dtype)
|
|
expect = F.conv2d(x, y, padding=(2, 3), dilation=2)
|
|
actual = F.conv2d(x, y, padding="same", dilation=2)
|
|
self.assertEqual(expect, actual)
|
|
|
|
# Dilation with asymmetric padding
|
|
y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype)
|
|
expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:]
|
|
actual = F.conv2d(x, y, padding="same", dilation=3)
|
|
self.assertEqual(expect, actual)
|
|
|
|
@dtypes(torch.float, torch.cfloat)
|
|
def test_conv3d_same_padding(self, device, dtype):
|
|
if dtype is torch.cfloat:
|
|
rtol, atol = 2e-6, 2e-6
|
|
else:
|
|
rtol, atol = None, None
|
|
# Compare F.conv3d padding='same' output against manual padding
|
|
# Without strides/dilation
|
|
x = torch.rand(1, 1, 10, 11, 12, device=device, dtype=dtype)
|
|
y = torch.rand(1, 1, 1, 2, 5, device=device, dtype=dtype)
|
|
expect = F.conv3d(x, y, padding=(0, 1, 2))[..., :, 1:, :]
|
|
actual = F.conv3d(x, y, padding="same")
|
|
self.assertEqual(expect, actual, rtol=rtol, atol=atol)
|
|
|
|
# With dilation
|
|
expect = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
|
|
actual = F.conv3d(x, y, padding="same", dilation=2)
|
|
self.assertEqual(expect, actual, rtol=rtol, atol=atol)
|
|
|
|
# Dilation with asymmetric padding
|
|
y = torch.rand(1, 1, 4, 4, 4, device=device, dtype=dtype)
|
|
expect = F.conv3d(x, y, padding=5, dilation=3)[..., 1:, 1:, 1:]
|
|
actual = F.conv3d(x, y, padding="same", dilation=3)
|
|
self.assertEqual(expect, actual, rtol=rtol, atol=atol)
|
|
|
|
@dtypes(torch.float, torch.cfloat)
|
|
@dtypesIfMPS(
|
|
*([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
|
|
) # Complex not supported on MacOS13
|
|
def test_conv1d_valid_padding(self, device, dtype):
|
|
# Test F.conv1d padding='valid' is the same as no padding
|
|
x = torch.rand(1, 1, 10, device=device, dtype=dtype)
|
|
y = torch.rand(1, 1, 4, device=device, dtype=dtype)
|
|
expect = F.conv1d(x, y)
|
|
actual = F.conv1d(x, y, padding="valid")
|
|
self.assertEqual(expect, actual)
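        # For reference, 'valid' is simply zero padding, so the output length here
        # is in_size - (kernel_size - 1) = 10 - 3 = 7.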
|
|
|
|
@dtypes(torch.float, torch.cfloat)
|
|
@dtypesIfMPS(
|
|
*([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
|
|
) # Complex not supported on MacOS13
|
|
def test_conv2d_valid_padding(self, device, dtype):
|
|
# Test F.conv2d padding='valid' is the same as no padding
|
|
x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype)
|
|
y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype)
|
|
expect = F.conv2d(x, y)
|
|
actual = F.conv2d(x, y, padding="valid")
|
|
self.assertEqual(expect, actual)
|
|
|
|
@dtypes(torch.float, torch.cfloat)
|
|
def test_conv3d_valid_padding(self, device, dtype):
|
|
# Test F.conv3d padding='valid' is the same as no padding
|
|
x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device)
|
|
y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device)
|
|
expect = F.conv3d(x, y)
|
|
actual = F.conv3d(x, y, padding="valid")
|
|
self.assertEqual(expect, actual)
|
|
|
|
@dtypes(torch.float, torch.cfloat)
|
|
@dtypesIfMPS(
|
|
*([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
|
|
) # Complex not supported on MacOS13
|
|
def test_conv1d_same_padding_backward(self, device, dtype):
|
|
# Test F.conv1d gradients work with padding='same'
|
|
x = torch.rand(1, 1, 12, dtype=dtype, device=device, requires_grad=True)
|
|
y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)
|
|
|
|
# Symmetric padding
|
|
z = F.conv1d(x, y, padding=3, dilation=2)
|
|
z.sum().abs().backward()
|
|
gx_expect, gy_expect = x.grad, y.grad
|
|
x.grad, y.grad = None, None
|
|
|
|
z = F.conv1d(x, y, padding="same", dilation=2)
|
|
z.sum().abs().backward()
|
|
self.assertEqual(gx_expect, x.grad)
|
|
self.assertEqual(gy_expect, y.grad)
|
|
x.grad, y.grad = None, None
|
|
|
|
# Asymmetric padding
|
|
z = F.conv1d(x, y, padding=2)[..., 1:]
|
|
z.sum().abs().backward()
|
|
gx_expect, gy_expect = x.grad, y.grad
|
|
x.grad, y.grad = None, None
|
|
|
|
z = F.conv1d(x, y, padding="same")
|
|
z.sum().abs().backward()
|
|
self.assertEqual(gx_expect, x.grad)
|
|
self.assertEqual(gy_expect, y.grad)
|
|
|
|
@dtypes(torch.float, torch.cfloat)
|
|
@dtypesIfMPS(
|
|
*([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
|
|
) # Complex not supported on MacOS13
|
|
@tf32_on_and_off(0.001)
|
|
def test_conv2d_same_padding_backward(self, device, dtype):
|
|
# Test F.conv2d gradients work with padding='same'
|
|
x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype, requires_grad=True)
|
|
y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype, requires_grad=True)
|
|
|
|
# Symmetric padding
|
|
z = F.conv2d(x, y, padding=(3, 4), dilation=2)
|
|
z.sum().abs().backward()
|
|
gx_expect, gy_expect = x.grad, y.grad
|
|
x.grad, y.grad = None, None
|
|
|
|
z = F.conv2d(x, y, padding="same", dilation=2)
|
|
z.sum().abs().backward()
|
|
self.assertEqual(gx_expect, x.grad)
|
|
self.assertEqual(gy_expect, y.grad)
|
|
x.grad, y.grad = None, None
|
|
|
|
# Asymmetric padding
|
|
y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype, requires_grad=True)
|
|
z = F.conv2d(x, y, padding=2)[..., 1:, 1:]
|
|
z.sum().abs().backward()
|
|
gx_expect, gy_expect = x.grad, y.grad
|
|
x.grad, y.grad = None, None
|
|
|
|
z = F.conv2d(x, y, padding="same")
|
|
z.sum().abs().backward()
|
|
self.assertEqual(gx_expect, x.grad)
|
|
self.assertEqual(gy_expect, y.grad)
|
|
|
|
@dtypes(torch.double, torch.cdouble)
|
|
@dtypesIfMPS(
|
|
torch.float, torch.cfloat
|
|
) # Double, complex double not supported on MPS
|
|
@expectedFailureMPS # https://github.com/pytorch/pytorch/issues/107214
|
|
def test_conv3d_same_padding_backward(self, device, dtype):
|
|
check_forward_ad = torch.device(device).type != "xla"
|
|
|
|
# Test F.conv3d gradients work with padding='same'
|
|
x = torch.rand(1, 1, 1, 11, 12, dtype=dtype, device=device, requires_grad=True)
|
|
y = torch.rand(1, 1, 1, 2, 5, dtype=dtype, device=device, requires_grad=True)
|
|
|
|
# Symmetric padding
|
|
z = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
|
|
z.sum().abs().backward()
|
|
gx_expect, gy_expect = x.grad, y.grad
|
|
x.grad, y.grad = None, None
|
|
|
|
z = F.conv3d(x, y, padding="same", dilation=2)
|
|
z.sum().abs().backward()
|
|
self.assertEqual(gx_expect, x.grad)
|
|
self.assertEqual(gy_expect, y.grad)
|
|
x.grad, y.grad = None, None
|
|
|
|
gradcheck(
|
|
lambda x, y: F.conv3d(x, y, padding="same", dilation=2),
|
|
(x, y),
|
|
check_forward_ad=check_forward_ad,
|
|
nondet_tol=1e-5,
|
|
)
|
|
if torch.device(device).type != "cuda":
|
|
# https://github.com/pytorch/pytorch/issues/70702
|
|
gradgradcheck(
|
|
lambda x, y: F.conv3d(x, y, padding="same", dilation=2),
|
|
(x, y),
|
|
check_fwd_over_rev=True,
|
|
)
|
|
|
|
# Asymmetric padding
|
|
y = torch.rand(1, 1, 1, 4, 4, dtype=dtype, device=device, requires_grad=True)
|
|
z = F.conv3d(x, y, padding=2)[..., 1:, 1:]
|
|
z.sum().abs().backward()
|
|
gx_expect, gy_expect = x.grad, y.grad
|
|
x.grad, y.grad = None, None
|
|
|
|
z = F.conv3d(x, y, padding="same")
|
|
z.sum().abs().backward()
|
|
self.assertEqual(gx_expect, x.grad)
|
|
self.assertEqual(gy_expect, y.grad)
|
|
|
|
gradcheck(
|
|
lambda x, y: F.conv3d(x, y, padding="same"),
|
|
(x, y),
|
|
check_forward_ad=check_forward_ad,
|
|
nondet_tol=1e-5,
|
|
)
|
|
if torch.device(device).type != "cuda":
|
|
# https://github.com/pytorch/pytorch/issues/70702
|
|
gradgradcheck(
|
|
lambda x, y: F.conv3d(x, y, padding="same"),
|
|
(x, y),
|
|
check_fwd_over_rev=True,
|
|
)
|
|
|
|
@dtypes(torch.float, torch.cfloat)
|
|
@dtypesIfMPS(
|
|
*([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
|
|
) # Complex not supported on MacOS13
|
|
def test_conv1d_valid_padding_backward(self, device, dtype):
|
|
# Test F.conv1d gradients work with padding='valid'
|
|
x = torch.rand(1, 1, 10, dtype=dtype, device=device, requires_grad=True)
|
|
y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)
|
|
F.conv1d(x, y, padding=0).sum().abs().backward()
|
|
gx_expect, gy_expect = x.grad, y.grad
|
|
x.grad, y.grad = None, None
|
|
|
|
F.conv1d(x, y, padding="valid").sum().abs().backward()
|
|
gx_actual, gy_actual = x.grad, y.grad
|
|
self.assertEqual(gx_expect, gx_actual)
|
|
self.assertEqual(gy_expect, gy_actual)
|
|
|
|
@unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
|
|
@dtypes(torch.float, torch.cfloat)
|
|
@dtypesIfMPS(
|
|
*([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
|
|
) # Complex not supported on MacOS13
|
|
@parametrize_test("mode", ("valid", "same"))
|
|
def test_conv1d_vs_scipy(self, device, dtype, mode):
|
|
t = make_tensor((1, 10), device=device, dtype=dtype)
|
|
feat_dim = t.shape[1]
|
|
weight_even = make_tensor((1, 1, 4), device=device, dtype=dtype)
|
|
weight_odd = make_tensor((1, 1, 5), device=device, dtype=dtype)
|
|
|
|
def _test(t, weight, mode):
|
|
# SciPy expects two 1-D inputs.
|
|
t_a = t.view(-1).cpu().numpy()
|
|
w_a = weight.view(-1).cpu().numpy()
|
|
expected = scipy.signal.convolve(t_a, w_a, mode=mode)
|
|
|
|
kwargs = {"padding": mode}
|
|
if mode == "same":
|
|
# `same` padding in PyTorch conv1d is different
|
|
# from SciPy
|
|
p = weight.shape[2] // 2
|
|
t = torch.nn.functional.pad(t, (p, p))
|
|
# We have already taken care of padding
|
|
kwargs.pop("padding")
|
|
|
|
# second input is flipped in SciPy's convolve
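            # (conv1d computes cross-correlation, while scipy.signal.convolve
            # reverses the kernel, so flipping the weight makes the two agree.)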
|
|
weight_flipped = torch.flip(weight, (2,))
|
|
actual = torch.nn.functional.conv1d(t, weight_flipped, **kwargs).squeeze(0)
|
|
if mode == "same":
|
|
actual = actual[:feat_dim]
|
|
|
|
self.assertEqual(actual, expected, atol=2e-5, rtol=2e-5)
|
|
|
|
# Global dtype for this test suite is torch.double
|
|
        # This leads to a change in type-promotion
|
|
# and conv1d outputs `complex128` for `complex64` input.
|
|
with set_default_dtype(torch.float):
|
|
_test(t, weight_even, mode)
|
|
_test(t, weight_odd, mode)
|
|
|
|
@unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
|
|
@dtypes(torch.float, torch.cfloat)
|
|
@dtypesIfMPS(
|
|
*([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
|
|
) # Complex not supported on MacOS13
|
|
@parametrize_test("mode", ("valid", "same"))
|
|
def test_conv2d_vs_scipy(self, device, dtype, mode):
|
|
t = make_tensor((1, 5, 10), device=device, dtype=dtype)
|
|
weight_even = make_tensor((1, 1, 2, 4), device=device, dtype=dtype)
|
|
weight_odd = make_tensor((1, 1, 3, 5), device=device, dtype=dtype)
|
|
|
|
def _test(t, weight, mode):
|
|
# SciPy expects two 2-D inputs.
|
|
t_a = t.squeeze(0).cpu().numpy()
|
|
w_a = weight.squeeze(0).squeeze(0).cpu().numpy()
|
|
expected = scipy.signal.convolve2d(t_a, w_a, mode=mode)
|
|
|
|
kwargs = {"padding": mode}
|
|
if mode == "same":
|
|
# `same` padding in PyTorch conv2d is different
|
|
# from SciPy
|
|
left_right_pad = weight.shape[3] // 2
|
|
top_bottom_pad = weight.shape[2] // 2
|
|
p = (left_right_pad, left_right_pad, top_bottom_pad, top_bottom_pad)
|
|
t = torch.nn.functional.pad(t, p)
|
|
# We have already taken care of padding
|
|
kwargs.pop("padding")
|
|
|
|
# second input is flipped in SciPy's convolve2d
|
|
weight_flipped = torch.flip(weight, (2, 3))
|
|
actual = torch.nn.functional.conv2d(t, weight_flipped, **kwargs).squeeze(0)
|
|
if mode == "same":
|
|
actual = actual[:5, :10]
|
|
|
|
self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6)
|
|
|
|
        # Global dtype for this test suite is torch.double.
        # This leads to a change in type-promotion
        # and conv2d outputs `complex128` for `complex64` input.
|
|
with set_default_dtype(torch.float):
|
|
_test(t, weight_even, mode)
|
|
_test(t, weight_odd, mode)
|
|
|
|
@unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
|
|
@skipMPS # Results in CI are inconsistent, forced to skip
|
|
@dtypes(torch.float, torch.cfloat)
|
|
@parametrize_test("mode", ("valid", "same"))
|
|
def test_conv3d_vs_scipy(self, device, dtype, mode):
|
|
t = make_tensor((1, 5, 5, 10), device=device, dtype=dtype)
|
|
weight_even = make_tensor((1, 1, 2, 2, 4), device=device, dtype=dtype)
|
|
weight_odd = make_tensor((1, 1, 2, 3, 5), device=device, dtype=dtype)
|
|
|
|
def _test(t, weight, mode):
|
|
# SciPy expects two 3-D inputs.
|
|
t_a = t.squeeze(0).cpu().numpy()
|
|
w_a = weight.squeeze(0).squeeze(0).cpu().numpy()
|
|
expected = scipy.signal.convolve(t_a, w_a, mode=mode)
|
|
|
|
kwargs = {"padding": mode}
|
|
if mode == "same":
|
|
# `same` padding in PyTorch conv3d is different
|
|
# from SciPy
|
|
left_right_pad = weight.shape[4] // 2
|
|
top_bottom_pad = weight.shape[3] // 2
|
|
front_back_pad = weight.shape[2] // 2
|
|
p = (
|
|
left_right_pad,
|
|
left_right_pad,
|
|
top_bottom_pad,
|
|
top_bottom_pad,
|
|
front_back_pad,
|
|
front_back_pad,
|
|
)
|
|
t = torch.nn.functional.pad(t, p)
|
|
# We have already taken care of padding
|
|
kwargs.pop("padding")
|
|
|
|
# second input is flipped in SciPy's convolve
|
|
weight_flipped = torch.flip(weight, (2, 3, 4))
|
|
actual = torch.nn.functional.conv3d(t, weight_flipped, **kwargs).squeeze(0)
|
|
if mode == "same":
|
|
actual = actual[:5, :5, :10]
|
|
|
|
if torch.cuda.is_tf32_supported() and (
|
|
dtype == torch.float or dtype == torch.complex64
|
|
):
|
|
self.assertEqual(actual, expected, atol=0.05, rtol=0.05)
|
|
else:
|
|
self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6)
|
|
|
|
        # Global dtype for this test suite is torch.double.
        # This leads to a change in type-promotion
        # and conv3d outputs `complex128` for `complex64` input.
|
|
with set_default_dtype(torch.float):
|
|
_test(t, weight_even, mode)
|
|
_test(t, weight_odd, mode)
|
|
|
|
@dtypes(torch.float, torch.complex64)
|
|
@dtypesIfMPS(
|
|
*([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
|
|
) # Complex not supported on MacOS13
|
|
def test_conv2d_valid_padding_backward(self, device, dtype):
|
|
# Test F.conv2d gradients work with padding='valid'
|
|
x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype, requires_grad=True)
|
|
y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype, requires_grad=True)
|
|
F.conv2d(x, y, padding=0).sum().abs().backward()
|
|
gx_expect, gy_expect = x.grad, y.grad
|
|
x.grad, y.grad = None, None
|
|
|
|
F.conv2d(x, y, padding="valid").sum().abs().backward()
|
|
gx_actual, gy_actual = x.grad, y.grad
|
|
self.assertEqual(gx_expect, gx_actual)
|
|
self.assertEqual(gy_expect, gy_actual)
|
|
|
|
@dtypes(torch.double, torch.cdouble)
|
|
@dtypesIfMPS(
|
|
torch.float, torch.cfloat
|
|
) # Double, complex double not supported on MPS
|
|
@expectedFailureMPS # https://github.com/pytorch/pytorch/issues/107214
|
|
def test_conv3d_valid_padding_backward(self, device, dtype):
|
|
check_forward_ad = torch.device(device).type != "xla"
|
|
|
|
# Test F.conv3d gradients work with padding='valid'
|
|
x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device, requires_grad=True)
|
|
y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device, requires_grad=True)
|
|
F.conv3d(x, y, padding=0).sum().abs().backward()
|
|
gx_expect, gy_expect = x.grad, y.grad
|
|
x.grad, y.grad = None, None
|
|
|
|
F.conv3d(x, y, padding="valid").sum().abs().backward()
|
|
gx_actual, gy_actual = x.grad, y.grad
|
|
self.assertEqual(gx_expect, gx_actual)
|
|
self.assertEqual(gy_expect, gy_actual)
|
|
|
|
gradcheck(
|
|
lambda x, y: F.conv3d(x, y, padding="valid"),
|
|
(x, y),
|
|
check_forward_ad=check_forward_ad,
|
|
)
|
|
gradgradcheck(
|
|
lambda x, y: F.conv3d(x, y, padding="valid"),
|
|
(x, y),
|
|
check_fwd_over_rev=check_forward_ad,
|
|
)
|
|
|
|
@parametrize_test(
|
|
arg_str="N",
|
|
arg_values=[
|
|
subtest(arg_values=(2), name="ConvTranspose2d"),
|
|
subtest(arg_values=(3), name="ConvTranspose3d"),
|
|
],
|
|
)
|
|
def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N):
|
|
# For inputs with no batch dim, verify output is the correct shape when output_size is set.
|
|
# See https://github.com/pytorch/pytorch/issues/75889
|
|
inp = torch.randn((1, 15, 13) if N == 2 else (1, 15, 13, 13), device=device)
|
|
output_size = (1, 240, 200) if N == 2 else (1, 240, 200, 200)
|
|
ConvTransposeNd = getattr(nn, f"ConvTranspose{N}d")
|
|
m = ConvTransposeNd(
|
|
1, 1, kernel_size=16, stride=16, padding=7, bias=False, device=device
|
|
)
|
|
output = m(inp, output_size=output_size)
|
|
self.assertEqual(output.shape, output_size)
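        # For reference, without output_size the spatial output size would be
        # (in - 1) * stride - 2 * padding + kernel_size = 12 * 16 - 14 + 16 = 194;
        # output_size may request any size in [194, 194 + stride - 1], so 200 is valid.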
|
|
|
|
@skipMeta
|
|
@parametrize_test(
|
|
"input_shape,transposed,dilated,groups,layout,backend_expected",
|
|
[
|
|
# === slow ===
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Slow2d,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
|
|
name="slow1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
True,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.SlowTranspose2d,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
|
|
name="slow1d_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
False,
|
|
True,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.SlowDilated2d,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
|
|
name="slow1d_dilated",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
True,
|
|
True,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.SlowTranspose2d,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
|
|
name="slow1d_dilated_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Slow2d,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
|
|
name="slow2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
True,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.SlowTranspose2d,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
|
|
name="slow2d_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
False,
|
|
True,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.SlowDilated2d,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
|
|
name="slow2d_dilated",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
True,
|
|
True,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.SlowTranspose2d,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
|
|
name="slow2d_dilated_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Slow3d,
|
|
),
|
|
decorators=[onlyCPU, disableMkldnn],
|
|
name="slow3d_cpu",
|
|
),
|
|
# CUDA doesn't have a slow 3D implementation, so it goes to the dilated 3D implementation instead
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.SlowDilated3d,
|
|
),
|
|
decorators=[onlyCUDA, disablecuDNN],
|
|
name="slow3d_cuda",
|
|
),
|
|
# FIXME: RuntimeError: CUDA out of memory.
|
|
# subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d),
|
|
# decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_transposed'),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
False,
|
|
True,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.SlowDilated3d,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
|
|
name="slow3d_dilated",
|
|
),
|
|
# FIXME: RuntimeError: CUDA out of memory.
|
|
# subtest(((2, 6, 7, 8, 9), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d),
|
|
# decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_dilated_transposed'),
|
|
subtest(
|
|
(
|
|
(0, 6, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Empty,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn],
|
|
name="empty_batch1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 0, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Empty,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn],
|
|
name="empty_channel1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 0, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Empty,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn],
|
|
name="empty_batch_channel1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 6, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Empty,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn],
|
|
name="empty_batch2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 0, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Empty,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn],
|
|
name="empty_channel2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 0, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Empty,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn],
|
|
name="empty_batch_channel2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 6, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Empty,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn],
|
|
name="empty_batch3d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 0, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Empty,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn],
|
|
name="empty_channel3d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 0, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Empty,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn],
|
|
name="empty_batch_channel3d",
|
|
),
|
|
# === cuda ===
|
|
# Note that disablecuDNN disables miopen as well.
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
False,
|
|
False,
|
|
6,
|
|
torch.strided,
|
|
torch._C._ConvBackend.CudaDepthwise2d,
|
|
),
|
|
decorators=[onlyCUDA, disablecuDNN],
|
|
name="cuda_depthwise1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
False,
|
|
False,
|
|
6,
|
|
torch.strided,
|
|
torch._C._ConvBackend.CudaDepthwise2d,
|
|
),
|
|
decorators=[onlyCUDA, disablecuDNN],
|
|
name="cuda_depthwise2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
False,
|
|
False,
|
|
6,
|
|
torch.strided,
|
|
torch._C._ConvBackend.CudaDepthwise3d,
|
|
),
|
|
decorators=[onlyCUDA, disablecuDNN],
|
|
name="cuda_depthwise3d",
|
|
),
|
|
# === cudnn ===
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Cudnn,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
|
|
name="cudnn1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Cudnn,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
|
|
name="cudnn2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Cudnn,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
|
|
name="cudnn3d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
True,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.CudnnTranspose,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
|
|
name="cudnn1d_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
True,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.CudnnTranspose,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
|
|
name="cudnn2d_transposed",
|
|
),
|
|
# FIXME: RuntimeError: CUDA out of memory.
|
|
# subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose),
|
|
# decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn3d_transposed'),
|
|
# === miopen ===
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Miopen,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoMiopen],
|
|
name="miopen1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Miopen,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoMiopen],
|
|
name="miopen2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Miopen,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoMiopen],
|
|
name="miopen3d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
True,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.MiopenTranspose,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoMiopen],
|
|
name="miopen1d_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
True,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.MiopenTranspose,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoMiopen],
|
|
name="miopen2d_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
True,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.MiopenTranspose,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoMiopen],
|
|
name="miopen3d_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
False,
|
|
False,
|
|
6,
|
|
torch.strided,
|
|
torch._C._ConvBackend.MiopenDepthwise,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoMiopen],
|
|
name="miopen_depthwise1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
False,
|
|
False,
|
|
6,
|
|
torch.strided,
|
|
torch._C._ConvBackend.MiopenDepthwise,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoMiopen],
|
|
name="miopen_depthwise2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
False,
|
|
False,
|
|
6,
|
|
torch.strided,
|
|
torch._C._ConvBackend.MiopenDepthwise,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoMiopen],
|
|
name="miopen_depthwise3d",
|
|
),
|
|
# === mkldnn ===
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.Mkldnn,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.Mkldnn,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.Mkldnn,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn3d",
|
|
),
|
|
# Transposed convolution is broken for mkldnn. See https://github.com/pytorch/pytorch/issues/68775.
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
True,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.Mkldnn,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure],
|
|
name="mkldnn1d_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
True,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.Mkldnn,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure],
|
|
name="mkldnn2d_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
True,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.Mkldnn,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure],
|
|
name="mkldnn3d_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
False,
|
|
True,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Mkldnn,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn1d_cpu_input",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
False,
|
|
True,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Mkldnn,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn2d_cpu_input",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
False,
|
|
True,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Mkldnn,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn3d_cpu_input",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 6, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.MkldnnEmpty,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn_empty_batch1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 0, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.MkldnnEmpty,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn_empty_channel1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 0, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.MkldnnEmpty,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn_empty_batch_channel1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 6, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.MkldnnEmpty,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn_empty_batch2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 0, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.MkldnnEmpty,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn_empty_channel2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 0, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.MkldnnEmpty,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn_empty_batch_channel2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 6, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.MkldnnEmpty,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn_empty_batch3d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 0, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.MkldnnEmpty,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn_empty_channel3d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 0, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.MkldnnEmpty,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn_empty_batch_channel3d",
|
|
),
|
|
# Note: Tests for mobile backends are not currently supported. This comprises
|
|
# NnpackSpatial, Winograd3x3Depthwise, and Xnnpack2d backends. Testing these
|
|
# requires the ability to gate tests by whether PyTorch is built with USE_MOBILE=1.
|
|
],
|
|
)
|
|
# Test with both bias and no bias.
|
|
@parametrize_test("has_bias", [False, True])
|
|
# Test with both stride=1 and stride>1 cases.
|
|
@parametrize_test("strided", [False, True])
|
|
# Test with both contiguous and non-contiguous inputs.
|
|
@parametrize_test("contiguous", [False, True])
|
|
@expectedFailureMPS # No double support
|
|
def test_conv_backend(
|
|
self,
|
|
device,
|
|
input_shape,
|
|
has_bias,
|
|
strided,
|
|
contiguous,
|
|
transposed,
|
|
dilated,
|
|
groups,
|
|
layout,
|
|
backend_expected,
|
|
):
|
|
# Build up inputs.
|
|
dtype = torch.float32
|
|
C_in, C_out, dim, kernel_size = input_shape[1], 12, len(input_shape) - 2, 3
|
|
x = torch.randn(*input_shape, device=device, dtype=dtype, requires_grad=True)
|
|
weight = torch.randn(
|
|
C_in if transposed else C_out,
|
|
C_out // groups if transposed else C_in // groups,
|
|
*[kernel_size for _ in range(dim)],
|
|
device=device,
|
|
dtype=dtype,
|
|
requires_grad=True,
|
|
)
|
|
bias = (
|
|
torch.randn(C_out, device=device, dtype=dtype, requires_grad=True)
|
|
if has_bias
|
|
else None
|
|
)
|
|
|
|
def _make_noncontiguous(inp):
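            # Duplicate every element along the last dim and then view every other
            # element: the values are unchanged, but the result has stride 2 in the
            # last dim and is therefore non-contiguous.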
|
|
if inp is None:
|
|
return None
|
|
old_requires_grad = inp.requires_grad
|
|
inp = torch.repeat_interleave(inp, 2, dim=-1)
|
|
inp = inp[..., ::2].detach().requires_grad_(old_requires_grad)
|
|
return inp
|
|
|
|
if not contiguous:
|
|
x = _make_noncontiguous(x)
|
|
weight = _make_noncontiguous(weight)
|
|
bias = _make_noncontiguous(bias)
|
|
|
|
if layout is torch._mkldnn:
|
|
x = x.to_mkldnn()
|
|
# Note that weight and bias are not supported as mkldnn tensors during training.
|
|
|
|
stride = (2,) * dim if strided else (1,) * dim
|
|
padding = (0,) * dim
|
|
dilation = (2,) * dim if dilated else (1,) * dim
|
|
output_padding = (0,) * dim
|
|
inputs = [
|
|
x,
|
|
weight,
|
|
bias,
|
|
stride,
|
|
padding,
|
|
dilation,
|
|
transposed,
|
|
output_padding,
|
|
groups,
|
|
]
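        # The argument order mirrors aten::convolution, so the same list drives both
        # the backend query and the actual op call below.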
|
|
|
|
# Ensure correct backend is selected.
|
|
backend_actual = torch._C._select_conv_backend(*inputs)
|
|
self.assertEqual(backend_actual, backend_expected)
|
|
|
|
# Ensure backward call succeeds.
|
|
convolution = torch.ops.aten.convolution
|
|
output = convolution(*inputs)
|
|
grad_output = torch.randn(output.shape, device=device, dtype=dtype)
|
|
if not contiguous:
|
|
grad_output = _make_noncontiguous(grad_output)
|
|
if layout is torch._mkldnn:
|
|
grad_output = grad_output.to_mkldnn()
|
|
output.backward(grad_output)
|
|
|
|
# mkldnn doesn't support gradcheck :(
|
|
if layout is torch._mkldnn:
|
|
return
|
|
|
|
if backend_actual != torch._C._ConvBackend.Empty: # FIXME: forward AD fails
|
|
# Forward AD and forward-over-reverse AD smoke test in float32
|
|
# TODO: remove this if we introduce per-op gradient tests for float32
|
|
with fwAD.dual_level():
|
|
dual_inputs = [
|
|
(
|
|
fwAD.make_dual(i, torch.rand_like(i))
|
|
if isinstance(i, torch.Tensor)
|
|
else i
|
|
)
|
|
for i in inputs
|
|
]
|
|
# Forward AD
|
|
output = convolution(*dual_inputs)
|
|
# Forward over reverse AD
|
|
grad_output_d = fwAD.make_dual(
|
|
torch.rand_like(output), torch.rand_like(output)
|
|
)
|
|
if has_bias:
|
|
torch.autograd.grad(output, [x, weight, bias], grad_output_d)
|
|
else:
|
|
torch.autograd.grad(output, [x, weight], grad_output_d)
|
|
|
|
# Convert to float64 for gradcheck.
|
|
x = x.to(torch.float64).detach().requires_grad_(True)
|
|
weight = weight.to(torch.float64).detach().requires_grad_(True)
|
|
if bias is not None:
|
|
bias = bias.to(torch.float64).detach().requires_grad_(True)
|
|
inputs = [
|
|
x,
|
|
weight,
|
|
bias,
|
|
stride,
|
|
padding,
|
|
dilation,
|
|
transposed,
|
|
output_padding,
|
|
groups,
|
|
]
|
|
|
|
# Set some backend-specific validation settings.
|
|
gradcheck_nondet_tol = 0.0
|
|
if torch.backends.cudnn.is_available():
|
|
# cuDNN introduces non-determinism
|
|
gradcheck_nondet_tol = GRADCHECK_NONDET_TOL
|
|
|
|
self.assertTrue(gradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol))
|
|
|
|
# double backward doesn't support bias gradients
|
|
if bias is not None:
|
|
bias.requires_grad_(False)
|
|
self.assertTrue(
|
|
gradgradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol)
|
|
)
|
|
|
|
@onlyCPU
|
|
def test_conv_contiguous_for_oneDNN(self):
|
|
# See https://github.com/pytorch/pytorch/issues/80837.
|
|
for dtype in [torch.float, torch.bfloat16, torch.half]:
|
|
conv = nn.Conv2d(
|
|
1,
|
|
128,
|
|
kernel_size=(5, 2),
|
|
stride=(2, 1),
|
|
padding=(0, 1),
|
|
dilation=(1, 1),
|
|
groups=1,
|
|
bias=True,
|
|
padding_mode="zeros",
|
|
).to(dtype=dtype)
|
|
|
|
x = torch.rand([1, 2, 321, 201, 1]).to(dtype=dtype)
|
|
x = torch.transpose(x, 1, 4)
|
|
x2 = x[..., 0]
|
|
if torch.backends.mkldnn.is_available():
|
|
y = conv(x2)
|
|
# Disable MKLDNN explicitly
|
|
with torch.backends.mkldnn.flags(enabled=False):
|
|
y_ = conv(x2)
|
|
self.assertEqual(y, y_)
|
|
|
|
@onlyCPU
|
|
def test_conv_ic1_channels_last_for_oneDNN(self):
|
|
# See https://github.com/pytorch/pytorch/issues/82060, N > 1 will call in OneDNN path.
|
|
for dtype in [torch.float, torch.bfloat16, torch.half]:
|
|
conv = torch.nn.Conv2d(
|
|
1, 64, kernel_size=(3, 3), padding=(1, 1), bias=False
|
|
)
|
|
conv = conv.to(memory_format=torch.channels_last).to(dtype=dtype)
|
|
x = torch.rand(2, 1, 100, 100).to(dtype=dtype)
|
|
if torch.backends.mkldnn.is_available():
|
|
y = conv(x)
|
|
# Disable MKLDNN explicitly
|
|
with torch.backends.mkldnn.flags(enabled=False):
|
|
y_ = conv(x)
|
|
self.assertEqual(y, y_)
|
|
|
|
@dtypes(torch.float, torch.cfloat)
|
|
def test_conv_empty_channel(self, device, dtype):
|
|
in_channels = 0
|
|
mod = torch.nn.Conv1d(in_channels, 8, 2, stride=2, dtype=dtype).to(device)
|
|
inp = torch.randn(2, 0, 15, device=device, dtype=dtype)
|
|
_test_module_empty_input(self, mod, inp, check_size=False)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
|
|
inp = torch.randn(2, 1, 0, device=device, dtype=dtype)
|
|
mod(inp)
|
|
|
|
mod = torch.nn.Conv2d(in_channels, 33, 3, stride=2, dtype=dtype).to(device)
|
|
inp = torch.randn(2, 0, 50, 100, device=device, dtype=dtype)
|
|
_test_module_empty_input(self, mod, inp, check_size=False)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
|
|
inp = torch.randn(2, 1, 40, 0, device=device, dtype=dtype)
|
|
mod(inp)
|
|
|
|
mod = torch.nn.Conv3d(in_channels, 33, 3, stride=2, dtype=dtype).to(device)
|
|
inp = torch.randn(2, 0, 50, 20, 40, device=device, dtype=dtype)
|
|
_test_module_empty_input(self, mod, inp, check_size=False)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
|
|
inp = torch.randn(2, 1, 50, 0, 40, device=device, dtype=dtype)
|
|
mod(inp)
|
|
|
|
def test_group_conv_empty(self, device):
|
|
mod = torch.nn.Conv2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(
|
|
device
|
|
)
|
|
inp = torch.randn(0, 4, 4, 4, device=device)
|
|
_test_module_empty_input(self, mod, inp, check_size=False)
|
|
if self.device_type == "cuda" and self.has_cudnn():
|
|
with torch.backends.cudnn.flags(enabled=False):
|
|
_test_module_empty_input(self, mod, inp, check_size=False)
|
|
|
|
def test_group_convTranspose_empty(self, device):
|
|
mod = torch.nn.ConvTranspose2d(
|
|
4, 4, stride=2, kernel_size=3, padding=1, groups=4
|
|
).to(device)
|
|
inp = torch.randn(0, 4, 4, 4, device=device)
|
|
_test_module_empty_input(self, mod, inp, check_size=False)
|
|
if self.device_type == "cuda" and self.has_cudnn():
|
|
with torch.backends.cudnn.flags(enabled=False):
|
|
_test_module_empty_input(self, mod, inp, check_size=False)
|
|
|
|
def test_convTranspose_empty(self, device):
|
|
mod = torch.nn.ConvTranspose2d(4, 4, stride=2, kernel_size=3, padding=1).to(
|
|
device
|
|
)
|
|
inp = torch.randn(0, 4, 4, 4, device=device)
|
|
_test_module_empty_input(self, mod, inp, check_size=False)
|
|
if self.device_type == "cuda" and self.has_cudnn():
|
|
with torch.backends.cudnn.flags(enabled=False):
|
|
_test_module_empty_input(self, mod, inp, check_size=False)
|
|
|
|
@onlyCUDA
|
|
@largeTensorTest("12GB")
|
|
def test_conv_large_nosplit(self, device):
|
|
        # Here we just test that the convolution correctly routes to the fallback
        # implementation, i.e. that it does not crash. The correctness of the fallback
        # implementation is covered by other tests.
|
|
dtype = torch.half if self.device_type == "cuda" else torch.float
|
|
conv1 = nn.Conv2d(2, 2, 8, 8).to(device).to(dtype)
|
|
input_large = torch.randn(1, 2, 1024, 1024 * 1024, dtype=dtype, device=device)
|
|
conv1(input_large)
|
|
conv2 = torch.nn.Conv2d(1, 1024, 1, 1).to(device).to(dtype)
|
|
input_large = torch.randn(1, 1, 2048, 1024, dtype=dtype, device=device)
|
|
conv2(input_large)
|
|
|
|
def test_conv_noncontig_weights(self, device):
|
|
for dim in (1, 2, 3):
|
|
for grouped in (False, True):
|
|
nc = 3
|
|
groups = 3 if grouped else 1
|
|
w = torch.randn([3] * dim, device=device)
|
|
w = w.expand([nc, int(nc / groups)] + list(w.shape))
|
|
w = w.detach().requires_grad_()
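                # expand() broadcasts without copying, so the weight has zero strides
                # in the expanded dims and is non-contiguous, which is what this test
                # targets.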
|
|
x = torch.randn(
|
|
[1, nc] + ([5] * dim), device=device, requires_grad=True
|
|
)
|
|
y = getattr(F, f"conv{dim}d")(x, w, groups=groups)
|
|
y.sum().backward()
|
|
y = getattr(F, f"conv_transpose{dim}d")(x, w, groups=groups)
|
|
y.sum().backward()
|
|
|
|
def test_conv_noncontig_weights_and_bias(self, device):
|
|
# need floats to exercise https://github.com/pytorch/pytorch/issues/16018
|
|
for bias in [True, False]:
|
|
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=bias).to(
|
|
device, torch.float
|
|
)
|
|
|
|
input_nc = torch.randn(
|
|
(1, 3, 224, 224, 2), device=device, dtype=torch.float
|
|
)[:, :, :, :, 1]
|
|
input_c = input_nc.contiguous()
|
|
|
|
weight_nc = torch.randn((64, 3, 7, 7, 2), device=device, dtype=torch.float)[
|
|
:, :, :, :, 1
|
|
]
|
|
conv1.weight = nn.Parameter(weight_nc)
|
|
weight_c = conv1.weight.contiguous()
|
|
|
|
if bias:
|
|
bias_nc = torch.randn((64, 2), device=device, dtype=torch.float)[:, 1]
|
|
conv1.bias = nn.Parameter(bias_nc)
|
|
bias_c = conv1.bias.contiguous()
|
|
|
|
out1 = conv1(input_nc)
|
|
conv1.weight = nn.Parameter(weight_c)
|
|
if bias:
|
|
conv1.bias = nn.Parameter(bias_c)
|
|
out2 = conv1(input_c)
|
|
self.assertEqual(out1, out2)
|
|
|
|
@onlyCUDA
|
|
@largeTensorTest("12GB")
|
|
def test_conv_transposed_large(self, device):
|
|
dtype = torch.half if self.device_type == "cuda" else torch.float
|
|
conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype)
|
|
input_large = torch.randn(4096, 1, 512, 1024, dtype=dtype, device=device)
|
|
# forward
|
|
ret = conv(input_large)
|
|
maxdiff0 = (
|
|
(ret.narrow(0, 0, 1024) - conv(input_large.narrow(0, 0, 1024)))
|
|
.abs_()
|
|
.max()
|
|
.item()
|
|
)
|
|
maxdiff1 = (
|
|
(ret.narrow(0, 1024, 1024) - conv(input_large.narrow(0, 1024, 1024)))
|
|
.abs_()
|
|
.max()
|
|
.item()
|
|
)
|
|
maxdiff2 = (
|
|
(ret.narrow(0, 2048, 1024) - conv(input_large.narrow(0, 2048, 1024)))
|
|
.abs_()
|
|
.max()
|
|
.item()
|
|
)
|
|
maxdiff3 = (
|
|
(ret.narrow(0, 3072, 1024) - conv(input_large.narrow(0, 3072, 1024)))
|
|
.abs_()
|
|
.max()
|
|
.item()
|
|
)
|
|
if self.device_type == "cuda":
|
|
# cuDNN may use algorithms such as FFT that don't guarantee a diff of 0
|
|
self.assertEqual(maxdiff0, 0, atol=2e-3, rtol=1e-5)
|
|
self.assertEqual(maxdiff1, 0, atol=2e-3, rtol=1e-5)
|
|
self.assertEqual(maxdiff2, 0, atol=2e-3, rtol=1e-5)
|
|
self.assertEqual(maxdiff3, 0, atol=2e-3, rtol=1e-5)
|
|
else:
|
|
self.assertEqual(maxdiff0, 0)
|
|
self.assertEqual(maxdiff1, 0)
|
|
self.assertEqual(maxdiff2, 0)
|
|
self.assertEqual(maxdiff3, 0)
|
|
|
|
@onlyCUDA
|
|
@largeTensorTest("12GB")
|
|
def test_conv_large(self, device):
|
|
dtype = torch.half if self.device_type == "cuda" else torch.float
|
|
conv = nn.Conv2d(2, 2, 8, 8, bias=False).to(device).to(dtype)
|
|
input_large = torch.randn(4097, 2, 512, 512, dtype=dtype, device=device)
|
|
# forward
|
|
ret = conv(input_large)
|
|
self.assertEqual(ret[:2048], conv(input_large[:2048]))
|
|
self.assertEqual(ret[2048:4096], conv(input_large[2048:4096]))
|
|
self.assertEqual(ret[4096:], conv(input_large[4096:]))
|
|
|
|
# backward
|
|
conv.zero_grad()
|
|
        # When computing the backward, we use `max(dim=1)` to create some sparsity.
        # Without this sparsity, the rounding error would be too large (as large as
        # 1e-5) to satisfy the criterion (1e-6) of `assertEqual`.
|
|
ret.view(4097, -1).max(dim=1).values.sum().backward()
|
|
del ret
|
|
grad1 = conv.weight.grad.detach().clone()
|
|
conv.zero_grad()
|
|
conv(input_large[:2048]).view(2048, -1).max(dim=1).values.sum().backward()
|
|
conv(input_large[2048:4096]).view(2048, -1).max(dim=1).values.sum().backward()
|
|
conv(input_large[4096:]).view(1, -1).max(dim=1).values.sum().backward()
|
|
grad2 = conv.weight.grad.detach().clone()
|
|
        # The gradients are on the order of hundreds, so we scale them to the order
        # of one before comparing.
|
|
scale = 1 / grad2.abs().mean()
|
|
grad1 = grad1 * scale
|
|
grad2 = grad2 * scale
|
|
self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3)
|
|
|
|
@onlyCUDA
|
|
@largeTensorTest("20GB", "cpu")
|
|
@largeTensorTest("60GB", "cuda")
|
|
def test_conv_large_batch_1(self, device):
|
|
in_channels = 514
|
|
dim = 2048
|
|
out_channels = 1
|
|
kernel_size = 3
|
|
stride = 1
|
|
padding = 1
|
|
|
|
input_tensor = torch.ones(1, in_channels, dim, dim).cuda().half()
|
|
model = (
|
|
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
|
|
.cuda()
|
|
.half()
|
|
)
|
|
output = model(input_tensor)
|
|
_model_cpu = model.cpu().float()
|
|
output_cpu = model(input_tensor.float().cpu())
|
|
self.assertEqual(output.cpu().float(), output_cpu, atol=1e-3, rtol=1e-3)
|
|
|
|
@onlyCUDA
|
|
@skipCUDAIfNoCudnn
|
|
def test_contig_wrong_stride_cudnn(self, device):
|
|
# x has to have batch_size 1 to test contiguous checks
|
|
x = torch.randn(1, 16, 5, 5, device=device)
|
|
stride = list(x.stride())
|
|
stride[0] = 20
|
|
        # Change the stride in dimension 0; the tensor is still contiguous because size[0] is 1
|
|
x.set_(x.storage(), 0, x.size(), stride)
|
|
self.assertTrue(x.is_contiguous())
|
|
F.conv_transpose2d(x, torch.randn(16, 1, 1, 1, device=device))
|
|
F.conv2d(x, torch.randn(1, 16, 1, 1, device=device))
|
|
|
|
@onlyCUDA
|
|
@tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005)
|
|
def test_Conv2d_size_1_kernel(self, device):
|
|
x_cpu = torch.randn(2, 3, 5, 5)
|
|
conv_cpu = torch.nn.Conv2d(3, 3, kernel_size=1)
|
|
y_cpu = conv_cpu(x_cpu)
|
|
y = torch.rand_like(y_cpu)
|
|
y_cpu.backward(y)
|
|
|
|
with cudnn.flags(enabled=False):
|
|
conv_cuda = torch.nn.Conv2d(3, 3, kernel_size=1).to(device)
|
|
conv_cuda.bias.data.copy_(conv_cpu.bias.data)
|
|
conv_cuda.weight.data.copy_(conv_cpu.weight.data)
|
|
y_cuda = conv_cuda(x_cpu.to(device))
|
|
y_cuda.backward(y.to(device))
|
|
|
|
self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
|
|
self.assertEqual(
|
|
conv_cpu.bias.grad.data,
|
|
conv_cuda.bias.grad.data,
|
|
atol=1e-5,
|
|
rtol=0,
|
|
exact_device=False,
|
|
)
|
|
self.assertEqual(
|
|
conv_cpu.weight.grad.data,
|
|
conv_cuda.weight.grad.data,
|
|
atol=1e-5,
|
|
rtol=0,
|
|
exact_device=False,
|
|
)
|
|
|
|
@onlyCUDA
|
|
@tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005)
|
|
def test_ConvTranspose2d_size_1_kernel(self, device):
|
|
x_cpu = torch.randn(2, 3, 5, 5)
|
|
conv_cpu = torch.nn.ConvTranspose2d(3, 3, kernel_size=1)
|
|
y_cpu = conv_cpu(x_cpu)
|
|
y = torch.rand_like(y_cpu)
|
|
y_cpu.backward(y)
|
|
|
|
with cudnn.flags(enabled=False):
|
|
conv_cuda = torch.nn.ConvTranspose2d(3, 3, kernel_size=1).to(device)
|
|
conv_cuda.bias.data.copy_(conv_cpu.bias.data)
|
|
conv_cuda.weight.data.copy_(conv_cpu.weight.data)
|
|
y_cuda = conv_cuda(x_cpu.to(device))
|
|
y_cuda.backward(y.to(device))
|
|
|
|
self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
|
|
self.assertEqual(
|
|
conv_cpu.bias.grad.data,
|
|
conv_cuda.bias.grad.data,
|
|
atol=1e-5,
|
|
rtol=0,
|
|
exact_device=False,
|
|
)
|
|
self.assertEqual(
|
|
conv_cpu.weight.grad.data,
|
|
conv_cuda.weight.grad.data,
|
|
atol=1e-5,
|
|
rtol=0,
|
|
exact_device=False,
|
|
)
|
|
|
|
@onlyCUDA
|
|
def test_ConvTranspose3d_size_1_kernel(self, device):
|
|
with set_default_dtype(torch.double):
|
|
x_cpu = torch.randn(2, 3, 3, 5, 5)
|
|
conv_cpu = torch.nn.ConvTranspose3d(3, 3, kernel_size=1)
|
|
y_cpu = conv_cpu(x_cpu)
|
|
y = torch.rand_like(y_cpu)
|
|
y_cpu.backward(y)
|
|
|
|
with cudnn.flags(enabled=False):
|
|
conv_cuda = torch.nn.ConvTranspose3d(3, 3, kernel_size=1).to(device)
|
|
conv_cuda.bias.data.copy_(conv_cpu.bias.data)
|
|
conv_cuda.weight.data.copy_(conv_cpu.weight.data)
|
|
y_cuda = conv_cuda(x_cpu.to(device))
|
|
y_cuda.backward(y.to(device))
|
|
|
|
self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
|
|
self.assertEqual(
|
|
conv_cpu.bias.grad.data,
|
|
conv_cuda.bias.grad.data,
|
|
atol=1e-5,
|
|
rtol=0,
|
|
exact_device=False,
|
|
)
|
|
self.assertEqual(
|
|
conv_cpu.weight.grad.data,
|
|
conv_cuda.weight.grad.data,
|
|
atol=1e-5,
|
|
rtol=0,
|
|
exact_device=False,
|
|
)
|
|
|
|
@dtypesIfCUDA(
|
|
*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])
|
|
)
|
|
@dtypes(torch.float)
|
|
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
|
|
@tf32_on_and_off(0.001)
|
|
def test_Conv2d_naive_groups(self, device, dtype):
|
|
# Check that grouped convolutions matches two half convolutions
|
|
m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype)
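        # groups=2 splits the 4 input and 4 output channels into two independent
        # 2 -> 2 convolutions; m1 and m2 below re-create those halves explicitly so
        # the grouped result can be compared slice by slice.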
|
|
i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
|
|
output = m(i)
|
|
grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype)
|
|
output.backward(grad_output)
|
|
|
|
m1 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype)
|
|
m1.weight.data.copy_(m.weight.data[:2])
|
|
m1.bias.data.copy_(m.bias.data[:2])
|
|
i1 = i.data[:, :2].contiguous().requires_grad_(True)
|
|
output1 = m1(i1)
|
|
output1.backward(grad_output[:, :2].contiguous())
|
|
|
|
m2 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype)
|
|
m2.weight.data.copy_(m.weight.data[2:])
|
|
m2.bias.data.copy_(m.bias.data[2:])
|
|
i2 = i.data[:, 2:].contiguous().requires_grad_(True)
|
|
output2 = m2(i2)
|
|
output2.backward(grad_output[:, 2:].contiguous())
|
|
|
|
self.assertEqual(output, torch.cat([output1, output2], 1))
|
|
self.assertEqual(
|
|
i.grad.data,
|
|
torch.cat([i1.grad.data, i2.grad.data], 1),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
m.bias.grad.data,
|
|
torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
m.weight.grad.data,
|
|
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
|
|
@dtypes(torch.double, torch.cdouble)
|
|
@dtypesIfMPS(torch.float, torch.cfloat)
|
|
@expectedFailureMPS # https://github.com/pytorch/pytorch/issues/107214
|
|
def test_Conv2d_backward_depthwise(self, device, dtype):
|
|
x = torch.randn(2, 2, 4, 20, device=device, dtype=dtype, requires_grad=True)
|
|
weight = torch.randn(2, 1, 3, 5, device=device, dtype=dtype, requires_grad=True)
|
|
|
|
def conv2d_depthwise(x, weight):
|
|
return torch.nn.functional.conv2d(
|
|
x, weight, bias=None, stride=(1, 10), groups=2
|
|
)
|
|
|
|
for cudnn_enabled in [False, True]:
|
|
with torch.backends.cudnn.flags(enabled=cudnn_enabled):
|
|
torch.autograd.gradcheck(conv2d_depthwise, (x, weight))
|
|
|
|
@onlyCPU
|
|
@dtypes(torch.float, torch.double)
|
|
def test_conv_thnn_nhwc(self, device, dtype):
|
|
def helper(
|
|
mod,
|
|
n,
|
|
c,
|
|
h,
|
|
w,
|
|
out_channels,
|
|
kernel_size,
|
|
dilation,
|
|
groups,
|
|
input_format,
|
|
weight_format,
|
|
):
|
|
input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to(
|
|
memory_format=input_format
|
|
)
|
|
input.requires_grad_()
|
|
conv = mod(
|
|
c, out_channels, kernel_size, dilation=dilation, groups=groups
|
|
).to(device="cpu", dtype=dtype, memory_format=weight_format)
|
|
for p in conv.parameters():
|
|
p.data = torch.randint_like(p, -3, 3)
|
|
|
|
ref_input = input.detach().clone().contiguous().requires_grad_()
|
|
ref_conv = mod(
|
|
c, out_channels, kernel_size, dilation=dilation, groups=groups
|
|
)
|
|
# load_state_dict will restore the stride & memory_layout on ref_conv.weight.
|
|
ref_conv.load_state_dict(conv.state_dict())
|
|
ref_conv = ref_conv.to(
|
|
device="cpu", dtype=dtype, memory_format=torch.contiguous_format
|
|
)
|
|
|
|
out = conv(input)
|
|
ref_out = ref_conv(ref_input)
|
|
|
|
grad = torch.randint_like(out, -3, 3)
|
|
ref_grad = grad.detach().clone().contiguous()
|
|
|
|
out.backward(grad)
|
|
ref_out.backward(ref_grad)
|
|
|
|
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
|
|
self.assertTrue(ref_out.is_contiguous())
|
|
self.assertEqual(out, ref_out, exact_dtype=False)
|
|
self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
|
|
self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
|
|
self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)
|
|
|
|
with torch.backends.mkldnn.flags(enabled=False):
|
|
formats = [
|
|
[torch.channels_last, torch.channels_last],
|
|
[torch.channels_last, torch.contiguous_format],
|
|
[torch.contiguous_format, torch.channels_last],
|
|
]
|
|
for input_format, weight_format in formats:
|
|
# non-dilated conv: thnn_conv2d normal path (with im2col)
|
|
helper(
|
|
nn.Conv2d,
|
|
2,
|
|
8,
|
|
4,
|
|
4,
|
|
out_channels=4,
|
|
kernel_size=3,
|
|
dilation=1,
|
|
groups=1,
|
|
input_format=input_format,
|
|
weight_format=weight_format,
|
|
)
|
|
helper(
|
|
nn.Conv2d,
|
|
2,
|
|
8,
|
|
4,
|
|
4,
|
|
out_channels=8,
|
|
kernel_size=3,
|
|
dilation=1,
|
|
groups=8,
|
|
input_format=input_format,
|
|
weight_format=weight_format,
|
|
)
|
|
                # test when the input channel count is 1 and the input is not converted to channels last
|
|
helper(
|
|
nn.Conv2d,
|
|
2,
|
|
1,
|
|
10,
|
|
10,
|
|
out_channels=8,
|
|
kernel_size=3,
|
|
dilation=1,
|
|
groups=1,
|
|
input_format=torch.contiguous_format,
|
|
weight_format=torch.channels_last,
|
|
)
|
|
# non-dilated conv: thnn_conv2d fast path (skip im2col)
|
|
helper(
|
|
nn.Conv2d,
|
|
1,
|
|
16,
|
|
56,
|
|
56,
|
|
out_channels=16,
|
|
kernel_size=1,
|
|
dilation=1,
|
|
groups=1,
|
|
input_format=input_format,
|
|
weight_format=weight_format,
|
|
)
|
|
                # ic == oc == 1 here, so we need to stick the input to channels_last to activate the channels-last path
|
|
helper(
|
|
nn.Conv2d,
|
|
1,
|
|
16,
|
|
56,
|
|
56,
|
|
out_channels=16,
|
|
kernel_size=1,
|
|
dilation=1,
|
|
groups=16,
|
|
input_format=torch.channels_last,
|
|
weight_format=weight_format,
|
|
)
|
|
# dilated conv: slow_conv_dilated2d
|
|
helper(
|
|
nn.Conv2d,
|
|
2,
|
|
8,
|
|
11,
|
|
13,
|
|
out_channels=16,
|
|
kernel_size=3,
|
|
dilation=2,
|
|
groups=1,
|
|
input_format=input_format,
|
|
weight_format=weight_format,
|
|
)
|
|
helper(
|
|
nn.Conv2d,
|
|
2,
|
|
16,
|
|
11,
|
|
13,
|
|
out_channels=32,
|
|
kernel_size=3,
|
|
dilation=2,
|
|
groups=16,
|
|
input_format=input_format,
|
|
weight_format=weight_format,
|
|
)
|
|
# transposed-conv: slow_conv_transpose2d
|
|
helper(
|
|
nn.ConvTranspose2d,
|
|
2,
|
|
8,
|
|
4,
|
|
4,
|
|
out_channels=4,
|
|
kernel_size=3,
|
|
dilation=1,
|
|
groups=1,
|
|
input_format=input_format,
|
|
weight_format=weight_format,
|
|
)
|
|
helper(
|
|
nn.ConvTranspose2d,
|
|
2,
|
|
8,
|
|
4,
|
|
4,
|
|
out_channels=8,
|
|
kernel_size=3,
|
|
dilation=1,
|
|
groups=8,
|
|
input_format=input_format,
|
|
weight_format=weight_format,
|
|
)
|
|
helper(
|
|
nn.ConvTranspose2d,
|
|
1,
|
|
16,
|
|
56,
|
|
56,
|
|
out_channels=16,
|
|
kernel_size=1,
|
|
dilation=1,
|
|
groups=1,
|
|
input_format=input_format,
|
|
weight_format=weight_format,
|
|
)
|
|
helper(
|
|
nn.ConvTranspose2d,
|
|
1,
|
|
16,
|
|
56,
|
|
56,
|
|
out_channels=32,
|
|
kernel_size=1,
|
|
dilation=1,
|
|
groups=16,
|
|
input_format=input_format,
|
|
weight_format=weight_format,
|
|
)
|
|
|
|
@onlyCUDA
|
|
@dtypes(torch.half, torch.float, torch.cfloat)
|
|
def test_conv_cudnn_nhwc(self, device, dtype):
|
|
def helper(n, c, h, w, out_channels, kernel_size, groups):
|
|
            # randint with dtype=torch.cfloat fails with
            # RuntimeError: check_random_bounds handles only integral, floating-point and boolean types
            # so we must create randint and randint_like with the default int64 and then cast to the desired dtype
|
|
input = torch.randint(
|
|
-3, 3, (n, c, h, w), dtype=torch.int64, device=device
|
|
).to(dtype, memory_format=torch.channels_last)
|
|
input.requires_grad_()
|
|
conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to(
|
|
device="cuda", dtype=dtype, memory_format=torch.channels_last
|
|
)
|
|
for p in conv.parameters():
|
|
p.data = torch.randint_like(p, -3, 3, dtype=torch.int64).to(p.dtype)
|
|
|
|
# use FP64 channels-first conv as reference
|
|
ref_input = input.detach().clone().contiguous().double().requires_grad_()
|
|
ref_conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups)
|
|
# load_state_dict will restore the stride & memory_layout on ref_conv.weight.
|
|
ref_conv.load_state_dict(conv.state_dict())
|
|
ref_conv = ref_conv.to(
|
|
device="cuda", dtype=torch.double, memory_format=torch.contiguous_format
|
|
)
|
|
|
|
out = conv(input)
|
|
ref_out = ref_conv(ref_input)
|
|
|
|
grad = torch.randint_like(out, -3, 3, dtype=torch.int64).to(out.dtype)
|
|
ref_grad = grad.detach().clone().double().contiguous()
|
|
|
|
out.backward(grad)
|
|
ref_out.backward(ref_grad)
|
|
|
|
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
|
|
self.assertTrue(input.grad.is_contiguous(memory_format=torch.channels_last))
|
|
self.assertTrue(
|
|
conv.weight.grad.is_contiguous(memory_format=torch.channels_last)
|
|
)
|
|
|
|
self.assertTrue(ref_out.is_contiguous())
|
|
self.assertTrue(ref_input.grad.is_contiguous())
|
|
self.assertTrue(ref_conv.weight.grad.is_contiguous())
|
|
|
|
self.assertEqual(out, ref_out, exact_dtype=False)
|
|
self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
|
|
self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
|
|
self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)
|
|
|
|
helper(2, 8, 4, 4, out_channels=4, kernel_size=3, groups=1)
|
|
helper(2, 8, 4, 4, out_channels=8, kernel_size=3, groups=8)
|
|
helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=1)
|
|
helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16)
|
|
|
|
    @onlyCUDA
    @dtypes(torch.half, torch.float)
    def test_conv_cudnn_ndhwc(self, device, dtype):
        def helper(n, c, d, h, w, out_channels, kernel_size, groups):
            input = torch.randint(
                -2, 2, (n, c, d, h, w), dtype=dtype, device=device
            ).to(memory_format=torch.channels_last_3d)
            input.requires_grad_()
            conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups).to(
                device="cuda", dtype=dtype, memory_format=torch.channels_last_3d
            )
            for p in conv.parameters():
                p.data = torch.randint_like(p, -2, 2)

            # use FP64 channels-first conv as reference
            ref_input = input.detach().clone().contiguous().double().requires_grad_()
            ref_conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups)
            # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
            ref_conv.load_state_dict(conv.state_dict())
            ref_conv = ref_conv.to(
                device="cuda", dtype=torch.double, memory_format=torch.contiguous_format
            )

            out = conv(input)
            ref_out = ref_conv(ref_input)

            grad = torch.randint_like(out, -2, 2)
            ref_grad = grad.detach().clone().double().contiguous()

            out.backward(grad)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last_3d))
            self.assertTrue(
                input.grad.is_contiguous(memory_format=torch.channels_last_3d)
            )
            self.assertTrue(
                conv.weight.grad.is_contiguous(memory_format=torch.channels_last_3d)
            )

            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ref_input.grad.is_contiguous())
            self.assertTrue(ref_conv.weight.grad.is_contiguous())

            self.assertEqual(out, ref_out, exact_dtype=False)
            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
            self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)

        helper(2, 8, 4, 4, 4, out_channels=4, kernel_size=3, groups=1)
        helper(2, 8, 4, 4, 4, out_channels=8, kernel_size=3, groups=8)
        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=1)
        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=16)

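    # Helper: build a fresh conv layer from ref_conv's state, force the requested
    # input / weight / grad memory formats, and check the output, its memory
    # format, and all gradients against the reference run.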
    def _run_conv(
        self,
        layer,
        device,
        inp,
        grad,
        ref_conv,
        ref_input,
        ref_out,
        input_format,
        weight_format,
        grad_format,
        output_format,
    ):
        conv = (
            layer(inp.size(1), grad.size(1), ref_conv.weight.size(2)).float().to(device)
        )
        # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
        conv.load_state_dict(ref_conv.state_dict())
        weight_data = (
            conv.weight.detach().clone().contiguous(memory_format=weight_format)
        )
        conv.weight.data = weight_data.resize_(
            weight_data.size(), memory_format=weight_format
        )
        input = inp.clone().contiguous(memory_format=input_format)
        input.resize_(input.size(), memory_format=input_format)
        input = input.requires_grad_()
        grad = grad.contiguous(memory_format=grad_format)
        grad.resize_(grad.size(), memory_format=grad_format)
        out = conv(input)
        out.backward(grad)
        self.assertTrue(out.is_contiguous(memory_format=output_format))
        self.assertEqual(out, ref_out)
        self.assertEqual(conv.weight.grad, ref_conv.weight.grad)
        self.assertEqual(conv.bias.grad, ref_conv.bias.grad)
        self.assertEqual(input.grad, ref_input.grad)

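    # Sweep every combination of contiguous / channels_last formats for the
    # weight, grad, and input tensors and check each one against a contiguous
    # (NCHW) reference via _run_conv above.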
    def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device):
        data = torch.randint(1, 10, (n, c, h, w), dtype=torch.float32, device=device)
        ref_input = data.clone().contiguous().requires_grad_(True)
        ref_conv = layer(c, k, filter_size).float().to(device)
        ref_out = ref_conv(ref_input)
        grad = torch.randint(1, 10, ref_out.size(), dtype=torch.float32, device="cuda")
        ref_out.backward(grad)

        for w_f in [torch.contiguous_format, torch.channels_last]:
            for g_f in [torch.contiguous_format, torch.channels_last]:
                for input_format in [torch.contiguous_format, torch.channels_last]:
                    output_format = torch.contiguous_format
                    # Older versions of cuDNN have channels_last support disabled
                    if torch.backends.cudnn.version() >= 7603:
                        if input_format == torch.channels_last:
                            output_format = torch.channels_last
                        # This is because we have N111 weight that cannot handle
                        # the ambiguous memory_format
                        if w_f == torch.channels_last:
                            if layer == nn.Conv2d and filter_size * c != 1:
                                output_format = torch.channels_last
                            if layer == nn.ConvTranspose2d and filter_size * k != 1:
                                output_format = torch.channels_last
                    self._run_conv(
                        layer,
                        device,
                        data,
                        grad,
                        ref_conv,
                        ref_input,
                        ref_out,
                        input_format,
                        w_f,
                        g_f,
                        output_format,
                    )

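    # Exercise mixed NCHW / NHWC combinations for Conv2d and ConvTranspose2d,
    # with configs [n, c, h, w, k, filter_size] that include single-channel
    # inputs and 1x1 kernels, where the weight's memory format can be ambiguous.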
    @onlyCUDA
    @tf32_on_and_off(0.05)
    def test_conv_cudnn_mismatch_memory_format(self, device):
        configs = [
            [4, 2, 8, 8, 4, 2],
            [4, 1, 8, 8, 4, 2],
            [1, 1, 8, 8, 4, 2],
            [4, 2, 2, 8, 4, 1],
            [4, 2, 1, 8, 4, 1],
            [4, 2, 8, 8, 4, 1],
            [4, 1, 8, 8, 4, 1],
        ]
        for n, c, h, w, k, filter_size in configs:
            self._test_conv_cudnn_nhwc_nchw(
                nn.Conv2d, n, c, h, w, k, filter_size, device
            )
            self._test_conv_cudnn_nhwc_nchw(
                nn.ConvTranspose2d, n, c, h, w, k, filter_size, device
            )

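    # cuDNN should accept a channels_last weight together with a contiguous
    # input: the output must come back channels_last and backward must succeed.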
    # torch.half errors out on Windows with CUDA 10.1 + cuDNN 7.6.4,
    # returning CUDNN_STATUS_BAD_PARAM.
    # Disabling that specific test for now [see issue # 33918]
    @onlyCUDA
    @skipCUDAIfNoCudnn
    @dtypes(torch.float, torch.double)
    def test_conv_cudnn_nhwc_support(self, device, dtype):
        input = torch.randn(
            (1, 16, 1, 1), dtype=dtype, device="cuda", requires_grad=True
        )
        weight = torch.randn(
            (8, 16, 3, 3), dtype=dtype, device="cuda", requires_grad=True
        )
        weight = weight.to(memory_format=torch.channels_last)
        o = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1)
        self.assertTrue(o.is_contiguous(memory_format=torch.channels_last))
        o.sum().backward()

    # Test that the faster algorithms used for inference produce the same
    # results as the autograd-enabled path.
    # Validates depthwise3x3 bug reported in https://github.com/pytorch/pytorch/issues/60176
    @onlyCPU
    @dtypes(torch.float)
    def test_conv2d_no_grad(self, device, dtype):
        for batch in [1, 2, 3]:
            for groups in [1, 2, 4]:
                input = torch.rand(batch, groups, 8, 8, dtype=dtype, device=device)
                m = nn.Conv2d(
                    groups,
                    8,
                    kernel_size=(3, 3),
                    groups=groups,
                    dtype=dtype,
                    device=device,
                )
                with torch.no_grad():
                    output_ng = m(input)
                output = m(input)
                self.assertEqual(output, output_ng, rtol=1e-2, atol=1e-5)

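    # Compare the fused convolution + ReLU kernel (cudnn_convolution_relu, or
    # miopen_convolution_relu on ROCm) against an unfused
    # torch.conv2d(...).relu() reference across batch sizes, group counts,
    # spatial sizes, and memory formats.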
    @onlyCUDA
    @skipCUDAIfNoCudnn
    @dtypes(torch.float, torch.float16)
    @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
    @precisionOverride({torch.half: 0.002, torch.float: 1e-4})
    def test_cudnn_convolution_relu(self, device, dtype):
        for batch, groups, image_size, kernel_size, memory_format in product(
            (1, 2, 3),
            (1, 2, 4),
            ((1, 1), (8, 8)),
            ((1, 1), (3, 3)),
            (torch.channels_last, torch.contiguous_format),
        ):
            if image_size[0] < kernel_size[0]:
                continue
            inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device)
            w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device)
            conv2d_out = torch.conv2d(inp, w, None, (1, 1), (0, 0), (1, 1), 1)
            inp = inp.to(memory_format=memory_format)
            w = w.to(memory_format=memory_format)
            if torch.version.hip:
                cudnn_out = torch.miopen_convolution_relu(
                    inp, w, None, (1, 1), (0, 0), (1, 1), 1
                )
            else:
                cudnn_out = torch.cudnn_convolution_relu(
                    inp, w, None, (1, 1), (0, 0), (1, 1), 1
                )
            self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format))
            if torch.cuda.is_tf32_supported() and dtype == torch.float:
                self.assertEqual(conv2d_out.relu(), cudnn_out, atol=4e-3, rtol=0.006)
            else:
                self.assertEqual(conv2d_out.relu(), cudnn_out)

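    # Same setup for the fused convolution + scaled-add + ReLU kernel:
    # F.relu(conv2d(inp, w) + alpha * z) is the unfused reference.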
    @onlyCUDA
    @skipCUDAIfNoCudnn
    @dtypes(torch.float, torch.float16)
    @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
    @precisionOverride({torch.half: 0.002, torch.float: 1e-4})
    def test_cudnn_convolution_add_relu(self, device, dtype):
        for batch, groups, image_size, kernel_size, memory_format in product(
            (1, 2, 3),
            (1, 2, 4),
            ((1, 1), (8, 8)),
            ((1, 1), (3, 3)),
            (torch.channels_last, torch.contiguous_format),
        ):
            if image_size[0] < kernel_size[0]:
                continue
            inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device)
            w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device)
            conv2d_out = torch.conv2d(inp, w, None, (1, 1), (0, 0), (1, 1), 1)
            alpha = 2.0
            z = torch.randn_like(conv2d_out)

            inp = inp.to(memory_format=memory_format)
            w = w.to(memory_format=memory_format)
            z = z.to(memory_format=memory_format)
            if torch.version.hip:
                cudnn_out = torch.miopen_convolution_add_relu(
                    inp, w, z, alpha, None, (1, 1), (0, 0), (1, 1), 1
                )
            else:
                cudnn_out = torch.cudnn_convolution_add_relu(
                    inp, w, z, alpha, None, (1, 1), (0, 0), (1, 1), 1
                )

            self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format))
            if torch.cuda.is_tf32_supported() and dtype == torch.float:
                self.assertEqual(
                    F.relu(conv2d_out + alpha * z), cudnn_out, atol=2e-3, rtol=0.006
                )
            else:
                self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out)

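    # convert_conv2d_weight_memory_format should propagate the requested memory
    # format through Conv2d / ConvTranspose2d weights so that model outputs come
    # back in that format.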
    @onlyCUDA
    def test_convert_conv2d_weight_memory_format(self, device):
        input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device)
        model = nn.Sequential(nn.Conv2d(8, 4, 3), nn.BatchNorm2d(4)).to(device).float()
        for memory_format in [torch.channels_last, torch.contiguous_format]:
            model = nn.utils.convert_conv2d_weight_memory_format(model, memory_format)
            out = model(input)
            self.assertTrue(out.is_contiguous(memory_format=memory_format))

        model = (
            nn.Sequential(nn.ConvTranspose2d(8, 4, 3), nn.BatchNorm2d(4))
            .to(device)
            .float()
        )
        for memory_format in [torch.channels_last, torch.contiguous_format]:
            model = nn.utils.convert_conv2d_weight_memory_format(model, memory_format)
            out = model(input)
            self.assertTrue(out.is_contiguous(memory_format=memory_format))

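    # 3D counterpart of the test above, using channels_last_3d on a
    # ConvTranspose3d + BatchNorm3d model.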
    @onlyCUDA
    def test_convert_conv3d_weight_memory_format(self, device):
        input = torch.randint(
            1, 10, (2, 8, 4, 4, 4), dtype=torch.float32, device=device
        )
        model = (
            nn.Sequential(nn.ConvTranspose3d(8, 4, 3), nn.BatchNorm3d(4))
            .to(device)
            .float()
        )
        for memory_format in [torch.channels_last_3d, torch.contiguous_format]:
            model = nn.utils.convert_conv3d_weight_memory_format(model, memory_format)
            out = model(input)
            self.assertTrue(out.is_contiguous(memory_format=memory_format))

    def test_conv_double_backward_strided_with_3D_input_and_weight(self, device):
        # Test that _convolution_double_backward() outputs the correct grad shapes
        # for 3D input / weight when stride > 1. This is an ad-hoc regression test for a
        # specific case that was uncovered during the convolution consolidation effort.
        # The test can be safely deleted if _convolution_double_backward() is removed.

        input = torch.randn(2, 3, 6, device=device)
        weight = torch.randn(3, 3, 3, device=device)
        bias = torch.randn(3, device=device)
        stride = (2,)
        padding = (1,)
        dilation = (1,)
        transposed = False
        output_padding = (0,)
        groups = 1
        output = torch.ops.aten.convolution(
            input,
            weight,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        )

        ggI = torch.randn(input.shape, device=device)
        ggW = torch.randn(weight.shape, device=device)
        ggB = torch.randn(bias.shape, device=device)
        gO = torch.randn(output.shape, device=device)
        output_mask = [True, True, True]
        (
            grad_grad_output,
            grad_input,
            grad_weight,
        ) = torch.ops.aten._convolution_double_backward(
            ggI,
            ggW,
            ggB,
            gO,
            weight,
            input,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            output_mask,
        )

        # Make sure the correct shapes are computed.
        self.assertEqual(grad_grad_output.shape, gO.shape)
        self.assertEqual(grad_input.shape, input.shape)
        self.assertEqual(grad_weight.shape, weight.shape)

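    # 64-bit indexing test: run a convolution whose tensors exceed the 32-bit
    # element-count limit and compare the GPU result against a CPU run of the
    # same module.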
    @skipCUDAIfRocm
    @onlyCUDA
    @largeTensorTest("40GB")
    @largeTensorTest("24GB", "cpu")
    @tf32_on_and_off(0.005)
    def test_conv3d_64bit_indexing(self, device):
        x = torch.rand(1, 32, 512, 512, 256)
        m = torch.nn.Conv3d(32, 1, kernel_size=1, padding=0, stride=1, bias=False)
        yref = m(x)
        y = m.to(device=device)(x.to(device=device))
        self.assertEqual(yref, y)

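    # Depthwise variant: a half-precision channels_last conv large enough to
    # need 64-bit indexing, plus a batch-splittable reshape of the same data.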
    @skipCUDAIfRocm
    @onlyCUDA
    @largeTensorTest("20GB")
    @largeTensorTest("64GB", "cpu")
    def test_depthwise_conv_64bit_indexing(self, device):
        x = torch.randn(1, 2, 32800, 32800, dtype=torch.half).to(
            memory_format=torch.channels_last
        )
        c = nn.Conv2d(
            2, 2, kernel_size=3, stride=1, padding=1, groups=2, dtype=torch.half
        ).to(memory_format=torch.channels_last)
        yref = c(x)
        y = c.to(device=device)(x.to(device=device))
        self.assertEqual(yref, y, atol=1e-3, rtol=1e-4)
        del y, yref

        # try a batch-splittable case
        x = x.reshape(100, 2, 3280, 3280).contiguous(memory_format=torch.channels_last)
        yref = c(x)
        y = c.to(device=device)(x.to(device=device))
        self.assertEqual(yref, y, atol=1e-3, rtol=1e-4)


instantiate_device_type_tests(TestConvolutionNNDeviceType, globals(), allow_mps=True)
instantiate_parametrized_tests(TestConvolutionNN)

if __name__ == "__main__":
    run_tests()