WIP Add compatibility with channels_last_3d for conv3d (#114790)

Part of a multi-PR effort to fix #59168

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114790
Approved by: https://github.com/albanD
Author: Damien
Date: 2023-12-20 19:28:21 +00:00
Committed by: PyTorch MergeBot
Parent: 8bff59e41d
Commit: 2d2016fdf8
4 changed files with 99 additions and 18 deletions
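
For orientation, here is a minimal usage sketch (not part of the diff) of the workflow this change enables, mirroring the new test_convert_conv3d_weight_memory_format test added below; it assumes a CUDA device and a cuDNN build new enough for channels_last_3d (version 7603, per the test's skip condition).

# Illustrative sketch only; mirrors the new test added in this commit.
# Assumes CUDA is available and cuDNN >= 7603 (the test's skip threshold).
import torch
import torch.nn as nn

model = nn.Sequential(nn.ConvTranspose3d(8, 4, 3), nn.BatchNorm3d(4)).cuda().float()

# Utility introduced by this change: convert only the 3d-conv weights to
# channels_last_3d, leaving other layers untouched.
model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)

x = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float32, device="cuda")  # NCDHW input
out = model(x)
print(out.is_contiguous(memory_format=torch.channels_last_3d))  # the new test asserts this is True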

View File

@@ -396,6 +396,7 @@ Utility functions to convert Module parameter memory formats.
:nosignatures:
convert_conv2d_weight_memory_format
convert_conv3d_weight_memory_format
Utility functions to apply and remove weight normalization from Module parameters.

View File

@@ -37,6 +37,7 @@ if TEST_SCIPY:
import scipy.signal
import scipy.ndimage
class TestConvolutionNN(NNTestCase):
_do_cuda_memory_leak_check = True
_do_cuda_non_default_stream = True
@@ -377,7 +378,6 @@ class TestConvolutionNN(NNTestCase):
self.assertRaisesRegex(RuntimeError, 'Specify retain_graph=True',
lambda: o1.sum().backward())
def test_conv_modules_raise_error_on_incorrect_input_size(self):
for dtype in [torch.half, torch.bfloat16, torch.double, torch.float]:
modules = [nn.Conv1d(3, 8, 3).to(dtype), nn.ConvTranspose1d(3, 8, 3).to(dtype),
@@ -666,7 +666,6 @@ class TestConvolutionNN(NNTestCase):
out = conv(input)
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_cudnn_noncontiguous_weight(self):
# Noncontiguous weights must be contiguous() before being
@@ -677,7 +676,6 @@ class TestConvolutionNN(NNTestCase):
self.assertEqual(F.conv1d(input, weights1, bias=None, stride=2, dilation=2),
F.conv1d(input, weights2, bias=None, stride=2, dilation=2))
def run_grad_conv_test(self, func_forward, func_backward, dim=1, gradient='input'):
for kern, inp_size in [(3, 6), (3, 7), (4, 9)]:
for batch, stride, padding, chan_in, chan_out, dilation in \
@@ -744,7 +742,8 @@ class TestConvolutionNN(NNTestCase):
if has_bias:
bias = torch.randn([chan_out], requires_grad=True, dtype=torch.float)
output = torch._nnpack_spatial_convolution(input, weight, stride=stride, padding=padding, bias=bias)
output_expected = torch.nn.functional.conv2d(input, weight, stride=stride, padding=padding, bias=bias)
output_expected = torch.nn.functional.conv2d(
input, weight, stride=stride, padding=padding, bias=bias)
self.assertEqual(output, output_expected, atol=3e-4, rtol=0)
gradient_o = torch.randn(output.shape, dtype=torch.float)
@@ -764,7 +763,6 @@ class TestConvolutionNN(NNTestCase):
with self.assertRaisesRegex(ValueError, "Only \"zeros\" "):
nn.ConvTranspose2d(3, 3, 3, padding_mode="reflect")
def test_functional_grad_conv(self):
# Conv 1D
input = torch.randn(1, 1, 5, requires_grad=True)
@@ -819,7 +817,8 @@ class TestConvolutionNN(NNTestCase):
input = torch.empty(BATCH_SIZE, IN_CH, SPATIAL, SPATIAL).uniform_(-8.0, 8.0).requires_grad_(True)
weight = torch.empty(OUT_CH, IN_CH // groups, kernel_size, kernel_size).uniform_(-4.0, 4.0).requires_grad_(True)
weight = torch.empty(OUT_CH, IN_CH // groups, kernel_size,
kernel_size).uniform_(-4.0, 4.0).requires_grad_(True)
output = F.conv2d(input, weight,
stride=stride, padding=padding, dilation=dilation, groups=groups)
@@ -909,7 +908,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
self.assertEqual(conv1.bias.grad.data, conv2.bias.grad.data, atol=0.0, rtol=0)
self.assertEqual(conv1.weight.grad.data, conv2.weight.grad.data, atol=0.0, rtol=0)
@onlyCUDA
@dtypes(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []))
def test_Conv2d_large_workspace(self, device, dtype):
@@ -932,7 +930,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
run_test(benchmark=False)
run_test(benchmark=True)
@onlyCUDA
@dtypes(torch.half, torch.float)
def test_ConvTranspose2d_large_output_padding(self, device, dtype):
@@ -949,7 +946,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
x.backward(torch.randn_like(x))
torch.cuda.synchronize()
@onlyCUDA
@dtypes(torch.float, torch.double, torch.half)
# Very similar to test_Conv2d_naive_groups but with special care to handle
@@ -1038,7 +1034,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
m2.weight.grad.data], 0),
atol=atol, rtol=rtol)
@onlyCUDA
@dtypes(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []))
def test_noncontig_conv_grad(self, device, dtype):
@@ -1081,7 +1076,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
"\ninp_size: " + str(inp_size) +
"\ndilation: " + str(dilation))
def test_conv_double_backward_no_bias(self):
kern = 3
stride = 2
@@ -1107,7 +1101,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
"\ninp_size: " + str(inp_size) +
"\ndilation: " + str(dilation))
def test_conv_double_backward_groups(self):
kern = 3
stride = 1
@@ -1134,7 +1127,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
"\ndilation: " + str(dilation) +
"\ngroups: " + str(groups))
def test_conv_double_backward_stride(self):
batch_size = 2
@@ -1750,7 +1742,8 @@ class TestConvolutionNNDeviceType(NNTestCase):
# Forward AD and forward-over-reverse AD smoke test in float32
# TODO: remove this if we introduce per-op gradient tests for float32
with fwAD.dual_level():
dual_inputs = [(fwAD.make_dual(i, torch.rand_like(i)) if isinstance(i, torch.Tensor) else i) for i in inputs]
dual_inputs = [(fwAD.make_dual(i, torch.rand_like(i)) if isinstance(i, torch.Tensor) else i)
for i in inputs]
# Forward AD
output = convolution(*dual_inputs)
# Forward over reverse AD
@@ -1780,7 +1773,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
bias.requires_grad_(False)
self.assertTrue(gradgradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol))
@onlyCPU
def test_conv_contiguous_for_oneDNN(self):
# See https://github.com/pytorch/pytorch/issues/80837.
@@ -1883,7 +1875,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
input_large = torch.randn(1, 2, 1024, 1024 * 1024, dtype=dtype, device=device)
conv1(input_large)
conv2 = torch.nn.Conv2d(1, 1024, 1, 1).to(device).to(dtype)
input_large = torch.randn(1, 1, 2048, 1024 , dtype=dtype, device=device)
input_large = torch.randn(1, 1, 2048, 1024, dtype=dtype, device=device)
conv2(input_large)
def test_conv_noncontig_weights(self, device):
@@ -2052,7 +2044,8 @@ class TestConvolutionNNDeviceType(NNTestCase):
self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
self.assertEqual(conv_cpu.bias.grad.data, conv_cuda.bias.grad.data, atol=1e-5, rtol=0, exact_device=False)
self.assertEqual(conv_cpu.weight.grad.data, conv_cuda.weight.grad.data, atol=1e-5, rtol=0, exact_device=False)
self.assertEqual(conv_cpu.weight.grad.data, conv_cuda.weight.grad.data,
atol=1e-5, rtol=0, exact_device=False)
@dtypesIfCUDA(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []))
@dtypes(torch.float)
@@ -2445,6 +2438,20 @@ class TestConvolutionNNDeviceType(NNTestCase):
out = model(input)
self.assertTrue(out.is_contiguous(memory_format=memory_format))
@onlyCUDA
@skipCUDAIfRocm
@skipCUDAIfCudnnVersionLessThan(7603)
def test_convert_conv3d_weight_memory_format(self, device):
input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float32, device=device)
model = nn.Sequential(
nn.ConvTranspose3d(8, 4, 3),
nn.BatchNorm3d(4)).to(device).float()
for memory_format in [torch.channels_last_3d, torch.contiguous_format]:
model = nn.utils.convert_conv3d_weight_memory_format(model, memory_format)
out = model(input)
self.assertTrue(out.is_contiguous(memory_format=memory_format))
def test_conv_double_backward_strided_with_3D_input_and_weight(self, device):
# Test that _convolution_double_backward() outputs the correct grad shapes
# for 3D input / weight when stride > 1. This is an ad-hoc regression test for a
@@ -2487,6 +2494,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
y = m.to(device=device)(x.to(device=device))
self.assertEqual(yref, y)
instantiate_device_type_tests(TestConvolutionNNDeviceType, globals())
instantiate_parametrized_tests(TestConvolutionNN)
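
As a complement to the new test_convert_conv3d_weight_memory_format test above, which loops over both memory formats, here is a small sketch (not part of the diff) showing that the utility round-trips a bare Conv3d between layouts; no GPU is needed just to inspect the weight layout.

# Complementary check (not in the diff): the conversion round-trips between formats,
# matching the loop over memory formats in the new test above.
import torch
import torch.nn as nn

m = nn.Conv3d(8, 4, 3)
m = nn.utils.convert_conv3d_weight_memory_format(m, torch.channels_last_3d)
assert m.weight.is_contiguous(memory_format=torch.channels_last_3d)

m = nn.utils.convert_conv3d_weight_memory_format(m, torch.contiguous_format)
assert m.weight.is_contiguous()  # back to the default NCDHW layout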

View File

@@ -4,7 +4,7 @@ from .weight_norm import weight_norm, remove_weight_norm
from .convert_parameters import parameters_to_vector, vector_to_parameters
from .spectral_norm import spectral_norm, remove_spectral_norm
from .fusion import fuse_conv_bn_eval, fuse_conv_bn_weights, fuse_linear_bn_eval, fuse_linear_bn_weights
from .memory_format import convert_conv2d_weight_memory_format
from .memory_format import convert_conv2d_weight_memory_format, convert_conv3d_weight_memory_format
from . import parametrizations
from .init import skip_init
from . import stateless
@@ -14,6 +14,7 @@ __all__ = [
"clip_grad_norm_",
"clip_grad_value_",
"convert_conv2d_weight_memory_format",
"convert_conv3d_weight_memory_format",
"fuse_conv_bn_eval",
"fuse_conv_bn_weights",
"fuse_linear_bn_eval",

View File

@@ -70,3 +70,74 @@ def convert_conv2d_weight_memory_format(module, memory_format):
for child in module.children():
convert_conv2d_weight_memory_format(child, memory_format)
return module
def convert_conv3d_weight_memory_format(module, memory_format):
r"""Convert ``memory_format`` of ``nn.Conv3d.weight`` to ``memory_format``
The conversion recursively applies to nested ``nn.Module``, including ``module``.
Note that it only changes the memory_format, but not the semantics of each dimensions.
This function is used to facilitate the computation to adopt NHWC kernels, which
provides considerable speed up for fp16 data on CUDA devices with compute capability >= 7.0
.. note::
Calling ``model.to(memory_format=torch.channels_last_3d)`` is more aggressive
than the utility function ``convert_conv3d_weight_memory_format``. Any
layer with a 5d weight will be affected by ``model.to``, which does not
necessarily benefit from conversion to the specified ``memory_format``.
One place we are confident in is the NDHWC (channels_last_3d) conversion for
convolution in cuDNN, as it is beneficial to run convolution in NDHWC
even in cases where a permutation of the input tensor is required.
Hence our strategy here is to convert only the weight of convolution to
channels_last_3d. This ensures that:
1. Fast convolution kernels will be used, the benefit of which could
outweigh the overhead of permutation (if the input is not in the same format).
2. No unnecessary permutations are applied on layers that do not benefit
from the memory_format conversion.
The optimal case is when the layers between convolution layers are
channels-last compatible. The input tensor is permuted to channels last
when it encounters the first convolution layer and stays in that memory
format, so the following convolutions do not need to permute their input
tensors. When a channels-last-incompatible layer sits between convolution
layers, we need to permute the input tensor back to contiguous format for
that layer. The input tensor then goes through the remaining layers in
contiguous format and is permuted to channels last again when it
encounters another convolution layer. There is no point in propagating
that permutation to an earlier layer, as most layers are quite agnostic
to ``memory_format``.
This claim might change when PyTorch supports fusion of permutations, as
there might be a better spot to fuse the permutation than immediately
before a convolution.
Args:
module (nn.Module): ``nn.Conv3d`` & ``nn.ConvTranspose3d`` or container
``nn.Module``
memory_format: user-specified ``memory_format``,
e.g. ``torch.channels_last_3d`` or ``torch.contiguous_format``
Returns:
The original module with updated ``nn.Conv3d``
Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
>>> input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda")
>>> model = nn.Sequential(
>>> nn.Conv3d(8, 4, 3)).cuda().half()
>>> # This is identical to:
>>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
>>> model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
>>> out = model(input)
"""
# TODO: expand this to `_ConvNd` when channels_last support is extended
# beyond only 4d tensors.
if isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
weight_data = module.weight.detach().clone().contiguous(memory_format=memory_format)
module.weight.data = weight_data.resize_(weight_data.size(), memory_format=memory_format)
for child in module.children():
convert_conv3d_weight_memory_format(child, memory_format)
return module
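
To make the strategy described in the docstring concrete, here is a small standalone sketch (not part of the diff) of what the conversion does to a single Conv3d weight, using the same contiguous()/resize_() pattern as the function body above; variable names are illustrative.

# Minimal sketch of the per-module conversion performed by convert_conv3d_weight_memory_format.
# Memory format is a property of the tensor, so the layout can be inspected on CPU;
# the speedup itself comes from NDHWC conv kernels on CUDA.
import torch
import torch.nn as nn

conv = nn.Conv3d(8, 4, kernel_size=3)  # weight shape: (4, 8, 3, 3, 3)
w = conv.weight.detach().clone().contiguous(memory_format=torch.channels_last_3d)
conv.weight.data = w.resize_(w.size(), memory_format=torch.channels_last_3d)

# The shape and values are unchanged; only the strides (physical layout) differ,
# with the channel dimension (dim 1) now having stride 1.
print(conv.weight.shape)                                                # torch.Size([4, 8, 3, 3, 3])
print(conv.weight.stride())                                             # (216, 1, 72, 24, 8)
print(conv.weight.is_contiguous(memory_format=torch.channels_last_3d))  # True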