mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
align signature of make_tensor with other creation ops (#72702)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72702 Test Plan: Imported from OSS Reviewed By: mrshenli Differential Revision: D34457729 Pulled By: mruberry fbshipit-source-id: 83d580c4201eef946dc9cf4b9e28a3d36be55609 (cherry picked from commit aa4cf20fbeb4b795595729b8ac2e6ba7707d8283)
This commit is contained in:
committed by
PyTorch MergeBot
parent
5f310c5e27
commit
0973c5a1cc
@ -134,13 +134,13 @@ class TestViewOps(TestCase):
|
||||
del dtypes[torch.bool]
|
||||
|
||||
def generate_inputs():
|
||||
yield make_tensor((4, 4, 64), device, dtype, low=-5, high=5)
|
||||
yield make_tensor((4, 4, 64), device, dtype, low=-5, high=5).permute(1, 0, 2)
|
||||
yield make_tensor((4, 64, 4), device, dtype, low=-5, high=5).permute(2, 0, 1)
|
||||
yield make_tensor((1, 5, 1), device, dtype, low=-5, high=5).expand(5, 5, 64)
|
||||
yield make_tensor((2, 5, 256), device, dtype, low=-5, high=5)[1::2, 1:, ::2]
|
||||
yield make_tensor((0, 5, 64), device, dtype, low=-5, high=5)
|
||||
yield make_tensor((), device, dtype, low=-5, high=5)
|
||||
yield make_tensor((4, 4, 64), dtype=dtype, device=device, low=-5, high=5)
|
||||
yield make_tensor((4, 4, 64), dtype=dtype, device=device, low=-5, high=5).permute(1, 0, 2)
|
||||
yield make_tensor((4, 64, 4), dtype=dtype, device=device, low=-5, high=5).permute(2, 0, 1)
|
||||
yield make_tensor((1, 5, 1), dtype=dtype, device=device, low=-5, high=5).expand(5, 5, 64)
|
||||
yield make_tensor((2, 5, 256), dtype=dtype, device=device, low=-5, high=5)[1::2, 1:, ::2]
|
||||
yield make_tensor((0, 5, 64), dtype=dtype, device=device, low=-5, high=5)
|
||||
yield make_tensor((), dtype=dtype, device=device, low=-5, high=5)
|
||||
|
||||
def calc_expected_size_and_stride(a, view_dtype):
|
||||
dtype_size = torch._utils._element_size(a.dtype)
|
||||
@ -211,7 +211,7 @@ class TestViewOps(TestCase):
|
||||
# TODO: Remove this when autograd support is added
|
||||
if dtype.is_floating_point or dtype.is_complex:
|
||||
for view_dtype in [*get_all_fp_dtypes(), *get_all_complex_dtypes()]:
|
||||
t = make_tensor((5, 5, 64), device, dtype, low=-5, high=5, requires_grad=True)
|
||||
t = make_tensor((5, 5, 64), dtype=dtype, device=device, low=-5, high=5, requires_grad=True)
|
||||
self.assertFalse(t.view(view_dtype).requires_grad)
|
||||
|
||||
# Test the extra error checks that happen when the view dtype
|
||||
@ -227,7 +227,7 @@ class TestViewOps(TestCase):
|
||||
continue
|
||||
|
||||
size_ratio = view_dtype_size // dtype_size
|
||||
a = make_tensor((4, 4, size_ratio + 1), device, dtype, low=-5, high=5)
|
||||
a = make_tensor((4, 4, size_ratio + 1), dtype=dtype, device=device, low=-5, high=5)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
rf"self.size\(-1\) must be divisible by {size_ratio}"):
|
||||
@ -238,7 +238,7 @@ class TestViewOps(TestCase):
|
||||
rf"self.storage_offset\(\) must be divisible by {size_ratio}"):
|
||||
a[:, :, 1:].view(view_dtype)
|
||||
|
||||
a = make_tensor((4, 4, size_ratio), device, dtype, low=-5, high=5)
|
||||
a = make_tensor((4, 4, size_ratio), dtype=dtype, device=device, low=-5, high=5)
|
||||
a = a.as_strided((4, 4, size_ratio), (size_ratio, 1, 1))
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
@ -342,7 +342,7 @@ class TestViewOps(TestCase):
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*get_all_dtypes())
|
||||
def test_view_tensor_split(self, device, dtype):
|
||||
a = make_tensor((40, 30), device, dtype, low=-9, high=9)
|
||||
a = make_tensor((40, 30), dtype=dtype, device=device, low=-9, high=9)
|
||||
a_split_dim0 = a.tensor_split(7, 0)
|
||||
for a_split_dim0_tensor in a_split_dim0:
|
||||
self.assertTrue(self.is_view_of(a, a_split_dim0_tensor))
|
||||
@ -353,7 +353,7 @@ class TestViewOps(TestCase):
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*get_all_dtypes())
|
||||
def test_view_tensor_hsplit(self, device, dtype):
|
||||
t = make_tensor((4, 4, 4), device, dtype, low=-9, high=9)
|
||||
t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
|
||||
t_hsplit = torch.hsplit(t, 2)
|
||||
for t_hsplit_tensor in t_hsplit:
|
||||
self.assertTrue(self.is_view_of(t, t_hsplit_tensor))
|
||||
@ -363,7 +363,7 @@ class TestViewOps(TestCase):
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*get_all_dtypes())
|
||||
def test_view_tensor_vsplit(self, device, dtype):
|
||||
t = make_tensor((4, 4, 4), device, dtype, low=-9, high=9)
|
||||
t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
|
||||
t_vsplit = torch.vsplit(t, 2)
|
||||
for t_vsplit_tensor in t_vsplit:
|
||||
self.assertTrue(self.is_view_of(t, t_vsplit_tensor))
|
||||
@ -373,7 +373,7 @@ class TestViewOps(TestCase):
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*get_all_dtypes())
|
||||
def test_view_tensor_dsplit(self, device, dtype):
|
||||
t = make_tensor((4, 4, 4), device, dtype, low=-9, high=9)
|
||||
t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
|
||||
t_dsplit = torch.dsplit(t, 2)
|
||||
for t_dsplit_tensor in t_dsplit:
|
||||
self.assertTrue(self.is_view_of(t, t_dsplit_tensor))
|
||||
@ -1478,7 +1478,7 @@ class TestOldViewOps(TestCase):
|
||||
(), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)
|
||||
)
|
||||
for s0, s1 in combinations(sizes, r=2):
|
||||
t = make_tensor(s0, device, dtype, low=-9, high=9)
|
||||
t = make_tensor(s0, dtype=dtype, device=device, low=-9, high=9)
|
||||
t_np = t.cpu().numpy()
|
||||
|
||||
if can_broadcast(s0, s1):
|
||||
@ -1568,7 +1568,7 @@ class TestOldViewOps(TestCase):
|
||||
|
||||
@dtypes(*get_all_dtypes())
|
||||
def test_reshape_view_semantics(self, device, dtype):
|
||||
tensor = make_tensor((15, 4), device, dtype)
|
||||
tensor = make_tensor((15, 4), dtype=dtype, device=device)
|
||||
target = (20, 3)
|
||||
|
||||
# Cases where the tensor can be returned as a view.
|
||||
@ -1604,7 +1604,7 @@ class TestOldViewOps(TestCase):
|
||||
(12, 3),
|
||||
]
|
||||
for input_size in input_sizes:
|
||||
a_base = make_tensor(input_size, device, dtype, low=-9, high=9)
|
||||
a_base = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
|
||||
# Run tests on transposed input if it has at least 2 dims
|
||||
for a in [a_base, a_base.t()] if a_base.dim() > 2 else [a_base]:
|
||||
a_n = a.cpu().numpy()
|
||||
@ -1647,7 +1647,7 @@ class TestOldViewOps(TestCase):
|
||||
(1, 5, 2, 8),
|
||||
]
|
||||
for input_size in input_sizes:
|
||||
a_base = make_tensor(input_size, device, dtype, low=-9, high=9)
|
||||
a_base = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
|
||||
# Run tests on transposed input if it has at least 2 dims
|
||||
for a in [a_base, a_base.t()] if a_base.dim() > 2 else [a_base]:
|
||||
a_n = a.cpu().numpy()
|
||||
|
Reference in New Issue
Block a user