[Expanded Weights] add 'same' and 'valid' padding support (#83345)
Co-authored-by: Ashkan <yousefpour@fb.com>

Adds "same" and "valid" padding support, as Opacus (well, @ashkan-software) did in https://github.com/pytorch/opacus/pull/451.

The basics of it are:
- During the forward pass, if there's "same" padding, we manually pad the input (NB: this will cause a small perf hit; not benchmarked yet).
- During the backward pass, the gradient wrt the input needs to be cut down to the correct size if the original padding was "same" (conv_transpose doesn't accept string padding). Because conv_transpose gives us a gradient wrt the padded shape, we cut the gradient down to the correct size (we know how much padding we added on the left and right).
- Then, for the per-sample gradients wrt the weights, the input is already padded, so neither the unfold nor the group convolution needs any padding.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83345
Approved by: https://github.com/zou3519
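A minimal sketch of the approach described above, assuming a 1d convolution with stride 1. The helper `same_padding_1d` is made up for illustration and is not the PR's API; it just restates the "dilation * (kernel - 1), split left/right" rule:

```python
import torch
import torch.nn.functional as F

def same_padding_1d(kernel_size, dilation=1):
    # total padding needed so the output length matches the input length (stride 1)
    total = dilation * (kernel_size - 1)
    left = total // 2
    return left, total - left

x = torch.randn(2, 3, 10, requires_grad=True)   # (batch, in_channels, length)
w = torch.randn(4, 3, 5)                        # (out_channels, in_channels, kernel)
left, right = same_padding_1d(w.shape[-1])

# forward: manually padding the input and running the conv with padding=0
# matches the built-in string padding
out = F.conv1d(F.pad(x, (left, right)), w)
assert torch.allclose(out, F.conv1d(x, w, padding="same"), atol=1e-6)

# backward: conv_transpose yields a gradient w.r.t. the *padded* input,
# which is narrowed back down to the original input size
grad_out = torch.randn_like(out)
grad_padded = F.conv_transpose1d(grad_out, w)               # padded-input shape
grad_input = torch.narrow(grad_padded, 2, left, x.shape[2])

out.backward(grad_out)
assert torch.allclose(grad_input, x.grad, atol=1e-6)
```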
@@ -178,9 +178,6 @@ class TestExpandedWeightFunctional(TestCase):
             if op.name == "nn.functional.embedding":  # embedding flips its argument order for autograd tests
                 sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs)
 
-            def reduction(x):
-                return x.sum()
-
             self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.sum)
 
     @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
@@ -559,12 +556,9 @@ def filter_supported_tests(t):
     supported_modules = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'Embedding', 'LayerNorm', 'GroupNorm', 'InstanceNorm']
     if 'module_name' in t and t['module_name'] in supported_modules:
         return True
-    if 'fullname' in t and any([module + "_" in t['fullname'] for module in supported_modules]):
-        return not('Conv' in t['fullname'] and 'pad' in t['fullname'])
 
 # TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
 # These currently use the legacy nn tests
-supported_modules = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'Embedding', 'LayerNorm', 'GroupNorm', 'InstanceNorm']
 supported_tests = [t for t in module_tests + new_module_tests if filter_supported_tests(t)]
 for test_param in supported_tests:
     if 'constructor' not in test_param:
@@ -628,8 +622,7 @@ def supported_inputs(op, sample_inputs, supported_inputs=True):
             is_supported_input = input.input.shape != normalized_shape  # would cause inter-batch operations
         elif op.name in convolutions:
             # currently can't deal with padding computation on Python level
-            is_supported_input = 'padding' not in input.kwargs or not isinstance(input.kwargs['padding'], str)
-            is_supported_input = is_supported_input and input.input.dim() == batched_input_size[op.name]
+            is_supported_input = input.input.dim() == batched_input_size[op.name]
         elif op.name == "nn.functional.embedding":
             idx = input.args[0]
             is_supported_input = len(idx.shape) > 1  # there's no batch size
@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 
-from .conv_utils import conv_backward, conv_args_and_kwargs, conv_picker
+from .conv_utils import conv_backward, conv_args_and_kwargs, conv_picker, conv_input_for_string_padding
 from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads
 from .expanded_weights_utils import forward_helper
 
@@ -11,10 +11,19 @@ from .expanded_weights_utils import forward_helper
 class ConvPerSampleGrad(torch.autograd.Function):
     @staticmethod
     def forward(ctx, kwarg_names, conv_fn, *expanded_args_and_kwargs):
-        if any([isinstance(i, str) for i in expanded_args_and_kwargs]):
-            raise RuntimeError("Expanded Weights does not support convolution padding as a string. "
-                               "Please file an issue to prioritize support")
         expanded_args, expanded_kwargs = conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs)
+        orig_input = expanded_args[0]
+        was_same_padding = expanded_kwargs['padding'] == "same"
+
+        if isinstance(expanded_kwargs['padding'], str):
+            # if padding is a string, we'll do the necessary padding (slowly) using F.pad
+            kernel_size = expanded_args[1].shape[2:]
+            padding, dilation = expanded_kwargs['padding'], expanded_kwargs['dilation']
+            input = conv_input_for_string_padding(conv_fn, padding, expanded_args[0], dilation, kernel_size)
+            expanded_args = (input, expanded_args[1])
+            # since we've already done the padding, don't need any more
+            expanded_kwargs['padding'] = 0
+
         output = forward_helper(conv_fn, expanded_args, expanded_kwargs)
         input, weight = expanded_args
         batched_dim_size = conv_picker(conv_fn, 3, 4, 5)
@@ -24,8 +33,10 @@ class ConvPerSampleGrad(torch.autograd.Function):
 
         ctx.conv_fn = conv_fn
 
-        ctx.batch_size = input.shape[0]
-        ctx.input_required_grad = input.requires_grad
+        ctx.batch_size = orig_input.shape[0]
+        ctx.input_required_grad = orig_input.requires_grad
+        ctx.orig_input_shape = orig_input.shape
+        ctx.was_same_padding = was_same_padding
         ctx.stride, ctx.padding = expanded_kwargs['stride'], expanded_kwargs['padding']
         ctx.dilation, ctx.groups = expanded_kwargs['dilation'], expanded_kwargs['groups']
 
@@ -28,6 +28,38 @@ def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs):
 def conv_normalizer(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
     return (input, weight), {'bias': bias, 'stride': stride, 'padding': padding, 'dilation': dilation, 'groups': groups}
 
+
+def conv_input_for_string_padding(func, padding_style, input, dilation, kernel_size):
+    if padding_style == "valid":
+        return input
+    else:
+        padding = int_padding_for_string_padding(func, padding_style, dilation, kernel_size)
+        return F.pad(input, padding)
+
+
+def int_padding_for_string_padding(func, padding_style, dilation, kernel_size):
+    def get_dilation(i):
+        return dilation[i] if isinstance(dilation, tuple) else dilation
+
+    if padding_style == "same":
+        padding: List[int] = []
+        # F.pad needs the padding in reverse order from what conv expects
+        for i in range(conv_picker(func, 0, 1, 2), -1, -1):
+            padding += conv_padding_for_same(get_dilation(i), kernel_size[i])
+        return padding
+    elif padding_style == "valid":
+        return conv_picker(func, 2, 4, 6) * (0,)
+    else:
+        raise RuntimeError(f"got padding type of {padding_style}, only accept 'same' or 'valid'")
+
+
+def conv_padding_for_same(dilation, kernel_size):
+    total_pad = dilation * (kernel_size - 1)
+    left_pad = total_pad // 2
+    right_pad = total_pad - left_pad
+    return left_pad, right_pad
+
+
 def conv_backward(func, ctx, grad_output):
 
     def weight_grad_sample(weight):
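The helpers added to conv_utils build the F.pad argument by walking the spatial dims in reverse, since F.pad lists (left, right) pairs starting from the last dimension, while conv padding is listed from the first spatial dimension. A small illustration for a 2d kernel of (3, 5) with dilation 1; it restates the tiny per-dimension helper locally rather than importing the private module:

```python
import torch
import torch.nn.functional as F

def padding_for_same(dilation, kernel_size):
    # per-dimension (left, right) amounts, as in conv_padding_for_same above
    total = dilation * (kernel_size - 1)
    left = total // 2
    return left, total - left

kernel_size = (3, 5)                              # (kH, kW)
pad_h = padding_for_same(1, kernel_size[0])       # (1, 1)
pad_w = padding_for_same(1, kernel_size[1])       # (2, 2)

# F.pad wants pairs starting from the last dimension: (W_left, W_right, H_left, H_right)
f_pad_arg = list(pad_w) + list(pad_h)             # [2, 2, 1, 1]

x = torch.randn(2, 3, 8, 8)
w = torch.randn(6, 3, *kernel_size)
out = F.conv2d(F.pad(x, f_pad_arg), w)
assert out.shape[-2:] == x.shape[-2:]             # spatial size preserved ("same")
assert torch.allclose(out, F.conv2d(x, w, padding="same"), atol=1e-6)
```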
@@ -43,6 +75,15 @@ def conv_backward(func, ctx, grad_output):
         else:
             return param
 
+    def calc_total_padding(func, was_same, padding, dilation, kernel_size):
+        if was_same:
+            all_padding = int_padding_for_string_padding(func, "same", dilation, kernel_size)
+            # F.pad needs the padding in reverse order from what conv expects
+            total_padding = tuple(all_padding[i] + all_padding[i - 1] for i in range(len(all_padding) - 1, -1, -2))
+            return total_padding
+        else:
+            return tuple(2 * pad for pad in padding)
+
     weight_shape = ctx.weight.shape
     stride, padding, dilation, groups = expand(ctx.stride), expand(ctx.padding), expand(ctx.dilation), ctx.groups
 
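The index arithmetic in calc_total_padding walks that F.pad-ordered list backwards two entries at a time, so the left and right amounts of each spatial dimension are summed and come back out in conv order. A tiny worked example with illustrative values:

```python
# F.pad-ordered "same" padding for a 2d conv with kernel (3, 5), dilation 1:
# (W_left, W_right, H_left, H_right)
all_padding = [2, 2, 1, 1]

# same pairing as calc_total_padding: walk from the end, two entries at a time
total_padding = tuple(all_padding[i] + all_padding[i - 1]
                      for i in range(len(all_padding) - 1, -1, -2))

assert total_padding == (2, 4)   # (H_total, W_total): back in conv's own ordering
```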
@@ -55,15 +96,24 @@ def conv_backward(func, ctx, grad_output):
     results.append(None)  # for kwarg names
     results.append(None)  # for op reference
 
+    # "same" padding may give uneven padding on either side so we need to separate the "padding" attr and total padding
+    total_padding = calc_total_padding(func, ctx.was_same_padding, padding, dilation, kernel_size)
+
     if ctx.input_required_grad:
         output_padding = []
         input_dims = conv_picker(func, 1, 2, 3)
         for i in range(input_dims):
-            input_dim = ctx.input.shape[2 + i]
-            output_padding.append((2 * padding[i] + input_dim - (kernel_size[i] * dilation[i] - dilation[i] + 1)) % stride[i])
+            input_dim = ctx.orig_input_shape[2 + i]
+            output_padding.append((total_padding[i] + input_dim - (kernel_size[i] * dilation[i] - dilation[i] + 1)) % stride[i])
         weight_ = unpack_expanded_weight_or_tensor(ctx.weight)
         transpose_func = conv_picker(func, F.conv_transpose1d, F.conv_transpose2d, F.conv_transpose3d)
-        results.append(transpose_func(grad_output, weight_, None, stride, padding, tuple(output_padding), groups, dilation))
+        out = transpose_func(grad_output, weight_, None, stride, padding, tuple(output_padding), groups, dilation)
+
+        if ctx.was_same_padding:
+            for i in range(len(total_padding)):
+                out = torch.narrow(out, 2 + i, total_padding[i] // 2, ctx.orig_input_shape[2 + i])
+
+        results.append(out)
     else:
         results.append(None)
     # weight and bias don't compute batched gradients; no other arguments are differentiable
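The output_padding term computed for the input gradient exists because conv's floor division can map two different input sizes to the same output size; output_padding tells conv_transpose which one to reproduce. A hedged sketch with hypothetical shapes (1d, stride 2, integer padding, so total padding is just 2 * padding):

```python
import torch
import torch.nn.functional as F

batch, c_in, c_out, length = 2, 3, 4, 10
kernel, stride, padding, dilation = 3, 2, 1, 1

x = torch.randn(batch, c_in, length)
w = torch.randn(c_out, c_in, kernel)
grad_out = torch.randn_like(F.conv1d(x, w, stride=stride, padding=padding))

# same formula as in the diff above; integer padding, so total_padding = 2 * padding
total_padding = 2 * padding
output_padding = (total_padding + length - (kernel * dilation - dilation + 1)) % stride

grad_input = F.conv_transpose1d(grad_out, w, None, stride, padding,
                                output_padding, 1, dilation)
assert grad_input.shape == x.shape   # (2, 3, 10) rather than (2, 3, 9)
```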