mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
functionalization: fix x.is_contiguous(channels_last) (#94195)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94195 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
aba4fb9a16
commit
abfd293c39
@ -583,6 +583,21 @@ def forward(self, arg0_1):
|
||||
return diagonal_scatter
|
||||
""")
|
||||
|
||||
def test_channels_last_contiguous(self):
|
||||
def f(x):
|
||||
return x.contiguous(memory_format=torch.channels_last)
|
||||
tmp = torch.ones(2)
|
||||
y = x.diagonal()
|
||||
y.add_(tmp)
|
||||
return x
|
||||
x = torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2)
|
||||
self.assert_functionalization(f, x)
|
||||
logs = self.get_logs(f, x).strip()
|
||||
# There should be no clone in the graph
|
||||
self.assertExpectedInline(logs, """\
|
||||
def forward(self, arg0_1):
|
||||
return arg0_1""")
|
||||
|
||||
def test_split(self):
|
||||
def f(x):
|
||||
# test: view ops that return multiple tensors (split)
|
||||
|
Reference in New Issue
Block a user