Added is_xla (#103100)

This change adds `is_xla`, which is congruent with `is_cuda` and `is_cpu`: it is `True` when the tensor is stored on an XLA device. Useful in situations like https://github.com/pytorch/pytorch/pull/102858

```
>>> import torch
>>> import torch_xla.core.xla_model as xm
>>> x = torch.tensor([1], device=xm.xla_device())
>>> x.is_xla
True
>>> x.is_cpu
False
>>> x = torch.tensor([1])
>>> x.is_cpu
True
>>> x.is_xla
False
```

Attn: @albanD
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103100
Approved by: https://github.com/albanD
Author: Muralidhar Andoorveedu
Date: 2023-06-22 23:31:00 +00:00
Committed by: PyTorch MergeBot
Commit: 4e204ff87b (parent 49dc26435f)
5 changed files with 28 additions and 0 deletions

@@ -6603,6 +6603,13 @@ Is ``True`` if the Tensor is stored on the CPU, ``False`` otherwise.
""",
)
add_docstr_all(
"is_xla",
r"""
Is ``True`` if the Tensor is stored on an XLA device, ``False`` otherwise.
""",
)
add_docstr_all(
"is_ipu",
r"""

@@ -1204,6 +1204,16 @@ PyObject* THPVariable_is_cuda(THPVariable* self, void* unused) {
END_HANDLE_TH_ERRORS
}
PyObject* THPVariable_is_xla(THPVariable* self, void* unused) {
HANDLE_TH_ERRORS
if (check_has_torch_function((PyObject*)self)) {
return handle_torch_function_getter(self, "is_xla");
}
auto& self_ = THPVariable_Unpack(self);
return torch::autograd::utils::wrap(self_.is_xla());
END_HANDLE_TH_ERRORS
}
PyObject* THPVariable_is_ipu(THPVariable* self, void* unused) {
HANDLE_TH_ERRORS
if (check_has_torch_function((PyObject*)self)) {
@@ -1470,6 +1480,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
{"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr},
{"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
{"is_cpu", (getter)THPVariable_is_cpu, nullptr, nullptr, nullptr},
{"is_xla", (getter)THPVariable_is_xla, nullptr, nullptr, nullptr},
{"is_xpu", (getter)THPVariable_is_xpu, nullptr, nullptr, nullptr},
{"is_ipu", (getter)THPVariable_is_ipu, nullptr, nullptr, nullptr},
{"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},

@@ -114,6 +114,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
{"shape", "prim"},
{"is_cuda", "prim"},
{"is_cpu", "prim"},
{"is_xla", "prim"},
{"is_xpu", "prim"},
{"is_sparse", "prim"},
{"is_sparse_csr", "prim"},

@@ -1149,6 +1149,14 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
push(stack, a.is_cpu());
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::is_xla(Tensor a) -> bool"),
[](Stack& stack) {
at::Tensor a;
pop(stack, a);
push(stack, a.is_xla());
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::is_xpu(Tensor a) -> bool"),
[](Stack& stack) {
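
The `prim::is_xla` kernel simply pops the tensor and pushes `a.is_xla()`, so scripted code gives the same answer as eager mode. A minimal check that needs no XLA device:

```
import torch

@torch.jit.script
def on_xla(x: torch.Tensor) -> bool:
    return x.is_xla

x = torch.tensor([1])         # CPU tensor
assert on_xla(x) == x.is_xla  # both report False
```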

@@ -1215,6 +1215,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
Tensor.dtype.__get__: lambda self: -1,
Tensor.is_cuda.__get__: lambda self: -1,
Tensor.is_cpu.__get__: lambda self: -1,
Tensor.is_xla.__get__: lambda self: -1,
Tensor.is_xpu.__get__: lambda self: -1,
Tensor.is_ipu.__get__: lambda self: -1,
Tensor.is_leaf.__get__: lambda self: -1,
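
The `get_testing_overrides` entry registers a dummy override for the new getter so the `__torch_function__` coverage machinery knows about it; the lambda only needs to match the getter's arity. A rough way to confirm the entry is present:

```
import torch
from torch.overrides import get_testing_overrides

overrides = get_testing_overrides()
# The property getter is keyed as Tensor.is_xla.__get__, as in the diff above.
print(torch.Tensor.is_xla.__get__ in overrides)  # expected: True
```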