Fix shape function for transpose convolution (#102139)

Fixes #98129.
Fixes the shape function for jit conv_transpose so that it includes output_padding, as defined by the documentation: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html#torch.nn.ConvTranspose2d.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102139
Approved by: https://github.com/mingfeima, https://github.com/davidberard98
This commit is contained in:
ecao
2023-06-21 17:50:56 +00:00
committed by PyTorch MergeBot
parent 678ce61cdb
commit 223f232928
3 changed files with 130 additions and 21 deletions

View File

@ -781,33 +781,41 @@ def conv_transpose2d_input(input: List[int], weight: List[int], bias: Optional[L
input_batch_size_dim = 0
weight_output_channels_dim = 1
output_size.append(input[input_batch_size_dim])
output_size.append(weight[weight_output_channels_dim])
output_size.append(weight[weight_output_channels_dim] * groups)
for d in range(2, dim):
dilation_ = dilation[d - 2] if has_dilation else 1
kernel = dilation_ * (weight[d] - 1)
output_size.append((input[d] - 1) * stride[d - 2] - 2 * padding[d - 2] + kernel + 1)
output_size.append((input[d] - 1) * stride[d - 2] - 2 * padding[d - 2] + kernel + output_padding[d - 2] + 1)
return output_size
def conv_forwards(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> List[int]:
    """Compute the output shape of aten::convolution (regular or transposed).

    Args:
        input: input tensor sizes, e.g. ``[N, C_in, *spatial]``.
        weight: weight tensor sizes; for a transposed convolution the
            output-channel count (per group) is in dim 1, otherwise in dim 0.
        bias: unused here; present to mirror the aten operator schema.
        stride: per-spatial-dimension strides.
        padding: per-spatial-dimension paddings.
        dilation: per-spatial-dimension dilations; may be empty, meaning 1.
        transposed: True for conv_transpose.
        output_padding: per-spatial-dimension extra output size (transposed
            convolution only); may be empty, meaning 0.
        groups: number of convolution groups.

    Returns:
        The output tensor sizes as a list of ints.
    """
    has_dilation = len(dilation) > 0
    has_output_padding = len(output_padding) > 0
    dim = len(input)
    output_size: List[int] = []
    input_batch_size_dim = 0
    # Transposed conv weight layout is [C_in, C_out // groups, *kernel],
    # regular conv weight layout is [C_out, C_in // groups, *kernel].
    weight_output_channels_dim = 1 if transposed else 0
    output_size.append(input[input_batch_size_dim])
    if transposed:
        output_size.append(weight[weight_output_channels_dim] * groups)
    else:
        output_size.append(weight[weight_output_channels_dim])
    for d in range(2, dim):
        dilation_ = dilation[d - 2] if has_dilation else 1
        output_padding_ = output_padding[d - 2] if has_output_padding else 0
        if transposed:
            # out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1)
            #       + output_padding + 1   (ConvTranspose2d docs formula)
            kernel = dilation_ * (weight[d] - 1)
            output_size.append((input[d] - 1) * stride[d - 2] - 2 * padding[d - 2] + kernel + output_padding_ + 1)
        else:
            # out = (in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
            kernel = dilation_ * (weight[d] - 1) + 1
            output_size.append((input[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1)
    return output_size
def _conv_forwards(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]:
    """Shape function for aten::_convolution.

    The backend-selection flags (benchmark, deterministic, cudnn_enabled,
    allow_tf32) have no effect on the output shape, so this simply
    delegates to conv_forwards.
    """
    return conv_forwards(
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
    )
def batch_norm(
input: List[int],
weight: Optional[List[int]],
@ -1124,6 +1132,7 @@ add_shape_compute_mapping("aten::batch_norm(Tensor input, Tensor? weight, Tensor
add_shape_compute_mapping("aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor", conv3d)
add_shape_compute_mapping("aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", conv_backwards)
add_shape_compute_mapping("aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor", conv_forwards)
add_shape_compute_mapping("aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor", _conv_forwards)
add_shape_compute_mapping("aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor", conv_transpose2d_input)
add_shape_compute_mapping("aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", flatten)
add_shape_compute_mapping("aten::cat(Tensor[] tensors, int dim=0) -> Tensor", cat)