functionalization: make view_copy outputs always contiguous (#85747)

This fixes an issue with mobile: the output of view_copy ops should always be contiguous.

Later, we can consider adding an optional argument to the `view_copy()` functions to let you explicitly specify the memory format of the output (e.g. channels_last).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85747
Approved by: https://github.com/ezyang
Brian Hirsh
2022-10-21 08:29:10 -07:00
committed by PyTorch MergeBot
parent 294bfb8e80
commit 9ad1659b17
3 changed files with 27 additions and 13 deletions
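In practice, the change means copy variants of view ops no longer inherit a non-default memory format (e.g. channels_last) from their input; they always return a contiguous tensor. A minimal sketch of the new behavior, mirroring the test added below:

import torch

# A channels_last (NHWC) input; previously, slice_copy could return a
# channels_last output, which mobile consumers did not expect.
a = torch.randn(4, 4, 4, 4).to(memory_format=torch.channels_last)

# With this change, the *_copy variants of view ops always produce a
# contiguous result.
b = torch.ops.aten.slice_copy(a, 0, 0, 2)
assert b.is_contiguous()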

test/test_functionalization.py

@@ -659,23 +659,31 @@ def forward(self, a_1):
getitem_1 = split_copy[1]; split_copy = None
add_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None
select_copy = torch.ops.aten.select_copy.int(_reshape_alias_copy, 0, 0); _reshape_alias_copy = None
clone = torch.ops.aten.clone.default(add_1, memory_format = torch.contiguous_format)
_unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None
_reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(add_1, [4], [1])
view_copy_1 = torch.ops.aten.view_copy.default(add, [8]); add = None
_reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(view_copy_1, [2, 4], [4, 1]); view_copy_1 = None
transpose_copy_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_1, 1, 0); _reshape_alias_copy_1 = None
_reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(view_copy_1, [2, 4], [4, 1]); view_copy_1 = None
transpose_copy_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_2, 1, 0); _reshape_alias_copy_2 = None
unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0); transpose_copy_1 = None
squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1); unsqueeze_copy_1 = None
slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2); squeeze_copy_1 = None
unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0); slice_scatter = None
squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0); unsqueeze_copy_2 = None
transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0); squeeze_copy_2 = None
_reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_2, [8], [1]); transpose_copy_2 = None
view_copy_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_2, [4, 2]); _reshape_alias_copy_2 = None
view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [8]); view_copy_2 = None
_reshape_alias_copy_3 = torch.ops.aten._reshape_alias_copy.default(view_copy_3, [2, 4], [4, 1]); view_copy_3 = None
select_copy_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_3, 0, 0); _reshape_alias_copy_3 = None
add_2 = torch.ops.aten.add.Tensor(select_copy_1, _unsafe_view); select_copy_1 = _unsafe_view = None
_reshape_alias_copy_3 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_2, [8], [1]); transpose_copy_2 = None
view_copy_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_3, [4, 2]); _reshape_alias_copy_3 = None
view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [8])
_reshape_alias_copy_4 = torch.ops.aten._reshape_alias_copy.default(view_copy_3, [2, 4], [4, 1]); view_copy_3 = None
select_copy_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_4, 0, 0); _reshape_alias_copy_4 = None
view_copy_4 = torch.ops.aten.view_copy.default(view_copy_2, [8]); view_copy_2 = None
_reshape_alias_copy_5 = torch.ops.aten._reshape_alias_copy.default(view_copy_4, [2, 4], [4, 1]); view_copy_4 = None
transpose_copy_3 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_5, 1, 0); _reshape_alias_copy_5 = None
unsqueeze_copy_3 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_3, 0); transpose_copy_3 = None
squeeze_copy_3 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_3); unsqueeze_copy_3 = None
split_copy_1 = torch.ops.aten.split_copy.Tensor(squeeze_copy_3, 2); squeeze_copy_3 = None
getitem_2 = split_copy_1[0]
getitem_3 = split_copy_1[1]; split_copy_1 = None
_reshape_alias_copy_6 = torch.ops.aten._reshape_alias_copy.default(getitem_2, [4], [1]); getitem_2 = None
add_2 = torch.ops.aten.add.Tensor(select_copy_1, _reshape_alias_copy_6); select_copy_1 = _reshape_alias_copy_6 = None
return add_1
""") # noqa: B950

test/test_view_ops.py

@@ -926,6 +926,12 @@ class TestViewOps(TestCase):
self.assertEqual(a_view_copy, a_view)
self.assertEqual(a.grad, a_ref.grad)

# Testing that the output of a view_copy kernel (by default) is contiguous.
def test_view_copy_output_contiguous(self, device):
a = torch.randn(4, 4, 4, 4, device=device).to(memory_format=torch.channels_last)
b = torch.ops.aten.slice_copy(a, 0, 0, 2)
self.assertTrue(b.is_contiguous())

def test_view_copy_out(self, device):
a = torch.randn(2, 2, device=device)
out = torch.empty(2, device=device)

torchgen/gen_functionalization_type.py

@@ -91,7 +91,7 @@ at::Tensor view_copy_symint(const at::Tensor & self, at::SymIntArrayRef size) {
return self.reshape_symint(size);
} else {
auto output = at::_ops::view::call(self, size);
return output.clone();
return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
}
}
"""
@@ -117,13 +117,13 @@ at::Tensor view_copy_symint(const at::Tensor & self, at::SymIntArrayRef size) {
if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
return_cloned_output = """\
return output.clone();"""
return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);"""
else:
# If the return type is a list, we need to clone each tensor in the list.
return_cloned_output = f"""\
{view_copy_sig.returns_type().cpp_type()} out_clone;
for (const auto i : c10::irange(output.size())) {{
out_clone.push_back(output[i].clone());
out_clone.push_back(output[i].clone(/*memory_format=*/at::MemoryFormat::Contiguous));
}}
return out_clone;"""
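For ops whose copy variant returns a list of tensors (e.g. split_copy), the generated loop above clones each element into contiguous memory. Assuming split_copy goes through this generated path, a quick sanity check might look like:

import torch

a = torch.randn(4, 4, 4, 4).to(memory_format=torch.channels_last)

# split_copy returns a list of tensors; with this change, each element
# should be cloned into contiguous memory regardless of the input's format.
chunks = torch.ops.aten.split_copy(a, 2)
assert all(c.is_contiguous() for c in chunks)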