functionalization: fix x.is_contiguous(channels_last) (#94195)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94195
Approved by: https://github.com/ezyang
Brian Hirsh
2023-02-11 16:08:44 +00:00
committed by PyTorch MergeBot
parent aba4fb9a16
commit abfd293c39
2 changed files with 16 additions and 1 deletion
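For context (not part of the commit itself): the new test below relies on the fact that a tensor allocated as NHWC and then permuted to NCHW dimension order reports is_contiguous(memory_format=torch.channels_last) as True, so contiguous(memory_format=torch.channels_last) on it should be a no-op and leave no clone in the functionalized graph. A standalone eager-mode sketch of that behavior, using only public torch APIs and not the test harness from the diff:

    import torch

    # Illustration only (not part of this commit): the test input is already
    # channels_last-contiguous, so .contiguous(channels_last) is a no-op.
    x = torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2)  # shape (4, 3, 8, 8), NHWC strides

    print(x.is_contiguous(memory_format=torch.channels_last))  # True
    print(x.is_contiguous())                                   # False (not NCHW-contiguous)

    y = x.contiguous(memory_format=torch.channels_last)
    print(y.data_ptr() == x.data_ptr())                        # True: no copy was made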


@@ -583,6 +583,21 @@ def forward(self, arg0_1):
    return diagonal_scatter
    """)

    def test_channels_last_contiguous(self):
        def f(x):
            return x.contiguous(memory_format=torch.channels_last)
            tmp = torch.ones(2)
            y = x.diagonal()
            y.add_(tmp)
            return x

        x = torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2)
        self.assert_functionalization(f, x)
        logs = self.get_logs(f, x).strip()
        # There should be no clone in the graph
        self.assertExpectedInline(logs, """\
def forward(self, arg0_1):
    return arg0_1""")

    def test_split(self):
        def f(x):
            # test: view ops that return multiple tensors (split)