Add "ndim" property to tensor (#20565)

Summary:
For compatibility with numpy.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20565

Differential Revision: D15374390

Pulled By: umanwizard

fbshipit-source-id: 4ab209a5fb27d8ba27ee7eb6b67b858ce2480594
This commit is contained in:
Brennan Vincent
2019-05-20 15:58:51 -07:00
committed by Facebook Github Bot
parent 6ae99aa5bc
commit 987f1ccf49
4 changed files with 22 additions and 0 deletions

View File

@ -142,6 +142,7 @@ view of a storage and defines numeric operations on it.
.. autoattribute:: is_cuda
.. autoattribute:: device
.. autoattribute:: grad
.. autoattribute:: ndim
.. automethod:: abs
.. automethod:: abs_

View File

@ -11529,6 +11529,14 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
f = Foo2()
self.assertEqual(f.foo(), 5)
def test_ndim(self):
a = torch.randn(1, 2, 3)
self.assertEqual(3, a.ndim)
b = torch.randn(())
self.assertEqual(0, b.ndim)
c = torch.randn(1, 0)
self.assertEqual(2, c.ndim)
# Functions to test negative dimension wrapping
METHOD = 1
INPLACE_METHOD = 2

View File

@ -3191,3 +3191,8 @@ add_docstr_all('device',
r"""
Is the :class:`torch.device` where this Tensor is.
""")
add_docstr_all('ndim',
r"""
Alias for :meth:`~Tensor.dim()`
""")

View File

@ -304,6 +304,13 @@ PyObject *THPVariable_get_requires_grad(THPVariable *self)
END_HANDLE_TH_ERRORS
}
PyObject *THPVariable_get_ndim(THPVariable *self)
{
HANDLE_TH_ERRORS
return PyInt_FromLong(self->cdata.dim());
END_HANDLE_TH_ERRORS
}
int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj)
{
HANDLE_TH_ERRORS
@ -443,6 +450,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
{"dtype", (getter)THPVariable_dtype, nullptr, nullptr, nullptr},
{"layout", (getter)THPVariable_layout, nullptr, nullptr, nullptr},
{"device", (getter)THPVariable_device, nullptr, nullptr, nullptr},
{"ndim", (getter)THPVariable_get_ndim, nullptr, nullptr, nullptr},
{nullptr}
};