Upgrades dlpack to v1.1 to include fp8/fp4 (#162195)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162195
Approved by: https://github.com/eqy, https://github.com/albanD, https://github.com/Skylion007, https://github.com/rgommers
This commit is contained in:
Syed Tousif Ahmed
2025-09-17 00:20:48 -07:00
committed by PyTorch MergeBot
parent f2206b1ed8
commit 928ac57c2a
3 changed files with 176 additions and 22 deletions

View File

@ -65,14 +65,24 @@ DLDataType getDLDataType(const Tensor& t) {
break;
// TODO(#146647): use macro here instead of spelling out each shell dtype
case ScalarType::Float8_e5m2:
dtype.code = DLDataTypeCode::kDLFloat8_e5m2;
break;
case ScalarType::Float8_e5m2fnuz:
dtype.code = DLDataTypeCode::kDLFloat8_e5m2fnuz;
break;
case ScalarType::Float8_e4m3fn:
dtype.code = DLDataTypeCode::kDLFloat8_e4m3fn;
break;
case ScalarType::Float8_e4m3fnuz:
dtype.code = DLDataTypeCode::kDLFloat8_e4m3fnuz;
break;
case ScalarType::Float8_e8m0fnu:
TORCH_CHECK_BUFFER(false, "float8 types are not supported by dlpack");
dtype.code = DLDataTypeCode::kDLFloat8_e8m0fnu;
break;
case ScalarType::Float4_e2m1fn_x2:
TORCH_CHECK_BUFFER(false, "float4 types are not supported by dlpack");
dtype.code = DLDataTypeCode::kDLFloat4_e2m1fn;
dtype.lanes = 2;
dtype.bits = 4;
break;
case ScalarType::QInt8:
case ScalarType::QUInt8:
@ -177,7 +187,11 @@ static Device getATenDevice(DLDeviceType type, c10::DeviceIndex index, void* dat
ScalarType toScalarType(const DLDataType& dtype) {
ScalarType stype = ScalarType::Undefined;
TORCH_CHECK_BUFFER(dtype.lanes == 1, "ATen does not support lanes != 1");
if (dtype.code != DLDataTypeCode::kDLFloat4_e2m1fn) {
TORCH_CHECK_BUFFER(
dtype.lanes == 1,
"ATen does not support lanes != 1 for dtype code", std::to_string(dtype.code));
}
switch (dtype.code) {
case DLDataTypeCode::kDLUInt:
switch (dtype.bits) {
@ -269,6 +283,73 @@ ScalarType toScalarType(const DLDataType& dtype) {
false, "Unsupported kDLBool bits ", std::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLFloat8_e5m2:
switch (dtype.bits) {
case 8:
stype = ScalarType::Float8_e5m2;
break;
default:
TORCH_CHECK_BUFFER(
false, "Unsupported kDLFloat8_e5m2 bits ", std::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLFloat8_e5m2fnuz:
switch (dtype.bits) {
case 8:
stype = ScalarType::Float8_e5m2fnuz;
break;
default:
TORCH_CHECK_BUFFER(
false, "Unsupported kDLFloat8_e5m2fnuz bits ", std::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLFloat8_e4m3fn:
switch (dtype.bits) {
case 8:
stype = ScalarType::Float8_e4m3fn;
break;
default:
TORCH_CHECK_BUFFER(
false, "Unsupported kDLFloat8_e4m3fn bits ", std::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLFloat8_e4m3fnuz:
switch (dtype.bits) {
case 8:
stype = ScalarType::Float8_e4m3fnuz;
break;
default:
TORCH_CHECK_BUFFER(
false, "Unsupported kDLFloat8_e4m3fnuz bits ", std::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLFloat8_e8m0fnu:
switch (dtype.bits) {
case 8:
stype = ScalarType::Float8_e8m0fnu;
break;
default:
TORCH_CHECK_BUFFER(
false, "Unsupported kDLFloat8_e8m0fnu bits ", std::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLFloat4_e2m1fn:
switch (dtype.bits) {
case 4:
switch (dtype.lanes) {
case 2:
stype = ScalarType::Float4_e2m1fn_x2;
break;
default:
TORCH_CHECK_BUFFER(
false, "Unsupported kDLFloat4_e2m1fn lanes ", std::to_string(dtype.lanes));
}
break;
default:
TORCH_CHECK_BUFFER(
false, "Unsupported kDLFloat4_e2m1fn bits ", std::to_string(dtype.bits));
}
break;
default:
TORCH_CHECK_BUFFER(false, "Unsupported code ", std::to_string(dtype.code));
}
@ -354,8 +435,8 @@ T* toDLPackImpl(const Tensor& src) {
atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device());
atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
atDLMTensor->tensor.dl_tensor.shape = view.sizes().data();
atDLMTensor->tensor.dl_tensor.strides = view.strides().data();
atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(view.sizes().data());
atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(view.strides().data());
atDLMTensor->tensor.dl_tensor.byte_offset = 0;
fillVersion(&atDLMTensor->tensor);

View File

@ -19,7 +19,7 @@
#define DLPACK_MAJOR_VERSION 1
/*! \brief The current minor version of dlpack */
#define DLPACK_MINOR_VERSION 0
#define DLPACK_MINOR_VERSION 1
/*! \brief DLPACK_DLL prefix for windows */
#ifdef _WIN32
@ -32,9 +32,7 @@
#define DLPACK_DLL
#endif
// NOLINTNEXTLINE(modernize-deprecated-headers)
#include <stdint.h>
// NOLINTNEXTLINE(modernize-deprecated-headers)
#include <stddef.h>
#ifdef __cplusplus
@ -159,6 +157,26 @@ typedef enum {
kDLComplex = 5U,
/*! \brief boolean */
kDLBool = 6U,
/*! \brief FP8 data types */
kDLFloat8_e3m4 = 7U,
kDLFloat8_e4m3 = 8U,
kDLFloat8_e4m3b11fnuz = 9U,
kDLFloat8_e4m3fn = 10U,
kDLFloat8_e4m3fnuz = 11U,
kDLFloat8_e5m2 = 12U,
kDLFloat8_e5m2fnuz = 13U,
kDLFloat8_e8m0fnu = 14U,
/*! \brief FP6 data types
* Setting bits != 6 is currently unspecified: the producer must ensure that bits is set to 6,
* and the consumer must stop importing if the value is unexpected.
*/
kDLFloat6_e2m3fn = 15U,
kDLFloat6_e3m2fn = 16U,
/*! \brief FP4 data types
* Setting bits != 4 is currently unspecified: the producer must ensure that bits is set to 4,
* and the consumer must stop importing if the value is unexpected.
*/
kDLFloat4_e2m1fn = 17U,
} DLDataTypeCode;
/*!
@ -172,6 +190,12 @@ typedef enum {
* - int8: type_code = 0, bits = 8, lanes = 1
* - std::complex<float>: type_code = 5, bits = 64, lanes = 1
* - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, the underlying storage size of bool is 8 bits)
* - float8_e4m3: type_code = 8, bits = 8, lanes = 1 (packed in memory)
* - float6_e3m2fn: type_code = 16, bits = 6, lanes = 1 (packed in memory)
* - float4_e2m1fn: type_code = 17, bits = 4, lanes = 1 (packed in memory)
*
* When a sub-byte type is packed, DLPack requires the data to be in little bit-endian order, i.e.,
* for a packed data set D, ((D >> (i * bits)) & bit_mask) stores the i-th element.
*/
typedef struct {
/*!
@ -229,12 +253,12 @@ typedef struct {
/*! \brief The data type of the pointer*/
DLDataType dtype;
/*! \brief The shape of the tensor */
const int64_t* shape;
int64_t* shape;
/*!
* \brief strides of the tensor (in number of elements, not bytes)
* can be NULL, indicating tensor is compact and row-majored.
*/
const int64_t* strides;
int64_t* strides;
/*! \brief The offset in bytes to the beginning pointer to data */
uint64_t byte_offset;
} DLTensor;
@ -282,6 +306,14 @@ typedef struct DLManagedTensor {
*/
#define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL)
/*!
* \brief bit mask to indicate whether a sub-byte type is packed or padded.
*
* The default for sub-byte types (e.g. fp4/fp6) is assumed to be packed. This flag can
* be set by the producer to signal that a tensor of sub-byte type is padded.
*/
#define DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED (1UL << 2UL)
/*!
* \brief A versioned and managed C Tensor object, manage memory of DLTensor.
*

View File

@ -93,15 +93,7 @@ class TestTorchDlPack(TestCase):
z[0] = z[0] + 20.0
self.assertEqual(z, x)
@skipMeta
@onlyCUDA
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
def test_dlpack_conversion_with_streams(self, device, dtype):
# Create a stream where the tensor will reside
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# Do an operation in the actual stream
x = make_tensor((5,), dtype=dtype, device=device) + 1
def _dlpack_conversion_with_streams(self, stream, x):
# DLPack protocol helps establish a correct stream order
# (hence data dependency) at the exchange boundary.
# DLPack manages this synchronization for us, so we don't need to
@ -114,8 +106,38 @@ class TestTorchDlPack(TestCase):
with torch.cuda.stream(stream):
z = from_dlpack(x)
stream.synchronize()
return z
@skipMeta
@onlyCUDA
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
def test_dlpack_conversion_with_streams(self, device, dtype):
# Create a stream where the tensor will reside
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# Do an operation in the actual stream
x = make_tensor((5,), dtype=dtype, device=device) + 1
z = self._dlpack_conversion_with_streams(stream, x)
self.assertEqual(z, x)
@skipMeta
@onlyCUDA
@dtypes(
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e8m0fnu,
torch.float4_e2m1fn_x2,
)
def test_dlpack_conversion_with_streams_narrow_precision(self, device, dtype):
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
x = make_tensor((5,), dtype=torch.uint8, device=device) + 1
x = x.view(dtype)
z = self._dlpack_conversion_with_streams(stream, x)
self.assertEqual(z.view(torch.uint8), x.view(torch.uint8))
@skipMeta
@onlyNativeDeviceTypes
@dtypes(
@ -187,6 +209,27 @@ class TestTorchDlPack(TestCase):
stream_b.synchronize()
self.assertEqual(z, x)
@skipMeta
@onlyCUDA
@dtypes(
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e8m0fnu,
torch.float4_e2m1fn_x2,
)
def test_dlpack_conversion_with_diff_streams_narrow_precision(self, device, dtype):
stream_a = torch.cuda.Stream()
stream_b = torch.cuda.Stream()
with torch.cuda.stream(stream_a):
x = make_tensor((5,), dtype=torch.uint8, device=device) + 1
x = x.view(dtype)
z = torch.from_dlpack(x.__dlpack__(stream=stream_b.cuda_stream))
stream_a.synchronize()
stream_b.synchronize()
self.assertEqual(z.view(torch.uint8), x.view(torch.uint8))
@skipMeta
@onlyNativeDeviceTypes
@dtypes(
@ -484,9 +527,7 @@ class TestTorchDlPack(TestCase):
@skipMeta
@onlyCPU
def test_dlpack_unsupported_dtype_error(self, device):
inp = make_tensor((5,), dtype=torch.float32, device=device).to(
torch.float8_e4m3fn
)
inp = torch.quantize_per_tensor(torch.randn(()), 0.1, 10, torch.qint8)
with self.assertRaisesRegex(
BufferError, ".* types are not supported by dlpack"