__cuda_array_interface__: Use "<V2" for bfloat16. (#143042)

Rationale: Numpy doesn't support `bfloat16`, so there is no official typestr for `bfloat16` in `__array_interface__` (https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.interface.html#__array_interface__). JAX/ml_dtypes, however, uses "<V2":

```
>>> from jax import numpy as jnp
>>> jnp.bfloat16.dtype.str
'<V2'
```

Using the same string in PyTorch has the upside that the typestrs returned by `__cuda_array_interface__` uniquely identify the torch dtype (previously, `bfloat16` shared "<f2" with `float16`).
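
For illustration, a minimal sketch of the user-visible effect (assuming a CUDA build of PyTorch with this change and an available CUDA device):

```
>>> import torch
>>> t = torch.arange(4, dtype=torch.bfloat16, device="cuda")
>>> t.__cuda_array_interface__["typestr"]  # previously '<f2', ambiguous with float16
'<V2'
```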

### Misc notes

(1) JAX itself refuses to provide `__cuda_array_interface__` for `bfloat16`:

```
>>> from jax import numpy as jnp
>>> jnp.arange(10, dtype=jnp.bfloat16).__cuda_array_interface__
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: __cuda_array_interface__ is not supported for bfloat16 buffers.
```

(2) The "official" description of `__cuda_array_interface__` doesn't mention bfloat16, it just references `__array_interface__`: https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html

(3) Ongoing issue for numpy to support bfloat16: https://github.com/numpy/numpy/issues/19808

(4) Tweet that triggered this: https://x.com/HeinrichKuttler/status/1866761979349844211, with @ezyang responding.

(5) "<V2" is kinda weird, as it's a "little-endian void" type. When given to Numpy, it gets turned into endian-agnostic:

```
>>> import numpy as np
>>> import ml_dtypes
>>> np.dtype("bfloat16").str
'<V2'
>>> np.dtype("<V2").str
'|V2'
```

Still, it makes sense to have a unique string for `bfloat16`, and since Google chose "<V2" we might as well use that.
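
Since NumPy itself cannot interpret "<V2" (or its normalized form "|V2") as bfloat16, consumers of the interface have to special-case it. A hypothetical consumer-side reverse mapping, purely for illustration (the names below are not part of this PR):

```
import torch

# Illustrative reverse of the typestr mapping used by __cuda_array_interface__.
# Both spellings of the 2-byte void type are accepted because NumPy normalizes
# "<V2" to "|V2".
_TYPESTR_TO_DTYPE = {
    "<V2": torch.bfloat16,
    "|V2": torch.bfloat16,
    "<c8": torch.complex64,
    "<c16": torch.complex128,
    "<f2": torch.float16,
    "<f4": torch.float32,
    "<f8": torch.float64,
    "|u1": torch.uint8,
    "|i1": torch.int8,
    "<u2": torch.uint16,
    "<i2": torch.int16,
    "<u4": torch.uint32,
    "<i4": torch.int32,
    "<u8": torch.uint64,
    "<i8": torch.int64,
    "|b1": torch.bool,
}


def typestr_to_dtype(typestr: str) -> torch.dtype:
    """Map a __cuda_array_interface__ typestr back to a torch dtype."""
    return _TYPESTR_TO_DTYPE[typestr]
```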

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143042
Approved by: https://github.com/ezyang

```
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -70,6 +70,28 @@ def _rebuild_from_type_v2(func, new_type, args, state):
     return ret
 
 
+def _dtype_to_typestr(dtype):
+    # CUDA devices are little-endian and tensors are stored in native byte
+    # order. 1-byte entries are endian-agnostic.
+    return {
+        torch.complex64: "<c8",
+        torch.complex128: "<c16",
+        torch.bfloat16: "<V2",  # Same as ml_dtypes.bfloat16.dtype.str.
+        torch.float16: "<f2",
+        torch.float32: "<f4",
+        torch.float64: "<f8",
+        torch.uint8: "|u1",
+        torch.int8: "|i1",
+        torch.uint16: "<u2",
+        torch.int16: "<i2",
+        torch.uint32: "<u4",
+        torch.int32: "<i4",
+        torch.uint64: "<u8",
+        torch.int64: "<i8",
+        torch.bool: "|b1",
+    }[dtype]
+
+
 # NB: If you subclass Tensor, and want to share the subclassed class
 # across processes, you must also update torch/multiprocessing/reductions.py
 # to define a ForkingPickler serialization mode for the class.
@@ -1262,28 +1284,8 @@ class Tensor(torch._C.TensorBase):
                 "If gradients aren't required, use var.detach() to get Variable that doesn't require grad."
             )
 
-        # CUDA devices are little-endian and tensors are stored in native byte
-        # order. 1-byte entries are endian-agnostic.
-        typestr = {
-            torch.complex64: "<c8",
-            torch.complex128: "<c16",
-            torch.bfloat16: "<f2",
-            torch.float16: "<f2",
-            torch.float32: "<f4",
-            torch.float64: "<f8",
-            torch.uint8: "|u1",
-            torch.int8: "|i1",
-            torch.uint16: "<u2",
-            torch.int16: "<i2",
-            torch.uint32: "<u4",
-            torch.int32: "<i4",
-            torch.uint64: "<u8",
-            torch.int64: "<i8",
-            torch.bool: "|b1",
-        }[self.dtype]
+        typestr = _dtype_to_typestr(self.dtype)
         itemsize = self.element_size()
         shape = tuple(self.shape)
         if self.is_contiguous():
             # __cuda_array_interface__ v2 requires the strides to be omitted
```
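
A hedged usage sketch of the new helper; `_dtype_to_typestr` is a private function in `torch/_tensor.py`, so the direct import below is for illustration only and may change between releases:

```
>>> import torch
>>> from torch._tensor import _dtype_to_typestr
>>> _dtype_to_typestr(torch.bfloat16)
'<V2'
>>> _dtype_to_typestr(torch.float16)
'<f2'
```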