mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Added is_xla (#103100)
This change creates `is_xla` which is congruent with `is_cuda` and `is_cpu`. Useful in situations like: https://github.com/pytorch/pytorch/pull/102858 ``` >>> x = torch.tensor([1], device=xm.xla_device()) >>> x.is_xla True >>> x.is_cpu False >>> x = torch.tensor([1]) >>> x.is_cpu True >>> x.is_xla False ``` Attn: @albanD Pull Request resolved: https://github.com/pytorch/pytorch/pull/103100 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
49dc26435f
commit
4e204ff87b
@ -6603,6 +6603,13 @@ Is ``True`` if the Tensor is stored on the CPU, ``False`` otherwise.
|
||||
""",
|
||||
)
|
||||
|
||||
add_docstr_all(
|
||||
"is_xla",
|
||||
r"""
|
||||
Is ``True`` if the Tensor is stored on an XLA device, ``False`` otherwise.
|
||||
""",
|
||||
)
|
||||
|
||||
add_docstr_all(
|
||||
"is_ipu",
|
||||
r"""
|
||||
|
@ -1204,6 +1204,16 @@ PyObject* THPVariable_is_cuda(THPVariable* self, void* unused) {
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THPVariable_is_xla(THPVariable* self, void* unused) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (check_has_torch_function((PyObject*)self)) {
|
||||
return handle_torch_function_getter(self, "is_xla");
|
||||
}
|
||||
auto& self_ = THPVariable_Unpack(self);
|
||||
return torch::autograd::utils::wrap(self_.is_xla());
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THPVariable_is_ipu(THPVariable* self, void* unused) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (check_has_torch_function((PyObject*)self)) {
|
||||
@ -1470,6 +1480,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
|
||||
{"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr},
|
||||
{"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
|
||||
{"is_cpu", (getter)THPVariable_is_cpu, nullptr, nullptr, nullptr},
|
||||
{"is_xla", (getter)THPVariable_is_xla, nullptr, nullptr, nullptr},
|
||||
{"is_xpu", (getter)THPVariable_is_xpu, nullptr, nullptr, nullptr},
|
||||
{"is_ipu", (getter)THPVariable_is_ipu, nullptr, nullptr, nullptr},
|
||||
{"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
|
||||
|
@ -114,6 +114,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
|
||||
{"shape", "prim"},
|
||||
{"is_cuda", "prim"},
|
||||
{"is_cpu", "prim"},
|
||||
{"is_xla", "prim"},
|
||||
{"is_xpu", "prim"},
|
||||
{"is_sparse", "prim"},
|
||||
{"is_sparse_csr", "prim"},
|
||||
|
@ -1149,6 +1149,14 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
|
||||
push(stack, a.is_cpu());
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA("prim::is_xla(Tensor a) -> bool"),
|
||||
[](Stack& stack) {
|
||||
at::Tensor a;
|
||||
pop(stack, a);
|
||||
push(stack, a.is_xla());
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA("prim::is_xpu(Tensor a) -> bool"),
|
||||
[](Stack& stack) {
|
||||
|
@ -1215,6 +1215,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
Tensor.dtype.__get__: lambda self: -1,
|
||||
Tensor.is_cuda.__get__: lambda self: -1,
|
||||
Tensor.is_cpu.__get__: lambda self: -1,
|
||||
Tensor.is_xla.__get__: lambda self: -1,
|
||||
Tensor.is_xpu.__get__: lambda self: -1,
|
||||
Tensor.is_ipu.__get__: lambda self: -1,
|
||||
Tensor.is_leaf.__get__: lambda self: -1,
|
||||
|
Reference in New Issue
Block a user