Numpy zerotensor handling (#164487)
Fixes #89034. Updated the tensor_to_numpy() function in tensor_numpy.cpp to handle ZeroTensors by throwing an error if force=False and returning an array full of zeros if force=True.

@ngimel, I just saw that you mentioned PyTorch is not too concerned with this issue, but I had already worked on it, so I figured I would push it anyway and see what you thought. Feel free to close the PR if you think it is not worth merging. @albanD

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164487
Approved by: https://github.com/ngimel, https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: f46bb04dcc
Commit: f7ad6dbad6
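For context, a minimal usage sketch of the behavior this commit adds, assuming a build that includes the change; torch._efficientzerotensor is a private helper and appears here only because the new test relies on it.

import torch

# Create an efficient ZeroTensor (no real backing storage).
x = torch._efficientzerotensor(10, dtype=torch.float)

# Without force, conversion to NumPy is rejected for ZeroTensors.
try:
    x.numpy()
except RuntimeError as err:
    print("numpy() raised:", err)

# With force=True, a regular NumPy array filled with zeros is returned.
y = x.numpy(force=True)
print(y)  # ten 0.0 values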
@@ -164,6 +164,28 @@ class TestNumPyInterop(TestCase):
         self.assertEqual(y.dtype, np.bool_)
         self.assertEqual(x[0], y[0])
 
+    @skipIfTorchDynamo(
+        "can't check if value is ZeroTensor since _is_zerotensor returns a bool and not a TensorVariable"
+    )
+    def test_to_numpy_zero_tensor(self, device) -> None:
+        dtypes = [
+            torch.uint8,
+            torch.int8,
+            torch.short,
+            torch.int,
+            torch.half,
+            torch.float,
+            torch.double,
+            torch.long,
+            torch.bool,
+        ]
+        for dtype in dtypes:
+            x = torch._efficientzerotensor((10), dtype=dtype)
+            self.assertRaises(RuntimeError, lambda: x.numpy())
+            y = x.numpy(force=True)
+            for i in range(10):
+                self.assertEqual(y[i], 0)
+
     @skipIfTorchDynamo("conj bit not implemented in TensorVariable yet")
     def test_to_numpy_force_argument(self, device) -> None:
         for force in [False, True]:
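To run just the new test locally, something along these lines should work; the test class lives in test/test_numpy_interop.py, and the device-generic test framework appends a device suffix to the generated test name.

pytest test/test_numpy_interop.py -k test_to_numpy_zero_tensor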
@@ -145,6 +145,10 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor, bool force /*=false*/) {
         " device type tensor to numpy. Use Tensor.cpu() to ",
         "copy the tensor to host memory first.");
 
+    TORCH_CHECK(
+        !at::_is_zerotensor(tensor),
+        " Cannot convert a ZeroTensor to numpy. Set force=True if you need the zero array.");
+
     TORCH_CHECK(
         !(at::GradMode::is_enabled() && tensor.requires_grad()),
         "Can't call numpy() on Tensor that requires grad. "
@@ -186,6 +190,9 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor, bool force /*=false*/) {
   if (!array)
     return nullptr;
 
+  if (at::_is_zerotensor(tensor))
+    PyArray_FILLWBYTE(reinterpret_cast<PyArrayObject*>(array.get()), 0);
+
   // TODO: This attempts to keep the underlying memory alive by setting the base
   // object of the ndarray to the tensor and disabling resizes on the storage.
   // This is not sufficient. For example, the tensor's storage may be changed