Revert "Revert D26733731: [pytorch][PR] Skip dispatch for `is_floatin… (#53242)

Summary:
…g_point`"

This reverts commit fbf2883d350f62d17292b71a58f404b5e3e58b7b.

Fixes #{issue number}

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53242

Reviewed By: mrshenli

Differential Revision: D26896105

Pulled By: iramazanli

fbshipit-source-id: 279a6f6d4fbb7949a7ed65df848db71a9b8d44e2
Author: iramazanli (committed by Facebook GitHub Bot)
Date: 2021-03-11 09:38:14 -08:00
Parent: 7484c56fa3
Commit: d7b5a6faaa
5 changed files with 24 additions and 1 deletion
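
Before the diffs, a quick illustration of the intent (not part of this commit): `is_floating_point` is a pure dtype query, so answering it directly from `scalar_type()` instead of going through the dispatcher cannot change results. A minimal Python check, assuming only public PyTorch APIs:

    import torch

    # Illustration only: is_floating_point depends solely on the tensor's dtype,
    # which is why it is safe to answer it without a dispatcher round trip.
    for dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64,
                  torch.int32, torch.int64, torch.bool, torch.complex64):
        t = torch.zeros(1, dtype=dtype)
        # Tensor.is_floating_point() must agree with the dtype-level property.
        assert t.is_floating_point() == t.dtype.is_floating_point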

@@ -19,7 +19,7 @@ bool is_complex(const Tensor& self) {
 }
 
 bool is_floating_point(const Tensor& self) {
-  return at::isFloatingType(self.scalar_type());
+  return self.is_floating_point();
 }
 
 bool is_signed(const Tensor &self) {

@@ -2071,6 +2071,7 @@
 - func: is_floating_point(Tensor self) -> bool
   variants: function, method
   device_guard: False
+  manual_cpp_binding: True
 
 - func: is_complex(Tensor self) -> bool
   variants: function, method
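
The `manual_cpp_binding: True` entry tells the code generator to skip emitting the C++ bindings (the `at::is_floating_point` free function and the `Tensor::is_floating_point` method) for this operator; the hand-written versions added in the next two hunks take their place and avoid the dispatcher round trip. The Python surface, including both variants declared above, is unchanged. A small check (illustrative only, not part of this diff):

    import torch

    t = torch.randn(2)
    # Both declared variants remain available and agree after the change:
    # the free function torch.is_floating_point and the Tensor method.
    assert torch.is_floating_point(t)
    assert t.is_floating_point()
    assert not torch.is_floating_point(torch.arange(2))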

@@ -146,4 +146,8 @@ inline bool is_complex(const Tensor& tensor) {
   return tensor.is_complex();
 }
 
+inline bool is_floating_point(const Tensor& tensor) {
+  return tensor.is_floating_point();
+}
+
 }

@@ -148,6 +148,10 @@ class TORCH_API Tensor {
     return at::isComplexType(this->scalar_type());
   }
 
+  bool is_floating_point() const {
+    return at::isFloatingType(this->scalar_type());
+  }
+
   int64_t size(int64_t dim) const {
     // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
     dim = c10::maybe_wrap_dim(dim, this->dim(), false);
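
A consequence that the new vmap test below exercises: because `is_floating_point()` now reads the scalar type directly instead of dispatching, it needs no batching rule and simply reports the per-example dtype inside `vmap`. A sketch using today's public `torch.func.vmap` (an assumption on my part; the test itself uses the 2021 prototype vmap):

    import torch
    from torch.func import vmap  # modern vmap API, not the import the test uses

    def label(x):
        # is_floating_point() returns a plain Python bool straight from the dtype,
        # so no batching rule is required under vmap.
        return torch.tensor(1) if x.is_floating_point() else torch.tensor(0)

    print(vmap(label)(torch.randn(3)))   # expected: tensor([1, 1, 1])
    print(vmap(label)(torch.arange(3)))  # expected: tensor([0, 0, 0])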

@@ -152,6 +152,7 @@ class TestVmapAPI(TestCase):
         with self.assertRaisesRegex(RuntimeError, msg):
             vmap(out_op)(tensor, tensor)
 
         tensor = torch.randn(2)
+        # The fallback doesn't support TensorList
         with self.assertRaisesRegex(RuntimeError, 'Batching rule not implemented'):
             vmap(lambda t: torch.atleast_1d([t]))(tensor)
@@ -1583,6 +1584,19 @@ class TestVmapOperators(Namespace.TestVmapBase):
         self.assertEqual(vmap(foo)(ctensor), torch.tensor([1, 1, 1]))
         self.assertEqual(vmap(foo)(tensor), torch.tensor([0, 0, 0]))
 
+    def test_is_floating_point(self):
+        float_tensor = torch.tensor([1., 2., 3.])
+        long_tensor = torch.tensor([1, 2, 3])
+
+        def foo(x):
+            if x.is_floating_point():
+                return torch.tensor(1)
+            else:
+                return torch.tensor(0)
+
+        self.assertEqual(vmap(foo)(float_tensor), torch.tensor([1, 1, 1]))
+        self.assertEqual(vmap(foo)(long_tensor), torch.tensor([0, 0, 0]))
+
     def test_is_contiguous(self):
         def foo(x):
             if x.is_contiguous():