align signature of make_tensor with other creation ops (#72702)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72702

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D34457729

Pulled By: mruberry

fbshipit-source-id: 83d580c4201eef946dc9cf4b9e28a3d36be55609
(cherry picked from commit aa4cf20fbeb4b795595729b8ac2e6ba7707d8283)
Author: Philip Meier
Date: 2022-02-24 21:47:38 -08:00
Committed by: PyTorch MergeBot
Parent: 5f310c5e27
Commit: 0973c5a1cc

17 changed files with 525 additions and 511 deletions
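
The change is mechanical: every call that passed device and dtype positionally (in that order) now passes both as keywords, matching creation ops such as torch.empty, whose dtype and device are keyword arguments. A minimal before/after sketch of the calling convention (the shape and bounds below are illustrative, not taken from the test suite):

    import torch
    from torch.testing import make_tensor

    # Old style (removed by this commit): device and dtype were positional,
    # with device first, the opposite order of most creation ops.
    # t = make_tensor((4, 4, 64), "cpu", torch.float32, low=-5, high=5)

    # New style: dtype and device are passed as keywords, aligning make_tensor
    # with creation ops like torch.empty(size, *, dtype=None, device=None).
    t = make_tensor((4, 4, 64), dtype=torch.float32, device="cpu", low=-5, high=5)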


@@ -134,13 +134,13 @@ class TestViewOps(TestCase):
         del dtypes[torch.bool]
 
         def generate_inputs():
-            yield make_tensor((4, 4, 64), device, dtype, low=-5, high=5)
-            yield make_tensor((4, 4, 64), device, dtype, low=-5, high=5).permute(1, 0, 2)
-            yield make_tensor((4, 64, 4), device, dtype, low=-5, high=5).permute(2, 0, 1)
-            yield make_tensor((1, 5, 1), device, dtype, low=-5, high=5).expand(5, 5, 64)
-            yield make_tensor((2, 5, 256), device, dtype, low=-5, high=5)[1::2, 1:, ::2]
-            yield make_tensor((0, 5, 64), device, dtype, low=-5, high=5)
-            yield make_tensor((), device, dtype, low=-5, high=5)
+            yield make_tensor((4, 4, 64), dtype=dtype, device=device, low=-5, high=5)
+            yield make_tensor((4, 4, 64), dtype=dtype, device=device, low=-5, high=5).permute(1, 0, 2)
+            yield make_tensor((4, 64, 4), dtype=dtype, device=device, low=-5, high=5).permute(2, 0, 1)
+            yield make_tensor((1, 5, 1), dtype=dtype, device=device, low=-5, high=5).expand(5, 5, 64)
+            yield make_tensor((2, 5, 256), dtype=dtype, device=device, low=-5, high=5)[1::2, 1:, ::2]
+            yield make_tensor((0, 5, 64), dtype=dtype, device=device, low=-5, high=5)
+            yield make_tensor((), dtype=dtype, device=device, low=-5, high=5)
 
         def calc_expected_size_and_stride(a, view_dtype):
             dtype_size = torch._utils._element_size(a.dtype)
@@ -211,7 +211,7 @@ class TestViewOps(TestCase):
         # TODO: Remove this when autograd support is added
         if dtype.is_floating_point or dtype.is_complex:
             for view_dtype in [*get_all_fp_dtypes(), *get_all_complex_dtypes()]:
-                t = make_tensor((5, 5, 64), device, dtype, low=-5, high=5, requires_grad=True)
+                t = make_tensor((5, 5, 64), dtype=dtype, device=device, low=-5, high=5, requires_grad=True)
                 self.assertFalse(t.view(view_dtype).requires_grad)
 
         # Test the extra error checks that happen when the view dtype
@@ -227,7 +227,7 @@ class TestViewOps(TestCase):
                 continue
 
             size_ratio = view_dtype_size // dtype_size
-            a = make_tensor((4, 4, size_ratio + 1), device, dtype, low=-5, high=5)
+            a = make_tensor((4, 4, size_ratio + 1), dtype=dtype, device=device, low=-5, high=5)
             with self.assertRaisesRegex(
                     RuntimeError,
                     rf"self.size\(-1\) must be divisible by {size_ratio}"):
@@ -238,7 +238,7 @@ class TestViewOps(TestCase):
                     rf"self.storage_offset\(\) must be divisible by {size_ratio}"):
                 a[:, :, 1:].view(view_dtype)
 
-            a = make_tensor((4, 4, size_ratio), device, dtype, low=-5, high=5)
+            a = make_tensor((4, 4, size_ratio), dtype=dtype, device=device, low=-5, high=5)
             a = a.as_strided((4, 4, size_ratio), (size_ratio, 1, 1))
             with self.assertRaisesRegex(
                     RuntimeError,
@@ -342,7 +342,7 @@ class TestViewOps(TestCase):
     @onlyNativeDeviceTypes
     @dtypes(*get_all_dtypes())
     def test_view_tensor_split(self, device, dtype):
-        a = make_tensor((40, 30), device, dtype, low=-9, high=9)
+        a = make_tensor((40, 30), dtype=dtype, device=device, low=-9, high=9)
         a_split_dim0 = a.tensor_split(7, 0)
         for a_split_dim0_tensor in a_split_dim0:
             self.assertTrue(self.is_view_of(a, a_split_dim0_tensor))
@@ -353,7 +353,7 @@ class TestViewOps(TestCase):
     @onlyNativeDeviceTypes
     @dtypes(*get_all_dtypes())
     def test_view_tensor_hsplit(self, device, dtype):
-        t = make_tensor((4, 4, 4), device, dtype, low=-9, high=9)
+        t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
         t_hsplit = torch.hsplit(t, 2)
         for t_hsplit_tensor in t_hsplit:
             self.assertTrue(self.is_view_of(t, t_hsplit_tensor))
@@ -363,7 +363,7 @@ class TestViewOps(TestCase):
     @onlyNativeDeviceTypes
     @dtypes(*get_all_dtypes())
     def test_view_tensor_vsplit(self, device, dtype):
-        t = make_tensor((4, 4, 4), device, dtype, low=-9, high=9)
+        t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
         t_vsplit = torch.vsplit(t, 2)
         for t_vsplit_tensor in t_vsplit:
             self.assertTrue(self.is_view_of(t, t_vsplit_tensor))
@@ -373,7 +373,7 @@ class TestViewOps(TestCase):
     @onlyNativeDeviceTypes
     @dtypes(*get_all_dtypes())
     def test_view_tensor_dsplit(self, device, dtype):
-        t = make_tensor((4, 4, 4), device, dtype, low=-9, high=9)
+        t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
        t_dsplit = torch.dsplit(t, 2)
         for t_dsplit_tensor in t_dsplit:
             self.assertTrue(self.is_view_of(t, t_dsplit_tensor))
@@ -1478,7 +1478,7 @@ class TestOldViewOps(TestCase):
             (), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)
         )
         for s0, s1 in combinations(sizes, r=2):
-            t = make_tensor(s0, device, dtype, low=-9, high=9)
+            t = make_tensor(s0, dtype=dtype, device=device, low=-9, high=9)
             t_np = t.cpu().numpy()
 
             if can_broadcast(s0, s1):
@@ -1568,7 +1568,7 @@ class TestOldViewOps(TestCase):
     @dtypes(*get_all_dtypes())
     def test_reshape_view_semantics(self, device, dtype):
-        tensor = make_tensor((15, 4), device, dtype)
+        tensor = make_tensor((15, 4), dtype=dtype, device=device)
         target = (20, 3)
 
         # Cases where the tensor can be returned as a view.
@@ -1604,7 +1604,7 @@ class TestOldViewOps(TestCase):
             (12, 3),
         ]
         for input_size in input_sizes:
-            a_base = make_tensor(input_size, device, dtype, low=-9, high=9)
+            a_base = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
             # Run tests on transposed input if it has at least 2 dims
             for a in [a_base, a_base.t()] if a_base.dim() > 2 else [a_base]:
                 a_n = a.cpu().numpy()
@@ -1647,7 +1647,7 @@ class TestOldViewOps(TestCase):
             (1, 5, 2, 8),
         ]
         for input_size in input_sizes:
-            a_base = make_tensor(input_size, device, dtype, low=-9, high=9)
+            a_base = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
             # Run tests on transposed input if it has at least 2 dims
             for a in [a_base, a_base.t()] if a_base.dim() > 2 else [a_base]:
                 a_n = a.cpu().numpy()