Add new keys for Graphcore IPU (DispatchKey / Backend / DeviceType)

We need a key to register our out-of-tree backend: https://github.com/graphcore/poptorch
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74763
Approved by: https://github.com/bdhirsh
Anthony Barbier
2022-04-07 17:18:45 +00:00
committed by PyTorch MergeBot
parent c7ae23b50e
commit ce9e27a0fc
28 changed files with 165 additions and 5 deletions
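
As the commit message notes, the whole point of these keys is to let an out-of-tree backend register kernels against them. Below is a minimal sketch of what that registration could look like from the poptorch side once DispatchKey::IPU exists; the kernel body and its name are placeholders invented for illustration, not code from this PR.

#include <torch/library.h>
#include <ATen/ATen.h>

// Hypothetical kernel: a real IPU backend would lower the op to its own
// compiler/runtime rather than compute anything on the host.
at::Tensor ipu_add_impl(const at::Tensor& self, const at::Tensor& other,
                        const at::Scalar& alpha) {
  TORCH_CHECK(self.is_ipu(), "expected an IPU tensor");
  return self; // placeholder result
}

// Register the kernel under the new IPU dispatch key added by this PR.
TORCH_LIBRARY_IMPL(aten, IPU, m) {
  m.impl("add.Tensor", TORCH_FN(ipu_add_impl));
}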

View File

@ -80,6 +80,9 @@ class TORCH_API Context {
static bool hasHIP() {
return detail::getHIPHooks().hasHIP();
}
static bool hasIPU() {
return c10::impl::hasDeviceGuardImpl(at::DeviceType::IPU);
}
static bool hasXLA() {
return c10::impl::hasDeviceGuardImpl(at::DeviceType::XLA);
}
@ -295,6 +298,10 @@ static inline bool hasHIP() {
return globalContext().hasHIP();
}
static inline bool hasIPU() {
return globalContext().hasIPU();
}
static inline bool hasXLA() {
return globalContext().hasXLA();
}
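
The two hasIPU() helpers above only report true once the backend has registered a DeviceGuardImplInterface for DeviceType::IPU. A minimal sketch of that registration follows; the no-op guard is used purely as a stand-in, and a real integration such as poptorch would supply its own implementation with genuine device and stream handling.

#include <c10/core/impl/DeviceGuardImplInterface.h>

namespace {
// Stand-in only: a production backend provides its own
// DeviceGuardImplInterface subclass instead of the no-op template.
C10_REGISTER_GUARD_IMPL(
    IPU, c10::impl::NoOpDeviceGuardImpl<c10::DeviceType::IPU>);
} // namespace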

View File

@ -1493,6 +1493,7 @@ void TensorIteratorBase::build(TensorIteratorConfig& config) {
// Nothing beyond this point is important for meta functions, so it's fine to exit early here.
// Extend the condition to ORT tensors as ORT tensors also don't have storage.
if (common_device_.type() == DeviceType::XLA ||
common_device_.type() == DeviceType::IPU ||
common_device_.type() == DeviceType::Lazy ||
common_device_.type() == DeviceType::ORT ||
common_device_.type() == DeviceType::HPU) return;

View File

@ -370,6 +370,12 @@ class TORCH_API TensorBase {
return impl_->is_cuda();
}
/// Returns if a `Tensor` has IPU backend.
bool is_ipu() const {
// NB: this is not a native function to avoid dispatching overhead.
return impl_->is_ipu();
}
/// Returns if a `Tensor` has XPU backend.
bool is_xpu() const {
// NB: this is not a native function to avoid dispatching overhead.

View File

@ -1179,7 +1179,7 @@ Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) {
//
// We need to do the checks here instead of in `native_functions.yaml`
// to preserve backwards compatibility.
if (!self.is_xla() && !self.is_lazy()) {
if (!self.is_xla() && !self.is_lazy() && !self.is_ipu()) {
return self._reshape_alias(shape, stride.value());
} else {
return self.view(shape);

View File

@ -32,6 +32,7 @@ enum class Backend {
HIP,
VE,
FPGA,
IPU,
XPU,
SparseCPU,
SparseCUDA,
@ -96,6 +97,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) {
return Backend::QuantizedCPU;
} else if (t == DispatchKey::QuantizedCUDA) {
return Backend::QuantizedCUDA;
} else if (t == DispatchKey::IPU || t == DispatchKey::AutogradIPU) {
return Backend::IPU;
} else if (t == DispatchKey::XPU || t == DispatchKey::AutogradXPU) {
return Backend::XPU;
} else if (t == DispatchKey::SparseXPU) {
@ -129,6 +132,8 @@ static inline DispatchKey backendToDispatchKey(Backend b) {
return DispatchKey::XLA;
case Backend::Lazy:
return DispatchKey::Lazy;
case Backend::IPU:
return DispatchKey::IPU;
case Backend::XPU:
return DispatchKey::XPU;
case Backend::SparseXPU:
@ -196,6 +201,8 @@ static inline DeviceType backendToDeviceType(Backend b) {
return DeviceType::CPU;
case Backend::SparseCsrCUDA:
return DeviceType::CUDA;
case Backend::IPU:
return DeviceType::IPU;
case Backend::XPU:
case Backend::SparseXPU:
case Backend::QuantizedXPU:
@ -235,6 +242,8 @@ static inline const char* toString(Backend b) {
return "FPGA";
case Backend::XPU:
return "XPU";
case Backend::IPU:
return "IPU";
case Backend::ORT:
return "ORT";
case Backend::XLA:

View File

@ -20,6 +20,7 @@ DeviceType parse_type(const std::string& device_string) {
types = {{
{"cpu", DeviceType::CPU},
{"cuda", DeviceType::CUDA},
{"ipu", DeviceType::IPU},
{"xpu", DeviceType::XPU},
{"mkldnn", DeviceType::MKLDNN},
{"opengl", DeviceType::OPENGL},
@ -47,7 +48,7 @@ DeviceType parse_type(const std::string& device_string) {
}
TORCH_CHECK(
false,
"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, ort, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: ",
"Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, ort, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: ",
device_string);
}
enum DeviceStringParsingState { START, INDEX_START, INDEX_REST, ERROR };
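
With "ipu" added to parse_type, device strings round-trip through c10::Device as expected. A small sketch, needing nothing beyond c10 itself since no IPU runtime is required just to construct a Device:

#include <c10/core/Device.h>
#include <iostream>

int main() {
  c10::Device d("ipu:0");  // accepted now that parse_type knows "ipu"
  std::cout << d << " is_ipu=" << d.is_ipu() << "\n";  // prints "ipu:0 is_ipu=1"
  return 0;
}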

View File

@ -96,6 +96,11 @@ struct C10_API Device final {
return type_ == DeviceType::XPU;
}
/// Return true if the device is of IPU type.
bool is_ipu() const noexcept {
return type_ == DeviceType::IPU;
}
/// Return true if the device is of HPU type.
bool is_hpu() const noexcept {
return type_ == DeviceType::HPU;

View File

@ -43,6 +43,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) {
return lower_case ? "meta" : "META";
case DeviceType::HPU:
return lower_case ? "hpu" : "HPU";
case DeviceType::IPU:
return lower_case ? "ipu" : "IPU";
default:
TORCH_CHECK(
false,
@ -84,6 +86,7 @@ bool isValidDeviceType(DeviceType d) {
case DeviceType::XPU:
case DeviceType::Meta:
case DeviceType::HPU:
case DeviceType::IPU:
return true;
default:
return false;

View File

@ -31,11 +31,12 @@ enum class DeviceType : int8_t {
HPU = 15, // HPU / HABANA
VE = 16, // SX-Aurora / NEC
Lazy = 17, // Lazy Tensors
IPU = 18, // Graphcore IPU
// NB: If you add more devices:
// - Change the implementations of DeviceTypeName and isValidDeviceType
// in DeviceType.cpp
// - Change the number below
COMPILE_TIME_MAX_DEVICE_TYPES = 18,
COMPILE_TIME_MAX_DEVICE_TYPES = 19,
};
constexpr DeviceType kCPU = DeviceType::CPU;
@ -52,18 +53,19 @@ constexpr DeviceType kXPU = DeviceType::XPU;
constexpr DeviceType kHPU = DeviceType::HPU;
constexpr DeviceType kVE = DeviceType::VE;
constexpr DeviceType kLazy = DeviceType::Lazy;
constexpr DeviceType kIPU = DeviceType::IPU;
// define explicit int constant
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
static_assert(
COMPILE_TIME_MAX_DEVICE_TYPES <= 18,
COMPILE_TIME_MAX_DEVICE_TYPES <= 19,
"Hey! You seem to be adding a lot of new DeviceTypes. The intent was "
"for this constant to reflect the actual number of DeviceTypes we support "
"in PyTorch; it's important that this number is not too large as we "
"use this to allocate stack arrays in some places in our code. If you "
"are indeed just adding the 18th device type, feel free to change "
"are indeed just adding the 19th device type, feel free to change "
"the check to 32; but if you are adding some sort of extensible device "
"types registration, please be aware that you are affecting code that "
"thinks this number is small. Try auditing uses of this constant.");

View File

@ -19,6 +19,8 @@ const char* toString(BackendComponent t) {
return "LazyBit";
case BackendComponent::XPUBit:
return "XPUBit";
case BackendComponent::IPUBit:
return "IPUBit";
case BackendComponent::MLCBit:
return "MLCBit";
case BackendComponent::HPUBit:
@ -54,6 +56,8 @@ const char* toString(DispatchKey t) {
return "FPGA";
case DispatchKey::XPU:
return "XPU";
case DispatchKey::IPU:
return "IPU";
case DispatchKey::ORT:
return "ORT";
case DispatchKey::XLA:
@ -124,6 +128,8 @@ const char* toString(DispatchKey t) {
return "Autograd";
case DispatchKey::AutogradCPU:
return "AutogradCPU";
case DispatchKey::AutogradIPU:
return "AutogradIPU";
case DispatchKey::AutogradXPU:
return "AutogradXPU";
case DispatchKey::AutogradCUDA:
@ -284,6 +290,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
{"XLA", c10::DispatchKey::XLA},
{"MLC", c10::DispatchKey::MLC},
{"XPU", c10::DispatchKey::XPU},
{"IPU", c10::DispatchKey::IPU},
{"HPU", c10::DispatchKey::HPU},
{"Lazy", c10::DispatchKey::Lazy},
{"NestedTensor", c10::DispatchKey::NestedTensor},
@ -305,6 +312,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
{"AutogradCUDA", c10::DispatchKey::AutogradCUDA},
{"AutogradXLA", c10::DispatchKey::AutogradXLA},
{"AutogradLazy", c10::DispatchKey::AutogradLazy},
{"AutogradIPU", c10::DispatchKey::AutogradIPU},
{"AutogradXPU", c10::DispatchKey::AutogradXPU},
{"AutogradMLC", c10::DispatchKey::AutogradMLC},
{"AutogradHPU", c10::DispatchKey::AutogradHPU},

View File

@ -52,6 +52,7 @@ enum class BackendComponent : uint8_t {
HIPBit,
XLABit,
MLCBit,
IPUBit,
XPUBit,
HPUBit,
VEBit,
@ -393,6 +394,7 @@ enum class DispatchKey : uint16_t {
// CUDA]
XLA, // lives out of tree at https://github.com/pytorch/xla
MLC, // lives out of tree at https://github.com/pytorch/MLCompute
IPU, // lives out of tree at https://github.com/graphcore/poptorch
XPU, // For out of tree Intel's heterogeneous computing plug-in
HPU, // For out of tree & closed source integration of HPU / Habana
VE, // For out of tree & closed source integration of SX-Aurora / NEC
@ -416,6 +418,7 @@ enum class DispatchKey : uint16_t {
_QuantizedHIP,
_QuantizedXLA,
_QuantizedMLC,
_QuantizedIPU,
QuantizedXPU, // For out of tree Intel's heterogeneous computing plug-in
_QuantizedHPU,
_QuantizedVE,
@ -437,6 +440,7 @@ enum class DispatchKey : uint16_t {
// [Masquerading as CUDA]
_SparseXLA,
_SparseMLC,
_SparseIPU,
SparseXPU, // For out of tree Intel's heterogeneous computing plug-in
_SparseHPU,
SparseVE, // For out of tree & closed source integration of SX-Aurora / NEC
@ -457,6 +461,7 @@ enum class DispatchKey : uint16_t {
_AutogradHIP,
AutogradXLA,
AutogradMLC,
AutogradIPU,
AutogradXPU,
AutogradHPU,
_AutogradVE,
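
DispatchKey::IPU is one of the per-backend "instance" keys: a keyset built from it carries the Dense functionality bit together with the new BackendComponent::IPUBit. A quick sketch of that decomposition, using the same has_all check the TensorImpl change in this PR performs (the function name is just for illustration):

#include <c10/core/DispatchKeySet.h>
#include <cassert>

void ipu_keyset_demo() {
  // Building a keyset from the runtime key sets Dense plus the IPU backend bit.
  c10::DispatchKeySet ks(c10::DispatchKey::IPU);
  constexpr auto ipu_bit = c10::DispatchKeySet(c10::BackendComponent::IPUBit);
  assert(ks.has_all(ipu_bit));  // same check as TensorImpl::is_ipu()
}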

View File

@ -79,6 +79,8 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
return DispatchKeySet(DispatchKey::MLC);
case DispatchKey::AutogradHPU:
return DispatchKeySet(DispatchKey::HPU);
case DispatchKey::AutogradIPU:
return DispatchKeySet(DispatchKey::IPU);
case DispatchKey::AutogradXPU:
return DispatchKeySet(DispatchKey::XPU);
case DispatchKey::AutogradPrivateUse1:

View File

@ -717,6 +717,7 @@ constexpr DispatchKeySet backend_bitset_mask =
constexpr auto inplace_or_view_ks =
DispatchKeySet(DispatchKey::ADInplaceOrView);
constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU);
constexpr auto autograd_ipu_ks = DispatchKeySet(DispatchKey::AutogradIPU);
constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU);
constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA);
constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA);
@ -777,6 +778,8 @@ inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
switch (t) {
case BackendComponent::CPUBit:
return inplace_or_view_ks | autograd_cpu_ks;
case BackendComponent::IPUBit:
return inplace_or_view_ks | autograd_ipu_ks;
case BackendComponent::XPUBit:
return inplace_or_view_ks | autograd_xpu_ks;
case BackendComponent::CUDABit:

View File

@ -737,6 +737,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return key_set_.has_all(xpu_ks);
}
bool is_ipu() const {
constexpr auto ipu_ks = DispatchKeySet(BackendComponent::IPUBit);
return key_set_.has_all(ipu_ks);
}
bool is_xla() const {
constexpr auto xla_ks = DispatchKeySet(BackendComponent::XLABit);
return key_set_.has_all(xla_ks);

View File

@ -643,6 +643,9 @@ inline DispatchKey computeDispatchKey(
}
return DispatchKey::CUDA;
}
case DeviceType::IPU: {
return DispatchKey::IPU;
}
case DeviceType::XPU: {
if (isQIntType(dtype_)) {
return DispatchKey::QuantizedXPU;
@ -780,6 +783,9 @@ inline DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) {
return DeviceType::Meta;
// stuff that people are actively developing
case DispatchKey::IPU:
case DispatchKey::AutogradIPU:
return DeviceType::IPU;
case DispatchKey::XPU:
case DispatchKey::SparseXPU:
case DispatchKey::QuantizedXPU:
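
These two helpers keep the device-to-key mapping consistent in both directions. A quick round-trip check, assuming only the usual c10 headers (computeDispatchKey and dispatchKeyToDeviceType both live in TensorOptions.h):

#include <c10/core/TensorOptions.h>
#include <cassert>

void ipu_key_roundtrip() {
  // Device -> DispatchKey via computeDispatchKey ...
  auto key = c10::computeDispatchKey(
      /*dtype=*/c10::nullopt, /*layout=*/c10::nullopt,
      c10::Device(c10::DeviceType::IPU));
  assert(key == c10::DispatchKey::IPU);

  // ... and DispatchKey -> DeviceType, including the Autograd flavour.
  assert(c10::dispatchKeyToDeviceType(c10::DispatchKey::AutogradIPU) ==
         c10::DeviceType::IPU);
}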

View File

@ -541,6 +541,28 @@ static PyObject * THPVariable_xpu(PyObject* self, PyObject* args, PyObject* kwar
END_HANDLE_TH_ERRORS
}
static PyObject * THPVariable_ipu(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
static PythonArgParser parser({
"ipu(Device? device=None, bool non_blocking=False, *, MemoryFormat? memory_format=None)",
"ipu(Device? device=None, bool async=False, *, MemoryFormat? memory_format=None)|deprecated"
});
auto& self_ = THPVariable_Unpack(self);
ParsedArgs<3> parsed_args;
auto r = parser.parse(self, args, kwargs, parsed_args);
if (r.has_torch_function()) {
return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
}
auto device = r.isNone(0) ? at::Device(at::DeviceType::IPU) : r.device(0);
auto opt_memory_format = r.memoryformatOptional(2);
TORCH_CHECK(device.is_ipu(), "Invalid device, must be ipu device");
return THPVariable_Wrap(dispatch_to(self_, device, r.toBool(1), false, opt_memory_format));
END_HANDLE_TH_ERRORS
}
static PyObject * THPVariable_to_type(PyObject* self, ScalarType scalarType, c10::optional<c10::MemoryFormat> optional_memory_format) {
HANDLE_TH_ERRORS
auto& self_ = THPVariable_Unpack(self);
@ -1205,6 +1227,7 @@ PyMethodDef variable_methods[] = {
{"cpu", castPyCFunctionWithKeywords(THPVariable_cpu), METH_VARARGS | METH_KEYWORDS, NULL},
{"cuda", castPyCFunctionWithKeywords(THPVariable_cuda), METH_VARARGS | METH_KEYWORDS, NULL},
{"xpu", castPyCFunctionWithKeywords(THPVariable_xpu), METH_VARARGS | METH_KEYWORDS, NULL},
{"ipu", castPyCFunctionWithKeywords(THPVariable_ipu), METH_VARARGS | METH_KEYWORDS, NULL},
{"data_ptr", THPVariable_data_ptr, METH_NOARGS, NULL},
{"dim", THPVariable_dim, METH_NOARGS, NULL},
{"has_names", THPVariable_has_names, METH_NOARGS, NULL},

View File

@ -84,6 +84,7 @@ class DispatchKey(Enum):
HIP = auto()
XLA = auto()
Lazy = auto()
IPU = auto()
XPU = auto()
NestedTensor = auto()
PrivateUse1 = auto()
@ -103,6 +104,7 @@ class DispatchKey(Enum):
AutogradCUDA = auto()
AutogradXLA = auto()
AutogradLazy = auto()
AutogradIPU = auto()
AutogradXPU = auto()
AutogradPrivateUse1 = auto()
AutogradPrivateUse2 = auto()

View File

@ -478,6 +478,7 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
'is_ort': ['is_ort: _bool'],
'is_mkldnn': ['is_mkldnn: _bool'],
'is_vulkan': ['is_vulkan: _bool'],
'is_ipu': ['is_ipu: _bool'],
'storage_offset': ['def storage_offset(self) -> _int: ...'],
'to': ['def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '

View File

@ -1060,6 +1060,24 @@ Args:
{memory_format}
""".format(**common_args))
add_docstr_all('ipu',
r"""
ipu(device=None, non_blocking=False, memory_format=torch.preserve_format) -> Tensor
Returns a copy of this object in IPU memory.
If this object is already in IPU memory and on the correct device,
then no copy is performed and the original object is returned.
Args:
device (:class:`torch.device`): The destination IPU device.
Defaults to the current IPU device.
non_blocking (bool): If ``True`` and the source is in pinned memory,
the copy will be asynchronous with respect to the host.
Otherwise, the argument has no effect. Default: ``False``.
{memory_format}
""".format(**common_args))
add_docstr_all('xpu',
r"""
xpu(device=None, non_blocking=False, memory_format=torch.preserve_format) -> Tensor
@ -4946,6 +4964,11 @@ add_docstr_all('is_cuda',
Is ``True`` if the Tensor is stored on the GPU, ``False`` otherwise.
""")
add_docstr_all('is_ipu',
r"""
Is ``True`` if the Tensor is stored on the IPU, ``False`` otherwise.
""")
add_docstr_all('is_xpu',
r"""
Is ``True`` if the Tensor is stored on the XPU, ``False`` otherwise.

View File

@ -899,6 +899,16 @@ PyObject *THPVariable_is_cuda(THPVariable *self, void *unused)
END_HANDLE_TH_ERRORS
}
PyObject* THPVariable_is_ipu(THPVariable* self, void* unused) {
HANDLE_TH_ERRORS
if (check_has_torch_function((PyObject*)self)) {
return handle_torch_function_getter(self, "is_ipu");
}
auto& self_ = THPVariable_Unpack(self);
return torch::autograd::utils::wrap(self_.is_ipu());
END_HANDLE_TH_ERRORS
}
PyObject* THPVariable_is_xpu(THPVariable* self, void* unused) {
HANDLE_TH_ERRORS
if (check_has_torch_function((PyObject*)self)) {
@ -1128,6 +1138,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
{"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr},
{"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
{"is_xpu", (getter)THPVariable_is_xpu, nullptr, nullptr, nullptr},
{"is_ipu", (getter)THPVariable_is_ipu, nullptr, nullptr, nullptr},
{"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
{"is_sparse_csr", (getter)THPVariable_is_sparse_csr, nullptr, nullptr, nullptr},
{"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},

View File

@ -121,6 +121,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
{"is_mlc", "prim"},
{"is_quantized", "prim"},
{"is_vulkan", "prim"},
{"is_ipu", "prim"},
{"is_meta", "prim"},
{"is_leaf", "aten"},
{"is_nested", "prim"},

View File

@ -2260,6 +2260,14 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs1{
push(stack, a.is_vulkan());
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::is_ipu(Tensor a) -> bool"),
[](Stack& stack) {
at::Tensor a;
pop(stack, a);
push(stack, a.is_ipu());
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::is_quantized(Tensor a) -> bool"),
[](Stack& stack) {

View File

@ -357,6 +357,7 @@ void check_base_legacy_new(c10::DispatchKey dispatch_key, at::Layout expected_la
c10::DispatchKey::HIP,
c10::DispatchKey::XLA,
c10::DispatchKey::Lazy,
c10::DispatchKey::IPU,
c10::DispatchKey::XPU,
c10::DispatchKey::HPU,
});

View File

@ -21,6 +21,7 @@ static const char* backend_to_string(const at::Backend& backend) {
case at::Backend::CPU: return "torch";
case at::Backend::CUDA: return "torch.cuda";
case at::Backend::XPU: return "torch.xpu";
case at::Backend::IPU: return "torch.ipu";
case at::Backend::SparseCPU: return "torch.sparse";
case at::Backend::SparseCUDA: return "torch.cuda.sparse";
case at::Backend::SparseXPU: return "torch.xpu.sparse";

View File

@ -323,6 +323,9 @@ class _RemoteModule(nn.Module):
def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return]
_raise_not_supported(self.cuda.__name__)
def ipu(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return]
_raise_not_supported(self.ipu.__name__)
def xpu(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return]
_raise_not_supported(self.xpu.__name__)

View File

@ -355,6 +355,8 @@ inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
return c10::DispatchKey::CPU;
case c10::DeviceType::CUDA:
return c10::DispatchKey::CUDA;
case c10::DeviceType::IPU:
return c10::DispatchKey::IPU;
case c10::DeviceType::XLA:
return c10::DispatchKey::XLA;
case c10::DeviceType::Lazy:
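
The extra case above also makes the DeviceType overload of torch::dispatch usable for IPU, which is convenient for custom ops. A hedged sketch; the operator namespace, schema, and kernel body are invented for illustration:

#include <torch/library.h>
#include <ATen/ATen.h>

// Hypothetical custom op; a real backend would execute this on the IPU.
at::Tensor my_ipu_relu(const at::Tensor& x) {
  return x; // placeholder body
}

TORCH_LIBRARY(poptorch_example, m) {
  m.def("my_relu(Tensor x) -> Tensor");
  // torch::dispatch(DeviceType) now resolves IPU to DispatchKey::IPU.
  m.impl("my_relu",
         torch::dispatch(c10::DeviceType::IPU, TORCH_FN(my_ipu_relu)));
}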

View File

@ -687,6 +687,25 @@ class Module:
"""
return self._apply(lambda t: t.cuda(device))
def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:
r"""Moves all model parameters and buffers to the IPU.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on the IPU while being optimized.
.. note::
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
copied to that device
Returns:
Module: self
"""
return self._apply(lambda t: t.ipu(device))
def xpu(self: T, device: Optional[Union[int, device]] = None) -> T:
r"""Moves all model parameters and buffers to the XPU.

View File

@ -1071,6 +1071,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
Tensor.dtype.__get__: lambda self: -1,
Tensor.is_cuda.__get__: lambda self: -1,
Tensor.is_xpu.__get__: lambda self: -1,
Tensor.is_ipu.__get__: lambda self: -1,
Tensor.is_leaf.__get__: lambda self: -1,
Tensor.retains_grad.__get__: lambda self: -1,
Tensor.is_meta.__get__: lambda self: -1,
@ -1123,6 +1124,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
Tensor.cpu: lambda self, memory_format=torch.preserve_format: -1,
Tensor.cuda: lambda self, memory_format=torch.preserve_format: -1,
Tensor.xpu: lambda self, memory_format=torch.preserve_format: -1,
Tensor.ipu: lambda self, memory_format=torch.preserve_format: -1,
Tensor.data_ptr: lambda self: -1,
Tensor.dense_dim: lambda self: -1,
Tensor.diagonal_scatter: lambda self, src, offset=0, dim1=0, dim2=1: -1,