Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
functionalization: make view_copy outputs always contiguous (#85747)

This fixes an issue with mobile: the output of view_copy ops should always be contiguous. Later, we can consider adding optional arguments to the `view_copy()` functions to let callers explicitly request a different memory format for the output (e.g. channels_last).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85747
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 294bfb8e80
Commit: 9ad1659b17
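For context, a minimal sketch of the behavior this change guarantees (illustrative only; it mirrors the test_view_copy_output_contiguous test added below): the *_copy variant of a view op now materializes its output with default contiguous strides even when its input uses another memory format such as channels_last.

import torch

# Input stored in channels_last memory format, so it is not
# contiguous in the default (torch.contiguous_format) sense.
a = torch.randn(4, 4, 4, 4).to(memory_format=torch.channels_last)
assert not a.is_contiguous()

# slice_copy is the out-of-place counterpart of slice; with this change
# its output is always materialized contiguously.
b = torch.ops.aten.slice_copy(a, 0, 0, 2)
assert b.is_contiguous()

# A plain (aliasing) slice, by contrast, keeps the channels_last strides.
c = a[0:2]
assert not c.is_contiguous()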
@@ -659,23 +659,31 @@ def forward(self, a_1):
getitem_1 = split_copy[1]; split_copy = None
add_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None
select_copy = torch.ops.aten.select_copy.int(_reshape_alias_copy, 0, 0); _reshape_alias_copy = None
clone = torch.ops.aten.clone.default(add_1, memory_format = torch.contiguous_format)
_unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None
_reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(add_1, [4], [1])
view_copy_1 = torch.ops.aten.view_copy.default(add, [8]); add = None
_reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(view_copy_1, [2, 4], [4, 1]); view_copy_1 = None
transpose_copy_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_1, 1, 0); _reshape_alias_copy_1 = None
_reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(view_copy_1, [2, 4], [4, 1]); view_copy_1 = None
transpose_copy_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_2, 1, 0); _reshape_alias_copy_2 = None
unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0); transpose_copy_1 = None
squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1); unsqueeze_copy_1 = None
slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2); squeeze_copy_1 = None
unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0); slice_scatter = None
squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0); unsqueeze_copy_2 = None
transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0); squeeze_copy_2 = None
_reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_2, [8], [1]); transpose_copy_2 = None
view_copy_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_2, [4, 2]); _reshape_alias_copy_2 = None
view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [8]); view_copy_2 = None
_reshape_alias_copy_3 = torch.ops.aten._reshape_alias_copy.default(view_copy_3, [2, 4], [4, 1]); view_copy_3 = None
select_copy_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_3, 0, 0); _reshape_alias_copy_3 = None
add_2 = torch.ops.aten.add.Tensor(select_copy_1, _unsafe_view); select_copy_1 = _unsafe_view = None
_reshape_alias_copy_3 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_2, [8], [1]); transpose_copy_2 = None
view_copy_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_3, [4, 2]); _reshape_alias_copy_3 = None
view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [8])
_reshape_alias_copy_4 = torch.ops.aten._reshape_alias_copy.default(view_copy_3, [2, 4], [4, 1]); view_copy_3 = None
select_copy_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_4, 0, 0); _reshape_alias_copy_4 = None
view_copy_4 = torch.ops.aten.view_copy.default(view_copy_2, [8]); view_copy_2 = None
_reshape_alias_copy_5 = torch.ops.aten._reshape_alias_copy.default(view_copy_4, [2, 4], [4, 1]); view_copy_4 = None
transpose_copy_3 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_5, 1, 0); _reshape_alias_copy_5 = None
unsqueeze_copy_3 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_3, 0); transpose_copy_3 = None
squeeze_copy_3 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_3); unsqueeze_copy_3 = None
split_copy_1 = torch.ops.aten.split_copy.Tensor(squeeze_copy_3, 2); squeeze_copy_3 = None
getitem_2 = split_copy_1[0]
getitem_3 = split_copy_1[1]; split_copy_1 = None
_reshape_alias_copy_6 = torch.ops.aten._reshape_alias_copy.default(getitem_2, [4], [1]); getitem_2 = None
add_2 = torch.ops.aten.add.Tensor(select_copy_1, _reshape_alias_copy_6); select_copy_1 = _reshape_alias_copy_6 = None
return add_1
""") # noqa: B950
@@ -926,6 +926,12 @@ class TestViewOps(TestCase):
        self.assertEqual(a_view_copy, a_view)
        self.assertEqual(a.grad, a_ref.grad)

    # Testing that the output of a view_copy kernel (by default) is contiguous.
    def test_view_copy_output_contiguous(self, device):
        a = torch.randn(4, 4, 4, 4, device=device).to(memory_format=torch.channels_last)
        b = torch.ops.aten.slice_copy(a, 0, 0, 2)
        self.assertTrue(b.is_contiguous())

    def test_view_copy_out(self, device):
        a = torch.randn(2, 2, device=device)
        out = torch.empty(2, device=device)
@@ -91,7 +91,7 @@ at::Tensor view_copy_symint(const at::Tensor & self, at::SymIntArrayRef size) {
    return self.reshape_symint(size);
  } else {
    auto output = at::_ops::view::call(self, size);
    return output.clone();
    return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
  }
}
"""
@@ -117,13 +117,13 @@ at::Tensor view_copy_symint(const at::Tensor & self, at::SymIntArrayRef size) {

    if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
        return_cloned_output = """\
  return output.clone();"""
  return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);"""
    else:
        # If the return type is a list, we need to clone each tensor in the list.
        return_cloned_output = f"""\
  {view_copy_sig.returns_type().cpp_type()} out_clone;
  for (const auto i : c10::irange(output.size())) {{
    out_clone.push_back(output[i].clone());
    out_clone.push_back(output[i].clone(/*memory_format=*/at::MemoryFormat::Contiguous));
  }}
  return out_clone;"""