Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Fix shape function for transpose convolution (#102139)
Fixes #98129. Fixes the shape function for JIT conv_transpose so that, as defined by the documentation (https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html#torch.nn.ConvTranspose2d), it includes output_padding in the output size. The transposed path also now multiplies the output-channel dimension by groups.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102139
Approved by: https://github.com/mingfeima, https://github.com/davidberard98
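Concretely, the documented per-dimension output size for ConvTranspose2d is out = (in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1. A minimal sketch of that formula (the helper name conv_transpose_out_size is illustrative, not part of the patch):

def conv_transpose_out_size(in_size: int, stride: int, padding: int,
                            kernel_size: int, dilation: int = 1,
                            output_padding: int = 0) -> int:
    # Per the torch.nn.ConvTranspose2d docs:
    # out = (in - 1) * stride - 2 * padding
    #       + dilation * (kernel_size - 1) + output_padding + 1
    return ((in_size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)

# Before this fix, the JIT shape function dropped the output_padding term:
print(conv_transpose_out_size(5, 2, 1, 3, output_padding=1))  # 10, not 9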
@@ -781,33 +781,41 @@ def conv_transpose2d_input(input: List[int], weight: List[int], bias: Optional[L
     input_batch_size_dim = 0
     weight_output_channels_dim = 1
     output_size.append(input[input_batch_size_dim])
-    output_size.append(weight[weight_output_channels_dim])
+    output_size.append(weight[weight_output_channels_dim] * groups)
 
     for d in range(2, dim):
         dilation_ = dilation[d - 2] if has_dilation else 1
         kernel = dilation_ * (weight[d] - 1)
-        output_size.append((input[d] - 1) * stride[d - 2] - 2 * padding[d - 2] + kernel + 1)
+        output_size.append((input[d] - 1) * stride[d - 2] - 2 * padding[d - 2] + kernel + output_padding[d - 2] + 1)
     return output_size
 
 
 def conv_forwards(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> List[int]:
     has_dilation = len(dilation) > 0
+    has_output_padding = len(output_padding) > 0
     dim = len(input)
     output_size: List[int] = []
     input_batch_size_dim = 0
     weight_output_channels_dim = 1 if transposed else 0
     output_size.append(input[input_batch_size_dim])
-    output_size.append(weight[weight_output_channels_dim])
+    if transposed:
+        output_size.append(weight[weight_output_channels_dim] * groups)
+    else:
+        output_size.append(weight[weight_output_channels_dim])
 
     for d in range(2, dim):
         dilation_ = dilation[d - 2] if has_dilation else 1
+        output_padding_ = output_padding[d - 2] if has_output_padding else 0
         if transposed:
             kernel = dilation_ * (weight[d] - 1)
-            output_size.append((input[d] - 1) * stride[d - 2] - 2 * padding[d - 2] + kernel + 1)
+            output_size.append((input[d] - 1) * stride[d - 2] - 2 * padding[d - 2] + kernel + output_padding_ + 1)
         else:
             kernel = dilation_ * (weight[d] - 1) + 1
             output_size.append((input[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1)
     return output_size
 
 
 def _conv_forwards(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]:
     return conv_forwards(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
 
 
 def batch_norm(
     input: List[int],
     weight: Optional[List[int]],
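To exercise both corrected terms directly, one can call the shape function itself. A minimal sketch, assuming a build that contains this fix, that the helper is importable from the private module torch.jit._shape_functions, and that its arguments follow the aten::conv_transpose2d.input schema order:

from torch.jit._shape_functions import conv_transpose2d_input

# ConvTranspose2d weights are laid out as (in_channels, out_channels // groups, kH, kW),
# so a [4, 4, 3, 3] weight with groups=2 means out_channels = 4 * 2 = 8.
out = conv_transpose2d_input(
    [1, 4, 5, 5],  # input NCHW shape
    [4, 4, 3, 3],  # weight shape
    None,          # bias
    [2, 2],        # stride
    [1, 1],        # padding
    [1, 1],        # output_padding
    2,             # groups
    [1, 1],        # dilation
)
print(out)  # [1, 8, 10, 10]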
@@ -1124,6 +1132,7 @@ add_shape_compute_mapping("aten::batch_norm(Tensor input, Tensor? weight, Tensor
 add_shape_compute_mapping("aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor", conv3d)
 add_shape_compute_mapping("aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", conv_backwards)
 add_shape_compute_mapping("aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor", conv_forwards)
 add_shape_compute_mapping("aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor", _conv_forwards)
+add_shape_compute_mapping("aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor", conv_transpose2d_input)
 add_shape_compute_mapping("aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", flatten)
 add_shape_compute_mapping("aten::cat(Tensor[] tensors, int dim=0) -> Tensor", cat)
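As a sanity check, eager mode already produces the shape that the fixed function computes; the hyperparameters below are illustrative, chosen so that both groups and output_padding affect the result:

import torch

conv = torch.nn.ConvTranspose2d(in_channels=4, out_channels=8, kernel_size=3,
                                stride=2, padding=1, output_padding=1, groups=2)
x = torch.randn(1, 4, 5, 5)
print(conv(x).shape)  # torch.Size([1, 8, 10, 10]); per dim: (5 - 1) * 2 - 2 + (3 - 1) + 1 + 1 = 10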