mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Adding SSA support for convolution_backward
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77283 Approved by: https://github.com/Krovatkin
This commit is contained in:
committed by
PyTorch MergeBot
parent
417373337f
commit
dbee7e5499
@ -7,9 +7,7 @@ from textwrap import dedent
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_methods_invocations import (
|
||||
sample_inputs_cat_concat,
|
||||
)
|
||||
from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat
|
||||
from torch.testing._internal.common_utils import make_tensor
|
||||
from torch.testing._internal.jit_utils import JitTestCase, execWrapper
|
||||
from typing import List, Any
|
||||
@ -328,6 +326,43 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
inps[2].setType(inps[2].type().with_sizes(args[1].size()))
|
||||
self.checkShapeAnalysis(out_size, mod.graph, assert_propagation=True)
|
||||
|
||||
def assert_shape_equal_scripted(self, script_fn, given_ins):
|
||||
expected_res = script_fn(*given_ins)
|
||||
g = script_fn.graph
|
||||
graph_ins = list(g.inputs())
|
||||
self.assertEqual(len(given_ins), len(graph_ins))
|
||||
for inp, graph_in in zip(given_ins, graph_ins):
|
||||
graph_in.setType(graph_in.type().with_sizes(inp.size()))
|
||||
|
||||
out_sizes = [out.size() for out in expected_res]
|
||||
self.checkShapeAnalysis(out_sizes, g, assert_propagation=True)
|
||||
|
||||
def test_convolution_backward(self):
|
||||
# No opinfos for ops that are not part of the Python API
|
||||
# Also, as the return shapes are the input, weight, and bias shape, there is no point
|
||||
# in a really complicated test
|
||||
|
||||
input = torch.randn((16, 16, 8, 8), dtype=torch.float32, device="cpu", requires_grad=True)
|
||||
weight = torch.randn((8, 4, 3, 3), dtype=torch.float32, device="cpu", requires_grad=True)
|
||||
out_grad = torch.randn((16, 8, 8, 8), dtype=torch.float32, device="cpu")
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def conv_bwd(input, weight, grad):
|
||||
bias_sizes = [8, ]
|
||||
args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True])
|
||||
return torch.ops.aten.convolution_backward(grad, input, weight, bias_sizes, *args)
|
||||
|
||||
self.assert_shape_equal_scripted(conv_bwd, (input, weight, out_grad))
|
||||
|
||||
@torch.jit.script
|
||||
def conv_bwd_2(input, weight, grad):
|
||||
bias_sizes = None
|
||||
args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True])
|
||||
return torch.ops.aten.convolution_backward(grad, input, weight, bias_sizes, *args)
|
||||
self.assert_shape_equal_scripted(conv_bwd_2, (input, weight, out_grad))
|
||||
|
||||
|
||||
def test_returning_input_symbolic_shapes(self):
|
||||
mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval()))
|
||||
inps = list(mm.graph.inputs())
|
||||
|
@ -1713,6 +1713,21 @@ def transpose(self: List[int],
|
||||
_19 = torch.append(output_size, torch.add(_18, 1))
|
||||
return output_size
|
||||
|
||||
)=====")
|
||||
+ std::string(R"=====(def conv_backwards(grad_output: List[int],
|
||||
input: List[int],
|
||||
weight: List[int],
|
||||
biases: Optional[List[int]]) -> Tuple[List[int], List[int], List[int]]:
|
||||
out = annotate(List[int], [])
|
||||
for _0 in range(torch.len(input)):
|
||||
elem = input[_0]
|
||||
_1 = torch.append(out, elem)
|
||||
out0 = annotate(List[int], [])
|
||||
for _2 in range(torch.len(weight)):
|
||||
elem0 = weight[_2]
|
||||
_3 = torch.append(out0, elem0)
|
||||
return (out, out0, [grad_output[1]])
|
||||
|
||||
)=====")
|
||||
+ std::string(R"=====(def flatten(input: List[int],
|
||||
start_dim: int,
|
||||
@ -2726,6 +2741,7 @@ const OperatorMap<std::string>& GetShapeFunctionMappings() {
|
||||
{"aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor", "conv2d"},
|
||||
{"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", "batch_norm"},
|
||||
{"aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor", "conv3d"},
|
||||
{"aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "conv_backwards"},
|
||||
{"aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", "flatten"},
|
||||
{"aten::cat(Tensor[] tensors, int dim=0) -> Tensor", "cat"},
|
||||
{"aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", "permute"},
|
||||
|
@ -724,6 +724,9 @@ def conv2d(
|
||||
assert len(input) == 4
|
||||
return conv_output_size(input, weight, bias, stride, padding, dilation, groups)
|
||||
|
||||
def conv_backwards(grad_output: List[int], input:List[int], weight:List[int], biases:Optional[List[int]]):
|
||||
# Bias gradient is always generated regardess of if biases is supplied
|
||||
return _copy(input), _copy(weight), [grad_output[1]]
|
||||
|
||||
def batch_norm(
|
||||
input: List[int],
|
||||
@ -993,6 +996,7 @@ add_shape_compute_mapping("aten::conv1d(Tensor input, Tensor weight, Tensor? bia
|
||||
add_shape_compute_mapping("aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor", conv2d)
|
||||
add_shape_compute_mapping("aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", batch_norm)
|
||||
add_shape_compute_mapping("aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor", conv3d)
|
||||
add_shape_compute_mapping("aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", conv_backwards)
|
||||
add_shape_compute_mapping("aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", flatten)
|
||||
add_shape_compute_mapping("aten::cat(Tensor[] tensors, int dim=0) -> Tensor", cat)
|
||||
add_shape_compute_mapping("aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", permute)
|
||||
|
Reference in New Issue
Block a user