mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
torch.tensor: add tests for list of numpy arrays case
References: https://github.com/pytorch/pytorch/issues/13918 Add more test cases for list of numpy array inputs Pull Request resolved: https://github.com/pytorch/pytorch/pull/72249 Approved by: https://github.com/mruberry
This commit is contained in:
committed by
PyTorch MergeBot
parent
631f035131
commit
0509022450
@ -234,6 +234,28 @@ class TestNumPyInterop(TestCase):
|
||||
with self.assertWarnsOnceRegex(UserWarning, warning_msg):
|
||||
torch.tensor([np.array([0]), np.array([1])], device=device)
|
||||
|
||||
def test_ctor_with_invalid_numpy_array_sequence(self, device):
|
||||
# Invalid list of numpy array
|
||||
with self.assertRaisesRegex(ValueError, "expected sequence of length"):
|
||||
torch.tensor([np.random.random(size=(3, 3)), np.random.random(size=(3, 0))], device=device)
|
||||
|
||||
# Invalid list of list of numpy array
|
||||
with self.assertRaisesRegex(ValueError, "expected sequence of length"):
|
||||
torch.tensor([[np.random.random(size=(3, 3)), np.random.random(size=(3, 2))]], device=device)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "expected sequence of length"):
|
||||
torch.tensor([[np.random.random(size=(3, 3)), np.random.random(size=(3, 3))],
|
||||
[np.random.random(size=(3, 3)), np.random.random(size=(3, 2))]], device=device)
|
||||
|
||||
# expected shape is `[1, 2, 3]`, hence we try to iterate over 0-D array
|
||||
# leading to type error : not a sequence.
|
||||
with self.assertRaisesRegex(TypeError, "not a sequence"):
|
||||
torch.tensor([[np.random.random(size=(3)), np.random.random()]], device=device)
|
||||
|
||||
# list of list or numpy array.
|
||||
with self.assertRaisesRegex(ValueError, "expected sequence of length"):
|
||||
torch.tensor([[1, 2, 3], np.random.random(size=(2,)), ], device=device)
|
||||
|
||||
@onlyCPU
|
||||
def test_ctor_with_numpy_scalar_ctor(self, device) -> None:
|
||||
dtypes = [
|
||||
|
Reference in New Issue
Block a user