Adding SSA support for convolution_backward

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77283

Approved by: https://github.com/Krovatkin
Author: John Clow
Date: 2022-05-20 10:33:42 -07:00
Committed by: PyTorch MergeBot
Parent: 417373337f
Commit: dbee7e5499
3 changed files with 58 additions and 3 deletions
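
For context, the shape rule this change registers is simple: aten::convolution_backward returns three gradients whose shapes mirror the input, the weight, and the bias (one entry per output channel, i.e. grad_output's dimension 1). Below is a minimal plain-Python sketch of that rule, given only as an illustration; it is not part of this commit, and the helper name conv_backwards_shapes is made up.

from typing import List, Optional, Tuple

def conv_backwards_shapes(grad_output: List[int],
                          input: List[int],
                          weight: List[int],
                          bias_sizes: Optional[List[int]]) -> Tuple[List[int], List[int], List[int]]:
    # grad_input matches the input shape, grad_weight matches the weight shape,
    # and grad_bias always has one element per output channel.
    return list(input), list(weight), [grad_output[1]]

# Shapes from the test added below (grouped convolution with groups=4):
print(conv_backwards_shapes([16, 8, 8, 8], [16, 16, 8, 8], [8, 4, 3, 3], [8]))
# ([16, 16, 8, 8], [8, 4, 3, 3], [8])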


@@ -7,9 +7,7 @@ from textwrap import dedent
 import torch
 from torch import nn
 from torch.testing import FileCheck
-from torch.testing._internal.common_methods_invocations import (
-    sample_inputs_cat_concat,
-)
+from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat
 from torch.testing._internal.common_utils import make_tensor
 from torch.testing._internal.jit_utils import JitTestCase, execWrapper
 from typing import List, Any
@@ -328,6 +326,43 @@ class TestSymbolicShapeAnalysis(JitTestCase):
         inps[2].setType(inps[2].type().with_sizes(args[1].size()))
         self.checkShapeAnalysis(out_size, mod.graph, assert_propagation=True)
 
+    def assert_shape_equal_scripted(self, script_fn, given_ins):
+        expected_res = script_fn(*given_ins)
+        g = script_fn.graph
+        graph_ins = list(g.inputs())
+        self.assertEqual(len(given_ins), len(graph_ins))
+
+        for inp, graph_in in zip(given_ins, graph_ins):
+            graph_in.setType(graph_in.type().with_sizes(inp.size()))
+
+        out_sizes = [out.size() for out in expected_res]
+        self.checkShapeAnalysis(out_sizes, g, assert_propagation=True)
+
+    def test_convolution_backward(self):
+        # No opinfos for ops that are not part of the Python API.
+        # Also, as the return shapes are just the input, weight, and bias shapes,
+        # there is no point in a really complicated test.
+        input = torch.randn((16, 16, 8, 8), dtype=torch.float32, device="cpu", requires_grad=True)
+        weight = torch.randn((8, 4, 3, 3), dtype=torch.float32, device="cpu", requires_grad=True)
+        out_grad = torch.randn((16, 8, 8, 8), dtype=torch.float32, device="cpu")
+
+        @torch.jit.script
+        def conv_bwd(input, weight, grad):
+            bias_sizes = [8, ]
+            args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True])
+            return torch.ops.aten.convolution_backward(grad, input, weight, bias_sizes, *args)
+
+        self.assert_shape_equal_scripted(conv_bwd, (input, weight, out_grad))
+
+        @torch.jit.script
+        def conv_bwd_2(input, weight, grad):
+            bias_sizes = None
+            args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True])
+            return torch.ops.aten.convolution_backward(grad, input, weight, bias_sizes, *args)
+
+        self.assert_shape_equal_scripted(conv_bwd_2, (input, weight, out_grad))
+
     def test_returning_input_symbolic_shapes(self):
         mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval()))
         inps = list(mm.graph.inputs())


@@ -1713,6 +1713,21 @@ def transpose(self: List[int],
   _19 = torch.append(output_size, torch.add(_18, 1))
   return output_size
 )=====")
++ std::string(R"=====(def conv_backwards(grad_output: List[int],
+    input: List[int],
+    weight: List[int],
+    biases: Optional[List[int]]) -> Tuple[List[int], List[int], List[int]]:
+  out = annotate(List[int], [])
+  for _0 in range(torch.len(input)):
+    elem = input[_0]
+    _1 = torch.append(out, elem)
+  out0 = annotate(List[int], [])
+  for _2 in range(torch.len(weight)):
+    elem0 = weight[_2]
+    _3 = torch.append(out0, elem0)
+  return (out, out0, [grad_output[1]])
+
+)=====")
 + std::string(R"=====(def flatten(input: List[int],
     start_dim: int,
@@ -2726,6 +2741,7 @@ const OperatorMap<std::string>& GetShapeFunctionMappings() {
     {"aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor", "conv2d"},
     {"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", "batch_norm"},
     {"aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor", "conv3d"},
+    {"aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "conv_backwards"},
     {"aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", "flatten"},
     {"aten::cat(Tensor[] tensors, int dim=0) -> Tensor", "cat"},
     {"aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", "permute"},

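The raw string added above is the serialized TorchScript form of the Python conv_backwards shape function defined in the next file; as I understand it, the generated registry embeds these strings so the shape functions can be loaded at runtime without depending on the Python source. As a rough illustration (not part of this commit), scripting an equivalent function and printing its .code produces TorchScript in roughly the same form:

import torch
from typing import List, Optional

def conv_backwards(grad_output: List[int], input: List[int],
                   weight: List[int], biases: Optional[List[int]]):
    # Bias gradient is always generated regardless of whether biases is supplied
    return input[:], weight[:], [grad_output[1]]

# ScriptFunction.code shows the TorchScript the compiler produced for the
# function above; the serialized registry stores the shape functions in a
# comparable textual form.
print(torch.jit.script(conv_backwards).code)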

@@ -724,6 +724,9 @@ def conv2d(
     assert len(input) == 4
     return conv_output_size(input, weight, bias, stride, padding, dilation, groups)
 
+def conv_backwards(grad_output: List[int], input: List[int], weight: List[int], biases: Optional[List[int]]):
+    # Bias gradient is always generated regardless of whether biases is supplied
+    return _copy(input), _copy(weight), [grad_output[1]]
 
 def batch_norm(
     input: List[int],
@@ -993,6 +996,7 @@ add_shape_compute_mapping("aten::conv1d(Tensor input, Tensor weight, Tensor? bia
 add_shape_compute_mapping("aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor", conv2d)
 add_shape_compute_mapping("aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", batch_norm)
 add_shape_compute_mapping("aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor", conv3d)
+add_shape_compute_mapping("aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", conv_backwards)
 add_shape_compute_mapping("aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", flatten)
 add_shape_compute_mapping("aten::cat(Tensor[] tensors, int dim=0) -> Tensor", cat)
 add_shape_compute_mapping("aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", permute)