diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 385b26d4d6c3..2460c4552ada 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -142,6 +142,7 @@ view of a storage and defines numeric operations on it. .. autoattribute:: is_cuda .. autoattribute:: device .. autoattribute:: grad + .. autoattribute:: ndim .. automethod:: abs .. automethod:: abs_ diff --git a/test/test_torch.py b/test/test_torch.py index 4a5f98dd54eb..20ae72c8ff1c 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -11529,6 +11529,14 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], f = Foo2() self.assertEqual(f.foo(), 5) + def test_ndim(self): + a = torch.randn(1, 2, 3) + self.assertEqual(3, a.ndim) + b = torch.randn(()) + self.assertEqual(0, b.ndim) + c = torch.randn(1, 0) + self.assertEqual(2, c.ndim) + # Functions to test negative dimension wrapping METHOD = 1 INPLACE_METHOD = 2 diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index b904f09009fa..44298e710874 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -3191,3 +3191,8 @@ add_docstr_all('device', r""" Is the :class:`torch.device` where this Tensor is. """) + +add_docstr_all('ndim', + r""" +Alias for :meth:`~Tensor.dim()` +""") diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 0169aa0d44d8..9b3ede12c38a 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -304,6 +304,13 @@ PyObject *THPVariable_get_requires_grad(THPVariable *self) END_HANDLE_TH_ERRORS } +PyObject *THPVariable_get_ndim(THPVariable *self) +{ + HANDLE_TH_ERRORS + return PyInt_FromLong(self->cdata.dim()); + END_HANDLE_TH_ERRORS +} + int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj) { HANDLE_TH_ERRORS @@ -443,6 +450,7 @@ static struct PyGetSetDef THPVariable_properties[] = { {"dtype", (getter)THPVariable_dtype, nullptr, nullptr, nullptr}, {"layout", (getter)THPVariable_layout, nullptr, nullptr, nullptr}, {"device", (getter)THPVariable_device, nullptr, nullptr, nullptr}, + {"ndim", (getter)THPVariable_get_ndim, nullptr, nullptr, nullptr}, {nullptr} };