mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[numpy] Add torch.moveaxis
(#48581)
Summary: Reference: https://github.com/pytorch/pytorch/issues/38349 #36048 https://github.com/pytorch/pytorch/pull/41480#issuecomment-734398262 Pull Request resolved: https://github.com/pytorch/pytorch/pull/48581 Reviewed By: bdhirsh Differential Revision: D25276307 Pulled By: mruberry fbshipit-source-id: 3e3e4df1343c5ce5b71457badc43f08c419ec5c3
This commit is contained in:
committed by
Facebook GitHub Bot
parent
befab0d9d4
commit
5c9cef9a6c
@@ -86,95 +86,97 @@ class TestShapeOps(TestCase):
|
||||
shape = self._rand_shape(4, min_size=5, max_size=10)
|
||||
x = _generate_input(shape, dtype, device, False)
|
||||
|
||||
# Invalid `source` and `destination` dimension
|
||||
with self.assertRaisesRegex(IndexError, "Dimension out of range"):
|
||||
torch.movedim(x, 5, 0)
|
||||
for fn in [torch.movedim, torch.moveaxis]:
|
||||
# Invalid `source` and `destination` dimension
|
||||
with self.assertRaisesRegex(IndexError, "Dimension out of range"):
|
||||
fn(x, 5, 0)
|
||||
|
||||
with self.assertRaisesRegex(IndexError, "Dimension out of range"):
|
||||
torch.movedim(x, 0, 5)
|
||||
with self.assertRaisesRegex(IndexError, "Dimension out of range"):
|
||||
fn(x, 0, 5)
|
||||
|
||||
# Mismatch in size of `source` and `destination`
|
||||
with self.assertRaisesRegex(RuntimeError, "movedim: Invalid source or destination dims:"):
|
||||
torch.movedim(x, (1, 0), (0, ))
|
||||
# Mismatch in size of `source` and `destination`
|
||||
with self.assertRaisesRegex(RuntimeError, "movedim: Invalid source or destination dims:"):
|
||||
fn(x, (1, 0), (0, ))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"):
|
||||
torch.movedim(x, (0, 0), (0, 1))
|
||||
with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"):
|
||||
fn(x, (0, 0), (0, 1))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"):
|
||||
torch.movedim(x, (0, 1, 0), (0, 1, 2))
|
||||
with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"):
|
||||
fn(x, (0, 1, 0), (0, 1, 2))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"):
|
||||
torch.movedim(x, (0, 1), (1, 1))
|
||||
with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"):
|
||||
fn(x, (0, 1), (1, 1))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"):
|
||||
torch.movedim(x, (0, 1, 2), (1, 0, 1))
|
||||
with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"):
|
||||
fn(x, (0, 1, 2), (1, 0, 1))
|
||||
|
||||
@dtypes(torch.int64, torch.float, torch.complex128)
def test_movedim(self, device, dtype):
    """Check torch.movedim / torch.moveaxis against np.moveaxis.

    Covers: integer and sequence `source`/`destination`, randomly negated
    dims, no-op moves (same source and destination), and empty sequences.
    """
    for fn in [torch.moveaxis, torch.movedim]:
        for nd in range(5):
            shape = self._rand_shape(nd, min_size=5, max_size=10)
            x = _generate_input(shape, dtype, device, with_extremal=False)
            for random_negative in [True, False]:
                for src_dim, dst_dim in permutations(range(nd), r=2):
                    random_prob = random.random()

                    # Randomly switch one or both dims to their negative form.
                    if random_negative and random_prob > 0.66:
                        src_dim = src_dim - nd
                    elif random_negative and random_prob > 0.33:
                        dst_dim = dst_dim - nd
                    elif random_negative:
                        src_dim = src_dim - nd
                        dst_dim = dst_dim - nd

                    # Integer `source` and `destination`
                    torch_fn = partial(fn, source=src_dim, destination=dst_dim)
                    np_fn = partial(np.moveaxis, source=src_dim, destination=dst_dim)
                    self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

                if nd == 0:
                    # No dims to build sequences from; skip the sequence cases.
                    continue

                def make_index_negative(sequence, idx):
                    # Return `sequence` as a tuple with entry `idx` replaced
                    # by its negative-dim equivalent (dim - nd).
                    # BUGFIX: previously this ignored both parameters — it
                    # indexed with the closed-over `random_idx` and returned
                    # `tuple(src_sequence)` regardless of which sequence was
                    # passed in, so negating a dim of `dst_sequence` silently
                    # overwrote it with `src_sequence`.
                    sequence = list(sequence)
                    sequence[idx] = sequence[idx] - nd
                    return tuple(sequence)

                for src_sequence in permutations(range(nd), r=random.randint(1, nd)):
                    # Sequence `source` and `destination`
                    dst_sequence = tuple(random.sample(range(nd), len(src_sequence)))

                    # Randomly change a dim to a negative dim representation of itself.
                    random_prob = random.random()
                    if random_negative and random_prob > 0.66:
                        random_idx = random.randint(0, len(src_sequence) - 1)
                        src_sequence = make_index_negative(src_sequence, random_idx)
                    elif random_negative and random_prob > 0.33:
                        random_idx = random.randint(0, len(src_sequence) - 1)
                        dst_sequence = make_index_negative(dst_sequence, random_idx)
                    elif random_negative:
                        # Negate one dim in each sequence.
                        random_idx = random.randint(0, len(src_sequence) - 1)
                        dst_sequence = make_index_negative(dst_sequence, random_idx)
                        random_idx = random.randint(0, len(src_sequence) - 1)
                        src_sequence = make_index_negative(src_sequence, random_idx)

                    torch_fn = partial(fn, source=src_sequence, destination=dst_sequence)
                    np_fn = partial(np.moveaxis, source=src_sequence, destination=dst_sequence)
                    self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

            # Move dim to same position
            x = torch.randn(2, 3, 5, 7, 11)
            torch_fn = partial(fn, source=(0, 1), destination=(0, 1))
            np_fn = partial(np.moveaxis, source=(0, 1), destination=(0, 1))
            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

            torch_fn = partial(fn, source=1, destination=1)
            np_fn = partial(np.moveaxis, source=1, destination=1)
            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

            # Empty Sequence
            torch_fn = partial(fn, source=(), destination=())
            np_fn = partial(np.moveaxis, source=(), destination=())
            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
|
||||
|
||||
@dtypes(torch.float, torch.bool)
|
||||
def test_diag(self, device, dtype):
|
||||
|
Reference in New Issue
Block a user