# Owner(s): ["module: nn"]
import itertools
import math
import os
import unittest
import warnings
from itertools import product

import torch
import torch.autograd.forward_ad as fwAD
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN, tf32_on_and_off
from torch.testing._internal.common_device_type import (
    disablecuDNN,
    disableMkldnn,
    dtypes,
    dtypesIfCUDA,
    dtypesIfMPS,
    expectedFailureMPS,
    instantiate_device_type_tests,
    largeTensorTest,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
    precisionOverride,
    skipCPUIfNoMkldnn,
    skipCUDAIfMiopen,
    skipCUDAIfNoCudnn,
    skipCUDAIfNoMiopen,
    skipCUDAIfRocm,
    skipMeta,
    skipMPS,
)
from torch.testing._internal.common_dtype import (
    floating_and_complex_types_and,
    floating_types_and,
)
from torch.testing._internal.common_nn import _test_module_empty_input, NNTestCase
from torch.testing._internal.common_utils import (
    download_file,
    dtype2prec_DONTUSE,
    gradcheck,
    GRADCHECK_NONDET_TOL,
    gradgradcheck,
    instantiate_parametrized_tests,
    MACOS_VERSION,
    parametrize as parametrize_test,
    run_tests,
    set_default_dtype,
    subtest,
    TEST_SCIPY,
    TEST_WITH_ROCM,
    xfailIf,
)


AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported()


if TEST_WITH_ROCM:
    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"


if TEST_SCIPY:
    import scipy.ndimage
    import scipy.signal


class TestConvolutionNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    def test_conv_backcompat(self):
        from torch.serialization import SourceChangeWarning

        # This file was generated by running on PyTorch 1.0.1 on Python 2:
        #
        # import torch
        # from torch import nn
        # m = nn.Conv2d(1, 1, 1)
        # torch.save(m, 'legacy_conv2d.pt')
        #
        # NB: This Pickle also contains some Unicode data!
        path = download_file("https://download.pytorch.org/test_data/legacy_conv2d.pt")
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", SourceChangeWarning)
            # weights_only=False as this is legacy code that saves the model
            m = torch.load(path, encoding="utf-8", weights_only=False)
        input = torch.randn((1, 1, 1, 1), dtype=torch.float)
        self.assertEqual(m(input).size(), (1, 1, 1, 1))

def test_invalid_conv1d(self):
|
|
for dtype in [
|
|
torch.half,
|
|
torch.bfloat16,
|
|
torch.float,
|
|
torch.double,
|
|
torch.cfloat,
|
|
torch.cdouble,
|
|
]:
|
|
module = nn.Conv1d(
|
|
in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True
|
|
).to(dtype)
|
|
input = torch.randn(1, 3, 4).to(dtype)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"Calculated padded input size per channel: \(4\). "
|
|
+ r"Kernel size: \(10\). Kernel size can\'t be greater than actual input size",
|
|
):
|
|
module(input)
|
|
|
|
# Negative stride check
|
|
module = nn.Conv1d(
|
|
in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True
|
|
).to(dtype)
|
|
input = torch.randn(1, 3, 4).to(dtype)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "non-positive stride is not supported"
|
|
):
|
|
module(input)
|
|
|
|
def test_mismatch_shape_conv2d(self):
|
|
for dtype in (torch.float, torch.cfloat):
|
|
x = torch.randn(1, 10, 1, 28, 28, dtype=dtype)
|
|
w = torch.randn(6, 1, 5, 5, dtype=dtype)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d, but got "
|
|
+ r"input of size: \[1, 10, 1, 28, 28\]",
|
|
):
|
|
F.conv2d(x, w)
|
|
|
|
def test_conv2d_discontiguous_weight(self):
|
|
for dtype in (torch.float, torch.cfloat):
|
|
# Test for https://github.com/pytorch/pytorch/issues/55781
|
|
x = torch.ones(64, 16, 16, 16, dtype=dtype)
|
|
weight = (
|
|
torch.arange(0, 1.0, 1 / 2.0**10)
|
|
.reshape(32, 16, 1, 2)
|
|
.to(dtype)[:, :, :, ::2]
|
|
)
|
|
self.assertFalse(weight.is_contiguous())
|
|
y = torch.nn.functional.conv2d(x, weight, None)
|
|
if torch.backends.mkldnn.is_available():
|
|
# Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used
|
|
with torch.backends.mkldnn.flags(enabled=False):
|
|
y_ = torch.nn.functional.conv2d(x, weight, None)
|
|
self.assertEqual(y, y_)
|
|
self.assertEqual(y.sum(), 4186112.0)
|
|
|
|
def test_invalid_conv2d(self):
|
|
for dtype in [
|
|
torch.half,
|
|
torch.bfloat16,
|
|
torch.float,
|
|
torch.double,
|
|
torch.cfloat,
|
|
torch.cdouble,
|
|
]:
|
|
module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(
|
|
dtype
|
|
)
|
|
input = torch.empty(1, 1, 4, 4).to(dtype)
|
|
self.assertRaises(RuntimeError, lambda: module(input))
|
|
|
|
module = nn.Conv2d(
|
|
in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True
|
|
)
|
|
input = torch.randn(1, 3, 1, 1)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"Calculated padded input size per channel: \(1 x 1\). "
|
|
+ r"Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size",
|
|
):
|
|
module(input)
|
|
|
|
# Negative stride check
|
|
module = nn.Conv2d(
|
|
in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True
|
|
).to(dtype)
|
|
input = torch.randn(1, 3, 4, 4).to(dtype)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "non-positive stride is not supported"
|
|
):
|
|
module(input)
|
|
|
|
# Zero stride check
|
|
module = nn.Conv2d(
|
|
in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True
|
|
).to(dtype)
|
|
input = torch.randn(1, 3, 4, 4).to(dtype)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "non-positive stride is not supported"
|
|
):
|
|
module(input)
|
|
|
|
def test_invalid_conv3d(self):
|
|
for dtype in [
|
|
torch.half,
|
|
torch.bfloat16,
|
|
torch.float,
|
|
torch.double,
|
|
torch.cfloat,
|
|
torch.cdouble,
|
|
]:
|
|
module = torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2).to(
|
|
dtype
|
|
)
|
|
input = torch.empty(1, 1, 4, 4, 4).to(dtype)
|
|
self.assertRaises(RuntimeError, lambda: module(input))
|
|
|
|
# Negative stride check
|
|
module = torch.nn.Conv3d(1, 1, kernel_size=3, stride=-2)
|
|
input = torch.empty(1, 1, 4, 4, 4)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "non-positive stride is not supported"
|
|
):
|
|
module(input)
|
|
|
|
def test_conv_invalid_groups(self):
|
|
with self.assertRaisesRegex(ValueError, "groups must be a positive integer"):
|
|
torch.nn.Conv1d(1, 1, kernel_size=3, dilation=2, stride=2, groups=0)
|
|
with self.assertRaisesRegex(ValueError, "groups must be a positive integer"):
|
|
torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-1)
|
|
with self.assertRaisesRegex(ValueError, "groups must be a positive integer"):
|
|
torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-2)
|
|
|
|
def test_conv3d_overflow_values(self):
|
|
input = torch.full(
|
|
(
|
|
0,
|
|
7,
|
|
9,
|
|
1,
|
|
5,
|
|
),
|
|
0,
|
|
dtype=torch.float32,
|
|
requires_grad=False,
|
|
)
|
|
weight = torch.full(
|
|
(
|
|
9,
|
|
1,
|
|
),
|
|
4.14214e16,
|
|
dtype=torch.float32,
|
|
requires_grad=False,
|
|
)
|
|
stride = [5, 5, 5]
|
|
|
|
with self.assertRaisesRegex(ValueError, "Padding height too large"):
|
|
torch.ops.aten.slow_conv3d(
|
|
input,
|
|
weight,
|
|
kernel_size=[5, 5, 5],
|
|
bias=None,
|
|
stride=stride,
|
|
padding=[2**62, 2**62, 2**62],
|
|
)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "Kernel height x width product is too large:"
|
|
):
|
|
torch.ops.aten.slow_conv3d(
|
|
input,
|
|
weight,
|
|
kernel_size=[2**32, 2**32, 2**32],
|
|
bias=None,
|
|
stride=stride,
|
|
padding=[2**31, 2**31, 2**31],
|
|
)
|
|
|
|
def test_Conv1d_module_same_padding(self):
|
|
# Compare module against functional: without strides/dilation, asymmetric padding
|
|
x = torch.rand(1, 1, 20)
|
|
module = nn.Conv1d(
|
|
in_channels=1, out_channels=1, kernel_size=10, padding="same"
|
|
)
|
|
expect = F.conv1d(x, module.weight, module.bias, padding="same")
|
|
self.assertEqual(expect, module(x))
|
|
|
|
# Test dilation, symmetric padding
|
|
module = nn.Conv1d(
|
|
in_channels=1, out_channels=1, kernel_size=10, padding="same", dilation=2
|
|
)
|
|
expect = F.conv1d(x, module.weight, module.bias, padding="same", dilation=2)
|
|
self.assertEqual(expect, module(x))
|
|
|
|
# Test non-zero padding_mode, requiring explicit padding
|
|
module = nn.Conv1d(
|
|
in_channels=1,
|
|
out_channels=1,
|
|
kernel_size=10,
|
|
padding="same",
|
|
padding_mode="replicate",
|
|
)
|
|
x_padded = F.pad(x, [4, 5], mode="replicate")
|
|
expect = F.conv1d(x_padded, module.weight, module.bias, padding="valid")
|
|
self.assertEqual(expect, module(x))
|
|
self.assertEqual(x.size(), expect.size())
|
|
|
|
        # Test construction with invalid padding string raises
|
|
with self.assertRaisesRegex(ValueError, "Invalid padding string"):
|
|
module = nn.Conv1d(
|
|
in_channels=3, out_channels=33, kernel_size=10, padding="foo"
|
|
)
|
|
|
|
        # Test construction with same padding and strides raises
|
|
with self.assertRaisesRegex(ValueError, "padding='same'"):
|
|
module = nn.Conv1d(
|
|
in_channels=3, out_channels=33, kernel_size=10, padding="same", stride=2
|
|
)
|
|
|
|
def test_Conv2d_module_same_padding(self):
|
|
# Compare module against functional:
|
|
# without strides/dilation, both symmetric and asymmetric padding
|
|
x = torch.rand(1, 1, 9, 20)
|
|
module = nn.Conv2d(
|
|
in_channels=1, out_channels=1, kernel_size=(5, 10), padding="same"
|
|
)
|
|
expect = F.conv2d(x, module.weight, module.bias, padding="same")
|
|
self.assertEqual(expect, module(x))
|
|
|
|
# with dilation, symmetric padding
|
|
module = nn.Conv2d(
|
|
in_channels=1,
|
|
out_channels=1,
|
|
kernel_size=(3, 4),
|
|
padding="same",
|
|
dilation=(1, 2),
|
|
)
|
|
expect = F.conv2d(
|
|
x, module.weight, module.bias, padding="same", dilation=(1, 2)
|
|
)
|
|
self.assertEqual(expect, module(x))
|
|
|
|
# Test non-zero padding_mode, requiring explicit padding
|
|
module = nn.Conv2d(
|
|
in_channels=1,
|
|
out_channels=1,
|
|
kernel_size=(3, 4),
|
|
padding="same",
|
|
padding_mode="reflect",
|
|
)
|
|
x_padded = F.pad(x, [1, 2, 1, 1], mode="reflect")
|
|
expect = F.conv2d(x_padded, module.weight, module.bias, padding="valid")
|
|
self.assertEqual(expect, module(x))
|
|
self.assertEqual(x.size(), expect.size())
|
|
|
|
        # Test construction with invalid padding string raises
|
|
with self.assertRaisesRegex(ValueError, "Invalid padding string"):
|
|
module = nn.Conv2d(
|
|
in_channels=3, out_channels=33, kernel_size=10, padding="foo"
|
|
)
|
|
|
|
        # Test construction with same padding and strides raises
|
|
with self.assertRaisesRegex(ValueError, "padding='same'"):
|
|
module = nn.Conv2d(
|
|
in_channels=3, out_channels=33, kernel_size=10, padding="same", stride=2
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "padding='same'"):
|
|
module = nn.Conv2d(
|
|
in_channels=3,
|
|
out_channels=33,
|
|
kernel_size=10,
|
|
padding="same",
|
|
stride=(1, 3),
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "padding='same'"):
|
|
module = nn.Conv2d(
|
|
in_channels=3,
|
|
out_channels=33,
|
|
kernel_size=10,
|
|
padding="same",
|
|
stride=(4, 1),
|
|
)
|
|
|
|
def test_Conv3d_module_same_padding(self):
|
|
# Compare module against functional:
|
|
x = torch.rand(1, 1, 4, 4, 4)
|
|
# without dilation, both symmetric and asymmetric padding
|
|
module = nn.Conv3d(
|
|
in_channels=1, out_channels=1, kernel_size=(2, 3, 4), padding="same"
|
|
)
|
|
expect = F.conv3d(x, module.weight, module.bias, padding="same")
|
|
self.assertEqual(expect, module(x))
|
|
|
|
# with dilation, both symmetric and asymmetric padding
|
|
module = nn.Conv3d(
|
|
in_channels=1,
|
|
out_channels=1,
|
|
kernel_size=(2, 3, 4),
|
|
padding="same",
|
|
dilation=(3, 2, 1),
|
|
)
|
|
expect = F.conv3d(
|
|
x, module.weight, module.bias, padding="same", dilation=(3, 2, 1)
|
|
)
|
|
self.assertEqual(expect, module(x))
|
|
|
|
# Test non-zero padding_mode, requiring explicit padding
|
|
module = nn.Conv3d(
|
|
in_channels=1,
|
|
out_channels=1,
|
|
kernel_size=(2, 3, 4),
|
|
padding="same",
|
|
padding_mode="circular",
|
|
)
|
|
x_padded = F.pad(x, [1, 2, 1, 1, 0, 1], mode="circular")
|
|
expect = F.conv3d(x_padded, module.weight, module.bias, padding="valid")
|
|
self.assertEqual(expect, module(x))
|
|
self.assertEqual(x.size(), expect.size())
|
|
|
|
        # Test construction with invalid padding string raises
|
|
with self.assertRaisesRegex(ValueError, "Invalid padding string"):
|
|
module = nn.Conv3d(
|
|
in_channels=3, out_channels=33, kernel_size=10, padding="foo"
|
|
)
|
|
|
|
        # Test construction with same padding and strides raises
|
|
with self.assertRaisesRegex(ValueError, "padding='same'"):
|
|
module = nn.Conv3d(
|
|
in_channels=3, out_channels=33, kernel_size=10, padding="same", stride=2
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "padding='same'"):
|
|
module = nn.Conv3d(
|
|
in_channels=3,
|
|
out_channels=33,
|
|
kernel_size=10,
|
|
padding="same",
|
|
stride=(1, 1, 3),
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "padding='same'"):
|
|
module = nn.Conv3d(
|
|
in_channels=3,
|
|
out_channels=33,
|
|
kernel_size=10,
|
|
padding="same",
|
|
stride=(1, 4, 1),
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "padding='same'"):
|
|
module = nn.Conv3d(
|
|
in_channels=3,
|
|
out_channels=33,
|
|
kernel_size=10,
|
|
padding="same",
|
|
stride=(5, 1, 1),
|
|
)
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
|
|
def test_thnn_conv_strided_padded_dilated(self):
|
|
for convfn, dims, transposed in (
|
|
(torch.nn.functional.conv2d, 2, False),
|
|
(torch.nn.functional.conv_transpose2d, 2, True),
|
|
(torch.nn.functional.conv3d, 3, False),
|
|
(torch.nn.functional.conv_transpose3d, 3, True),
|
|
):
|
|
for stride, padding, dilation in (
|
|
(2, 0, 1),
|
|
(1, 1, 1),
|
|
(2, 1, 1),
|
|
(1, 0, 2),
|
|
):
|
|
kwargs = {"stride": stride, "padding": padding, "dilation": dilation}
|
|
inp_shape = (1, 2) + dims * (4,)
|
|
weight_shape = (2, 2) + dims * (1,)
|
|
inputs = torch.randn(
|
|
inp_shape, dtype=torch.double, device="cuda", requires_grad=True
|
|
)
|
|
weight = torch.randn(
|
|
weight_shape, dtype=torch.double, device="cuda", requires_grad=True
|
|
)
|
|
bias = torch.randn(
|
|
2, dtype=torch.double, device="cuda", requires_grad=True
|
|
)
|
|
with torch.backends.cudnn.flags(enabled=False):
|
|
res = convfn(inputs, weight, bias, **kwargs)
|
|
res_cpu = convfn(inputs.cpu(), weight.cpu(), bias.cpu(), **kwargs)
|
|
self.assertEqual(res, res_cpu)
|
|
with torch.backends.cudnn.flags(enabled=False):
|
|
torch.autograd.gradcheck(
|
|
lambda x, w, b: convfn(x, w, b, **kwargs),
|
|
(inputs, weight, bias),
|
|
)
|
|
torch.autograd.gradcheck(
|
|
lambda x, w, b: convfn(x, w, b, **kwargs),
|
|
(inputs.cpu(), weight.cpu(), bias.cpu()),
|
|
)
|
|
|
|
def test_Conv2d_inconsistent_types(self):
|
|
inputs = torch.randn(4, 1, 7, 7, dtype=torch.float)
|
|
weights = torch.randn(1, 1, 3, 3, dtype=torch.double)
|
|
# inconsistent types should raise an exception
|
|
self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights))
|
|
# but it should work with the same type
|
|
nn.functional.conv2d(inputs.float(), weights.float())
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
|
|
def test_Conv2d_inconsistent_types_on_GPU_without_cudnn(self):
|
|
inputs = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda")
|
|
weights = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda")
|
|
bias = torch.randn(1, dtype=torch.double, device="cuda")
|
|
|
|
with torch.backends.cudnn.flags(enabled=False):
|
|
# inconsistent types should raise an exception
|
|
self.assertRaises(
|
|
RuntimeError, lambda: nn.functional.conv2d(inputs, weights)
|
|
)
|
|
self.assertRaises(
|
|
RuntimeError,
|
|
lambda: nn.functional.conv2d(inputs, weights.float(), bias),
|
|
)
|
|
|
|
# but it should work with the same type
|
|
nn.functional.conv2d(inputs.float(), weights.float(), bias.float())
|
|
|
|
def test_Conv2d_1x1(self):
|
|
in_channels = 2
|
|
mod = torch.nn.Conv2d(2, 2, 1, bias=False).to(dtype=torch.double)
|
|
input = torch.randn(
|
|
1, in_channels, 5, 5, requires_grad=True, dtype=torch.double
|
|
)
|
|
for enabled in (False, True):
|
|
with torch.backends.mkldnn.flags(enabled=enabled):
|
|
gradcheck(F.conv2d, (input, mod.weight))
|
|
|
|
def test_Conv2d_OneDNN(self):
|
|
def run_once(group_val=24, dilation=1):
|
|
ifm = torch.ones([1, group_val, 6, 6], dtype=torch.float32)
|
|
weights = torch.ones([group_val, 1, 3, 3], dtype=torch.float32)
|
|
op = torch.nn.Conv2d(
|
|
in_channels=group_val,
|
|
out_channels=group_val,
|
|
kernel_size=[3, 3],
|
|
stride=[2, 2],
|
|
padding=[1, 1],
|
|
dilation=[dilation, dilation],
|
|
groups=group_val,
|
|
bias=False,
|
|
padding_mode="zeros",
|
|
)
|
|
|
|
op.weight.data = weights
|
|
res = op(ifm)
|
|
grad_in = torch.ones(res.shape, dtype=torch.float32)
|
|
res.backward(grad_in)
|
|
return op.weight.grad
|
|
|
|
        for group_val in (24, 48, 23, 25):
            for dilation in (1, 2):
                with torch.backends.mkldnn.flags(enabled=False):
                    without_onednn = run_once(group_val, dilation)

                with torch.backends.mkldnn.flags(enabled=True):
                    with_onednn = run_once(group_val, dilation)

                self.assertEqual(without_onednn, with_onednn)
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
|
|
@unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
|
|
def test_cudnn_non_contiguous(self):
|
|
x = torch.randn(192, 16, 50).cuda()
|
|
x = x.permute(0, 2, 1).contiguous().permute(0, 2, 1)
|
|
m = torch.nn.Conv1d(
|
|
in_channels=16, out_channels=32, kernel_size=2, bias=True
|
|
).cuda()
|
|
m(x)
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
|
|
@unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
|
|
def test_cudnn_not_mutate_stride(self):
|
|
weight = torch.randn(64, 64, 1, 1)
|
|
x = torch.randn(2, 64, 10, 10).to(memory_format=torch.channels_last)
|
|
weight_stride = weight.stride()
|
|
|
|
def conv(x, weight):
|
|
return torch.convolution(
|
|
x,
|
|
weight,
|
|
stride=(1, 1),
|
|
padding=(0, 0),
|
|
dilation=(1, 1),
|
|
transposed=False,
|
|
output_padding=(0, 0),
|
|
groups=1,
|
|
bias=None,
|
|
)
|
|
|
|
# should have run in nhwc without mutating input strides
|
|
out_nhwc = conv(x, weight)
|
|
self.assertEqual(weight.stride(), weight_stride)
|
|
self.assertTrue(out_nhwc.is_contiguous(memory_format=torch.channels_last))
|
|
|
|
x = x.contiguous(memory_format=torch.contiguous_format)
|
|
out_c = conv(x, weight)
|
|
self.assertTrue(out_c.is_contiguous(memory_format=torch.contiguous_format))
|
|
self.assertEqual(out_c, out_nhwc)
|
|
self.assertEqual(weight.stride(), weight_stride)
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
|
|
@unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
|
|
def test_Conv2d_inconsistent_types_on_GPU_with_cudnn(self):
|
|
inputs = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda")
|
|
weights = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda")
|
|
bias = torch.randn(1, dtype=torch.double, device="cuda")
|
|
|
|
with torch.backends.cudnn.flags(enabled=True):
|
|
# inconsistent types should raise an exception
|
|
self.assertRaises(
|
|
RuntimeError, lambda: nn.functional.conv2d(inputs, weights)
|
|
)
|
|
self.assertRaises(
|
|
RuntimeError,
|
|
lambda: nn.functional.conv2d(inputs, weights.float(), bias),
|
|
)
|
|
|
|
# but it should work with the same type
|
|
nn.functional.conv2d(inputs.float(), weights.float(), bias.float())
|
|
|
|
def test_Conv2d_missing_argument(self):
|
|
c = nn.Conv2d(3, 3, 3)
|
|
self.assertRaises(TypeError, lambda: c(None))
|
|
|
|
def test_Conv2d_backward_twice(self):
|
|
input = torch.randn(2, 3, 5, 5)
|
|
c = nn.Conv2d(3, 3, 3)
|
|
o1 = c(input)
|
|
o1.sum().backward()
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "Specify retain_graph=True", lambda: o1.sum().backward()
|
|
)
|
|
|
|
def test_conv_modules_raise_error_on_incorrect_input_size(self):
|
|
for dtype in [torch.half, torch.bfloat16, torch.double, torch.float]:
|
|
modules = [
|
|
nn.Conv1d(3, 8, 3).to(dtype),
|
|
nn.ConvTranspose1d(3, 8, 3).to(dtype),
|
|
nn.Conv2d(3, 8, 3).to(dtype),
|
|
nn.ConvTranspose2d(3, 8, 3).to(dtype),
|
|
nn.Conv3d(3, 8, 3).to(dtype),
|
|
nn.ConvTranspose3d(3, 8, 3).to(dtype),
|
|
]
|
|
|
|
invalid_input_dims = [(1, 4), (1, 4), (2, 5), (2, 5), (3, 6), (3, 6)]
|
|
|
|
for invalid_dims, module in zip(invalid_input_dims, modules):
|
|
for dims in invalid_dims:
|
|
input = torch.empty(torch.Size((3,) * dims))
|
|
self.assertRaises(RuntimeError, lambda: module(input))
|
|
|
|
def test_conv_shapecheck(self):
|
|
def test(should_raise, module, input_size, dtype):
|
|
input = torch.empty(3, *input_size).to(dtype)
|
|
if should_raise:
|
|
self.assertRaises(RuntimeError, lambda: module(input))
|
|
else:
|
|
                # just run it to ensure no exception is raised.
|
|
module(input)
|
|
|
|
for dtype in [
|
|
torch.half,
|
|
torch.bfloat16,
|
|
torch.float,
|
|
torch.double,
|
|
torch.cfloat,
|
|
torch.cdouble,
|
|
]:
|
|
# Conv1d
|
|
test(True, nn.Conv1d(1, 1, 3).to(dtype), (1, 2), dtype)
|
|
test(True, nn.Conv1d(1, 1, 3, stride=2).to(dtype), (1, 2), dtype)
|
|
test(False, nn.Conv1d(1, 1, 2).to(dtype), (1, 2), dtype)
|
|
test(False, nn.Conv1d(1, 1, 2, stride=2).to(dtype), (1, 2), dtype)
|
|
test(
|
|
False, nn.Conv1d(1, 1, 3, stride=2, padding=1).to(dtype), (1, 2), dtype
|
|
)
|
|
|
|
# Conv2d
|
|
test(True, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 2, 2), dtype)
|
|
test(False, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 3, 3), dtype)
|
|
test(False, nn.Conv2d(1, 1, (3, 3), padding=1).to(dtype), (1, 2, 2), dtype)
|
|
|
|
# Conv3D
|
|
test(True, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 2, 2, 2), dtype)
|
|
test(False, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 3, 3, 3), dtype)
|
|
test(
|
|
False,
|
|
nn.Conv3d(1, 1, (3, 3, 3), padding=1).to(dtype),
|
|
(1, 2, 2, 2),
|
|
dtype,
|
|
)
|
|
|
|
def test_ConvTranspose2d_output_size(self):
|
|
m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2)
|
|
i = torch.randn(2, 3, 6, 6)
|
|
for h in range(15, 22):
|
|
for w in range(15, 22):
|
|
if 18 <= h <= 20 and 18 <= w <= 20:
|
|
output = m(i, output_size=(h, w))
|
|
self.assertEqual(output.size()[2:], (h, w))
|
|
else:
|
|
self.assertRaises(ValueError, lambda: m(i, (h, w)))
|
|
|
|
def test_ConvTranspose2d_output_size_downsample_upsample(self):
|
|
b, c, hid_c = 2, 3, 2
|
|
for h in range(13, 24):
|
|
for w in range(13, 17):
|
|
for k in range(2, 5):
|
|
for d in range(1, 5):
|
|
for s in range(1, 4):
|
|
for p in range(3):
|
|
conv = nn.Conv2d(
|
|
in_channels=c,
|
|
out_channels=hid_c,
|
|
kernel_size=k,
|
|
stride=s,
|
|
padding=p,
|
|
dilation=d,
|
|
)
|
|
|
|
t_conv = nn.ConvTranspose2d(
|
|
in_channels=hid_c,
|
|
out_channels=c,
|
|
kernel_size=k,
|
|
stride=s,
|
|
padding=p,
|
|
dilation=d,
|
|
)
|
|
|
|
i = torch.randn(b, c, h, w)
|
|
|
|
out = t_conv(conv(i), output_size=i.shape)
|
|
|
|
self.assertEqual(out.size()[2:], i.size()[2:])
|
|
|
|
def test_ConvTranspose3d_correct_output_size(self):
|
|
# Check that ConvTranspose3d can take a 5d output_size.
|
|
m = nn.ConvTranspose3d(2, 2, 2)
|
|
i = torch.rand(1, 2, 1, 1, 1)
|
|
m(i, output_size=(1, 2, 2, 2, 2))
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
|
|
def test_ConvTranspose2d_half_cublas_gemm(self):
|
|
with torch.backends.cudnn.flags(enabled=False):
|
|
inputs = torch.randn(1, 1, 16, 16, device="cuda", dtype=torch.half)
|
|
deconv = (
|
|
nn.ConvTranspose2d(1, 1, 3, stride=2, padding=1, output_padding=1)
|
|
.cuda()
|
|
.half()
|
|
)
|
|
output = deconv(inputs)
|
|
output.mean().backward()
|
|
|
|
# For https://github.com/pytorch/pytorch/pull/1273
|
|
# Almost identical to the above `test_Conv2d_naive_groups`
|
|
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
|
|
@torch.backends.miopen.flags(immediate=True)
|
|
@tf32_on_and_off(0.001)
|
|
def test_Conv2d_groups_nobias(self):
|
|
dev_dtypes = [("cpu", torch.float)]
|
|
if TEST_CUDA:
|
|
dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)]
|
|
if AMPERE_OR_ROCM:
|
|
dev_dtypes += [("cuda", torch.bfloat16)]
|
|
for device, dtype in dev_dtypes:
|
|
m = nn.Conv2d(4, 4, kernel_size=3, groups=2, bias=False).to(device, dtype)
|
|
i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
|
|
output = m(i)
|
|
grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype)
|
|
output.backward(grad_output)
|
|
|
|
m1 = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype)
|
|
m1.weight.data.copy_(m.weight.data[:2])
|
|
i1 = i.data[:, :2].contiguous().requires_grad_(True)
|
|
output1 = m1(i1)
|
|
output1.backward(grad_output[:, :2].contiguous())
|
|
|
|
m2 = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype)
|
|
m2.weight.data.copy_(m.weight.data[2:])
|
|
i2 = i.data[:, 2:].contiguous().requires_grad_(True)
|
|
output2 = m2(i2)
|
|
output2.backward(grad_output[:, 2:].contiguous())
|
|
|
|
self.assertEqual(output, torch.cat([output1, output2], 1))
|
|
self.assertEqual(
|
|
i.grad.data,
|
|
torch.cat([i1.grad.data, i2.grad.data], 1),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
m.weight.grad.data,
|
|
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
|
|
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
|
|
# Almost identical to the above `test_Conv2d_naive_groups`
|
|
    # Covers the special case where groups > 1, input-channels / groups < 16, and output-channels is a multiple of 16
|
|
# See also https://github.com/pytorch/pytorch/pull/18463#issuecomment-476563686
|
|
# and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024
|
|
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
|
|
@torch.backends.miopen.flags(immediate=True)
|
|
@tf32_on_and_off(0.001)
|
|
def test_Conv2d_groups_nobias_v2(self):
|
|
torch.manual_seed(123)
|
|
dev_dtypes = [("cpu", torch.float)]
|
|
if TEST_CUDA:
|
|
dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)]
|
|
if AMPERE_OR_ROCM:
|
|
dev_dtypes += [("cuda", torch.bfloat16)]
|
|
for device, dtype in dev_dtypes:
|
|
m = nn.Conv2d(4, 16, kernel_size=3, groups=2, bias=False).to(device, dtype)
|
|
i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
|
|
output = m(i)
|
|
grad_output = torch.randn(2, 16, 4, 4, device=device, dtype=dtype)
|
|
output.backward(grad_output)
|
|
|
|
m1 = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype)
|
|
m1.weight.data.copy_(m.weight.data[:8])
|
|
i1 = i.data[:, :2].contiguous().requires_grad_(True)
|
|
output1 = m1(i1)
|
|
output1.backward(grad_output[:, :8].contiguous())
|
|
|
|
m2 = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype)
|
|
m2.weight.data.copy_(m.weight.data[8:])
|
|
i2 = i.data[:, 2:].contiguous().requires_grad_(True)
|
|
output2 = m2(i2)
|
|
output2.backward(grad_output[:, 8:].contiguous())
|
|
|
|
self.assertEqual(output, torch.cat([output1, output2], 1))
|
|
self.assertEqual(
|
|
i.grad.data,
|
|
torch.cat([i1.grad.data, i2.grad.data], 1),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
m.weight.grad.data,
|
|
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
|
|
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
|
|
# CPU-only test for group conv3d fast implementation using bmm
|
|
# See: https://github.com/pytorch/pytorch/pull/36355
|
|
def test_Conv3d_groups_nobias(self):
|
|
torch.manual_seed(123)
|
|
m = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=False).to("cpu", torch.float)
|
|
i = torch.randn(
|
|
2, 4, 6, 6, 6, device="cpu", dtype=torch.float, requires_grad=True
|
|
)
|
|
output = m(i)
|
|
grad_output = torch.randn(2, 16, 4, 4, 4, device="cpu", dtype=torch.float)
|
|
output.backward(grad_output)
|
|
|
|
m1 = nn.Conv3d(2, 8, kernel_size=3, bias=False).to("cpu", torch.float)
|
|
m1.weight.data.copy_(m.weight.data[:8])
|
|
i1 = i.data[:, :2].contiguous().requires_grad_(True)
|
|
output1 = m1(i1)
|
|
output1.backward(grad_output[:, :8].contiguous())
|
|
|
|
m2 = nn.Conv3d(2, 8, kernel_size=3, bias=False).to("cpu", torch.float)
|
|
m2.weight.data.copy_(m.weight.data[8:])
|
|
i2 = i.data[:, 2:].contiguous().requires_grad_(True)
|
|
output2 = m2(i2)
|
|
output2.backward(grad_output[:, 8:].contiguous())
|
|
|
|
self.assertEqual(output, torch.cat([output1, output2], 1))
|
|
self.assertEqual(
|
|
i.grad.data,
|
|
torch.cat([i1.grad.data, i2.grad.data], 1),
|
|
atol=dtype2prec_DONTUSE[torch.float],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
m.weight.grad.data,
|
|
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
|
|
atol=dtype2prec_DONTUSE[torch.float],
|
|
rtol=dtype2prec_DONTUSE[torch.float],
|
|
)
|
|
|
|
def test_Conv3d_groups_wbias(self):
|
|
torch.manual_seed(123)
|
|
m = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=True).to("cpu", torch.float)
|
|
i = torch.randn(
|
|
2, 4, 6, 6, 6, device="cpu", dtype=torch.float, requires_grad=True
|
|
)
|
|
output = m(i)
|
|
grad_output = torch.randn(2, 16, 4, 4, 4, device="cpu", dtype=torch.float)
|
|
output.backward(grad_output)
|
|
|
|
m1 = nn.Conv3d(2, 8, kernel_size=3, bias=True).to("cpu", torch.float)
|
|
m1.weight.data.copy_(m.weight.data[:8])
|
|
m1.bias.data.copy_(m.bias.data[:8])
|
|
i1 = i.data[:, :2].contiguous().requires_grad_(True)
|
|
output1 = m1(i1)
|
|
output1.backward(grad_output[:, :8].contiguous())
|
|
|
|
m2 = nn.Conv3d(2, 8, kernel_size=3, bias=True).to("cpu", torch.float)
|
|
m2.weight.data.copy_(m.weight.data[8:])
|
|
m2.bias.data.copy_(m.bias.data[8:])
|
|
i2 = i.data[:, 2:].contiguous().requires_grad_(True)
|
|
output2 = m2(i2)
|
|
output2.backward(grad_output[:, 8:].contiguous())
|
|
|
|
self.assertEqual(output, torch.cat([output1, output2], 1))
|
|
self.assertEqual(
|
|
i.grad.data,
|
|
torch.cat([i1.grad.data, i2.grad.data], 1),
|
|
atol=dtype2prec_DONTUSE[torch.float],
|
|
rtol=dtype2prec_DONTUSE[torch.float],
|
|
)
|
|
self.assertEqual(
|
|
m.weight.grad.data,
|
|
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
|
|
atol=dtype2prec_DONTUSE[torch.float],
|
|
rtol=dtype2prec_DONTUSE[torch.float],
|
|
)
|
|
self.assertEqual(
|
|
m.bias.grad.data,
|
|
torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
|
|
atol=dtype2prec_DONTUSE[torch.float],
|
|
rtol=dtype2prec_DONTUSE[torch.float],
|
|
)
|
|
|
|
def test_conv_tbc(self):
|
|
with set_default_dtype(torch.double):
|
|
inp = torch.randn(9, 4, 5, requires_grad=True)
|
|
weight = torch.randn(3, 5, 6, requires_grad=True)
|
|
bias = torch.randn(6, requires_grad=True)
|
|
|
|
gradcheck(
|
|
lambda i, w, b, pad: F.conv_tbc(i, w, b, pad), (inp, weight, bias, 3)
|
|
)
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
|
@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
|
|
def test_grouped_conv_cudnn_nhwc_support(self):
|
|
        # in order to catch the holes in grouped convolution nhwc support for earlier cudnn versions
|
|
input = torch.randn((16, 16, 8, 8), dtype=torch.float16, device="cuda").to(
|
|
memory_format=torch.channels_last
|
|
)
|
|
weight = torch.randn((8, 4, 3, 3), dtype=torch.float16, device="cuda").to(
|
|
memory_format=torch.channels_last
|
|
)
|
|
torch.convolution(input, weight, None, (1, 1), (1, 1), (1, 1), False, (0, 0), 4)
|
|
input = torch.randn((16, 8, 8, 8), dtype=torch.float16, device="cuda").to(
|
|
memory_format=torch.channels_last
|
|
)
|
|
torch.convolution(input, weight, None, (1, 1), (1, 1), (1, 1), True, (0, 0), 4)
|
|
|
|
@unittest.expectedFailure
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
|
@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
|
|
def test_conv_cudnn_memory_layout_dominance(self):
|
|
        # The desired behavior here is for the memory layout of conv.weight to
        # dominate the layout of the output, which is not the current behavior.
        # We'll fix this in follow-up PRs and remove the `expectedFailure` tag.
|
|
input = torch.randint(
|
|
1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True
|
|
)
|
|
conv = nn.Conv2d(8, 4, 3).cuda().float()
|
|
|
|
out = conv(input)
|
|
self.assertTrue(out.is_contiguous())
|
|
|
|
input = input.contiguous(memory_format=torch.channels_last)
|
|
out = conv(input)
|
|
self.assertTrue(out.is_contiguous())
|
|
|
|
conv.weight.data = conv.weight.contiguous(memory_format=torch.channels_last)
|
|
out = conv(input)
|
|
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
|
|
|
|
input = input.contiguous()
|
|
out = conv(input)
|
|
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
|
def test_cudnn_noncontiguous_weight(self):
|
|
# Noncontiguous weights must be contiguous() before being
|
|
# passed to cuDNN
|
|
input = torch.tensor([1, 1, 1], dtype=torch.double, device="cuda").view(1, 1, 3)
|
|
weights1 = torch.tensor([1], dtype=torch.double, device="cuda").expand(1, 1, 2)
|
|
weights2 = (
|
|
torch.tensor([1], dtype=torch.double, device="cuda")
|
|
.expand(1, 1, 2)
|
|
.contiguous()
|
|
)
|
|
self.assertEqual(
|
|
F.conv1d(input, weights1, bias=None, stride=2, dilation=2),
|
|
F.conv1d(input, weights2, bias=None, stride=2, dilation=2),
|
|
)
|
|
|
|
def run_grad_conv_test(self, func_forward, func_backward, dim=1, gradient="input"):
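        # Compares the autograd gradient w.r.t. the input or the weight against the
        # result of the corresponding functional helper passed in as func_backward
        # (e.g. F.grad.conv1d_input / F.grad.conv1d_weight) over a range of conv configs.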
|
|
for kern, inp_size in [(3, 6), (3, 7), (4, 9)]:
|
|
for batch, stride, padding, chan_in, chan_out, dilation in product(
|
|
[1, 2], [1, 2], [0, 1, 2], [2], [3], [1]
|
|
):
|
|
for has_bias in [True, False]:
|
|
input_shape = [batch, chan_in]
|
|
weight_shape = [chan_out, chan_in]
|
|
for _ in range(dim):
|
|
input_shape.append(inp_size)
|
|
weight_shape.append(kern)
|
|
|
|
input = torch.randn(input_shape, requires_grad=True)
|
|
weight = torch.randn(weight_shape, requires_grad=True)
|
|
if has_bias:
|
|
bias = torch.randn([chan_out], requires_grad=True)
|
|
output = func_forward(
|
|
input,
|
|
weight,
|
|
stride=stride,
|
|
padding=padding,
|
|
dilation=dilation,
|
|
bias=bias,
|
|
)
|
|
|
|
gradient_o = torch.randn(output.shape)
|
|
gradient_w = torch.autograd.grad(
|
|
output, input if (gradient == "input") else weight, gradient_o
|
|
)
|
|
|
|
self.assertEqual(
|
|
gradient_w[0],
|
|
func_backward(
|
|
input_shape if (gradient == "input") else input,
|
|
weight_shape if (gradient == "weight") else weight,
|
|
gradient_o,
|
|
stride=stride,
|
|
padding=padding,
|
|
dilation=dilation,
|
|
),
|
|
)
|
|
|
|
def test_grad_conv1d_input(self):
|
|
self.run_grad_conv_test(F.conv1d, F.grad.conv1d_input, 1, "input")
|
|
|
|
def test_grad_conv1d_weight(self):
|
|
self.run_grad_conv_test(F.conv1d, F.grad.conv1d_weight, 1, "weight")
|
|
|
|
def test_grad_conv2d_input(self):
|
|
self.run_grad_conv_test(F.conv2d, F.grad.conv2d_input, 2, "input")
|
|
|
|
def test_grad_conv2d_weight(self):
|
|
self.run_grad_conv_test(F.conv2d, F.grad.conv2d_weight, 2, "weight")
|
|
|
|
def test_grad_conv3d_input(self):
|
|
self.run_grad_conv_test(F.conv3d, F.grad.conv3d_input, 3, "input")
|
|
|
|
def test_grad_conv3d_weight(self):
|
|
self.run_grad_conv_test(F.conv3d, F.grad.conv3d_weight, 3, "weight")
|
|
|
|
@unittest.skipIf(not torch._nnpack_available(), "NNPACK unavailable")
|
|
def test_nnpack_conv(self):
|
|
for kern, inp_size in [(3, 6), (3, 7), (4, 9)]:
|
|
for batch, stride, padding, chan_in, chan_out in product(
|
|
[1, 2, 3, 4], [1, 2], [0, 1, 2], [2], [3]
|
|
):
|
|
for has_bias in [True, False]:
|
|
input_shape = [batch, chan_in]
|
|
weight_shape = [chan_out, chan_in]
|
|
for _ in range(2):
|
|
input_shape.append(inp_size)
|
|
weight_shape.append(kern)
|
|
|
|
input = torch.randn(
|
|
input_shape, requires_grad=True, dtype=torch.float
|
|
)
|
|
weight = torch.randn(
|
|
weight_shape, requires_grad=True, dtype=torch.float
|
|
)
|
|
if has_bias:
|
|
bias = torch.randn(
|
|
[chan_out], requires_grad=True, dtype=torch.float
|
|
)
|
|
output = torch._nnpack_spatial_convolution(
|
|
input, weight, stride=stride, padding=padding, bias=bias
|
|
)
|
|
output_expected = torch.nn.functional.conv2d(
|
|
input, weight, stride=stride, padding=padding, bias=bias
|
|
)
|
|
self.assertEqual(output, output_expected, atol=3e-4, rtol=0)
|
|
|
|
gradient_o = torch.randn(output.shape, dtype=torch.float)
|
|
|
|
grads = torch.autograd.grad(output, [input, weight], gradient_o)
|
|
grads_expected = torch.autograd.grad(
|
|
output_expected, [input, weight], gradient_o
|
|
)
|
|
for gr, gr_expected in zip(grads, grads_expected):
|
|
self.assertEqual(gr, gr_expected, atol=3e-4, rtol=0)
|
|
|
|
def test_conv_padding_mode(self):
|
|
with self.assertRaisesRegex(ValueError, "padding_mode must be one of"):
|
|
nn.Conv2d(3, 3, 3, padding_mode="xyz")
|
|
|
|
with self.assertRaisesRegex(ValueError, "padding_mode must be one of"):
|
|
nn.Conv2d(3, 3, 3, padding_mode=3)
|
|
|
|
with self.assertRaisesRegex(ValueError, 'Only "zeros" '):
|
|
nn.ConvTranspose2d(3, 3, 3, padding_mode="reflect")
|
|
|
|
def test_functional_grad_conv(self):
|
|
# Conv 1D
|
|
input = torch.randn(1, 1, 5, requires_grad=True)
|
|
weight = torch.randn(1, 1, 3, requires_grad=True)
|
|
output = F.conv1d(input, weight, dilation=2)
|
|
grad_output = torch.randn(output.shape)
|
|
|
|
grad_input_autograd, grad_weight_autograd = torch.autograd.grad(
|
|
output, (input, weight), grad_output
|
|
)
|
|
|
|
grad_input_functional = torch.nn.grad.conv1d_input(
|
|
input.shape, weight, grad_output, dilation=2
|
|
)
|
|
self.assertEqual(grad_input_functional, grad_input_autograd)
|
|
|
|
grad_weight_functional = torch.nn.grad.conv1d_weight(
|
|
input, weight.shape, grad_output, dilation=2
|
|
)
|
|
self.assertEqual(grad_weight_functional, grad_weight_autograd)
|
|
|
|
# Conv 2D
|
|
input = torch.randn(1, 1, 5, 5, requires_grad=True)
|
|
weight = torch.randn(1, 1, 3, 3, requires_grad=True)
|
|
output = F.conv2d(input, weight, dilation=2)
|
|
grad_output = torch.randn(output.shape)
|
|
|
|
(grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(
|
|
output, (input, weight), grad_output
|
|
)
|
|
|
|
grad_input_functional = torch.nn.grad.conv2d_input(
|
|
input.shape, weight, grad_output, dilation=2
|
|
)
|
|
self.assertEqual(grad_input_functional, grad_input_autograd)
|
|
|
|
grad_weight_functional = torch.nn.grad.conv2d_weight(
|
|
input, weight.shape, grad_output, dilation=2
|
|
)
|
|
self.assertEqual(grad_weight_functional, grad_weight_autograd)
|
|
|
|
# Conv 3D
|
|
input = torch.randn(1, 1, 5, 5, 5, requires_grad=True)
|
|
weight = torch.randn(1, 1, 3, 3, 3, requires_grad=True)
|
|
output = F.conv3d(input, weight, dilation=2)
|
|
grad_output = torch.randn(output.shape)
|
|
|
|
(grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(
|
|
output, (input, weight), grad_output
|
|
)
|
|
|
|
grad_input_functional = torch.nn.grad.conv3d_input(
|
|
input.shape, weight, grad_output, dilation=2
|
|
)
|
|
self.assertEqual(grad_input_functional, grad_input_autograd)
|
|
|
|
grad_weight_functional = torch.nn.grad.conv3d_weight(
|
|
input, weight.shape, grad_output, dilation=2
|
|
)
|
|
self.assertEqual(grad_weight_functional, grad_weight_autograd)
|
|
|
|
def test_functional_grad_conv2d(self):
|
|
BATCH_SIZE = 4
|
|
IN_CH = 8
|
|
OUT_CH = 16
|
|
SPATIAL = 32
|
|
|
|
def _test_conv2d(stride, kernel_size, groups, dilation):
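            # Checks torch.nn.grad.conv2d_input / conv2d_weight against the autograd
            # gradients for this (stride, kernel_size, groups, dilation) configuration.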
|
|
padding = kernel_size // 2
|
|
|
|
input = (
|
|
torch.empty(BATCH_SIZE, IN_CH, SPATIAL, SPATIAL)
|
|
.uniform_(-8.0, 8.0)
|
|
.requires_grad_(True)
|
|
)
|
|
|
|
weight = (
|
|
torch.empty(OUT_CH, IN_CH // groups, kernel_size, kernel_size)
|
|
.uniform_(-4.0, 4.0)
|
|
.requires_grad_(True)
|
|
)
|
|
|
|
output = F.conv2d(
|
|
input,
|
|
weight,
|
|
stride=stride,
|
|
padding=padding,
|
|
dilation=dilation,
|
|
groups=groups,
|
|
)
|
|
|
|
grad_output = torch.randn(output.shape)
|
|
|
|
(grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(
|
|
output, (input, weight), grad_output
|
|
)
|
|
|
|
grad_input_functional = torch.nn.grad.conv2d_input(
|
|
input.shape,
|
|
weight,
|
|
grad_output,
|
|
stride=stride,
|
|
padding=padding,
|
|
dilation=dilation,
|
|
groups=groups,
|
|
)
|
|
self.assertEqual(grad_input_functional, grad_input_autograd)
|
|
|
|
grad_weight_functional = torch.nn.grad.conv2d_weight(
|
|
input,
|
|
weight.shape,
|
|
grad_output,
|
|
stride=stride,
|
|
padding=padding,
|
|
dilation=dilation,
|
|
groups=groups,
|
|
)
|
|
self.assertEqual(grad_weight_functional, grad_weight_autograd)
|
|
|
|
strides = [1, 2]
|
|
kernel_sizes = [1, 3, 5]
|
|
groups = [1, 2, 4]
|
|
dilates = [1, 2]
|
|
|
|
for s, k, g, d in product(strides, kernel_sizes, groups, dilates):
|
|
_test_conv2d(s, k, g, d)
|
|
|
|
def test_permute_conv2d_issue_120211(self):
|
|
def reproducer(radius: int):
|
|
image = torch.rand(1, 1024, 1024, 3)
|
|
image = image.permute(0, 3, 1, 2)
|
|
kernel_x = torch.zeros([3, 1, 1, radius * 2 + 1], device=image.device)
|
|
image = torch.nn.functional.conv2d(image, kernel_x, groups=image.shape[-3])
|
|
|
|
for i in range(0, 128):
|
|
# This should not fail
|
|
reproducer(radius=i)
|
|
|
|
def test_conv3d_issue_120406(self):
|
|
# This should not fail
|
|
F.conv3d(torch.ones(2, 3, 8, 9, 26), torch.ones(3, 1, 1, 1, 17), groups=3)
|
|
|
|
def test_conv1d_issue_120547(self):
|
|
weight = torch.ones([16, 1, 32])
|
|
bias = torch.ones([16])
|
|
stride, padding, dilation, groups = (1, 16, 1, 16)
|
|
input = torch.rand((1, 1, 16))
|
|
input = input.transpose(1, 2)
|
|
# This should not fail
|
|
F.conv1d(input, weight, bias, stride, padding, dilation, groups)
|
|
|
|
|
|
class TestConvolutionNNDeviceType(NNTestCase):
|
|
def run_conv_double_back_test(
|
|
self,
|
|
kern,
|
|
stride,
|
|
padding,
|
|
chan_in,
|
|
chan_out,
|
|
batch_size,
|
|
inp_size,
|
|
dilation,
|
|
no_weight,
|
|
groups=1,
|
|
use_cuda=False,
|
|
use_bias=True,
|
|
dtype=torch.double,
|
|
):
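        # Builds an F.conv2d problem with the given shapes and runs gradgradcheck on
        # it; for torch.float inputs it only verifies that the double-backward graph
        # can be constructed (see the Issue #15353 note below).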
|
|
if use_cuda:
|
|
device = torch.device("cuda")
|
|
else:
|
|
device = torch.device("cpu")
|
|
|
|
x = torch.randn(
|
|
batch_size,
|
|
chan_in,
|
|
inp_size,
|
|
inp_size,
|
|
device=device,
|
|
dtype=dtype,
|
|
requires_grad=True,
|
|
)
|
|
weight = torch.randn(
|
|
chan_out,
|
|
chan_in // groups,
|
|
kern,
|
|
kern,
|
|
device=device,
|
|
dtype=dtype,
|
|
requires_grad=not no_weight,
|
|
)
|
|
if use_bias:
|
|
bias = torch.randn(chan_out, device=device, dtype=dtype, requires_grad=True)
|
|
else:
|
|
bias = None
|
|
|
|
def func(*inputs):
|
|
if use_bias:
|
|
lx, lweight, lbias = inputs
|
|
else:
|
|
lx, lweight = inputs
|
|
lbias = None
|
|
# We disable cudnn during forward to avoid finite difference imprecision issues
|
|
with cudnn.flags(enabled=False):
|
|
out = F.conv2d(lx, lweight, lbias, stride, padding, dilation, groups)
|
|
return out
|
|
|
|
if use_bias:
|
|
inputs = x, weight, bias
|
|
else:
|
|
inputs = x, weight
|
|
|
|
dummy_out = func(*inputs)
|
|
grad_y = torch.randn_like(
|
|
dummy_out, device=device, dtype=dtype, requires_grad=True
|
|
)
|
|
|
|
# Issue #15353: test mkldnn double backward, don't run gradgradcheck due
|
|
# to imprecision issues
|
|
if dtype == torch.float:
|
|
(g,) = torch.autograd.grad(dummy_out.sum(), x, create_graph=True)
|
|
return g.requires_grad
|
|
|
|
return gradgradcheck(func, inputs, (grad_y,))
|
|
|
|
@onlyCUDA
|
|
@skipCUDAIfNoCudnn
|
|
@dtypes(
|
|
*floating_and_complex_types_and(
|
|
torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []
|
|
)
|
|
)
|
|
@parametrize_test("dilation", [1, 2, 3])
|
|
def test_Conv2d_deterministic_cudnn(self, device, dtype, dilation):
|
|
inputs = torch.randn(2, 3, 7, 7, device=device, dtype=dtype, requires_grad=True)
|
|
with cudnn.flags(enabled=True, benchmark=True, deterministic=True):
|
|
conv1 = torch.nn.Conv2d(3, 3, 3, dilation=dilation).to(device, dtype)
|
|
conv2 = torch.nn.Conv2d(3, 3, 3, dilation=dilation).to(device, dtype)
|
|
conv2.bias.data.copy_(conv1.bias.data)
|
|
conv2.weight.data.copy_(conv1.weight.data)
|
|
out1 = conv1(inputs)
|
|
out2 = conv2(inputs)
|
|
self.assertEqual(out1, out2, atol=0.0, rtol=0)
|
|
y = torch.randn(out1.size(), device=device, dtype=dtype)
|
|
out1.backward(y)
|
|
out2.backward(y)
|
|
self.assertEqual(
|
|
conv1.bias.grad.data, conv2.bias.grad.data, atol=0.0, rtol=0
|
|
)
|
|
self.assertEqual(
|
|
conv1.weight.grad.data, conv2.weight.grad.data, atol=0.0, rtol=0
|
|
)
|
|
|
|
@onlyCUDA
|
|
@dtypes(
|
|
*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])
|
|
)
|
|
def test_Conv2d_large_workspace(self, device, dtype):
|
|
# These sizes require huge cuDNN workspaces. Make sure we choose a
|
|
# reasonable algorithm that does not run out of memory
|
|
sizes = [
|
|
(1, 256, 109, 175),
|
|
(1, 256, 80, 128),
|
|
(1, 256, 120, 192),
|
|
]
|
|
|
|
def run_test(benchmark):
|
|
with torch.backends.cudnn.flags(enabled=True, benchmark=benchmark):
|
|
conv = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1).to(
|
|
device, dtype
|
|
)
|
|
for size in sizes:
|
|
x = torch.randn(size, device=device, dtype=dtype)
|
|
out = conv(x.detach().clone().requires_grad_())
|
|
out.backward(torch.ones_like(out))
|
|
|
|
run_test(benchmark=False)
|
|
run_test(benchmark=True)
|
|
|
|
@onlyCUDA
|
|
@dtypes(torch.half, torch.float)
|
|
def test_ConvTranspose2d_large_output_padding(self, device, dtype):
|
|
net1 = torch.nn.ConvTranspose2d(
|
|
128, 64, kernel_size=3, stride=2, padding=1, output_padding=1
|
|
).to(device=device, dtype=dtype)
|
|
net2 = torch.nn.ConvTranspose2d(
|
|
64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
|
|
).to(device=device, dtype=dtype)
|
|
net3 = torch.nn.ConvTranspose2d(
|
|
32, 3, kernel_size=3, stride=2, padding=1, output_padding=1
|
|
).to(device=device, dtype=dtype)
|
|
x = torch.rand(1, 128, 6, 6, device=device, dtype=dtype, requires_grad=True)
|
|
x = net1(x)
|
|
x = net2(x)
|
|
x = net3(x)
|
|
x.backward(torch.randn_like(x))
|
|
torch.cuda.synchronize()
|
|
|
|
@onlyCUDA
|
|
@dtypes(torch.float, torch.double, torch.half)
|
|
# Very similar to test_Conv2d_naive_groups but with special care to handle
|
|
# the number of groups == number of input channels
|
|
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
|
|
@torch.backends.miopen.flags(immediate=True)
|
|
@tf32_on_and_off(0.01)
|
|
def test_Conv2d_depthwise_naive_groups(self, device, dtype):
|
|
for depth_multiplier in [1, 2]:
|
|
m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(
|
|
device, dtype
|
|
)
|
|
i = (
|
|
torch.randn(2, 2, 6, 6, device="cuda", dtype=dtype)
|
|
.div_(2)
|
|
.requires_grad_()
|
|
)
|
|
output = m(i)
|
|
grad_output = (
|
|
torch.randn(2, 2 * depth_multiplier, 4, 4, device=device, dtype=dtype)
|
|
/ 2
|
|
)
|
|
output.backward(grad_output)
|
|
|
|
offset = 1 * depth_multiplier
|
|
|
|
m1 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
|
|
m1.weight.data = m.weight.data[:offset].clone()
|
|
m1.bias.data = m.bias.data[:offset].clone()
|
|
i1 = i.detach()[:, :1].clone().requires_grad_()
|
|
output1 = m1(i1)
|
|
output1.backward(grad_output[:, :offset].contiguous())
|
|
|
|
m2 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
|
|
m2.weight.data.copy_(m.weight.data[offset:])
|
|
m2.bias.data.copy_(m.bias.data[offset:])
|
|
i2 = i.detach()[:, 1:].clone().requires_grad_()
|
|
output2 = m2(i2)
|
|
output2.backward(grad_output[:, offset:].contiguous())
|
|
|
|
self.assertEqual(
|
|
output,
|
|
torch.cat([output1, output2], 1),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
i.grad.data,
|
|
torch.cat([i1.grad.data, i2.grad.data], 1),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
m.bias.grad.data,
|
|
torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
m.weight.grad.data,
|
|
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
|
|
@onlyCUDA
|
|
@dtypes(torch.float, torch.double, torch.half)
|
|
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
|
|
@torch.backends.miopen.flags(immediate=True)
|
|
@tf32_on_and_off(0.01)
|
|
def test_Conv3d_depthwise_naive_groups(self, device, dtype):
|
|
for depth_multiplier in [1, 2]:
|
|
m = nn.Conv3d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(
|
|
device, dtype
|
|
)
|
|
i = (
|
|
torch.randn(2, 2, 6, 6, 6, device="cuda", dtype=dtype)
|
|
.div_(2)
|
|
.requires_grad_()
|
|
)
|
|
output = m(i)
|
|
grad_output = (
|
|
torch.randn(
|
|
2, 2 * depth_multiplier, 4, 4, 4, device=device, dtype=dtype
|
|
)
|
|
/ 2
|
|
)
|
|
output.backward(grad_output)
|
|
|
|
offset = 1 * depth_multiplier
|
|
|
|
m1 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
|
|
m1.weight.data = m.weight.data[:offset].clone()
|
|
m1.bias.data = m.bias.data[:offset].clone()
|
|
i1 = i.detach()[:, :1].clone().requires_grad_()
|
|
output1 = m1(i1)
|
|
output1.backward(grad_output[:, :offset].contiguous())
|
|
|
|
m2 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
|
|
m2.weight.data.copy_(m.weight.data[offset:])
|
|
m2.bias.data.copy_(m.bias.data[offset:])
|
|
i2 = i.detach()[:, 1:].clone().requires_grad_()
|
|
output2 = m2(i2)
|
|
output2.backward(grad_output[:, offset:].contiguous())
|
|
is_cuda_sm86 = device.startswith(
|
|
"cuda"
|
|
) and torch.cuda.get_device_capability(0) == (8, 6)
|
|
atol, rtol = (
|
|
(3e-4, 3e-2)
|
|
if dtype == torch.float32 and is_cuda_sm86
|
|
else (dtype2prec_DONTUSE[dtype], 0)
|
|
)
|
|
|
|
self.assertEqual(
|
|
output, torch.cat([output1, output2], 1), atol=atol, rtol=rtol
|
|
)
|
|
self.assertEqual(
|
|
i.grad.data,
|
|
torch.cat([i1.grad.data, i2.grad.data], 1),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
m.bias.grad.data,
|
|
torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
m.weight.grad.data,
|
|
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
|
|
atol=atol,
|
|
rtol=rtol,
|
|
)
|
|
|
|
@onlyCUDA
|
|
@dtypes(
|
|
*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])
|
|
)
|
|
def test_noncontig_conv_grad(self, device, dtype):
|
|
# FIXME: remove after adding non-contiguous grad tests for all modules
|
|
module = nn.Conv2d(3, 5, kernel_size=3, padding=1).to(device, dtype)
|
|
input = torch.randn(
|
|
2, 3, 10, 10, dtype=dtype, device=device, requires_grad=True
|
|
)
|
|
output = module(input)
|
|
|
|
grad = torch.randn(2, 2, 5, 10, 10, dtype=dtype, device=device)[:, 1]
|
|
assert not grad.is_contiguous()
|
|
output.backward(grad, retain_graph=True)
|
|
self.assertIsNotNone(input.grad)
|
|
result = input.grad.data.clone()
|
|
input.grad.data.zero_()
|
|
|
|
output.backward(grad.contiguous())
|
|
self.assertEqual(
|
|
result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0
|
|
)
|
|
|
|
@onlyCUDA
|
|
@dtypes(torch.double)
|
|
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
|
|
@torch.backends.miopen.flags(immediate=True)
|
|
def test_conv_double_backward(self, device, dtype):
|
|
        # Double backward only runs with DoubleTensor due to precision reasons
|
|
batch_size = 1
|
|
for kern, inp_size, dilations in [(3, 5, [1, 2]), (4, 9, [1])]:
|
|
for stride, padding, chan_in, chan_out, dilation in product(
|
|
[1], [2], [2], [3], dilations
|
|
):
|
|
no_weight = stride == 2
|
|
result = self.run_conv_double_back_test(
|
|
kern,
|
|
stride,
|
|
padding,
|
|
chan_in,
|
|
chan_out,
|
|
batch_size,
|
|
inp_size,
|
|
dilation,
|
|
no_weight,
|
|
use_cuda=True,
|
|
dtype=dtype,
|
|
)
|
|
self.assertTrue(
|
|
result,
|
|
"Conv double backward test failed with parameters:"
|
|
+ "\nkern: "
|
|
+ str(kern)
|
|
+ "\nstride: "
|
|
+ str(stride)
|
|
+ "\npadding: "
|
|
+ str(padding)
|
|
+ "\nchan_in: "
|
|
+ str(chan_in)
|
|
+ "\nchan_out: "
|
|
+ str(chan_out)
|
|
+ "\nbatch_size: "
|
|
+ str(batch_size)
|
|
+ "\ninp_size: "
|
|
+ str(inp_size)
|
|
+ "\ndilation: "
|
|
+ str(dilation),
|
|
)
|
|
|
|
def test_conv_double_backward_no_bias(self):
|
|
kern = 3
|
|
stride = 2
|
|
chan_in, chan_out = 2, 4
|
|
batch_size = 2
|
|
inp_size = 5
|
|
padding = 1
|
|
dilation = 1
|
|
no_weight = False
|
|
use_bias = True
|
|
result = self.run_conv_double_back_test(
|
|
kern,
|
|
stride,
|
|
padding,
|
|
chan_in,
|
|
chan_out,
|
|
batch_size,
|
|
inp_size,
|
|
dilation,
|
|
no_weight,
|
|
use_bias=use_bias,
|
|
)
|
|
self.assertTrue(
|
|
result,
|
|
"Conv double backward test failed with parameters:"
|
|
+ "\nkern: "
|
|
+ str(kern)
|
|
+ "\nstride: "
|
|
+ str(stride)
|
|
+ "\npadding: "
|
|
+ str(padding)
|
|
+ "\nchan_in: "
|
|
+ str(chan_in)
|
|
+ "\nchan_out: "
|
|
+ str(chan_out)
|
|
+ "\nbatch_size: "
|
|
+ str(batch_size)
|
|
+ "\ninp_size: "
|
|
+ str(inp_size)
|
|
+ "\ndilation: "
|
|
+ str(dilation),
|
|
)
|
|
|
|
def test_conv_double_backward_groups(self):
|
|
kern = 3
|
|
stride = 1
|
|
padding = 2
|
|
chan_in, chan_out = 2, 4
|
|
batch_size = 2
|
|
inp_size = 6
|
|
dilation = 1
|
|
no_weight = False
|
|
groups = 2
|
|
result = self.run_conv_double_back_test(
|
|
kern,
|
|
stride,
|
|
padding,
|
|
chan_in * groups,
|
|
chan_out * groups,
|
|
batch_size,
|
|
inp_size,
|
|
dilation,
|
|
no_weight,
|
|
groups=groups,
|
|
)
|
|
self.assertTrue(
|
|
result,
|
|
"Conv double backward test failed with parameters:"
|
|
+ "\nkern: "
|
|
+ str(kern)
|
|
+ "\nstride: "
|
|
+ str(stride)
|
|
+ "\npadding: "
|
|
+ str(padding)
|
|
+ "\nchan_in: "
|
|
+ str(chan_in)
|
|
+ "\nchan_out: "
|
|
+ str(chan_out)
|
|
+ "\nbatch_size: "
|
|
+ str(batch_size)
|
|
+ "\ninp_size: "
|
|
+ str(inp_size)
|
|
+ "\ndilation: "
|
|
+ str(dilation)
|
|
+ "\ngroups: "
|
|
+ str(groups),
|
|
)
|
|
|
|
def test_conv_double_backward_stride(self):
|
|
batch_size = 2
|
|
|
|
# Cannot provide ggW when stride is > 1
|
|
for kern, inp_size, dilations in [(3, 5, [1, 2]), (3, 7, [1])]:
|
|
for stride, padding, chan_in, chan_out, dilation in product(
|
|
[2], [0, 1], [1], [2], dilations
|
|
):
|
|
no_weight = False
|
|
self.run_conv_double_back_test(
|
|
kern,
|
|
stride,
|
|
padding,
|
|
chan_in,
|
|
chan_out,
|
|
batch_size,
|
|
inp_size,
|
|
dilation,
|
|
no_weight,
|
|
)
|
|
|
|
@dtypes(torch.float, torch.cfloat)
|
|
@dtypesIfMPS(
|
|
*([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
|
|
) # Complex not supported on MacOS13
|
|
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
|
|
@torch.backends.miopen.flags(immediate=True)
|
|
def test_conv1d_same_padding(self, device, dtype):
|
|
# Test padding='same' outputs the correct shape
|
|
test_args = [
|
|
# in_size
|
|
range(50, 55),
|
|
# kernel_size
|
|
[1, 2, 3, 8],
|
|
# dilation
|
|
range(1, 4),
|
|
# stride
|
|
[1],
|
|
]
|
|
for in_size, k_size, dilation, stride in itertools.product(*test_args):
|
|
x = torch.rand(1, 1, in_size, device=device, dtype=dtype)
|
|
y = torch.rand(1, 1, k_size, device=device, dtype=dtype)
|
|
z = F.conv1d(x, y, padding="same", dilation=dilation, stride=stride)
|
|
self.assertEqual(z.size(2), int(math.ceil(in_size / stride)))
|
|
|
|
# Compare F.conv1d padding='same' output against manual padding
|
|
# Without strides/dilation
|
|
x = torch.rand(1, 1, 12, device=device, dtype=dtype)
|
|
y = torch.rand(1, 1, 3, device=device, dtype=dtype)
|
|
expect = F.conv1d(x, y, padding=1)
|
|
actual = F.conv1d(x, y, padding="same")
|
|
self.assertEqual(expect, actual)
|
|
|
|
# With dilation
|
|
x = torch.rand(1, 1, 12, device=device, dtype=dtype)
|
|
y = torch.rand(1, 1, 4, device=device, dtype=dtype)
|
|
expect = F.conv1d(x, y, padding=3, dilation=2)
|
|
actual = F.conv1d(x, y, padding="same", dilation=2)
|
|
self.assertEqual(expect, actual)
|
|
|
|
# Dilation with asymmetric padding
        expect = F.conv1d(x, y, padding=5, dilation=3)[..., 1:]
        actual = F.conv1d(x, y, padding="same", dilation=3)
        self.assertEqual(expect, actual)

    @tf32_on_and_off(0.005)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    @dtypes(torch.float, torch.cfloat)
    def test_conv2d_same_padding(self, device, dtype):
        # Compare F.conv2d padding='same' output against manual padding
        # Without strides/dilation
        x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype)
        y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype)
        expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :]
        actual = F.conv2d(x, y, padding="same")
        self.assertEqual(expect, actual)

        # With dilation
        y = torch.rand(1, 1, 3, 4, device=device, dtype=dtype)
        expect = F.conv2d(x, y, padding=(2, 3), dilation=2)
        actual = F.conv2d(x, y, padding="same", dilation=2)
        self.assertEqual(expect, actual)

        # Dilation with asymmetric padding
        y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype)
        expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:]
        actual = F.conv2d(x, y, padding="same", dilation=3)
        self.assertEqual(expect, actual)

    @dtypes(torch.float, torch.cfloat)
    def test_conv3d_same_padding(self, device, dtype):
        if dtype is torch.cfloat:
            rtol, atol = 2e-6, 2e-6
        else:
            rtol, atol = None, None
        # Compare F.conv3d padding='same' output against manual padding
        # Without strides/dilation
        x = torch.rand(1, 1, 10, 11, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 1, 2, 5, device=device, dtype=dtype)
        expect = F.conv3d(x, y, padding=(0, 1, 2))[..., :, 1:, :]
        actual = F.conv3d(x, y, padding="same")
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

        # With dilation
        expect = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
        actual = F.conv3d(x, y, padding="same", dilation=2)
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

        # Dilation with asymmetric padding
        y = torch.rand(1, 1, 4, 4, 4, device=device, dtype=dtype)
        expect = F.conv3d(x, y, padding=5, dilation=3)[..., 1:, 1:, 1:]
        actual = F.conv3d(x, y, padding="same", dilation=3)
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

    @dtypes(torch.float, torch.cfloat)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    def test_conv1d_valid_padding(self, device, dtype):
        # Test F.conv1d padding='valid' is the same as no padding
        x = torch.rand(1, 1, 10, device=device, dtype=dtype)
        y = torch.rand(1, 1, 4, device=device, dtype=dtype)
        expect = F.conv1d(x, y)
        actual = F.conv1d(x, y, padding="valid")
        self.assertEqual(expect, actual)

    @dtypes(torch.float, torch.cfloat)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    def test_conv2d_valid_padding(self, device, dtype):
        # Test F.conv2d padding='valid' is the same as no padding
        x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype)
        y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype)
        expect = F.conv2d(x, y)
        actual = F.conv2d(x, y, padding="valid")
        self.assertEqual(expect, actual)

    @dtypes(torch.float, torch.cfloat)
    def test_conv3d_valid_padding(self, device, dtype):
        # Test F.conv3d padding='valid' is the same as no padding
        x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device)
        y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device)
        expect = F.conv3d(x, y)
        actual = F.conv3d(x, y, padding="valid")
        self.assertEqual(expect, actual)

    @dtypes(torch.float, torch.cfloat)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    def test_conv1d_same_padding_backward(self, device, dtype):
        # Test F.conv1d gradients work with padding='same'
        x = torch.rand(1, 1, 12, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)

        # Symmetric padding
        z = F.conv1d(x, y, padding=3, dilation=2)
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv1d(x, y, padding="same", dilation=2)
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        x.grad, y.grad = None, None

        # Asymmetric padding
        z = F.conv1d(x, y, padding=2)[..., 1:]
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv1d(x, y, padding="same")
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)

    @dtypes(torch.float, torch.cfloat)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    @tf32_on_and_off(0.001)
    def test_conv2d_same_padding_backward(self, device, dtype):
        # Test F.conv2d gradients work with padding='same'
        x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype, requires_grad=True)
        y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype, requires_grad=True)

        # Symmetric padding
        z = F.conv2d(x, y, padding=(3, 4), dilation=2)
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv2d(x, y, padding="same", dilation=2)
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        x.grad, y.grad = None, None

        # Asymmetric padding
        y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype, requires_grad=True)
        z = F.conv2d(x, y, padding=2)[..., 1:, 1:]
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv2d(x, y, padding="same")
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)

    @dtypes(torch.double, torch.cdouble)
    @dtypesIfMPS(
        torch.float, torch.cfloat
    )  # Double, complex double not supported on MPS
    @expectedFailureMPS  # https://github.com/pytorch/pytorch/issues/107214
    def test_conv3d_same_padding_backward(self, device, dtype):
        check_forward_ad = torch.device(device).type != "xla"
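        # (the gradcheck calls below only request forward-mode AD on non-XLA
        # devices, which is what this flag encodes)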

        # Test F.conv3d gradients work with padding='same'
        x = torch.rand(1, 1, 1, 11, 12, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 1, 2, 5, dtype=dtype, device=device, requires_grad=True)

        # Symmetric padding
        z = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv3d(x, y, padding="same", dilation=2)
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        x.grad, y.grad = None, None

        gradcheck(
            lambda x, y: F.conv3d(x, y, padding="same", dilation=2),
            (x, y),
            check_forward_ad=check_forward_ad,
            nondet_tol=1e-5,
        )
        if torch.device(device).type != "cuda":
            # https://github.com/pytorch/pytorch/issues/70702
            gradgradcheck(
                lambda x, y: F.conv3d(x, y, padding="same", dilation=2),
                (x, y),
                check_fwd_over_rev=True,
            )

        # Asymmetric padding
        y = torch.rand(1, 1, 1, 4, 4, dtype=dtype, device=device, requires_grad=True)
        z = F.conv3d(x, y, padding=2)[..., 1:, 1:]
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv3d(x, y, padding="same")
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)

        gradcheck(
            lambda x, y: F.conv3d(x, y, padding="same"),
            (x, y),
            check_forward_ad=check_forward_ad,
            nondet_tol=1e-5,
        )
        if torch.device(device).type != "cuda":
            # https://github.com/pytorch/pytorch/issues/70702
            gradgradcheck(
                lambda x, y: F.conv3d(x, y, padding="same"),
                (x, y),
                check_fwd_over_rev=True,
            )

    @dtypes(torch.float, torch.cfloat)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    def test_conv1d_valid_padding_backward(self, device, dtype):
        # Test F.conv1d gradients work with padding='valid'
        x = torch.rand(1, 1, 10, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)
        F.conv1d(x, y, padding=0).sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        F.conv1d(x, y, padding="valid").sum().abs().backward()
        gx_actual, gy_actual = x.grad, y.grad
        self.assertEqual(gx_expect, gx_actual)
        self.assertEqual(gy_expect, gy_actual)

    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(torch.float, torch.cfloat)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    @parametrize_test("mode", ("valid", "same"))
    def test_conv1d_vs_scipy(self, device, dtype, mode):
        t = make_tensor((1, 10), device=device, dtype=dtype)
        feat_dim = t.shape[1]
        weight_even = make_tensor((1, 1, 4), device=device, dtype=dtype)
        weight_odd = make_tensor((1, 1, 5), device=device, dtype=dtype)

        def _test(t, weight, mode):
            # SciPy expects two 1-D inputs.
            t_a = t.view(-1).cpu().numpy()
            w_a = weight.view(-1).cpu().numpy()
            expected = scipy.signal.convolve(t_a, w_a, mode=mode)

            kwargs = {"padding": mode}
            if mode == "same":
                # `same` padding in PyTorch conv1d is different
                # from SciPy
                p = weight.shape[2] // 2
                t = torch.nn.functional.pad(t, (p, p))
                # We have already taken care of padding
                kwargs.pop("padding")

            # second input is flipped in SciPy's convolve
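            # (F.conv1d computes cross-correlation, so flipping the kernel here
            # makes it agree with SciPy's true convolution)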
            weight_flipped = torch.flip(weight, (2,))
            actual = torch.nn.functional.conv1d(t, weight_flipped, **kwargs).squeeze(0)
            if mode == "same":
                actual = actual[:feat_dim]

            self.assertEqual(actual, expected, atol=2e-5, rtol=2e-5)

        # Global dtype for this test suite is torch.double
        # This leads to change in type-promotion
        # and conv1d outputs `complex128` for `complex64` input.
        with set_default_dtype(torch.float):
            _test(t, weight_even, mode)
            _test(t, weight_odd, mode)

    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(torch.float, torch.cfloat)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    @parametrize_test("mode", ("valid", "same"))
    def test_conv2d_vs_scipy(self, device, dtype, mode):
        t = make_tensor((1, 5, 10), device=device, dtype=dtype)
        weight_even = make_tensor((1, 1, 2, 4), device=device, dtype=dtype)
        weight_odd = make_tensor((1, 1, 3, 5), device=device, dtype=dtype)

        def _test(t, weight, mode):
            # SciPy expects two 2-D inputs.
            t_a = t.squeeze(0).cpu().numpy()
            w_a = weight.squeeze(0).squeeze(0).cpu().numpy()
            expected = scipy.signal.convolve2d(t_a, w_a, mode=mode)

            kwargs = {"padding": mode}
            if mode == "same":
                # `same` padding in PyTorch conv2d is different
                # from SciPy
                left_right_pad = weight.shape[3] // 2
                top_bottom_pad = weight.shape[2] // 2
                p = (left_right_pad, left_right_pad, top_bottom_pad, top_bottom_pad)
                t = torch.nn.functional.pad(t, p)
                # We have already taken care of padding
                kwargs.pop("padding")

            # second input is flipped in SciPy's convolve2d
            weight_flipped = torch.flip(weight, (2, 3))
            actual = torch.nn.functional.conv2d(t, weight_flipped, **kwargs).squeeze(0)
            if mode == "same":
                actual = actual[:5, :10]

            self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6)

        # Global dtype for this test suite is torch.double
        # This leads to change in type-promotion
        # and conv2d outputs `complex128` for `complex64` input.
        with set_default_dtype(torch.float):
            _test(t, weight_even, mode)
            _test(t, weight_odd, mode)

    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @skipMPS  # Results in CI are inconsistent, forced to skip
    @dtypes(torch.float, torch.cfloat)
    @parametrize_test("mode", ("valid", "same"))
    def test_conv3d_vs_scipy(self, device, dtype, mode):
        t = make_tensor((1, 5, 5, 10), device=device, dtype=dtype)
        weight_even = make_tensor((1, 1, 2, 2, 4), device=device, dtype=dtype)
        weight_odd = make_tensor((1, 1, 2, 3, 5), device=device, dtype=dtype)

        def _test(t, weight, mode):
            # SciPy expects two 3-D inputs.
            t_a = t.squeeze(0).cpu().numpy()
            w_a = weight.squeeze(0).squeeze(0).cpu().numpy()
            expected = scipy.signal.convolve(t_a, w_a, mode=mode)

            kwargs = {"padding": mode}
            if mode == "same":
                # `same` padding in PyTorch conv3d is different
                # from SciPy
                left_right_pad = weight.shape[4] // 2
                top_bottom_pad = weight.shape[3] // 2
                front_back_pad = weight.shape[2] // 2
                p = (
                    left_right_pad,
                    left_right_pad,
                    top_bottom_pad,
                    top_bottom_pad,
                    front_back_pad,
                    front_back_pad,
                )
                t = torch.nn.functional.pad(t, p)
                # We have already taken care of padding
                kwargs.pop("padding")

            # second input is flipped in SciPy's convolve
            weight_flipped = torch.flip(weight, (2, 3, 4))
            actual = torch.nn.functional.conv3d(t, weight_flipped, **kwargs).squeeze(0)
            if mode == "same":
                actual = actual[:5, :5, :10]
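
            # TF32 kernels trade precision for speed, so the comparison
            # tolerance is loosened when the convolution may have run
            # through TF32.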
            if torch.cuda.is_tf32_supported() and (
                dtype == torch.float or dtype == torch.complex64
            ):
                self.assertEqual(actual, expected, atol=0.05, rtol=0.05)
            else:
                self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6)

        # Global dtype for this test suite is torch.double
        # This leads to change in type-promotion
        # and conv3d outputs `complex128` for `complex64` input.
        with set_default_dtype(torch.float):
            _test(t, weight_even, mode)
            _test(t, weight_odd, mode)

    @dtypes(torch.float, torch.complex64)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    def test_conv2d_valid_padding_backward(self, device, dtype):
        # Test F.conv2d gradients work with padding='valid'
        x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype, requires_grad=True)
        y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype, requires_grad=True)
        F.conv2d(x, y, padding=0).sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        F.conv2d(x, y, padding="valid").sum().abs().backward()
        gx_actual, gy_actual = x.grad, y.grad
        self.assertEqual(gx_expect, gx_actual)
        self.assertEqual(gy_expect, gy_actual)

    @dtypes(torch.double, torch.cdouble)
    @dtypesIfMPS(
        torch.float, torch.cfloat
    )  # Double, complex double not supported on MPS
    @expectedFailureMPS  # https://github.com/pytorch/pytorch/issues/107214
    def test_conv3d_valid_padding_backward(self, device, dtype):
        check_forward_ad = torch.device(device).type != "xla"

        # Test F.conv3d gradients work with padding='valid'
        x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device, requires_grad=True)
        F.conv3d(x, y, padding=0).sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        F.conv3d(x, y, padding="valid").sum().abs().backward()
        gx_actual, gy_actual = x.grad, y.grad
        self.assertEqual(gx_expect, gx_actual)
        self.assertEqual(gy_expect, gy_actual)

        gradcheck(
            lambda x, y: F.conv3d(x, y, padding="valid"),
            (x, y),
            check_forward_ad=check_forward_ad,
        )
        gradgradcheck(
            lambda x, y: F.conv3d(x, y, padding="valid"),
            (x, y),
            check_fwd_over_rev=check_forward_ad,
        )

    @parametrize_test(
        arg_str="N",
        arg_values=[
            subtest(arg_values=(2), name="ConvTranspose2d"),
            subtest(arg_values=(3), name="ConvTranspose3d"),
        ],
    )
    def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N):
        # For inputs with no batch dim, verify output is the correct shape when output_size is set.
        # See https://github.com/pytorch/pytorch/issues/75889
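        # For the size-13 dimension the "natural" transposed output length is
        # (13 - 1) * 16 - 2 * 7 + 16 = 194, and any length in
        # [194, 194 + stride - 1] is valid; output_size picks 200 out of that
        # range (and 240 for the size-15 dimension).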
        inp = torch.randn((1, 15, 13) if N == 2 else (1, 15, 13, 13), device=device)
        output_size = (1, 240, 200) if N == 2 else (1, 240, 200, 200)
        ConvTransposeNd = getattr(nn, f"ConvTranspose{N}d")
        m = ConvTransposeNd(
            1, 1, kernel_size=16, stride=16, padding=7, bias=False, device=device
        )
        output = m(inp, output_size=output_size)
        self.assertEqual(output.shape, output_size)

@skipMeta
|
|
@parametrize_test(
|
|
"input_shape,transposed,dilated,groups,layout,backend_expected",
|
|
[
|
|
# === slow ===
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Slow2d,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
|
|
name="slow1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
True,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.SlowTranspose2d,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
|
|
name="slow1d_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
False,
|
|
True,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.SlowDilated2d,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
|
|
name="slow1d_dilated",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
True,
|
|
True,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.SlowTranspose2d,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
|
|
name="slow1d_dilated_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Slow2d,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
|
|
name="slow2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
True,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.SlowTranspose2d,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
|
|
name="slow2d_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
False,
|
|
True,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.SlowDilated2d,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
|
|
name="slow2d_dilated",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
True,
|
|
True,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.SlowTranspose2d,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
|
|
name="slow2d_dilated_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Slow3d,
|
|
),
|
|
decorators=[onlyCPU, disableMkldnn],
|
|
name="slow3d_cpu",
|
|
),
|
|
# CUDA doesn't have a slow 3D implementation, so it goes to the dilated 3D implementation instead
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.SlowDilated3d,
|
|
),
|
|
decorators=[onlyCUDA, disablecuDNN],
|
|
name="slow3d_cuda",
|
|
),
|
|
# FIXME: RuntimeError: CUDA out of memory.
|
|
# subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d),
|
|
# decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_transposed'),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
False,
|
|
True,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.SlowDilated3d,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
|
|
name="slow3d_dilated",
|
|
),
|
|
# FIXME: RuntimeError: CUDA out of memory.
|
|
# subtest(((2, 6, 7, 8, 9), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d),
|
|
# decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_dilated_transposed'),
|
|
subtest(
|
|
(
|
|
(0, 6, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Empty,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn],
|
|
name="empty_batch1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 0, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Empty,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn],
|
|
name="empty_channel1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 0, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Empty,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn],
|
|
name="empty_batch_channel1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 6, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Empty,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn],
|
|
name="empty_batch2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 0, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Empty,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn],
|
|
name="empty_channel2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 0, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Empty,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn],
|
|
name="empty_batch_channel2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 6, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Empty,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn],
|
|
name="empty_batch3d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 0, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Empty,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn],
|
|
name="empty_channel3d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 0, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Empty,
|
|
),
|
|
decorators=[onlyNativeDeviceTypes, disableMkldnn],
|
|
name="empty_batch_channel3d",
|
|
),
|
|
# === cuda ===
|
|
# Note that disablecuDNN disables miopen as well.
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
False,
|
|
False,
|
|
6,
|
|
torch.strided,
|
|
torch._C._ConvBackend.CudaDepthwise2d,
|
|
),
|
|
decorators=[onlyCUDA, disablecuDNN],
|
|
name="cuda_depthwise1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
False,
|
|
False,
|
|
6,
|
|
torch.strided,
|
|
torch._C._ConvBackend.CudaDepthwise2d,
|
|
),
|
|
decorators=[onlyCUDA, disablecuDNN],
|
|
name="cuda_depthwise2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
False,
|
|
False,
|
|
6,
|
|
torch.strided,
|
|
torch._C._ConvBackend.CudaDepthwise3d,
|
|
),
|
|
decorators=[onlyCUDA, disablecuDNN],
|
|
name="cuda_depthwise3d",
|
|
),
|
|
# === cudnn ===
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Cudnn,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
|
|
name="cudnn1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Cudnn,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
|
|
name="cudnn2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Cudnn,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
|
|
name="cudnn3d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
True,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.CudnnTranspose,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
|
|
name="cudnn1d_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
True,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.CudnnTranspose,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
|
|
name="cudnn2d_transposed",
|
|
),
|
|
# FIXME: RuntimeError: CUDA out of memory.
|
|
# subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose),
|
|
# decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn3d_transposed'),
|
|
# === miopen ===
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Miopen,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoMiopen],
|
|
name="miopen1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Miopen,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoMiopen],
|
|
name="miopen2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Miopen,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoMiopen],
|
|
name="miopen3d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
True,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.MiopenTranspose,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoMiopen],
|
|
name="miopen1d_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
True,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.MiopenTranspose,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoMiopen],
|
|
name="miopen2d_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
True,
|
|
False,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.MiopenTranspose,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoMiopen],
|
|
name="miopen3d_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
False,
|
|
False,
|
|
6,
|
|
torch.strided,
|
|
torch._C._ConvBackend.MiopenDepthwise,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoMiopen],
|
|
name="miopen_depthwise1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
False,
|
|
False,
|
|
6,
|
|
torch.strided,
|
|
torch._C._ConvBackend.MiopenDepthwise,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoMiopen],
|
|
name="miopen_depthwise2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
False,
|
|
False,
|
|
6,
|
|
torch.strided,
|
|
torch._C._ConvBackend.MiopenDepthwise,
|
|
),
|
|
decorators=[onlyCUDA, skipCUDAIfNoMiopen],
|
|
name="miopen_depthwise3d",
|
|
),
|
|
# === mkldnn ===
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.Mkldnn,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.Mkldnn,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.Mkldnn,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn3d",
|
|
),
|
|
# Transposed convolution is broken for mkldnn. See https://github.com/pytorch/pytorch/issues/68775.
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
True,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.Mkldnn,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure],
|
|
name="mkldnn1d_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
True,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.Mkldnn,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure],
|
|
name="mkldnn2d_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
True,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.Mkldnn,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure],
|
|
name="mkldnn3d_transposed",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7),
|
|
False,
|
|
True,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Mkldnn,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn1d_cpu_input",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8),
|
|
False,
|
|
True,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Mkldnn,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn2d_cpu_input",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 6, 7, 8, 9),
|
|
False,
|
|
True,
|
|
3,
|
|
torch.strided,
|
|
torch._C._ConvBackend.Mkldnn,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn3d_cpu_input",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 6, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.MkldnnEmpty,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn_empty_batch1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 0, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.MkldnnEmpty,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn_empty_channel1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 0, 7),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.MkldnnEmpty,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn_empty_batch_channel1d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 6, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.MkldnnEmpty,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn_empty_batch2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 0, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.MkldnnEmpty,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn_empty_channel2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 0, 7, 8),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.MkldnnEmpty,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn_empty_batch_channel2d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 6, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.MkldnnEmpty,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn_empty_batch3d",
|
|
),
|
|
subtest(
|
|
(
|
|
(2, 0, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.MkldnnEmpty,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn_empty_channel3d",
|
|
),
|
|
subtest(
|
|
(
|
|
(0, 0, 7, 8, 9),
|
|
False,
|
|
False,
|
|
3,
|
|
torch._mkldnn,
|
|
torch._C._ConvBackend.MkldnnEmpty,
|
|
),
|
|
decorators=[onlyCPU, skipCPUIfNoMkldnn],
|
|
name="mkldnn_empty_batch_channel3d",
|
|
),
|
|
# Note: Tests for mobile backends are not currently supported. This comprises
|
|
# NnpackSpatial, Winograd3x3Depthwise, and Xnnpack2d backends. Testing these
|
|
# requires the ability to gate tests by whether PyTorch is built with USE_MOBILE=1.
|
|
],
|
|
)
|
|
# Test with both bias and no bias.
|
|
@parametrize_test("has_bias", [False, True])
|
|
# Test with both stride=1 and stride>1 cases.
|
|
@parametrize_test("strided", [False, True])
|
|
# Test with both contiguous and non-contiguous inputs.
|
|
@parametrize_test("contiguous", [False, True])
|
|
@expectedFailureMPS # No double support
|
|
def test_conv_backend(
|
|
self,
|
|
device,
|
|
input_shape,
|
|
has_bias,
|
|
strided,
|
|
contiguous,
|
|
transposed,
|
|
dilated,
|
|
groups,
|
|
layout,
|
|
backend_expected,
|
|
):
|
|
# Build up inputs.
|
|
dtype = torch.float32
|
|
C_in, C_out, dim, kernel_size = input_shape[1], 12, len(input_shape) - 2, 3
|
|
x = torch.randn(*input_shape, device=device, dtype=dtype, requires_grad=True)
|
|
weight = torch.randn(
|
|
C_in if transposed else C_out,
|
|
C_out // groups if transposed else C_in // groups,
|
|
*[kernel_size for _ in range(dim)],
|
|
device=device,
|
|
dtype=dtype,
|
|
requires_grad=True,
|
|
)
|
|
bias = (
|
|
torch.randn(C_out, device=device, dtype=dtype, requires_grad=True)
|
|
if has_bias
|
|
else None
|
|
)
|
|
|
|
def _make_noncontiguous(inp):
|
|
if inp is None:
|
|
return None
|
|
old_requires_grad = inp.requires_grad
|
|
inp = torch.repeat_interleave(inp, 2, dim=-1)
|
|
inp = inp[..., ::2].detach().requires_grad_(old_requires_grad)
|
|
return inp
|
|
|
|
if not contiguous:
|
|
x = _make_noncontiguous(x)
|
|
weight = _make_noncontiguous(weight)
|
|
bias = _make_noncontiguous(bias)
|
|
|
|
if layout is torch._mkldnn:
|
|
x = x.to_mkldnn()
|
|
# Note that weight and bias are not supported as mkldnn tensors during training.
|
|
|
|
stride = (2,) * dim if strided else (1,) * dim
|
|
padding = (0,) * dim
|
|
dilation = (2,) * dim if dilated else (1,) * dim
|
|
output_padding = (0,) * dim
|
|
inputs = [
|
|
x,
|
|
weight,
|
|
bias,
|
|
stride,
|
|
padding,
|
|
dilation,
|
|
transposed,
|
|
output_padding,
|
|
groups,
|
|
]
|
|
|
|
# Ensure correct backend is selected.
|
|
backend_actual = torch._C._select_conv_backend(*inputs)
|
|
self.assertEqual(backend_actual, backend_expected)
|
|
|
|
# Ensure backward call succeeds.
|
|
convolution = torch.ops.aten.convolution
|
|
output = convolution(*inputs)
|
|
grad_output = torch.randn(output.shape, device=device, dtype=dtype)
|
|
if not contiguous:
|
|
grad_output = _make_noncontiguous(grad_output)
|
|
if layout is torch._mkldnn:
|
|
grad_output = grad_output.to_mkldnn()
|
|
output.backward(grad_output)
|
|
|
|
# mkldnn doesn't support gradcheck :(
|
|
if layout is torch._mkldnn:
|
|
return
|
|
|
|
if backend_actual != torch._C._ConvBackend.Empty: # FIXME: forward AD fails
|
|
# Forward AD and forward-over-reverse AD smoke test in float32
|
|
# TODO: remove this if we introduce per-op gradient tests for float32
|
|
with fwAD.dual_level():
|
|
dual_inputs = [
|
|
(
|
|
fwAD.make_dual(i, torch.rand_like(i))
|
|
if isinstance(i, torch.Tensor)
|
|
else i
|
|
)
|
|
for i in inputs
|
|
]
|
|
# Forward AD
|
|
output = convolution(*dual_inputs)
|
|
# Forward over reverse AD
|
|
grad_output_d = fwAD.make_dual(
|
|
torch.rand_like(output), torch.rand_like(output)
|
|
)
|
|
if has_bias:
|
|
torch.autograd.grad(output, [x, weight, bias], grad_output_d)
|
|
else:
|
|
torch.autograd.grad(output, [x, weight], grad_output_d)
|
|
|
|
# Convert to float64 for gradcheck.
|
|
x = x.to(torch.float64).detach().requires_grad_(True)
|
|
weight = weight.to(torch.float64).detach().requires_grad_(True)
|
|
if bias is not None:
|
|
bias = bias.to(torch.float64).detach().requires_grad_(True)
|
|
inputs = [
|
|
x,
|
|
weight,
|
|
bias,
|
|
stride,
|
|
padding,
|
|
dilation,
|
|
transposed,
|
|
output_padding,
|
|
groups,
|
|
]
|
|
|
|
# Set some backend-specific validation settings.
|
|
gradcheck_nondet_tol = 0.0
|
|
if torch.backends.cudnn.is_available():
|
|
# cuDNN introduces non-determinism
|
|
gradcheck_nondet_tol = GRADCHECK_NONDET_TOL
|
|
|
|
self.assertTrue(gradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol))
|
|
|
|
# double backward doesn't support bias gradients
|
|
if bias is not None:
|
|
bias.requires_grad_(False)
|
|
self.assertTrue(
|
|
gradgradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol)
|
|
)
|
|
|
|
@onlyCPU
|
|
def test_conv_contiguous_for_oneDNN(self):
|
|
# See https://github.com/pytorch/pytorch/issues/80837.
|
|
for dtype in [torch.float, torch.bfloat16, torch.half]:
|
|
conv = nn.Conv2d(
|
|
1,
|
|
128,
|
|
kernel_size=(5, 2),
|
|
stride=(2, 1),
|
|
padding=(0, 1),
|
|
dilation=(1, 1),
|
|
groups=1,
|
|
bias=True,
|
|
padding_mode="zeros",
|
|
).to(dtype=dtype)
|
|
|
|
x = torch.rand([1, 2, 321, 201, 1]).to(dtype=dtype)
|
|
x = torch.transpose(x, 1, 4)
|
|
x2 = x[..., 0]
|
|
if torch.backends.mkldnn.is_available():
|
|
y = conv(x2)
|
|
# Disable MKLDNN explicitly
|
|
with torch.backends.mkldnn.flags(enabled=False):
|
|
y_ = conv(x2)
|
|
self.assertEqual(y, y_)
|
|
|
|
@onlyCPU
|
|
def test_conv_ic1_channels_last_for_oneDNN(self):
|
|
# See https://github.com/pytorch/pytorch/issues/82060, N > 1 will call in OneDNN path.
|
|
for dtype in [torch.float, torch.bfloat16, torch.half]:
|
|
conv = torch.nn.Conv2d(
|
|
1, 64, kernel_size=(3, 3), padding=(1, 1), bias=False
|
|
)
|
|
conv = conv.to(memory_format=torch.channels_last).to(dtype=dtype)
|
|
x = torch.rand(2, 1, 100, 100).to(dtype=dtype)
|
|
if torch.backends.mkldnn.is_available():
|
|
y = conv(x)
|
|
# Disable MKLDNN explicitly
|
|
with torch.backends.mkldnn.flags(enabled=False):
|
|
y_ = conv(x)
|
|
self.assertEqual(y, y_)
|
|
|
|
@dtypes(torch.float, torch.cfloat)
|
|
def test_conv_empty_channel(self, device, dtype):
|
|
in_channels = 0
|
|
mod = torch.nn.Conv1d(in_channels, 8, 2, stride=2, dtype=dtype).to(device)
|
|
inp = torch.randn(2, 0, 15, device=device, dtype=dtype)
|
|
_test_module_empty_input(self, mod, inp, check_size=False)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
|
|
inp = torch.randn(2, 1, 0, device=device, dtype=dtype)
|
|
mod(inp)
|
|
|
|
mod = torch.nn.Conv2d(in_channels, 33, 3, stride=2, dtype=dtype).to(device)
|
|
inp = torch.randn(2, 0, 50, 100, device=device, dtype=dtype)
|
|
_test_module_empty_input(self, mod, inp, check_size=False)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
|
|
inp = torch.randn(2, 1, 40, 0, device=device, dtype=dtype)
|
|
mod(inp)
|
|
|
|
mod = torch.nn.Conv3d(in_channels, 33, 3, stride=2, dtype=dtype).to(device)
|
|
inp = torch.randn(2, 0, 50, 20, 40, device=device, dtype=dtype)
|
|
_test_module_empty_input(self, mod, inp, check_size=False)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
|
|
inp = torch.randn(2, 1, 50, 0, 40, device=device, dtype=dtype)
|
|
mod(inp)
|
|
|
|
def test_group_conv_empty(self, device):
|
|
mod = torch.nn.Conv2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(
|
|
device
|
|
)
|
|
inp = torch.randn(0, 4, 4, 4, device=device)
|
|
_test_module_empty_input(self, mod, inp, check_size=False)
|
|
if self.device_type == "cuda" and self.has_cudnn():
|
|
with torch.backends.cudnn.flags(enabled=False):
|
|
_test_module_empty_input(self, mod, inp, check_size=False)
|
|
|
|
def test_group_convTranspose_empty(self, device):
|
|
mod = torch.nn.ConvTranspose2d(
|
|
4, 4, stride=2, kernel_size=3, padding=1, groups=4
|
|
).to(device)
|
|
inp = torch.randn(0, 4, 4, 4, device=device)
|
|
_test_module_empty_input(self, mod, inp, check_size=False)
|
|
if self.device_type == "cuda" and self.has_cudnn():
|
|
with torch.backends.cudnn.flags(enabled=False):
|
|
_test_module_empty_input(self, mod, inp, check_size=False)
|
|
|
|
def test_convTranspose_empty(self, device):
|
|
mod = torch.nn.ConvTranspose2d(4, 4, stride=2, kernel_size=3, padding=1).to(
|
|
device
|
|
)
|
|
inp = torch.randn(0, 4, 4, 4, device=device)
|
|
_test_module_empty_input(self, mod, inp, check_size=False)
|
|
if self.device_type == "cuda" and self.has_cudnn():
|
|
with torch.backends.cudnn.flags(enabled=False):
|
|
_test_module_empty_input(self, mod, inp, check_size=False)
|
|
|
|
@onlyCUDA
|
|
@largeTensorTest("12GB")
|
|
def test_conv_large_nosplit(self, device):
|
|
# Here we just test the convolution correctly route to the fallback implementation
|
|
# that is, it does not crash. The correctness of fallback implementation should be
|
|
# covered in other tests
|
|
dtype = torch.half if self.device_type == "cuda" else torch.float
|
|
conv1 = nn.Conv2d(2, 2, 8, 8).to(device).to(dtype)
|
|
input_large = torch.randn(1, 2, 1024, 1024 * 1024, dtype=dtype, device=device)
|
|
conv1(input_large)
|
|
conv2 = torch.nn.Conv2d(1, 1024, 1, 1).to(device).to(dtype)
|
|
input_large = torch.randn(1, 1, 2048, 1024, dtype=dtype, device=device)
|
|
conv2(input_large)
|
|
|
|
def test_conv_noncontig_weights(self, device):
|
|
for dim in (1, 2, 3):
|
|
for grouped in (False, True):
|
|
nc = 3
|
|
groups = 3 if grouped else 1
|
|
w = torch.randn([3] * dim, device=device)
|
|
w = w.expand([nc, int(nc / groups)] + list(w.shape))
|
|
w = w.detach().requires_grad_()
|
|
x = torch.randn(
|
|
[1, nc] + ([5] * dim), device=device, requires_grad=True
|
|
)
|
|
y = getattr(F, f"conv{dim}d")(x, w, groups=groups)
|
|
y.sum().backward()
|
|
y = getattr(F, f"conv_transpose{dim}d")(x, w, groups=groups)
|
|
y.sum().backward()
|
|
|
|
def test_conv_noncontig_weights_and_bias(self, device):
|
|
# need floats to exercise https://github.com/pytorch/pytorch/issues/16018
|
|
for bias in [True, False]:
|
|
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=bias).to(
|
|
device, torch.float
|
|
)
|
|
|
|
input_nc = torch.randn(
|
|
(1, 3, 224, 224, 2), device=device, dtype=torch.float
|
|
)[:, :, :, :, 1]
|
|
input_c = input_nc.contiguous()
|
|
|
|
weight_nc = torch.randn((64, 3, 7, 7, 2), device=device, dtype=torch.float)[
|
|
:, :, :, :, 1
|
|
]
|
|
conv1.weight = nn.Parameter(weight_nc)
|
|
weight_c = conv1.weight.contiguous()
|
|
|
|
if bias:
|
|
bias_nc = torch.randn((64, 2), device=device, dtype=torch.float)[:, 1]
|
|
conv1.bias = nn.Parameter(bias_nc)
|
|
bias_c = conv1.bias.contiguous()
|
|
|
|
out1 = conv1(input_nc)
|
|
conv1.weight = nn.Parameter(weight_c)
|
|
if bias:
|
|
conv1.bias = nn.Parameter(bias_c)
|
|
out2 = conv1(input_c)
|
|
self.assertEqual(out1, out2)
|
|
|
|
@onlyCUDA
|
|
@largeTensorTest("12GB")
|
|
def test_conv_transposed_large(self, device):
|
|
dtype = torch.half if self.device_type == "cuda" else torch.float
|
|
conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype)
|
|
input_large = torch.randn(4096, 1, 512, 1024, dtype=dtype, device=device)
|
|
# forward
|
|
ret = conv(input_large)
|
|
maxdiff0 = (
|
|
(ret.narrow(0, 0, 1024) - conv(input_large.narrow(0, 0, 1024)))
|
|
.abs_()
|
|
.max()
|
|
.item()
|
|
)
|
|
maxdiff1 = (
|
|
(ret.narrow(0, 1024, 1024) - conv(input_large.narrow(0, 1024, 1024)))
|
|
.abs_()
|
|
.max()
|
|
.item()
|
|
)
|
|
maxdiff2 = (
|
|
(ret.narrow(0, 2048, 1024) - conv(input_large.narrow(0, 2048, 1024)))
|
|
.abs_()
|
|
.max()
|
|
.item()
|
|
)
|
|
maxdiff3 = (
|
|
(ret.narrow(0, 3072, 1024) - conv(input_large.narrow(0, 3072, 1024)))
|
|
.abs_()
|
|
.max()
|
|
.item()
|
|
)
|
|
if self.device_type == "cuda":
|
|
# cuDNN may use algorithms such as FFT that don't guarantee a diff of 0
|
|
self.assertEqual(maxdiff0, 0, atol=2e-3, rtol=1e-5)
|
|
self.assertEqual(maxdiff1, 0, atol=2e-3, rtol=1e-5)
|
|
self.assertEqual(maxdiff2, 0, atol=2e-3, rtol=1e-5)
|
|
self.assertEqual(maxdiff3, 0, atol=2e-3, rtol=1e-5)
|
|
else:
|
|
self.assertEqual(maxdiff0, 0)
|
|
self.assertEqual(maxdiff1, 0)
|
|
self.assertEqual(maxdiff2, 0)
|
|
self.assertEqual(maxdiff3, 0)
|
|
|
|
@onlyCUDA
|
|
@largeTensorTest("12GB")
|
|
def test_conv_large(self, device):
|
|
dtype = torch.half if self.device_type == "cuda" else torch.float
|
|
conv = nn.Conv2d(2, 2, 8, 8, bias=False).to(device).to(dtype)
|
|
input_large = torch.randn(4097, 2, 512, 512, dtype=dtype, device=device)
|
|
# forward
|
|
ret = conv(input_large)
|
|
self.assertEqual(ret[:2048], conv(input_large[:2048]))
|
|
self.assertEqual(ret[2048:4096], conv(input_large[2048:4096]))
|
|
self.assertEqual(ret[4096:], conv(input_large[4096:]))
|
|
|
|
# backward
|
|
conv.zero_grad()
|
|
# When computing the backward, we are using the `max(dim=1)`` to create
|
|
# some sparsity. Without this sparsity, the rounding error would be
|
|
# too large (as large as 1e-5) to satisfy the creterion (1e-6) of `assertEqual`
|
|
ret.view(4097, -1).max(dim=1).values.sum().backward()
|
|
del ret
|
|
grad1 = conv.weight.grad.detach().clone()
|
|
conv.zero_grad()
|
|
conv(input_large[:2048]).view(2048, -1).max(dim=1).values.sum().backward()
|
|
conv(input_large[2048:4096]).view(2048, -1).max(dim=1).values.sum().backward()
|
|
conv(input_large[4096:]).view(1, -1).max(dim=1).values.sum().backward()
|
|
grad2 = conv.weight.grad.detach().clone()
|
|
# gradients are at the order of hundreds, we need to scale it to
|
|
# the order of one so that we can compare
|
|
scale = 1 / grad2.abs().mean()
|
|
grad1 = grad1 * scale
|
|
grad2 = grad2 * scale
|
|
self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3)
|
|
|
|
@onlyCUDA
|
|
@largeTensorTest("20GB", "cpu")
|
|
@largeTensorTest("60GB", "cuda")
|
|
def test_conv_large_batch_1(self, device):
|
|
in_channels = 514
|
|
dim = 2048
|
|
out_channels = 1
|
|
kernel_size = 3
|
|
stride = 1
|
|
padding = 1
|
|
|
|
input_tensor = torch.ones(1, in_channels, dim, dim).cuda().half()
|
|
model = (
|
|
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
|
|
.cuda()
|
|
.half()
|
|
)
|
|
output = model(input_tensor)
|
|
_model_cpu = model.cpu().float()
|
|
output_cpu = model(input_tensor.float().cpu())
|
|
self.assertEqual(output.cpu().float(), output_cpu, atol=1e-3, rtol=1e-3)
|
|
|
|
@onlyCUDA
|
|
@skipCUDAIfNoCudnn
|
|
def test_contig_wrong_stride_cudnn(self, device):
|
|
# x has to have batch_size 1 to test contiguous checks
|
|
x = torch.randn(1, 16, 5, 5, device=device)
|
|
stride = list(x.stride())
|
|
stride[0] = 20
|
|
# change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
|
|
x.set_(x.storage(), 0, x.size(), stride)
|
|
self.assertTrue(x.is_contiguous())
|
|
F.conv_transpose2d(x, torch.randn(16, 1, 1, 1, device=device))
|
|
F.conv2d(x, torch.randn(1, 16, 1, 1, device=device))
|
|
|
|
@onlyCUDA
|
|
@tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005)
|
|
def test_Conv2d_size_1_kernel(self, device):
|
|
x_cpu = torch.randn(2, 3, 5, 5)
|
|
conv_cpu = torch.nn.Conv2d(3, 3, kernel_size=1)
|
|
y_cpu = conv_cpu(x_cpu)
|
|
y = torch.rand_like(y_cpu)
|
|
y_cpu.backward(y)
|
|
|
|
with cudnn.flags(enabled=False):
|
|
conv_cuda = torch.nn.Conv2d(3, 3, kernel_size=1).to(device)
|
|
conv_cuda.bias.data.copy_(conv_cpu.bias.data)
|
|
conv_cuda.weight.data.copy_(conv_cpu.weight.data)
|
|
y_cuda = conv_cuda(x_cpu.to(device))
|
|
y_cuda.backward(y.to(device))
|
|
|
|
self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
|
|
self.assertEqual(
|
|
conv_cpu.bias.grad.data,
|
|
conv_cuda.bias.grad.data,
|
|
atol=1e-5,
|
|
rtol=0,
|
|
exact_device=False,
|
|
)
|
|
self.assertEqual(
|
|
conv_cpu.weight.grad.data,
|
|
conv_cuda.weight.grad.data,
|
|
atol=1e-5,
|
|
rtol=0,
|
|
exact_device=False,
|
|
)
|
|
|
|
@onlyCUDA
|
|
@tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005)
|
|
def test_ConvTranspose2d_size_1_kernel(self, device):
|
|
x_cpu = torch.randn(2, 3, 5, 5)
|
|
conv_cpu = torch.nn.ConvTranspose2d(3, 3, kernel_size=1)
|
|
y_cpu = conv_cpu(x_cpu)
|
|
y = torch.rand_like(y_cpu)
|
|
y_cpu.backward(y)
|
|
|
|
with cudnn.flags(enabled=False):
|
|
conv_cuda = torch.nn.ConvTranspose2d(3, 3, kernel_size=1).to(device)
|
|
conv_cuda.bias.data.copy_(conv_cpu.bias.data)
|
|
conv_cuda.weight.data.copy_(conv_cpu.weight.data)
|
|
y_cuda = conv_cuda(x_cpu.to(device))
|
|
y_cuda.backward(y.to(device))
|
|
|
|
self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
|
|
self.assertEqual(
|
|
conv_cpu.bias.grad.data,
|
|
conv_cuda.bias.grad.data,
|
|
atol=1e-5,
|
|
rtol=0,
|
|
exact_device=False,
|
|
)
|
|
self.assertEqual(
|
|
conv_cpu.weight.grad.data,
|
|
conv_cuda.weight.grad.data,
|
|
atol=1e-5,
|
|
rtol=0,
|
|
exact_device=False,
|
|
)
|
|
|
|
@onlyCUDA
|
|
def test_ConvTranspose3d_size_1_kernel(self, device):
|
|
with set_default_dtype(torch.double):
|
|
x_cpu = torch.randn(2, 3, 3, 5, 5)
|
|
conv_cpu = torch.nn.ConvTranspose3d(3, 3, kernel_size=1)
|
|
y_cpu = conv_cpu(x_cpu)
|
|
y = torch.rand_like(y_cpu)
|
|
y_cpu.backward(y)
|
|
|
|
with cudnn.flags(enabled=False):
|
|
conv_cuda = torch.nn.ConvTranspose3d(3, 3, kernel_size=1).to(device)
|
|
conv_cuda.bias.data.copy_(conv_cpu.bias.data)
|
|
conv_cuda.weight.data.copy_(conv_cpu.weight.data)
|
|
y_cuda = conv_cuda(x_cpu.to(device))
|
|
y_cuda.backward(y.to(device))
|
|
|
|
self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
|
|
self.assertEqual(
|
|
conv_cpu.bias.grad.data,
|
|
conv_cuda.bias.grad.data,
|
|
atol=1e-5,
|
|
rtol=0,
|
|
exact_device=False,
|
|
)
|
|
self.assertEqual(
|
|
conv_cpu.weight.grad.data,
|
|
conv_cuda.weight.grad.data,
|
|
atol=1e-5,
|
|
rtol=0,
|
|
exact_device=False,
|
|
)
|
|
|
|
@dtypesIfCUDA(
|
|
*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])
|
|
)
|
|
@dtypes(torch.float)
|
|
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
|
|
@torch.backends.miopen.flags(immediate=True)
|
|
@tf32_on_and_off(0.001)
|
|
def test_Conv2d_naive_groups(self, device, dtype):
|
|
# Check that grouped convolutions matches two half convolutions
|
|
m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype)
|
|
i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
|
|
output = m(i)
|
|
grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype)
|
|
output.backward(grad_output)
|
|
|
|
m1 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype)
|
|
m1.weight.data.copy_(m.weight.data[:2])
|
|
m1.bias.data.copy_(m.bias.data[:2])
|
|
i1 = i.data[:, :2].contiguous().requires_grad_(True)
|
|
output1 = m1(i1)
|
|
output1.backward(grad_output[:, :2].contiguous())
|
|
|
|
m2 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype)
|
|
m2.weight.data.copy_(m.weight.data[2:])
|
|
m2.bias.data.copy_(m.bias.data[2:])
|
|
i2 = i.data[:, 2:].contiguous().requires_grad_(True)
|
|
output2 = m2(i2)
|
|
output2.backward(grad_output[:, 2:].contiguous())
|
|
|
|
self.assertEqual(output, torch.cat([output1, output2], 1))
|
|
self.assertEqual(
|
|
i.grad.data,
|
|
torch.cat([i1.grad.data, i2.grad.data], 1),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
m.bias.grad.data,
|
|
torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
self.assertEqual(
|
|
m.weight.grad.data,
|
|
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
|
|
atol=dtype2prec_DONTUSE[dtype],
|
|
rtol=0,
|
|
)
|
|
|
|
@dtypes(torch.double, torch.cdouble)
|
|
@dtypesIfMPS(torch.float, torch.cfloat)
|
|
@expectedFailureMPS # https://github.com/pytorch/pytorch/issues/107214
|
|
def test_Conv2d_backward_depthwise(self, device, dtype):
|
|
x = torch.randn(2, 2, 4, 20, device=device, dtype=dtype, requires_grad=True)
|
|
weight = torch.randn(2, 1, 3, 5, device=device, dtype=dtype, requires_grad=True)
|
|
|
|
def conv2d_depthwise(x, weight):
|
|
return torch.nn.functional.conv2d(
|
|
x, weight, bias=None, stride=(1, 10), groups=2
|
|
)
|
|
|
|
for cudnn_enabled in [False, True]:
|
|
with torch.backends.cudnn.flags(enabled=cudnn_enabled):
|
|
torch.autograd.gradcheck(conv2d_depthwise, (x, weight))
|
|
|
|
@onlyCPU
|
|
@dtypes(torch.float, torch.double)
|
|
def test_conv_thnn_nhwc(self, device, dtype):
|
|
def helper(
|
|
mod,
|
|
n,
|
|
c,
|
|
h,
|
|
w,
|
|
out_channels,
|
|
kernel_size,
|
|
dilation,
|
|
groups,
|
|
input_format,
|
|
weight_format,
|
|
):
|
|
input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to(
|
|
memory_format=input_format
|
|
)
|
|
input.requires_grad_()
|
|
conv = mod(
|
|
c, out_channels, kernel_size, dilation=dilation, groups=groups
|
|
).to(device="cpu", dtype=dtype, memory_format=weight_format)
|
|
for p in conv.parameters():
|
|
p.data = torch.randint_like(p, -3, 3)
|
|
|
|
ref_input = input.detach().clone().contiguous().requires_grad_()
|
|
ref_conv = mod(
|
|
c, out_channels, kernel_size, dilation=dilation, groups=groups
|
|
)
|
|
# load_state_dict will restore the stride & memory_layout on ref_conv.weight.
|
|
ref_conv.load_state_dict(conv.state_dict())
|
|
ref_conv = ref_conv.to(
|
|
device="cpu", dtype=dtype, memory_format=torch.contiguous_format
|
|
)
|
|
|
|
out = conv(input)
|
|
ref_out = ref_conv(ref_input)
|
|
|
|
grad = torch.randint_like(out, -3, 3)
|
|
ref_grad = grad.detach().clone().contiguous()
|
|
|
|
out.backward(grad)
|
|
ref_out.backward(ref_grad)
|
|
|
|
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
|
|
self.assertTrue(ref_out.is_contiguous())
|
|
self.assertEqual(out, ref_out, exact_dtype=False)
|
|
self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
|
|
self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
|
|
self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)
|
|
|
|
with torch.backends.mkldnn.flags(enabled=False):
|
|
formats = [
|
|
[torch.channels_last, torch.channels_last],
|
|
[torch.channels_last, torch.contiguous_format],
|
|
[torch.contiguous_format, torch.channels_last],
|
|
]
|
|
for input_format, weight_format in formats:
|
|
# non-dilated conv: thnn_conv2d normal path (with im2col)
|
|
helper(
|
|
nn.Conv2d,
|
|
2,
|
|
8,
|
|
4,
|
|
4,
|
|
out_channels=4,
|
|
kernel_size=3,
|
|
dilation=1,
|
|
groups=1,
|
|
input_format=input_format,
|
|
weight_format=weight_format,
|
|
)
|
|
helper(
|
|
nn.Conv2d,
|
|
2,
|
|
8,
|
|
4,
|
|
4,
|
|
out_channels=8,
|
|
kernel_size=3,
|
|
dilation=1,
|
|
groups=8,
|
|
input_format=input_format,
|
|
weight_format=weight_format,
|
|
)
|
|
# test when input chanels is 1 and not converted to channels last
|
|
helper(
|
|
nn.Conv2d,
|
|
2,
|
|
1,
|
|
10,
|
|
10,
|
|
out_channels=8,
|
|
kernel_size=3,
|
|
dilation=1,
|
|
groups=1,
|
|
input_format=torch.contiguous_format,
|
|
weight_format=torch.channels_last,
|
|
)
|
|
# non-dilated conv: thnn_conv2d fast path (skip im2col)
|
|
helper(
|
|
nn.Conv2d,
|
|
1,
|
|
16,
|
|
56,
|
|
56,
|
|
out_channels=16,
|
|
kernel_size=1,
|
|
dilation=1,
|
|
groups=1,
|
|
input_format=input_format,
|
|
weight_format=weight_format,
|
|
)
|
|
# ic == oc == 1 here, so need to stick input to CL to activate channels last
|
|
helper(
|
|
nn.Conv2d,
|
|
1,
|
|
16,
|
|
56,
|
|
56,
|
|
out_channels=16,
|
|
kernel_size=1,
|
|
dilation=1,
|
|
groups=16,
|
|
input_format=torch.channels_last,
|
|
weight_format=weight_format,
|
|
)
|
|
# dilated conv: slow_conv_dilated2d
|
|
helper(
|
|
nn.Conv2d,
|
|
2,
|
|
8,
|
|
11,
|
|
13,
|
|
out_channels=16,
|
|
kernel_size=3,
|
|
dilation=2,
|
|
groups=1,
|
|
input_format=input_format,
|
|
weight_format=weight_format,
|
|
)
|
|
helper(
|
|
nn.Conv2d,
|
|
2,
|
|
16,
|
|
11,
|
|
13,
|
|
out_channels=32,
|
|
kernel_size=3,
|
|
dilation=2,
|
|
groups=16,
|
|
input_format=input_format,
|
|
weight_format=weight_format,
|
|
)
|
|
# transposed-conv: slow_conv_transpose2d
|
|
helper(
|
|
nn.ConvTranspose2d,
|
|
2,
|
|
8,
|
|
4,
|
|
4,
|
|
out_channels=4,
|
|
kernel_size=3,
|
|
dilation=1,
|
|
groups=1,
|
|
input_format=input_format,
|
|
weight_format=weight_format,
|
|
)
|
|
helper(
|
|
nn.ConvTranspose2d,
|
|
2,
|
|
8,
|
|
4,
|
|
4,
|
|
out_channels=8,
|
|
kernel_size=3,
|
|
dilation=1,
|
|
groups=8,
|
|
input_format=input_format,
|
|
weight_format=weight_format,
|
|
)
|
|
helper(
|
|
nn.ConvTranspose2d,
|
|
1,
|
|
16,
|
|
56,
|
|
56,
|
|
out_channels=16,
|
|
kernel_size=1,
|
|
dilation=1,
|
|
groups=1,
|
|
input_format=input_format,
|
|
weight_format=weight_format,
|
|
)
|
|
helper(
|
|
nn.ConvTranspose2d,
|
|
1,
|
|
16,
|
|
56,
|
|
56,
|
|
out_channels=32,
|
|
kernel_size=1,
|
|
dilation=1,
|
|
groups=16,
|
|
input_format=input_format,
|
|
weight_format=weight_format,
|
|
)
|
|
|
|
@onlyCUDA
|
|
@dtypes(torch.half, torch.float, torch.cfloat)
|
|
def test_conv_cudnn_nhwc(self, device, dtype):
|
|
def helper(n, c, h, w, out_channels, kernel_size, groups):
|
|
# randint with dtype=torch.cfloat fails with
|
|
# RuntimeError: check_random_bounds handles only integral, floating-point and boolean types
|
|
# must create randint and randint_like using default int64, then cast to desired
|
|
input = torch.randint(
|
|
-3, 3, (n, c, h, w), dtype=torch.int64, device=device
|
|
).to(dtype, memory_format=torch.channels_last)
|
|
input.requires_grad_()
|
|
conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to(
|
|
device="cuda", dtype=dtype, memory_format=torch.channels_last
|
|
)
|
|
for p in conv.parameters():
|
|
p.data = torch.randint_like(p, -3, 3, dtype=torch.int64).to(p.dtype)
|
|
|
|
# use FP64 channels-first conv as reference
|
|
ref_input = input.detach().clone().contiguous().double().requires_grad_()
|
|
ref_conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups)
|
|
# load_state_dict will restore the stride & memory_layout on ref_conv.weight.
|
|
ref_conv.load_state_dict(conv.state_dict())
|
|
ref_conv = ref_conv.to(
|
|
device="cuda", dtype=torch.double, memory_format=torch.contiguous_format
|
|
)
|
|
|
|
out = conv(input)
|
|
ref_out = ref_conv(ref_input)
|
|
|
|
grad = torch.randint_like(out, -3, 3, dtype=torch.int64).to(out.dtype)
|
|
ref_grad = grad.detach().clone().double().contiguous()
|
|
|
|
out.backward(grad)
|
|
ref_out.backward(ref_grad)
|
|
|
|
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
|
|
self.assertTrue(input.grad.is_contiguous(memory_format=torch.channels_last))
|
|
self.assertTrue(
|
|
conv.weight.grad.is_contiguous(memory_format=torch.channels_last)
|
|
)
|
|
|
|
self.assertTrue(ref_out.is_contiguous())
|
|
self.assertTrue(ref_input.grad.is_contiguous())
|
|
self.assertTrue(ref_conv.weight.grad.is_contiguous())
|
|
|
|
self.assertEqual(out, ref_out, exact_dtype=False)
|
|
self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
|
|
self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
|
|
self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)
|
|
|
|
helper(2, 8, 4, 4, out_channels=4, kernel_size=3, groups=1)
|
|
helper(2, 8, 4, 4, out_channels=8, kernel_size=3, groups=8)
|
|
helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=1)
|
|
helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16)
|
|
|
|
    @onlyCUDA
    @dtypes(torch.half, torch.float)
    def test_conv_cudnn_ndhwc(self, device, dtype):
        def helper(n, c, d, h, w, out_channels, kernel_size, groups):
            input = torch.randint(
                -2, 2, (n, c, d, h, w), dtype=dtype, device=device
            ).to(memory_format=torch.channels_last_3d)
            input.requires_grad_()
            conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups).to(
                device="cuda", dtype=dtype, memory_format=torch.channels_last_3d
            )
            for p in conv.parameters():
                p.data = torch.randint_like(p, -2, 2)

            # use FP64 channels-first conv as reference
            ref_input = input.detach().clone().contiguous().double().requires_grad_()
            ref_conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups)
            # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
            ref_conv.load_state_dict(conv.state_dict())
            ref_conv = ref_conv.to(
                device="cuda", dtype=torch.double, memory_format=torch.contiguous_format
            )

            out = conv(input)
            ref_out = ref_conv(ref_input)

            grad = torch.randint_like(out, -2, 2)
            ref_grad = grad.detach().clone().double().contiguous()

            out.backward(grad)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last_3d))
            self.assertTrue(
                input.grad.is_contiguous(memory_format=torch.channels_last_3d)
            )
            self.assertTrue(
                conv.weight.grad.is_contiguous(memory_format=torch.channels_last_3d)
            )

            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ref_input.grad.is_contiguous())
            self.assertTrue(ref_conv.weight.grad.is_contiguous())

            self.assertEqual(out, ref_out, exact_dtype=False)
            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
            self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)

        helper(2, 8, 4, 4, 4, out_channels=4, kernel_size=3, groups=1)
        helper(2, 8, 4, 4, 4, out_channels=8, kernel_size=3, groups=8)
        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=1)
        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=16)

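    # Shared helper for the mixed memory-format tests below: run one convolution with
    # the requested input/weight/grad formats and compare against a reference run.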
    def _run_conv(
        self,
        layer,
        device,
        inp,
        grad,
        ref_conv,
        ref_input,
        ref_out,
        input_format,
        weight_format,
        grad_format,
        output_format,
    ):
        conv = (
            layer(inp.size(1), grad.size(1), ref_conv.weight.size(2)).float().to(device)
        )
        # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
        conv.load_state_dict(ref_conv.state_dict())
        weight_data = (
            conv.weight.detach().clone().contiguous(memory_format=weight_format)
        )
        conv.weight.data = weight_data.resize_(
            weight_data.size(), memory_format=weight_format
        )
        input = inp.clone().contiguous(memory_format=input_format)
        input.resize_(input.size(), memory_format=input_format)
        input = input.requires_grad_()
        grad = grad.contiguous(memory_format=grad_format)
        grad.resize_(grad.size(), memory_format=grad_format)
        out = conv(input)
        out.backward(grad)
        self.assertTrue(out.is_contiguous(memory_format=output_format))
        self.assertEqual(out, ref_out)
        self.assertEqual(conv.weight.grad, ref_conv.weight.grad)
        self.assertEqual(conv.bias.grad, ref_conv.bias.grad)
        self.assertEqual(input.grad, ref_input.grad)

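    # Sweep every combination of contiguous/channels_last for the weight, grad, and
    # input, and work out which memory format the output is expected to come back in.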
    def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device):
        data = torch.randint(1, 10, (n, c, h, w), dtype=torch.float32, device=device)
        ref_input = data.clone().contiguous().requires_grad_(True)
        ref_conv = layer(c, k, filter_size).float().to(device)
        ref_out = ref_conv(ref_input)
        grad = torch.randint(1, 10, ref_out.size(), dtype=torch.float32, device="cuda")
        ref_out.backward(grad)

        for w_f in [torch.contiguous_format, torch.channels_last]:
            for g_f in [torch.contiguous_format, torch.channels_last]:
                for input_format in [torch.contiguous_format, torch.channels_last]:
                    output_format = torch.contiguous_format
                    # Older versions of cuDNN have channels-last support disabled
                    if torch.backends.cudnn.version() >= 7603:
                        if input_format == torch.channels_last:
                            output_format = torch.channels_last
                        # This is because we have N111 weight that cannot handle
                        # the ambiguous memory_format
                        if w_f == torch.channels_last:
                            if layer == nn.Conv2d and filter_size * c != 1:
                                output_format = torch.channels_last
                            if layer == nn.ConvTranspose2d and filter_size * k != 1:
                                output_format = torch.channels_last
                    self._run_conv(
                        layer,
                        device,
                        data,
                        grad,
                        ref_conv,
                        ref_input,
                        ref_out,
                        input_format,
                        w_f,
                        g_f,
                        output_format,
                    )

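    # Each config is (n, c, h, w, k, filter_size); several of them use size-1 batch,
    # channel, or filter dims where the memory format of a tensor becomes ambiguous.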
    @onlyCUDA
    @tf32_on_and_off(0.05)
    def test_conv_cudnn_mismatch_memory_format(self, device):
        configs = [
            [4, 2, 8, 8, 4, 2],
            [4, 1, 8, 8, 4, 2],
            [1, 1, 8, 8, 4, 2],
            [4, 2, 2, 8, 4, 1],
            [4, 2, 1, 8, 4, 1],
            [4, 2, 8, 8, 4, 1],
            [4, 1, 8, 8, 4, 1],
        ]
        for n, c, h, w, k, filter_size in configs:
            self._test_conv_cudnn_nhwc_nchw(
                nn.Conv2d, n, c, h, w, k, filter_size, device
            )
            self._test_conv_cudnn_nhwc_nchw(
                nn.ConvTranspose2d, n, c, h, w, k, filter_size, device
            )

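    # A default-contiguous input combined with a channels-last weight should still
    # produce a channels-last output, and the backward pass should run without error.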
    @onlyCUDA
    @skipCUDAIfNoCudnn
    @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
    def test_conv_cudnn_nhwc_support(self, device, dtype):
        input = torch.randn(
            (1, 16, 1, 1), dtype=dtype, device="cuda", requires_grad=True
        )
        weight = torch.randn(
            (8, 16, 3, 3), dtype=dtype, device="cuda", requires_grad=True
        )
        weight = weight.to(memory_format=torch.channels_last)
        o = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1)
        self.assertTrue(o.is_contiguous(memory_format=torch.channels_last))
        o.sum().backward()

    # Test that faster algorithms used for inference produce the same results
    # Validates depthwise3x3 bug reported in https://github.com/pytorch/pytorch/issues/60176
    @onlyCPU
    @dtypes(torch.float)
    def test_conv2d_no_grad(self, device, dtype):
        for batch in [1, 2, 3]:
            for groups in [1, 2, 4]:
                input = torch.rand(batch, groups, 8, 8, dtype=dtype, device=device)
                m = nn.Conv2d(
                    groups,
                    8,
                    kernel_size=(3, 3),
                    groups=groups,
                    dtype=dtype,
                    device=device,
                )
                with torch.no_grad():
                    output_ng = m(input)
                output = m(input)
                self.assertEqual(output, output_ng, rtol=1e-2, atol=1e-5)

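    # The fused conv+relu kernel (cudnn_convolution_relu, or miopen_convolution_relu
    # on ROCm) must match an unfused conv2d followed by relu for all tested configs.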
    @onlyCUDA
    @skipCUDAIfNoCudnn
    @dtypes(torch.float, torch.float16)
    @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
    @precisionOverride({torch.half: 0.002, torch.float: 1e-4})
    def test_cudnn_convolution_relu(self, device, dtype):
        for batch, groups, image_size, kernel_size, memory_format in product(
            (1, 2, 3),
            (1, 2, 4),
            ((1, 1), (8, 8)),
            ((1, 1), (3, 3)),
            (torch.channels_last, torch.contiguous_format),
        ):
            if image_size[0] < kernel_size[0]:
                continue
            inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device)
            w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device)
            inp = inp.to(memory_format=memory_format)
            w = w.to(memory_format=memory_format)
            conv2d_out = torch.conv2d(inp, w, None, (1, 1), (0, 0), (1, 1), 1)
            if torch.version.hip:
                cudnn_out = torch.miopen_convolution_relu(
                    inp, w, None, (1, 1), (0, 0), (1, 1), 1
                )
            else:
                cudnn_out = torch.cudnn_convolution_relu(
                    inp, w, None, (1, 1), (0, 0), (1, 1), 1
                )
            self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format))
            if torch.cuda.is_tf32_supported() and dtype == torch.float:
                self.assertEqual(conv2d_out.relu(), cudnn_out, atol=4e-3, rtol=0.006)
            else:
                self.assertEqual(conv2d_out.relu(), cudnn_out)

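    # Same idea for the fused conv+add+relu kernel: compare against
    # relu(conv2d(inp, w) + alpha * z).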
    @onlyCUDA
    @skipCUDAIfNoCudnn
    @dtypes(torch.float, torch.float16)
    @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
    @precisionOverride({torch.half: 0.002, torch.float: 1e-4})
    def test_cudnn_convolution_add_relu(self, device, dtype):
        for batch, groups, image_size, kernel_size, memory_format in product(
            (1, 2, 3),
            (1, 2, 4),
            ((1, 1), (8, 8)),
            ((1, 1), (3, 3)),
            (torch.channels_last, torch.contiguous_format),
        ):
            if image_size[0] < kernel_size[0]:
                continue
            inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device)
            w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device)
            inp = inp.to(memory_format=memory_format)
            w = w.to(memory_format=memory_format)
            conv2d_out = torch.conv2d(inp, w, None, (1, 1), (0, 0), (1, 1), 1)
            alpha = 2.0
            z = torch.randn_like(conv2d_out)
            z = z.to(memory_format=memory_format)
            if torch.version.hip:
                cudnn_out = torch.miopen_convolution_add_relu(
                    inp, w, z, alpha, None, (1, 1), (0, 0), (1, 1), 1
                )
            else:
                cudnn_out = torch.cudnn_convolution_add_relu(
                    inp, w, z, alpha, None, (1, 1), (0, 0), (1, 1), 1
                )

            self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format))
            if torch.cuda.is_tf32_supported() and dtype == torch.float:
                self.assertEqual(
                    F.relu(conv2d_out + alpha * z), cudnn_out, atol=2e-3, rtol=0.006
                )
            else:
                self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out)

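    # convert_conv2d_weight_memory_format rewrites the conv weights so that the model
    # produces outputs in the requested memory format (plain and transposed conv).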
    @onlyCUDA
    def test_convert_conv2d_weight_memory_format(self, device):
        input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device)
        model = nn.Sequential(nn.Conv2d(8, 4, 3), nn.BatchNorm2d(4)).to(device).float()
        for memory_format in [torch.channels_last, torch.contiguous_format]:
            model = nn.utils.convert_conv2d_weight_memory_format(model, memory_format)
            out = model(input)
            self.assertTrue(out.is_contiguous(memory_format=memory_format))

        model = (
            nn.Sequential(nn.ConvTranspose2d(8, 4, 3), nn.BatchNorm2d(4))
            .to(device)
            .float()
        )
        for memory_format in [torch.channels_last, torch.contiguous_format]:
            model = nn.utils.convert_conv2d_weight_memory_format(model, memory_format)
            out = model(input)
            self.assertTrue(out.is_contiguous(memory_format=memory_format))

    @onlyCUDA
    def test_convert_conv3d_weight_memory_format(self, device):
        input = torch.randint(
            1, 10, (2, 8, 4, 4, 4), dtype=torch.float32, device=device
        )
        model = (
            nn.Sequential(nn.ConvTranspose3d(8, 4, 3), nn.BatchNorm3d(4))
            .to(device)
            .float()
        )
        for memory_format in [torch.channels_last_3d, torch.contiguous_format]:
            model = nn.utils.convert_conv3d_weight_memory_format(model, memory_format)
            out = model(input)
            self.assertTrue(out.is_contiguous(memory_format=memory_format))

    def test_conv_double_backward_strided_with_3D_input_and_weight(self, device):
        # Test that _convolution_double_backward() outputs the correct grad shapes
        # for 3D input / weight when stride > 1. This is an ad-hoc regression test for a
        # specific case that was uncovered during the convolution consolidation effort.
        # The test can be safely deleted if _convolution_double_backward() is removed.

        input = torch.randn(2, 3, 6, device=device)
        weight = torch.randn(3, 3, 3, device=device)
        bias = torch.randn(3, device=device)
        stride = (2,)
        padding = (1,)
        dilation = (1,)
        transposed = False
        output_padding = (0,)
        groups = 1
        output = torch.ops.aten.convolution(
            input,
            weight,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        )

        ggI = torch.randn(input.shape, device=device)
        ggW = torch.randn(weight.shape, device=device)
        ggB = torch.randn(bias.shape, device=device)
        gO = torch.randn(output.shape, device=device)
        output_mask = [True, True, True]
        (
            grad_grad_output,
            grad_input,
            grad_weight,
        ) = torch.ops.aten._convolution_double_backward(
            ggI,
            ggW,
            ggB,
            gO,
            weight,
            input,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            output_mask,
        )

        # Make sure the correct shapes are computed.
        self.assertEqual(grad_grad_output.shape, gO.shape)
        self.assertEqual(grad_input.shape, input.shape)
        self.assertEqual(grad_weight.shape, weight.shape)

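    # The 1 x 32 x 512 x 512 x 256 input holds 2**31 elements, forcing the conv
    # kernels onto 64-bit indexing; the CPU result serves as the reference.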
    @skipCUDAIfRocm
    @onlyCUDA
    @largeTensorTest("40GB")
    @largeTensorTest("24GB", "cpu")
    @tf32_on_and_off(0.005)
    def test_conv3d_64bit_indexing(self, device):
        x = torch.rand(1, 32, 512, 512, 256)
        m = torch.nn.Conv3d(32, 1, kernel_size=1, padding=0, stride=1, bias=False)
        yref = m(x)
        y = m.to(device=device)(x.to(device=device))
        self.assertEqual(yref, y)

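    # Run a large fp16/bf16 Conv3d once with cuDNN disabled (reference) and once with
    # the default backend; both paths must agree.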
    @skipCUDAIfRocm
    @onlyCUDA
    @largeTensorTest("40GB", "cuda")
    def test_conv3d_cudnn_broken(self, device):
        for dtype in (torch.half, torch.bfloat16):
            x = torch.rand(1, 16, 124, 1282, 722, dtype=dtype, device=device)
            m = torch.nn.Conv3d(
                16,
                16,
                kernel_size=(1, 3, 3),
                padding=0,
                stride=1,
                bias=False,
                dtype=dtype,
                device=device,
            )
            with torch.backends.cudnn.flags(enabled=False):
                yref = m(x)
            y = m(x)
            self.assertEqual(yref, y)

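    # The 1 x 2 x 32800 x 32800 half input exceeds 2**31 - 1 elements, exercising
    # 64-bit indexing in the depthwise channels-last conv, before retrying with a
    # batch-splittable reshape of the same data.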
    @skipCUDAIfRocm
    @onlyCUDA
    @largeTensorTest("20GB")
    @largeTensorTest("64GB", "cpu")
    # TODO(eqy): Remove this once it is fixed in cuDNN and we can dispatch to it again
    @xfailIf(
        torch.backends.cudnn.version() is not None
        and torch.backends.cudnn.version() > 91000
    )
    def test_depthwise_conv_64bit_indexing(self, device):
        x = torch.randn(1, 2, 32800, 32800, dtype=torch.half).to(
            memory_format=torch.channels_last
        )
        c = nn.Conv2d(
            2, 2, kernel_size=3, stride=1, padding=1, groups=2, dtype=torch.half
        ).to(memory_format=torch.channels_last)
        yref = c(x)
        y = c.to(device=device)(x.to(device=device))
        self.assertEqual(yref, y, atol=1e-3, rtol=1e-4)
        del y, yref

        # try a batch-splittable case
        x = x.reshape(100, 2, 3280, 3280)
        x = x.contiguous(memory_format=torch.channels_last)
        yref = c(x)
        y = c.to(device=device)(x.to(device=device))
        self.assertEqual(yref, y, atol=1e-3, rtol=1e-4)


instantiate_device_type_tests(TestConvolutionNNDeviceType, globals(), allow_mps=True)
instantiate_parametrized_tests(TestConvolutionNN)


if __name__ == "__main__":
    run_tests()