__cuda_array_interface__: Use "<V2" for bfloat16. (#143042)
Rationale: While Numpy doesn't support `bfloat16`, and therefore there's no official typestr for `bfloat16` in `__array_interface__` (https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.interface.html#__array_interface__), JAX/ml_dtypes uses "<V2":

```
>>> from jax import numpy as jnp
>>> jnp.bfloat16.dtype.str
'<V2'
```

Using the same in PyTorch has the upside of making the typestrs returned by `__cuda_array_interface__` identify the torch dtype uniquely.

### Misc notes

(1) JAX itself just refuses to do `__cuda_array_interface__` for `bfloat16`:

```
>>> from jax import numpy as jnp
>>> jnp.arange(10, dtype=jnp.bfloat16).__cuda_array_interface__
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: __cuda_array_interface__ is not supported for bfloat16 buffers.
```

(2) The "official" description of `__cuda_array_interface__` doesn't mention bfloat16; it just references `__array_interface__`: https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html

(3) Ongoing issue for numpy to support bfloat16: https://github.com/numpy/numpy/issues/19808

(4) Tweet that triggered this: https://x.com/HeinrichKuttler/status/1866761979349844211, with @ezyang responding.

(5) "<V2" is kinda weird, as it's a "little-endian void" type. When given to Numpy, it gets turned into an endian-agnostic one:

```
>>> import numpy as np
>>> import ml_dtypes
>>> np.dtype("bfloat16").str
'<V2'
>>> np.dtype("<V2").str
'|V2'
```

Still, it makes sense to have a unique string for `bfloat16`, and since Google chose "<V2" we might as well use that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143042
Approved by: https://github.com/ezyang
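Not part of the commit message, but a minimal hedged illustration of the effect of this change (assumes a CUDA device is available; expected outputs reflect the new behavior described above):

```python
import torch

# After this change, a bfloat16 CUDA tensor advertises the ml_dtypes-style
# "<V2" typestr; no other torch dtype maps to that string, so the typestr
# identifies the torch dtype uniquely.
bf16 = torch.arange(10, dtype=torch.bfloat16, device="cuda")
f16 = torch.arange(10, dtype=torch.float16, device="cuda")

print(bf16.__cuda_array_interface__["typestr"])  # expected: '<V2'
print(f16.__cuda_array_interface__["typestr"])   # expected: '<f2'
```

Before this change, both tensors above would have reported "<f2", so a consumer could not tell bfloat16 and float16 buffers apart from the interface alone.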
@@ -70,6 +70,28 @@ def _rebuild_from_type_v2(func, new_type, args, state):
     return ret
 
 
+def _dtype_to_typestr(dtype):
+    # CUDA devices are little-endian and tensors are stored in native byte
+    # order. 1-byte entries are endian-agnostic.
+    return {
+        torch.complex64: "<c8",
+        torch.complex128: "<c16",
+        torch.bfloat16: "<V2",  # Same as ml_dtypes.bfloat16.dtype.str.
+        torch.float16: "<f2",
+        torch.float32: "<f4",
+        torch.float64: "<f8",
+        torch.uint8: "|u1",
+        torch.int8: "|i1",
+        torch.uint16: "<u2",
+        torch.int16: "<i2",
+        torch.uint32: "<u4",
+        torch.int32: "<i4",
+        torch.uint64: "<u8",
+        torch.int64: "<i8",
+        torch.bool: "|b1",
+    }[dtype]
+
+
 # NB: If you subclass Tensor, and want to share the subclassed class
 # across processes, you must also update torch/multiprocessing/reductions.py
 # to define a ForkingPickler serialization mode for the class.
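As a sketch of the "identify the torch dtype uniquely" point from the rationale: because the new mapping above is injective, a consumer could invert it to recover the torch dtype from a typestr. The helper below is hypothetical and not part of the patch, shown only for illustration:

```python
import torch

# Hypothetical inverse of _dtype_to_typestr (not in the patch). This is only
# well-defined because every typestr above is unique, including "<V2" for
# bfloat16; with the old "<f2" value, float16 and bfloat16 would collide.
_TYPESTR_TO_DTYPE = {
    "<c8": torch.complex64,
    "<c16": torch.complex128,
    "<V2": torch.bfloat16,  # same string as ml_dtypes.bfloat16.dtype.str
    "<f2": torch.float16,
    "<f4": torch.float32,
    "<f8": torch.float64,
    "|u1": torch.uint8,
    "|i1": torch.int8,
    "<u2": torch.uint16,
    "<i2": torch.int16,
    "<u4": torch.uint32,
    "<i4": torch.int32,
    "<u8": torch.uint64,
    "<i8": torch.int64,
    "|b1": torch.bool,
}


def _typestr_to_dtype(typestr):
    """Map a __cuda_array_interface__ typestr back to the torch dtype."""
    return _TYPESTR_TO_DTYPE[typestr]
```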
@@ -1262,28 +1284,8 @@ class Tensor(torch._C.TensorBase):
                 "If gradients aren't required, use var.detach() to get Variable that doesn't require grad."
             )
-
-        # CUDA devices are little-endian and tensors are stored in native byte
-        # order. 1-byte entries are endian-agnostic.
-        typestr = {
-            torch.complex64: "<c8",
-            torch.complex128: "<c16",
-            torch.bfloat16: "<f2",
-            torch.float16: "<f2",
-            torch.float32: "<f4",
-            torch.float64: "<f8",
-            torch.uint8: "|u1",
-            torch.int8: "|i1",
-            torch.uint16: "<u2",
-            torch.int16: "<i2",
-            torch.uint32: "<u4",
-            torch.int32: "<i4",
-            torch.uint64: "<u8",
-            torch.int64: "<i8",
-            torch.bool: "|b1",
-        }[self.dtype]
-
+        typestr = _dtype_to_typestr(self.dtype)
         itemsize = self.element_size()
 
         shape = tuple(self.shape)
         if self.is_contiguous():
             # __cuda_array_interface__ v2 requires the strides to be omitted
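Finally, a hedged sketch of how a consumer receiving the "<V2" typestr can reinterpret the raw two-byte payload with ml_dtypes (the package cited in the rationale). Shown on CPU for simplicity; it assumes ml_dtypes is installed and that `Tensor.view(torch.uint8)` is available for byte-level reinterpretation in your PyTorch version:

```python
import ml_dtypes  # registers the bfloat16 extension dtype with NumPy
import torch

t = torch.tensor([1.0, 2.5, -3.0], dtype=torch.bfloat16)

# Reinterpret the tensor's storage as raw little-endian bytes, hand those to
# NumPy, then view them as ml_dtypes' bfloat16 (whose dtype str is '<V2').
raw = t.view(torch.uint8).numpy()        # shape (6,): 2 bytes per element
as_bf16 = raw.view(ml_dtypes.bfloat16)   # shape (3,): [1.0, 2.5, -3.0]
print(as_bf16)
```

The same byte-level reinterpretation is what a device-side `__cuda_array_interface__` consumer would do with the pointer and the "<V2" typestr.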