mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "remove torch.equal usages (#89527)"
This reverts commit 4095ef8b809f922f2e0e09011afd00037d20a771.
Reverted https://github.com/pytorch/pytorch/pull/89527 on behalf of https://github.com/clee2000 due to broke periodic multigpu tests 4095ef8b80
https://github.com/pytorch/pytorch/actions/runs/3592806602/jobs/6049368502
This commit is contained in:
@ -1082,31 +1082,31 @@ class TestNamedTensor(TestCase):
|
||||
|
||||
def test_unflatten(self):
|
||||
# test args: tensor, int, namedshape
|
||||
self.assertTrue(
|
||||
(torch.ones(4, names=('A',)).unflatten('A', (('A', 2), ('B', 2))) ==
|
||||
torch.ones(2, 2, names=('A', 'B'))).all())
|
||||
self.assertTrue(
|
||||
(torch.ones(4, names=('A',)).unflatten('A', [('A', 2), ('B', 2)]) ==
|
||||
torch.ones(2, 2, names=('A', 'B'))).all())
|
||||
self.assertTrue(
|
||||
(torch.ones(4, names=('A',)).unflatten('A', (['A', 2], ['B', 2])) ==
|
||||
torch.ones(2, 2, names=('A', 'B'))).all())
|
||||
self.assertTrue(
|
||||
(torch.ones(2, 10, names=('A', 'B')).unflatten('B', (['B1', -1],)) ==
|
||||
torch.ones(2, 10, names=('A', 'B1'))).all())
|
||||
self.assertTrue(
|
||||
(torch.ones(2, 3 * 4 * 5 * 6, names=('A', 'B'))
|
||||
.unflatten('B', (['B1', 3], ['B2', 4], ['B3', -1], ['B4', 6])) ==
|
||||
torch.ones(2, 3, 4, 5, 6, names=('A', 'B1', 'B2', 'B3', 'B4'))).all())
|
||||
self.assertTrue(
|
||||
(torch.ones(2, 0, names=('A', 'B'))
|
||||
.unflatten('B', (['B1', 3], ['B2', -1], ['B3', 4])) ==
|
||||
torch.ones(2, 3, 0, 4, names=('A', 'B1', 'B2', 'B3'))).all())
|
||||
self.assertTrue(torch.equal(
|
||||
torch.ones(4, names=('A',)).unflatten('A', (('A', 2), ('B', 2))),
|
||||
torch.ones(2, 2, names=('A', 'B'))))
|
||||
self.assertTrue(torch.equal(
|
||||
torch.ones(4, names=('A',)).unflatten('A', [('A', 2), ('B', 2)]),
|
||||
torch.ones(2, 2, names=('A', 'B'))))
|
||||
self.assertTrue(torch.equal(
|
||||
torch.ones(4, names=('A',)).unflatten('A', (['A', 2], ['B', 2])),
|
||||
torch.ones(2, 2, names=('A', 'B'))))
|
||||
self.assertTrue(torch.equal(
|
||||
torch.ones(2, 10, names=('A', 'B')).unflatten('B', (['B1', -1],)),
|
||||
torch.ones(2, 10, names=('A', 'B1'))))
|
||||
self.assertTrue(torch.equal(
|
||||
torch.ones(2, 3 * 4 * 5 * 6, names=('A', 'B'))
|
||||
.unflatten('B', (['B1', 3], ['B2', 4], ['B3', -1], ['B4', 6])),
|
||||
torch.ones(2, 3, 4, 5, 6, names=('A', 'B1', 'B2', 'B3', 'B4'))))
|
||||
self.assertTrue(torch.equal(
|
||||
torch.ones(2, 0, names=('A', 'B'))
|
||||
.unflatten('B', (['B1', 3], ['B2', -1], ['B3', 4])),
|
||||
torch.ones(2, 3, 0, 4, names=('A', 'B1', 'B2', 'B3'))))
|
||||
|
||||
# test args: namedtensor, str, namedshape
|
||||
self.assertTrue(
|
||||
(torch.ones(2, 4, names=('A', 'B')).unflatten('B', (('B1', 2), ('B2', 2))) ==
|
||||
torch.ones(2, 2, 2, names=('A', 'B1', 'B2'))).all())
|
||||
self.assertTrue(torch.equal(
|
||||
torch.ones(2, 4, names=('A', 'B')).unflatten('B', (('B1', 2), ('B2', 2))),
|
||||
torch.ones(2, 2, 2, names=('A', 'B1', 'B2'))))
|
||||
|
||||
# test invalid args: namedtensor, str, sizes
|
||||
with self.assertRaisesRegex(TypeError, r"unflatten\(\): argument 'dim' \(position 1\) must be int, not str"):
|
||||
|
Reference in New Issue
Block a user