Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Upgrades dlpack to v1.1 to include fp8/fp4 (#162195)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162195 Approved by: https://github.com/eqy, https://github.com/albanD, https://github.com/Skylion007, https://github.com/rgommers
Committed by: PyTorch MergeBot
Parent: f2206b1ed8
Commit: 928ac57c2a
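At a high level, this change lets fp8 and fp4 tensors cross the DLPack boundary instead of raising a BufferError. A minimal sketch of the user-visible behavior (not part of the diff; assumes a PyTorch build that contains this change and exposes torch.float8_e4m3fn):

import torch
from torch.utils.dlpack import from_dlpack

# Build an fp8 tensor by reinterpreting raw uint8 storage, mirroring the tests below.
raw = torch.arange(8, dtype=torch.uint8)
x = raw.view(torch.float8_e4m3fn)

# Round-trip through DLPack; previously this raised
# "float8 types are not supported by dlpack".
z = from_dlpack(x)

assert z.dtype == torch.float8_e4m3fn
# Compare bit patterns rather than values, since fp8 comparison kernels
# are not available on every backend.
assert torch.equal(z.view(torch.uint8), x.view(torch.uint8))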
@@ -65,14 +65,24 @@ DLDataType getDLDataType(const Tensor& t) {
      break;
    // TODO(#146647): use macro here instead of spelling out each shell dtype
    case ScalarType::Float8_e5m2:
      dtype.code = DLDataTypeCode::kDLFloat8_e5m2;
      break;
    case ScalarType::Float8_e5m2fnuz:
      dtype.code = DLDataTypeCode::kDLFloat8_e5m2fnuz;
      break;
    case ScalarType::Float8_e4m3fn:
      dtype.code = DLDataTypeCode::kDLFloat8_e4m3fn;
      break;
    case ScalarType::Float8_e4m3fnuz:
      dtype.code = DLDataTypeCode::kDLFloat8_e4m3fnuz;
      break;
    case ScalarType::Float8_e8m0fnu:
      TORCH_CHECK_BUFFER(false, "float8 types are not supported by dlpack");
      dtype.code = DLDataTypeCode::kDLFloat8_e8m0fnu;
      break;
    case ScalarType::Float4_e2m1fn_x2:
      TORCH_CHECK_BUFFER(false, "float4 types are not supported by dlpack");
      dtype.code = DLDataTypeCode::kDLFloat4_e2m1fn;
      dtype.lanes = 2;
      dtype.bits = 4;
      break;
    case ScalarType::QInt8:
    case ScalarType::QUInt8:
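The hunk above is the producer side: the old "not supported by dlpack" checks give way to real DLPack dtype codes. The fp4 case is the only one that is not a plain one-byte mapping: torch.float4_e2m1fn_x2 stores two 4-bit values per byte, which getDLDataType now describes as code=kDLFloat4_e2m1fn, bits=4, lanes=2. A hedged Python sketch of that packed layout (not part of the diff; assumes float4_e2m1fn_x2 has 1-byte elements, as exercised by the tests further down):

import torch

# One byte holds two e2m1 values; the "x2" dtype keeps the byte-sized element view.
packed = torch.tensor([0x21, 0x43], dtype=torch.uint8)
x = packed.view(torch.float4_e2m1fn_x2)

assert x.element_size() == 1    # still one byte per packed pair
assert x.shape == packed.shape  # shape counts packed bytes, not nibbles

# On DLPack export this is reported as bits=4, lanes=2, so a consumer knows each
# byte-sized element carries two fp4 lanes.
z = torch.from_dlpack(x)
assert torch.equal(z.view(torch.uint8), packed)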
@@ -177,7 +187,11 @@ static Device getATenDevice(DLDeviceType type, c10::DeviceIndex index, void* dat

ScalarType toScalarType(const DLDataType& dtype) {
  ScalarType stype = ScalarType::Undefined;
  TORCH_CHECK_BUFFER(dtype.lanes == 1, "ATen does not support lanes != 1");
  if (dtype.code != DLDataTypeCode::kDLFloat4_e2m1fn) {
    TORCH_CHECK_BUFFER(
        dtype.lanes == 1,
        "ATen does not support lanes != 1 for dtype code", std::to_string(dtype.code));
  }
  switch (dtype.code) {
    case DLDataTypeCode::kDLUInt:
      switch (dtype.bits) {
@@ -269,6 +283,73 @@ ScalarType toScalarType(const DLDataType& dtype) {
              false, "Unsupported kDLBool bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLFloat8_e5m2:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Float8_e5m2;
          break;
        default:
          TORCH_CHECK_BUFFER(
              false, "Unsupported kDLFloat8_e5m2 bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLFloat8_e5m2fnuz:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Float8_e5m2fnuz;
          break;
        default:
          TORCH_CHECK_BUFFER(
              false, "Unsupported kDLFloat8_e5m2fnuz bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLFloat8_e4m3fn:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Float8_e4m3fn;
          break;
        default:
          TORCH_CHECK_BUFFER(
              false, "Unsupported kDLFloat8_e4m3fn bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLFloat8_e4m3fnuz:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Float8_e4m3fnuz;
          break;
        default:
          TORCH_CHECK_BUFFER(
              false, "Unsupported kDLFloat8_e4m3fnuz bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLFloat8_e8m0fnu:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Float8_e8m0fnu;
          break;
        default:
          TORCH_CHECK_BUFFER(
              false, "Unsupported kDLFloat8_e8m0fnu bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLFloat4_e2m1fn:
      switch (dtype.bits) {
        case 4:
          switch (dtype.lanes) {
            case 2:
              stype = ScalarType::Float4_e2m1fn_x2;
              break;
            default:
              TORCH_CHECK_BUFFER(
                  false, "Unsupported kDLFloat4_e2m1fn lanes ", std::to_string(dtype.lanes));
          }
          break;
        default:
          TORCH_CHECK_BUFFER(
              false, "Unsupported kDLFloat4_e2m1fn bits ", std::to_string(dtype.bits));
      }
      break;
    default:
      TORCH_CHECK_BUFFER(false, "Unsupported code ", std::to_string(dtype.code));
  }
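The consumer-side mapping added above is, in effect, a lookup from (type code, bits, lanes) to a torch dtype; anything else falls through to TORCH_CHECK_BUFFER, which surfaces in Python as a BufferError. A rough restatement of that table in Python (illustrative only; the key names mirror the C++ switch, not a public API):

import torch

# (DLPack type code name, bits, lanes) -> torch dtype, as handled by toScalarType above.
NARROW_DLPACK_TO_TORCH = {
    ("kDLFloat8_e5m2",     8, 1): torch.float8_e5m2,
    ("kDLFloat8_e5m2fnuz", 8, 1): torch.float8_e5m2fnuz,
    ("kDLFloat8_e4m3fn",   8, 1): torch.float8_e4m3fn,
    ("kDLFloat8_e4m3fnuz", 8, 1): torch.float8_e4m3fnuz,
    ("kDLFloat8_e8m0fnu",  8, 1): torch.float8_e8m0fnu,
    # The only lanes != 1 case the importer accepts: packed fp4 pairs.
    ("kDLFloat4_e2m1fn",   4, 2): torch.float4_e2m1fn_x2,
}

def to_torch_dtype(code_name, bits, lanes):
    """Sketch of the new dispatch: unknown combinations raise, like TORCH_CHECK_BUFFER."""
    try:
        return NARROW_DLPACK_TO_TORCH[(code_name, bits, lanes)]
    except KeyError:
        raise BufferError(f"Unsupported {code_name} bits {bits} / lanes {lanes}")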
@@ -354,8 +435,8 @@ T* toDLPackImpl(const Tensor& src) {
  atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device());
  atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
  atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
  atDLMTensor->tensor.dl_tensor.shape = view.sizes().data();
  atDLMTensor->tensor.dl_tensor.strides = view.strides().data();
  atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(view.sizes().data());
  atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(view.strides().data());
  atDLMTensor->tensor.dl_tensor.byte_offset = 0;
  fillVersion(&atDLMTensor->tensor);
@@ -19,7 +19,7 @@
#define DLPACK_MAJOR_VERSION 1

/*! \brief The current minor version of dlpack */
#define DLPACK_MINOR_VERSION 0
#define DLPACK_MINOR_VERSION 1

/*! \brief DLPACK_DLL prefix for windows */
#ifdef _WIN32
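The minor-version bump (1.0 to 1.1) is what introduces the fp8/fp6/fp4 type codes further down in this header. From Python the versioned capsule path can be exercised through Tensor.__dlpack__; a hedged sketch, assuming the build exposes the max_version keyword (present in PyTorch releases that implement the versioned DLPack protocol):

import torch

x = torch.arange(4, dtype=torch.uint8).view(torch.float8_e5m2)

# Ask the producer for a versioned capsule (DLPack >= 1.0). The producer may answer
# with any version up to the requested maximum; torch.from_dlpack understands both
# the legacy "dltensor" and the versioned "dltensor_versioned" capsules.
# NOTE: max_version support is assumed here, not shown by this diff.
capsule = x.__dlpack__(max_version=(1, 1))
z = torch.from_dlpack(capsule)

assert torch.equal(z.view(torch.uint8), x.view(torch.uint8))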
@@ -32,9 +32,7 @@
#define DLPACK_DLL
#endif

// NOLINTNEXTLINE(modernize-deprecated-headers)
#include <stdint.h>
// NOLINTNEXTLINE(modernize-deprecated-headers)
#include <stddef.h>

#ifdef __cplusplus
@@ -159,6 +157,26 @@ typedef enum {
  kDLComplex = 5U,
  /*! \brief boolean */
  kDLBool = 6U,
  /*! \brief FP8 data types */
  kDLFloat8_e3m4 = 7U,
  kDLFloat8_e4m3 = 8U,
  kDLFloat8_e4m3b11fnuz = 9U,
  kDLFloat8_e4m3fn = 10U,
  kDLFloat8_e4m3fnuz = 11U,
  kDLFloat8_e5m2 = 12U,
  kDLFloat8_e5m2fnuz = 13U,
  kDLFloat8_e8m0fnu = 14U,
  /*! \brief FP6 data types
   * Setting bits != 6 is currently unspecified, and the producer must ensure it is set
   * while the consumer must stop importing if the value is unexpected.
   */
  kDLFloat6_e2m3fn = 15U,
  kDLFloat6_e3m2fn = 16U,
  /*! \brief FP4 data types
   * Setting bits != 4 is currently unspecified, and the producer must ensure it is set
   * while the consumer must stop importing if the value is unexpected.
   */
  kDLFloat4_e2m1fn = 17U,
} DLDataTypeCode;

/*!
@@ -172,6 +190,12 @@ typedef enum {
 *  - int8: type_code = 0, bits = 8, lanes = 1
 *  - std::complex<float>: type_code = 5, bits = 64, lanes = 1
 *  - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, the underlying storage size of bool is 8 bits)
 *  - float8_e4m3: type_code = 8, bits = 8, lanes = 1 (packed in memory)
 *  - float6_e3m2fn: type_code = 16, bits = 6, lanes = 1 (packed in memory)
 *  - float4_e2m1fn: type_code = 17, bits = 4, lanes = 1 (packed in memory)
 *
 * When a sub-byte type is packed, DLPack requires the data to be in little bit-endian, i.e.,
 * for a packed data set D ((D >> (i * bits)) && bit_mask) stores the i-th element.
 */
typedef struct {
  /*!
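The "little bit-endian" rule in the comment above pins down which part of a packed byte is element 0: the lowest bits come first. A small sketch of that extraction rule (plain Python arithmetic, independent of torch, using bitwise & as the comment intends):

# For a packed byte D holding sub-byte elements of width `bits`,
# element i is (D >> (i * bits)) & bit_mask, i.e. the low nibble comes first.
def unpack_subbyte(D: int, bits: int):
    bit_mask = (1 << bits) - 1
    return [(D >> (i * bits)) & bit_mask for i in range(8 // bits)]

# 0x4C packs two fp4 code points: element 0 = 0xC (low nibble), element 1 = 0x4.
assert unpack_subbyte(0x4C, 4) == [0xC, 0x4]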
@@ -229,12 +253,12 @@ typedef struct {
  /*! \brief The data type of the pointer*/
  DLDataType dtype;
  /*! \brief The shape of the tensor */
  const int64_t* shape;
  int64_t* shape;
  /*!
   * \brief strides of the tensor (in number of elements, not bytes)
   *  can be NULL, indicating tensor is compact and row-majored.
   */
  const int64_t* strides;
  int64_t* strides;
  /*! \brief The offset in bytes to the beginning pointer to data */
  uint64_t byte_offset;
} DLTensor;
@@ -282,6 +306,14 @@ typedef struct DLManagedTensor {
 */
#define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL)

/*
 * \brief bit mask to indicate that whether a sub-byte type is packed or padded.
 *
 * The default for sub-byte types (ex: fp4/fp6) is assumed packed. This flag can
 * be set by the producer to signal that a tensor of sub-byte type is padded.
 */
#define DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED (1UL << 2UL)

/*!
 * \brief A versioned and managed C Tensor object, manage memory of DLTensor.
 *
@@ -93,15 +93,7 @@ class TestTorchDlPack(TestCase):
        z[0] = z[0] + 20.0
        self.assertEqual(z, x)

    @skipMeta
    @onlyCUDA
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_dlpack_conversion_with_streams(self, device, dtype):
        # Create a stream where the tensor will reside
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            # Do an operation in the actual stream
            x = make_tensor((5,), dtype=dtype, device=device) + 1
    def _dlpack_conversion_with_streams(self, stream, x):
        # DLPack protocol helps establish a correct stream order
        # (hence data dependency) at the exchange boundary.
        # DLPack manages this synchronization for us, so we don't need to
@@ -114,8 +106,38 @@ class TestTorchDlPack(TestCase):
        with torch.cuda.stream(stream):
            z = from_dlpack(x)
        stream.synchronize()
        return z

    @skipMeta
    @onlyCUDA
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_dlpack_conversion_with_streams(self, device, dtype):
        # Create a stream where the tensor will reside
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            # Do an operation in the actual stream
            x = make_tensor((5,), dtype=dtype, device=device) + 1
        z = self._dlpack_conversion_with_streams(stream, x)
        self.assertEqual(z, x)

    @skipMeta
    @onlyCUDA
    @dtypes(
        torch.float8_e5m2,
        torch.float8_e5m2fnuz,
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
        torch.float8_e8m0fnu,
        torch.float4_e2m1fn_x2,
    )
    def test_dlpack_conversion_with_streams_narrow_precision(self, device, dtype):
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            x = make_tensor((5,), dtype=torch.uint8, device=device) + 1
            x = x.view(dtype)
        z = self._dlpack_conversion_with_streams(stream, x)
        self.assertEqual(z.view(torch.uint8), x.view(torch.uint8))

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(
@@ -187,6 +209,27 @@ class TestTorchDlPack(TestCase):
        stream_b.synchronize()
        self.assertEqual(z, x)

    @skipMeta
    @onlyCUDA
    @dtypes(
        torch.float8_e5m2,
        torch.float8_e5m2fnuz,
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
        torch.float8_e8m0fnu,
        torch.float4_e2m1fn_x2,
    )
    def test_dlpack_conversion_with_diff_streams_narrow_precision(self, device, dtype):
        stream_a = torch.cuda.Stream()
        stream_b = torch.cuda.Stream()
        with torch.cuda.stream(stream_a):
            x = make_tensor((5,), dtype=torch.uint8, device=device) + 1
            x = x.view(dtype)
            z = torch.from_dlpack(x.__dlpack__(stream=stream_b.cuda_stream))
            stream_a.synchronize()
        stream_b.synchronize()
        self.assertEqual(z.view(torch.uint8), x.view(torch.uint8))

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(
@@ -484,9 +527,7 @@ class TestTorchDlPack(TestCase):
    @skipMeta
    @onlyCPU
    def test_dlpack_unsupported_dtype_error(self, device):
        inp = make_tensor((5,), dtype=torch.float32, device=device).to(
            torch.float8_e4m3fn
        )
        inp = torch.quantize_per_tensor(torch.randn(()), 0.1, 10, torch.qint8)

        with self.assertRaisesRegex(
            BufferError, ".* types are not supported by dlpack"
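Taken together, the new tests exercise the stream-ordered exchange for the narrow dtypes: the producer's __dlpack__(stream=...) records the cross-stream ordering, and from_dlpack imports the capsule on the consumer side. A condensed standalone version of that pattern (not part of the test suite; assumes a CUDA build with fp8 support):

import torch

producer = torch.cuda.Stream()
consumer = torch.cuda.Stream()

with torch.cuda.stream(producer):
    x = (torch.randint(0, 255, (5,), dtype=torch.uint8, device="cuda")
         .view(torch.float8_e5m2))
    # Export for the consumer stream; __dlpack__ inserts the synchronization
    # so the consumer sees a fully populated x.
    z = torch.from_dlpack(x.__dlpack__(stream=consumer.cuda_stream))

producer.synchronize()
consumer.synchronize()
assert torch.equal(z.view(torch.uint8), x.view(torch.uint8))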