WIP Add compatibility with channels_last_3d for conv3d (#114790)

Part of a multi-PR effort to fix #59168.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114790
Approved by: https://github.com/albanD
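For orientation before the diff: a minimal usage sketch of the utility this PR adds, mirroring the docstring example and the new test below. It assumes the PR is applied and a CUDA device with a cuDNN build that supports NDHWC convolutions (>= 7603, per the test's skip condition); the tensor shapes and values are illustrative, not taken from any file outside this diff.

    import torch
    import torch.nn as nn

    # A small 3d convolution stack in fp16 on the GPU (assumes CUDA is available).
    model = nn.Sequential(nn.Conv3d(8, 4, kernel_size=3)).cuda().half()
    input = torch.randn(2, 8, 4, 4, 4, dtype=torch.float16, device="cuda")

    # Convert only the Conv3d weight to the channels_last_3d (NDHWC) layout.
    model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)

    out = model(input)
    # With an NDHWC weight, cuDNN is expected to keep the output in channels_last_3d.
    print(out.is_contiguous(memory_format=torch.channels_last_3d))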
@@ -396,6 +396,7 @@ Utility functions to convert Module parameter memory formats.
    :nosignatures:

    convert_conv2d_weight_memory_format
    convert_conv3d_weight_memory_format

Utility functions to apply and remove weight normalization from Module parameters.
@@ -37,6 +37,7 @@ if TEST_SCIPY:
    import scipy.signal
    import scipy.ndimage


class TestConvolutionNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True
@@ -377,7 +378,6 @@ class TestConvolutionNN(NNTestCase):
        self.assertRaisesRegex(RuntimeError, 'Specify retain_graph=True',
                               lambda: o1.sum().backward())


    def test_conv_modules_raise_error_on_incorrect_input_size(self):
        for dtype in [torch.half, torch.bfloat16, torch.double, torch.float]:
            modules = [nn.Conv1d(3, 8, 3).to(dtype), nn.ConvTranspose1d(3, 8, 3).to(dtype),
@@ -666,7 +666,6 @@ class TestConvolutionNN(NNTestCase):
            out = conv(input)
            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))


    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_cudnn_noncontiguous_weight(self):
        # Noncontiguous weights must be contiguous() before being
@@ -677,7 +676,6 @@ class TestConvolutionNN(NNTestCase):
            self.assertEqual(F.conv1d(input, weights1, bias=None, stride=2, dilation=2),
                             F.conv1d(input, weights2, bias=None, stride=2, dilation=2))


    def run_grad_conv_test(self, func_forward, func_backward, dim=1, gradient='input'):
        for kern, inp_size in [(3, 6), (3, 7), (4, 9)]:
            for batch, stride, padding, chan_in, chan_out, dilation in \
@@ -744,7 +742,8 @@ class TestConvolutionNN(NNTestCase):
                if has_bias:
                    bias = torch.randn([chan_out], requires_grad=True, dtype=torch.float)
                output = torch._nnpack_spatial_convolution(input, weight, stride=stride, padding=padding, bias=bias)
                output_expected = torch.nn.functional.conv2d(input, weight, stride=stride, padding=padding, bias=bias)
                output_expected = torch.nn.functional.conv2d(
                    input, weight, stride=stride, padding=padding, bias=bias)
                self.assertEqual(output, output_expected, atol=3e-4, rtol=0)

                gradient_o = torch.randn(output.shape, dtype=torch.float)
@@ -764,7 +763,6 @@ class TestConvolutionNN(NNTestCase):
        with self.assertRaisesRegex(ValueError, "Only \"zeros\" "):
            nn.ConvTranspose2d(3, 3, 3, padding_mode="reflect")


    def test_functional_grad_conv(self):
        # Conv 1D
        input = torch.randn(1, 1, 5, requires_grad=True)
@@ -819,7 +817,8 @@ class TestConvolutionNN(NNTestCase):

            input = torch.empty(BATCH_SIZE, IN_CH, SPATIAL, SPATIAL).uniform_(-8.0, 8.0).requires_grad_(True)

            weight = torch.empty(OUT_CH, IN_CH // groups, kernel_size, kernel_size).uniform_(-4.0, 4.0).requires_grad_(True)
            weight = torch.empty(OUT_CH, IN_CH // groups, kernel_size,
                                 kernel_size).uniform_(-4.0, 4.0).requires_grad_(True)

            output = F.conv2d(input, weight,
                              stride=stride, padding=padding, dilation=dilation, groups=groups)
@@ -909,7 +908,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
        self.assertEqual(conv1.bias.grad.data, conv2.bias.grad.data, atol=0.0, rtol=0)
        self.assertEqual(conv1.weight.grad.data, conv2.weight.grad.data, atol=0.0, rtol=0)


    @onlyCUDA
    @dtypes(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []))
    def test_Conv2d_large_workspace(self, device, dtype):
@@ -932,7 +930,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
        run_test(benchmark=False)
        run_test(benchmark=True)


    @onlyCUDA
    @dtypes(torch.half, torch.float)
    def test_ConvTranspose2d_large_output_padding(self, device, dtype):
@@ -949,7 +946,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
        x.backward(torch.randn_like(x))
        torch.cuda.synchronize()


    @onlyCUDA
    @dtypes(torch.float, torch.double, torch.half)
    # Very similar to test_Conv2d_naive_groups but with special care to handle
@@ -1038,7 +1034,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
                                    m2.weight.grad.data], 0),
                         atol=atol, rtol=rtol)


    @onlyCUDA
    @dtypes(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []))
    def test_noncontig_conv_grad(self, device, dtype):
@@ -1081,7 +1076,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
                            "\ninp_size: " + str(inp_size) +
                            "\ndilation: " + str(dilation))


    def test_conv_double_backward_no_bias(self):
        kern = 3
        stride = 2
@@ -1107,7 +1101,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
                            "\ninp_size: " + str(inp_size) +
                            "\ndilation: " + str(dilation))


    def test_conv_double_backward_groups(self):
        kern = 3
        stride = 1
@@ -1134,7 +1127,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
                            "\ndilation: " + str(dilation) +
                            "\ngroups: " + str(groups))


    def test_conv_double_backward_stride(self):
        batch_size = 2
@@ -1750,7 +1742,8 @@ class TestConvolutionNNDeviceType(NNTestCase):
        # Forward AD and forward-over-reverse AD smoke test in float32
        # TODO: remove this if we introduce per-op gradient tests for float32
        with fwAD.dual_level():
            dual_inputs = [(fwAD.make_dual(i, torch.rand_like(i)) if isinstance(i, torch.Tensor) else i) for i in inputs]
            dual_inputs = [(fwAD.make_dual(i, torch.rand_like(i)) if isinstance(i, torch.Tensor) else i)
                           for i in inputs]
            # Forward AD
            output = convolution(*dual_inputs)
            # Forward over reverse AD
@@ -1780,7 +1773,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
            bias.requires_grad_(False)
        self.assertTrue(gradgradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol))


    @onlyCPU
    def test_conv_contiguous_for_oneDNN(self):
        # See https://github.com/pytorch/pytorch/issues/80837.
@@ -1883,7 +1875,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
        input_large = torch.randn(1, 2, 1024, 1024 * 1024, dtype=dtype, device=device)
        conv1(input_large)
        conv2 = torch.nn.Conv2d(1, 1024, 1, 1).to(device).to(dtype)
        input_large = torch.randn(1, 1, 2048, 1024 , dtype=dtype, device=device)
        input_large = torch.randn(1, 1, 2048, 1024, dtype=dtype, device=device)
        conv2(input_large)

    def test_conv_noncontig_weights(self, device):
@@ -2052,7 +2044,8 @@ class TestConvolutionNNDeviceType(NNTestCase):

        self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
        self.assertEqual(conv_cpu.bias.grad.data, conv_cuda.bias.grad.data, atol=1e-5, rtol=0, exact_device=False)
        self.assertEqual(conv_cpu.weight.grad.data, conv_cuda.weight.grad.data, atol=1e-5, rtol=0, exact_device=False)
        self.assertEqual(conv_cpu.weight.grad.data, conv_cuda.weight.grad.data,
                         atol=1e-5, rtol=0, exact_device=False)

    @dtypesIfCUDA(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []))
    @dtypes(torch.float)
@@ -2445,6 +2438,20 @@ class TestConvolutionNNDeviceType(NNTestCase):
            out = model(input)
            self.assertTrue(out.is_contiguous(memory_format=memory_format))

    @onlyCUDA
    @skipCUDAIfRocm
    @skipCUDAIfCudnnVersionLessThan(7603)
    def test_convert_conv3d_weight_memory_format(self, device):

        input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float32, device=device)
        model = nn.Sequential(
            nn.ConvTranspose3d(8, 4, 3),
            nn.BatchNorm3d(4)).to(device).float()
        for memory_format in [torch.channels_last_3d, torch.contiguous_format]:
            model = nn.utils.convert_conv3d_weight_memory_format(model, memory_format)
            out = model(input)
            self.assertTrue(out.is_contiguous(memory_format=memory_format))

    def test_conv_double_backward_strided_with_3D_input_and_weight(self, device):
        # Test that _convolution_double_backward() outputs the correct grad shapes
        # for 3D input / weight when stride > 1. This is an ad-hoc regression test for a
@@ -2487,6 +2494,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
        y = m.to(device=device)(x.to(device=device))
        self.assertEqual(yref, y)


instantiate_device_type_tests(TestConvolutionNNDeviceType, globals())
instantiate_parametrized_tests(TestConvolutionNN)
@@ -4,7 +4,7 @@ from .weight_norm import weight_norm, remove_weight_norm
from .convert_parameters import parameters_to_vector, vector_to_parameters
from .spectral_norm import spectral_norm, remove_spectral_norm
from .fusion import fuse_conv_bn_eval, fuse_conv_bn_weights, fuse_linear_bn_eval, fuse_linear_bn_weights
from .memory_format import convert_conv2d_weight_memory_format
from .memory_format import convert_conv2d_weight_memory_format, convert_conv3d_weight_memory_format
from . import parametrizations
from .init import skip_init
from . import stateless
@@ -14,6 +14,7 @@ __all__ = [
    "clip_grad_norm_",
    "clip_grad_value_",
    "convert_conv2d_weight_memory_format",
    "convert_conv3d_weight_memory_format",
    "fuse_conv_bn_eval",
    "fuse_conv_bn_weights",
    "fuse_linear_bn_eval",
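A quick, hedged sanity check of what the two hunks above do for the torch.nn.utils namespace (only meaningful once this PR is applied): the new helper is re-exported next to its 2d counterpart and listed in __all__.

    import torch.nn.utils as nn_utils

    # Both conversion helpers are importable from the package root.
    assert hasattr(nn_utils, "convert_conv2d_weight_memory_format")
    assert hasattr(nn_utils, "convert_conv3d_weight_memory_format")
    assert "convert_conv3d_weight_memory_format" in nn_utils.__all__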
@@ -70,3 +70,74 @@ def convert_conv2d_weight_memory_format(module, memory_format):
    for child in module.children():
        convert_conv2d_weight_memory_format(child, memory_format)
    return module


def convert_conv3d_weight_memory_format(module, memory_format):
    r"""Convert the ``memory_format`` of ``nn.Conv3d.weight`` to the given ``memory_format``.

    The conversion recursively applies to nested ``nn.Module``, including ``module``.
    Note that it only changes the memory_format, but not the semantics of each dimension.
    This function is used to facilitate the computation to adopt NDHWC (``channels_last_3d``)
    kernels, which provide a considerable speed-up for fp16 data on CUDA devices with
    compute capability >= 7.0.

    .. note::
        Calling ``model.to(memory_format=torch.channels_last_3d)`` is more aggressive
        than the utility function ``convert_conv3d_weight_memory_format``. Any
        layer with a 5d weight will be affected by ``model.to``, which does not
        necessarily benefit from conversion to the specified ``memory_format``.
        One case we are confident about is the NDHWC (channels_last_3d) conversion for
        convolution in cuDNN, as it is beneficial to run convolution in NDHWC,
        even in cases where we have to apply a permutation to the input tensors.

        Hence our strategy here is to convert only the weight of convolution to
        channels_last_3d. This ensures that:
        1. Fast convolution kernels will be used, the benefit of which could
           outweigh the overhead of permutation (if the input is not in the same format).
        2. No unnecessary permutations are applied to layers that do not benefit
           from the memory_format conversion.

        The optimal case is when the layers between convolution layers are channels-last
        compatible. The input tensor is permuted to channels last when it encounters the
        first convolution layer and stays in that memory format, so the following
        convolutions do not need to permute their input tensors.

        In cases where a channels-last-incompatible layer sits between convolution
        layers, we need to permute the input tensor back to contiguous format
        for that layer. The input tensor will go through the remaining layers in
        contiguous format and be permuted to channels last when it encounters
        another convolution layer. There is no point in propagating that
        permutation to an earlier layer, as most layers are quite agnostic to
        ``memory_format``.

        This claim might change when PyTorch supports fusion of permutations, as
        there might then be a better spot to fuse the permutation than immediately
        before a convolution.

    Args:
        module (nn.Module): ``nn.Conv3d`` & ``nn.ConvTranspose3d`` or container
                            ``nn.Module``
        memory_format: user specified ``memory_format``,
            e.g. ``torch.channels_last_3d`` or ``torch.contiguous_format``

    Returns:
        The original module with updated ``nn.Conv3d``

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
        >>> input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda")
        >>> model = nn.Sequential(
        >>>     nn.Conv3d(8, 4, 3)).cuda().half()
        >>> # This is identical to:
        >>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
        >>> model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
        >>> out = model(input)
    """

    # TODO: expand this to `_ConvNd` when channels_last support is extended
    # beyond only 4d tensors.
    if isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
        weight_data = module.weight.detach().clone().contiguous(memory_format=memory_format)
        module.weight.data = weight_data.resize_(weight_data.size(), memory_format=memory_format)
    for child in module.children():
        convert_conv3d_weight_memory_format(child, memory_format)
    return module
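To make the docstring's claim concrete that the conversion only changes the memory format (the strides) and not the meaning of the dimensions, here is a small illustrative sketch of my own, not part of the diff; the printed strides assume the default Conv3d weight shape shown in the comments, and the converted strides follow the NDHWC ordering.

    import torch
    import torch.nn as nn

    conv = nn.Conv3d(8, 4, kernel_size=3)
    print(conv.weight.shape)     # torch.Size([4, 8, 3, 3, 3]) -- logical NCDHW view
    print(conv.weight.stride())  # (216, 27, 9, 3, 1), row-major contiguous

    # Convert only the convolution weight to the channels_last_3d layout.
    conv = nn.utils.convert_conv3d_weight_memory_format(conv, torch.channels_last_3d)

    print(conv.weight.shape)     # still torch.Size([4, 8, 3, 3, 3])
    print(conv.weight.stride())  # (216, 1, 72, 24, 8) -- channel stride is now 1
    print(conv.weight.is_contiguous(memory_format=torch.channels_last_3d))  # True

The contrast the docstring draws with model.to(memory_format=...) is one of scope: model.to reorders every parameter it can, while this helper only touches Conv3d / ConvTranspose3d weights and leaves everything else alone.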