mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Revert D26733731: [pytorch][PR] Skip dispatch for `is_floatin… (#53242)
Summary:
…g_point`"
This reverts commit fbf2883d350f62d17292b71a58f404b5e3e58b7b.
Fixes #{issue number}
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53242
Reviewed By: mrshenli
Differential Revision: D26896105
Pulled By: iramazanli
fbshipit-source-id: 279a6f6d4fbb7949a7ed65df848db71a9b8d44e2
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7484c56fa3
commit
d7b5a6faaa
@ -19,7 +19,7 @@ bool is_complex(const Tensor& self) {
|
||||
}
|
||||
|
||||
bool is_floating_point(const Tensor& self) {
|
||||
return at::isFloatingType(self.scalar_type());
|
||||
return self.is_floating_point();
|
||||
}
|
||||
|
||||
bool is_signed(const Tensor &self) {
|
||||
|
||||
@ -2071,6 +2071,7 @@
|
||||
- func: is_floating_point(Tensor self) -> bool
|
||||
variants: function, method
|
||||
device_guard: False
|
||||
manual_cpp_binding: True
|
||||
|
||||
- func: is_complex(Tensor self) -> bool
|
||||
variants: function, method
|
||||
|
||||
@ -146,4 +146,8 @@ inline bool is_complex(const Tensor& tensor) {
|
||||
return tensor.is_complex();
|
||||
}
|
||||
|
||||
inline bool is_floating_point(const Tensor& tensor) {
|
||||
return tensor.is_floating_point();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -148,6 +148,10 @@ class TORCH_API Tensor {
|
||||
return at::isComplexType(this->scalar_type());
|
||||
}
|
||||
|
||||
bool is_floating_point() const {
|
||||
return at::isFloatingType(this->scalar_type());
|
||||
}
|
||||
|
||||
int64_t size(int64_t dim) const {
|
||||
// false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
|
||||
dim = c10::maybe_wrap_dim(dim, this->dim(), false);
|
||||
|
||||
@ -152,6 +152,7 @@ class TestVmapAPI(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, msg):
|
||||
vmap(out_op)(tensor, tensor)
|
||||
|
||||
tensor = torch.randn(2)
|
||||
# The fallback doesn't support TensorList
|
||||
with self.assertRaisesRegex(RuntimeError, 'Batching rule not implemented'):
|
||||
vmap(lambda t: torch.atleast_1d([t]))(tensor)
|
||||
@ -1583,6 +1584,19 @@ class TestVmapOperators(Namespace.TestVmapBase):
|
||||
self.assertEqual(vmap(foo)(ctensor), torch.tensor([1, 1, 1]))
|
||||
self.assertEqual(vmap(foo)(tensor), torch.tensor([0, 0, 0]))
|
||||
|
||||
def test_is_floating_point(self):
|
||||
float_tensor = torch.tensor([1., 2., 3.])
|
||||
long_tensor = torch.tensor([1, 2, 3])
|
||||
|
||||
def foo(x):
|
||||
if x.is_floating_point():
|
||||
return torch.tensor(1)
|
||||
else:
|
||||
return torch.tensor(0)
|
||||
|
||||
self.assertEqual(vmap(foo)(float_tensor), torch.tensor([1, 1, 1]))
|
||||
self.assertEqual(vmap(foo)(long_tensor), torch.tensor([0, 0, 0]))
|
||||
|
||||
def test_is_contiguous(self):
|
||||
def foo(x):
|
||||
if x.is_contiguous():
|
||||
|
||||
Reference in New Issue
Block a user