add the torch.float8_e8m0fnu dtype to PyTorch (#147466)

Summary:

Continuing the work from https://github.com/pytorch/pytorch/pull/146427

Adds the `torch.float8_e8m0fnu` dtype to PyTorch, as detailed in
https://github.com/pytorch/pytorch/issues/146414. Please see the issue for a detailed definition of the format. Example of basic functionality:

```python
import torch

# round trip
x0 = torch.randn(4, 4, dtype=torch.float32)
x1 = x0.to(torch.float8_e8m0fnu)  # RNE rounding
x2 = x1.to(torch.float32)  # 2 ** exponent

# creation with empty
x0 = torch.empty(4, 4, dtype=torch.float8_e8m0fnu)

# printing
print(x0)
```
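
Since e8m0 stores only a biased exponent (no sign and no mantissa bits), a float32 → e8m0 → float32 round trip snaps every value to a power of two. A minimal sketch of that behavior; the expected values below are derived from the format definition and the RNE rounding described in the issue, not captured program output:

```python
import torch

x = torch.tensor([0.5, 0.75, 1.0, 3.0], dtype=torch.float32)
y = x.to(torch.float8_e8m0fnu).to(torch.float32)
# expected: 0.5, 1.0, 1.0, 4.0 -- only powers of two are representable;
# 0.75 and 3.0 both round up under the RNE rule applied to the dropped mantissa
print(y)
```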

Done in this PR:
* numerical correctness
* op coverage (except for `torch._scaled_mm`): create tensor, cast to/from float32
* printing a tensor works

For future PRs:
* performance optimizations for casting
* torch._scaled_mm
* PT2
* various cleanups (detailed in comments with issue numbers)

Test Plan:

```
pytest test/quantization/core/experimental/test_float8.py -s
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147466
Approved by: https://github.com/drisspg
Authored by: vasiliy
Date: 2025-02-20 13:55:42 +00:00
Committed by: PyTorch MergeBot
Parent: 574371d828
Commit: 382fbcc1e4
25 changed files with 535 additions and 44 deletions


@ -63,10 +63,12 @@ DLDataType getDLDataType(const Tensor& t) {
case ScalarType::BFloat16:
dtype.code = DLDataTypeCode::kDLBfloat;
break;
// TODO(#146647): use macro here instead of spelling out each shell dtype
case ScalarType::Float8_e5m2:
case ScalarType::Float8_e5m2fnuz:
case ScalarType::Float8_e4m3fn:
case ScalarType::Float8_e4m3fnuz:
case ScalarType::Float8_e8m0fnu:
TORCH_CHECK(false, "float8 types are not supported by dlpack");
break;
case ScalarType::QInt8:


@ -87,7 +87,7 @@
#define AT_FLOAT8_TYPES \
c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
c10::kFloat8_e4m3fnuz
c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu
#define AT_INTEGRAL_TYPES \
c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort


@ -59,8 +59,8 @@ bool copy_transpose_valid(const Tensor& self, const Tensor& src) {
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_V2( \
TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, \
kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kHalf, kBool, kBFloat16, \
AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#else
#define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \


@ -460,7 +460,8 @@ Tensor isinf(const Tensor& self) {
Tensor isfinite(const Tensor& self) {
// Note: Integral tensor values are always finite
if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true) ||
self.scalar_type() == kFloat8_e8m0fnu) {
return at::ones_like(self, at::kBool, at::MemoryFormat::Preserve);
}


@ -204,12 +204,12 @@ static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_ne
#define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
kComplexHalf, kHalf, kBool, \
kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), \
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...) \
AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
kBool, kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
kBool, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), \
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#else
#define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \


@ -51,6 +51,9 @@ void fill_kernel(TensorIterator& iter, const Scalar& value_scalar) {
fill_non_native_type<at::Float8_e4m3fnuz>(iter, value_scalar);
} else if (iter.dtype() == ScalarType::Float8_e5m2fnuz) {
fill_non_native_type<at::Float8_e5m2fnuz>(iter, value_scalar);
} else if (iter.dtype() == ScalarType::Float8_e8m0fnu) {
// TODO(#146647): use macro here instead of spelling out each float8 dtype
fill_non_native_type<at::Float8_e8m0fnu>(iter, value_scalar);
} else {
AT_DISPATCH_V2(
iter.dtype(), "fill_cpu", AT_WRAP([&]() {


@ -184,7 +184,13 @@ void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
}
}),
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
AT_EXPAND(AT_FLOAT8_TYPES),
// AT_EXPAND(AT_FLOAT8_TYPES),
// TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
// should not be supported here, then reenable AT_FLOAT8_DTYPES
kFloat8_e4m3fn,
kFloat8_e5m2,
kFloat8_e4m3fnuz,
kFloat8_e5m2fnuz,
kComplexHalf,
kHalf,
kBool,


@ -144,6 +144,28 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2fnuz x) { return x; });
break;
}
} else if (dtype == kFloat8_e8m0fnu) {
// TODO(#146647): clean this up, too much copy-pasta
switch (other_dtype) {
case kFloat:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
return Float8_e8m0fnu(value);
});
break;
case kHalf:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
return Float8_e8m0fnu(value);
});
break;
case kBFloat16:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
return Float8_e8m0fnu(value);
});
break;
default:
gpu_kernel(iter, [] GPU_LAMBDA(Float8_e8m0fnu x) { return x; });
break;
}
} else {
TORCH_CHECK(false, "This supposed ot be called only for Float8 types");
}
@ -157,7 +179,7 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
});
} else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn || dtype == kFloat8_e5m2fnuz || dtype == kFloat8_e4m3fnuz) {
} else if (isFloat8Type(dtype)) {
float8_copy_kernel_cuda(iter);
} else if (iter.dtype(1) == kFloat && (dtype == kBFloat16 || dtype == kHalf)) {
if (dtype == kBFloat16) {
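
With the branch added above, the CUDA copy kernel gains direct float/half/bfloat16 → e8m0 conversion paths. A hedged sketch of exercising them (assumes a CUDA device is available; the commit message only commits to float32 coverage, so treat the half/bfloat16 sources as illustrative):

```python
import torch

if torch.cuda.is_available():
    for src_dtype in (torch.float32, torch.float16, torch.bfloat16):
        x = torch.randn(16, device="cuda", dtype=src_dtype)
        y = x.to(torch.float8_e8m0fnu)  # dispatches into float8_copy_kernel_cuda
        print(src_dtype, y.to(torch.float32))
```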


@ -582,7 +582,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
C10_CUDA_KERNEL_LAUNCH_CHECK();
}),
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
AT_EXPAND(AT_FLOAT8_TYPES),
// AT_EXPAND(AT_FLOAT8_TYPES),
// TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
// should not be supported here, then reenable AT_FLOAT8_DTYPES
kFloat8_e4m3fn,
kFloat8_e5m2,
kFloat8_e4m3fnuz,
kFloat8_e5m2fnuz,
kComplexHalf,
kHalf,
kBool,
@ -606,7 +612,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
C10_CUDA_KERNEL_LAUNCH_CHECK();
}),
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
AT_EXPAND(AT_FLOAT8_TYPES),
// AT_EXPAND(AT_FLOAT8_TYPES),
// TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
// should not be supported here, then reenable AT_FLOAT8_DTYPES
kFloat8_e4m3fn,
kFloat8_e5m2,
kFloat8_e4m3fnuz,
kFloat8_e5m2fnuz,
kComplexHalf,
kHalf,
kBool,
@ -630,7 +642,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
C10_CUDA_KERNEL_LAUNCH_CHECK();
}),
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
AT_EXPAND(AT_FLOAT8_TYPES),
// AT_EXPAND(AT_FLOAT8_TYPES),
// TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
// should not be supported here, then reenable AT_FLOAT8_DTYPES
kFloat8_e4m3fn,
kFloat8_e5m2,
kFloat8_e4m3fnuz,
kFloat8_e5m2fnuz,
kComplexHalf,
kHalf,
kBool,
@ -652,7 +670,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
C10_CUDA_KERNEL_LAUNCH_CHECK();
}),
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
AT_EXPAND(AT_FLOAT8_TYPES),
// AT_EXPAND(AT_FLOAT8_TYPES),
// TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
// should not be supported here, then reenable AT_FLOAT8_DTYPES
kFloat8_e4m3fn,
kFloat8_e5m2,
kFloat8_e4m3fnuz,
kFloat8_e5m2fnuz,
kComplexHalf,
kHalf,
kBool,
@ -677,7 +701,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
C10_CUDA_KERNEL_LAUNCH_CHECK();
}),
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
AT_EXPAND(AT_FLOAT8_TYPES),
// AT_EXPAND(AT_FLOAT8_TYPES),
// TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
// should not be supported here, then reenable AT_FLOAT8_DTYPES
kFloat8_e4m3fn,
kFloat8_e5m2,
kFloat8_e4m3fnuz,
kFloat8_e5m2fnuz,
kComplexHalf,
kHalf,
kBool,


@ -228,6 +228,10 @@ template <> inline std::string typeName<at::Float8_e5m2fnuz>() {
template <> inline std::string typeName<at::Float8_e4m3fnuz>() {
return "at::Float8_e4m3fnuz";
}
template <> inline std::string typeName<at::Float8_e8m0fnu>() {
// TODO(#146647): Can the code here be made generic for any scalartype?
return "at::Float8_e8m0fnu";
}
#define TYPE_NAME_CASE(ctype, scalartype) \
case ScalarType::scalartype: return typeName<ctype>();


@ -49,16 +49,9 @@ class C10_API Scalar {
#define DEFINE_IMPLICIT_CTOR(type, name) \
Scalar(type vv) : Scalar(vv, true) {}
AT_FORALL_SCALAR_TYPES_AND7(
Half,
BFloat16,
Float8_e5m2,
Float8_e4m3fn,
Float8_e5m2fnuz,
Float8_e4m3fnuz,
ComplexHalf,
DEFINE_IMPLICIT_CTOR)
AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR)
AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR)
AT_FORALL_FLOAT8_TYPES(DEFINE_IMPLICIT_CTOR)
// Helper constructors to allow Scalar creation from long and long long types
// As std::is_same_v<long, long long> is false(except Android), one needs to


@ -222,6 +222,9 @@ std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
return std::make_pair("float8_e5m2fnuz", "");
case c10::ScalarType::Float8_e4m3fnuz:
return std::make_pair("float8_e4m3fnuz", "");
case c10::ScalarType::Float8_e8m0fnu:
// TODO(#146647): macroify all of this
return std::make_pair("float8_e8m0fnu", "");
default:
throw std::runtime_error("Unimplemented scalar type");
}


@ -7,6 +7,7 @@
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Float8_e8m0fnu.h>
#include <c10/util/Half.h>
#include <c10/util/bits.h>
#include <c10/util/complex.h>
@ -102,7 +103,8 @@ struct dummy_int1_7_t {};
_(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \
_(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \
_(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \
_(c10::dummy_int1_7_t<7>, Int7) /* 43 */
_(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \
_(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */
// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
@ -146,7 +148,8 @@ struct dummy_int1_7_t {};
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn) \
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz)
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
enum class ScalarType : int8_t {
#define DEFINE_ST_ENUM_VAL_(_1, n) n,
@ -317,6 +320,13 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
_(c10::quint4x2, QUInt4x2) \
_(c10::quint2x4, QUInt2x4)
#define AT_FORALL_FLOAT8_TYPES(_) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn) \
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
#define AT_FORALL_COMPLEX_TYPES(_) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble)
@ -372,7 +382,8 @@ inline bool isIntegralType(ScalarType t) {
inline bool isFloat8Type(ScalarType t) {
return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e5m2fnuz ||
t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz;
t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz ||
t == ScalarType::Float8_e8m0fnu;
}
inline bool isReducedFloatingType(ScalarType t) {
@ -446,6 +457,10 @@ inline bool isSignedType(ScalarType t) {
return std::numeric_limits< \
::c10::impl::ScalarTypeToCPPTypeT<ScalarType::name>>::is_signed;
// TODO(#146647): If we expect to have numeric_limits for everything,
// let's just have a big macro for the whole thing.
// If we're hardcoding it, let's just use the macro and a "true"/"false"
// below?
switch (t) {
case ScalarType::QInt8:
case ScalarType::QUInt8:
@ -467,6 +482,7 @@ inline bool isSignedType(ScalarType t) {
CASE_ISSIGNED(Float8_e5m2fnuz);
CASE_ISSIGNED(Float8_e4m3fn);
CASE_ISSIGNED(Float8_e4m3fnuz);
CASE_ISSIGNED(Float8_e8m0fnu);
CASE_ISSIGNED(Byte);
CASE_ISSIGNED(Char);
CASE_ISSIGNED(Short);


@ -0,0 +1,112 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/floating_point_utils.h>
#include <cstring>
#include <limits>
// TODO(#146647): Can we remove the below warning?
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
namespace c10 {
/// Constructors
inline C10_HOST_DEVICE Float8_e8m0fnu::Float8_e8m0fnu(float value)
: x(detail::fp8e8m0fnu_from_fp32_value(value)) {}
/// Implicit conversions
inline C10_HOST_DEVICE Float8_e8m0fnu::operator float() const {
// TODO(#146647): maybe rewrite without control flow
// if exponent is zero, need to special case to return 2^-127 instead of zero
if (x == 0) {
return c10::detail::fp32_from_bits(0x00400000);
}
// if exponent is NaN, need to special case to return properly encoded NaN
if (isnan()) {
return c10::detail::fp32_from_bits(0x7f800001);
}
// leave sign at 0, set the exponent bits, leave stored mantissa at 0
uint32_t res = x << 23;
return c10::detail::fp32_from_bits(res);
}
/// Special values helper
inline C10_HOST_DEVICE bool Float8_e8m0fnu::isnan() const {
return x == 0b11111111;
}
/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Float8_e8m0fnu to float.
} // namespace c10
namespace std {
template <>
class numeric_limits<c10::Float8_e8m0fnu> {
public:
static constexpr bool is_specialized = true;
static constexpr bool is_signed = false;
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = false;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = false;
static constexpr auto has_denorm = false;
static constexpr auto has_denorm_loss = false;
static constexpr auto round_style = numeric_limits<float>::round_style;
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 1;
static constexpr int digits10 = 0;
static constexpr int max_digits10 = 1; // just a 2!
static constexpr int radix = 2;
static constexpr int min_exponent = -126;
static constexpr int min_exponent10 = -38;
static constexpr int max_exponent = 128;
static constexpr int max_exponent10 = 38;
static constexpr auto traps = numeric_limits<float>::traps;
static constexpr auto tinyness_before = false;
static constexpr c10::Float8_e8m0fnu min() {
// 2^-127
return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits());
}
static constexpr c10::Float8_e8m0fnu lowest() {
// 2^-127
return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits());
}
static constexpr c10::Float8_e8m0fnu max() {
// 254 biased, which is 127 unbiased, so 2^127
return c10::Float8_e8m0fnu(0b11111110, c10::Float8_e8m0fnu::from_bits());
}
static constexpr c10::Float8_e8m0fnu epsilon() {
// according to https://en.cppreference.com/w/cpp/types/numeric_limits, this
// is "the difference between 1.0 and the next representable value of the
// given floating-point type". The next representable value is 2.0, so the
// difference is 1.0 which is 2^0. 0 unbiased is 127 biased.
return c10::Float8_e8m0fnu(0b01111111, c10::Float8_e8m0fnu::from_bits());
}
static constexpr c10::Float8_e8m0fnu round_error() {
// 0.5 in float, which is 2^-1, and -1 + 127 = 126
return c10::Float8_e8m0fnu(0b01111110, c10::Float8_e8m0fnu::from_bits());
}
static constexpr c10::Float8_e8m0fnu quiet_NaN() {
return c10::Float8_e8m0fnu(0b11111111, c10::Float8_e8m0fnu::from_bits());
}
};
} // namespace std
C10_CLANG_DIAGNOSTIC_POP()
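
The `numeric_limits` specialization above pins down the finite range and special values of the format. A small check of how this surfaces in Python, assuming `torch.finfo` maps `lowest()`/`max()`/`epsilon()` in the usual way (the values below follow from the definitions above, they are not captured output):

```python
import torch

fi = torch.finfo(torch.float8_e8m0fnu)
# expected from the numeric_limits specialization:
#   fi.min -> 2**-127 (lowest(); the format has no negative values)
#   fi.max -> 2**127  (biased exponent 254)
#   fi.eps -> 1.0     (the next representable value after 1.0 is 2.0)
print(fi.min, fi.max, fi.eps)
```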


@ -0,0 +1,12 @@
#include <c10/macros/Macros.h>
#include <c10/util/Float8_e8m0fnu.h>
namespace c10 {
// TODO(#146647): Can we have these in a single shared cpp file
// built with macro to remove the need for a new cpp file?
static_assert(
std::is_standard_layout_v<Float8_e8m0fnu>,
"c10::Float8_e8m0fnu must be standard layout.");
} // namespace c10

c10/util/Float8_e8m0fnu.h (new file, 120 lines)

@ -0,0 +1,120 @@
#pragma once
/// Defines the Float8_e8m0fnu type (8-bit floating-point) including
/// conversions to standard C types
/// Binary configuration :
/// eeeeeeee
/// no sign bits
/// 8 exponent bits
/// no mantissa bits
///
/// This is the E8M0 dtype from the OCP MX format spec
/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
/// Section 5.4.1)
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/floating_point_utils.h>
#include <type_traits>
// TODO(#146647): do we need to special case OPENCL?
#if defined(__cplusplus)
#include <cstdint>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
#include <stdint.h>
#endif
#include <iosfwd>
#include <ostream>
namespace c10 {
namespace detail {
/*
* Convert a 32-bit floating-point number in IEEE single-precision format to a
* 8-bit floating-point number in fp8 e8m0fnu format, in bit representation.
*/
inline C10_HOST_DEVICE uint8_t fp8e8m0fnu_from_fp32_value(float f) {
// TODO(#146647): maybe rewrite without control flow
uint32_t f_bits = c10::detail::fp32_to_bits(f);
// extract the exponent
uint32_t exponent = (f_bits >> 23) & 0b11111111;
// special case float32 NaN and +-inf to map to e8m0 nan
if (exponent == 0b11111111) {
return exponent;
}
// next, we use guard, round, sticky bits and the LSB to implement round to
// nearest, with ties to even
// guard bit - bit 23, or 22 zero-indexed
uint8_t g = (f_bits & 0x400000) > 0;
// round bit - bit 22, or 21 zero-indexed
uint8_t r = (f_bits & 0x200000) > 0;
// sticky bit - bits 21 to 1, or 20 to 0 zero-indexed
uint8_t s = (f_bits & 0x1FFFFF) > 0;
// in casting to e8m0, LSB is the implied mantissa bit. It equals to 0 if the
// original float32 is denormal, and to 1 if the original float32 is normal.
uint8_t lsb = exponent > 0;
// implement the RNE logic
bool round_up = false;
// if g == 0, round down (no-op)
if (g == 1) {
if ((r == 1) || (s == 1)) {
// round up
round_up = true;
} else {
if (lsb == 1) {
// round up
round_up = true;
}
// if lsb == 0, round down (no-op)
}
}
if (round_up) {
// adjust exponent
// note that if exponent was 255 we would have already returned earlier, so
// we know we can add one safely without running out of bounds
exponent++;
}
return exponent;
}
} // namespace detail
struct alignas(1) Float8_e8m0fnu {
uint8_t x;
struct from_bits_t {};
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
Float8_e8m0fnu() = default;
constexpr C10_HOST_DEVICE Float8_e8m0fnu(uint8_t bits, from_bits_t)
: x(bits) {}
inline C10_HOST_DEVICE Float8_e8m0fnu(float value);
inline C10_HOST_DEVICE operator float() const;
inline C10_HOST_DEVICE bool isnan() const;
};
C10_API inline std::ostream& operator<<(
std::ostream& out,
const Float8_e8m0fnu& value) {
out << (float)value;
return out;
}
} // namespace c10
#include <c10/util/Float8_e8m0fnu-inl.h> // IWYU pragma: keep
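
The guard/round/sticky logic above implements round-to-nearest-even on the dropped float32 mantissa. Below is a Python mirror of `fp8e8m0fnu_from_fp32_value`, written as an illustrative sketch (the helper name and structure are mine; only the bit manipulation follows the C++ above):

```python
import struct

def e8m0_bits_from_fp32(f: float) -> int:
    # keep only the float32 exponent; round the discarded mantissa to nearest-even
    (bits,) = struct.unpack("!I", struct.pack("!f", f))
    exponent = (bits >> 23) & 0xFF
    if exponent == 0xFF:                 # float32 inf/NaN map to the e8m0 NaN encoding
        return exponent
    g = (bits >> 22) & 1                 # guard bit (MSB of the dropped mantissa)
    r = (bits >> 21) & 1                 # round bit
    s = int((bits & 0x1FFFFF) != 0)      # sticky: any of the remaining mantissa bits
    lsb = int(exponent > 0)              # implicit leading bit of the source float
    if g and (r or s or lsb):
        exponent += 1                    # round up; 0xFF was handled above, so this fits in 8 bits
    return exponent

# e8m0_bits_from_fp32(1.0)  -> 127 (decodes back to 2**0 == 1.0)
# e8m0_bits_from_fp32(1.75) -> 128 (rounds up, decodes to 2.0)
```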


@ -5,6 +5,7 @@
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Float8_e8m0fnu.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>
#include <c10/util/overflows.h>
@ -151,6 +152,19 @@ struct static_cast_with_inter_type<
}
};
// TODO(#146647): Can we make all these template specialization happen
// based off our apply macros?
template <>
struct static_cast_with_inter_type<
c10::complex<c10::Half>,
c10::Float8_e8m0fnu> {
C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
c10::Half>
apply(c10::Float8_e8m0fnu src) {
return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
}
};
template <>
struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::Half> {
C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<


@ -1,5 +1,6 @@
# Owner(s): ["oncall: quantization"]
import struct
import unittest
import torch
@ -14,6 +15,7 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
subtest,
TemporaryFileName,
TestCase,
)
@ -23,11 +25,13 @@ FLOAT8_DTYPES = [
torch.float8_e5m2fnuz,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e8m0fnu,
]
CUDA_FLOAT8_DTYPES = [
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.float8_e8m0fnu,
]
# The following information are not yet provided by torch.finfo.
@ -37,6 +41,7 @@ MANTISSA_BITS = {
torch.float8_e5m2fnuz: 2,
torch.float8_e4m3fn: 3,
torch.float8_e4m3fnuz: 3,
torch.float8_e8m0fnu: 0,
}
# As in np.finfo(dtype).minexp
@ -45,6 +50,7 @@ MINEXP = {
torch.float8_e5m2fnuz: -15,
torch.float8_e4m3fn: -6,
torch.float8_e4m3fnuz: -7,
torch.float8_e8m0fnu: -127,
}
SPECIAL_NUMBERS = {
@ -108,11 +114,24 @@ SPECIAL_NUMBERS = {
("00000001", 0.125 * (2**-7), "min_subnorm"),
("10000001", -0.125 * (2**-7), "neg_min_subnorm"),
],
torch.float8_e8m0fnu: [
("00000000", float(2**-127), "smallest_number"),
("11111110", float(2**127), "largest_number"),
("01111110", 0.5, "zero_point_five"),
("01111111", 1.0, "one"),
("10000000", 2.0, "two"),
("11111111", float("nan"), "nan"),
],
}
FLOAT8_DTYPES_WITH_INF = [torch.float8_e5m2]
def _int_bits_to_float(x):
y = struct.unpack("!f", struct.pack("!I", x))[0]
return y
def simulate_fp8_precision(input, variant):
"""Round input (as float32) to the given float8 datatype variant."""
@ -165,6 +184,24 @@ def simulate_fp8_precision(input, variant):
return vals * signs
def _round_e8m0_rne(biased_exponent, lsb, g, r, s):
round_up = False
# apply g,r,s rounding rules for RNE rounding
if g == 1:
if (r == 1) or (s == 1):
round_up = True
else:
if lsb:
round_up = True
# round up if necessary
if round_up:
biased_exponent += 1
return biased_exponent
ROUND_TRIP_TEST_CASES = (
# A general 'soak test'.
subtest(
@ -198,17 +235,19 @@ ROUND_TRIP_TEST_CASES = (
class TestFloat8Dtype(TestCase):
"""
Sanity test for zeros comparison
"""
@dtypes(*FLOAT8_DTYPES)
@dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
def test_creation_with_zeros(self, dtype, device):
"""Sanity test, round-trip casting of zeros."""
x = torch.zeros(8, dtype=torch.float, device=device)
x8 = torch.zeros(8, dtype=dtype, device=device)
self.assertEqual(x, x8.float(), atol=0, rtol=0)
if dtype is torch.float8_e8m0fnu:
# zeros are not supported for this dtype, values get clamped
# to 2 ^ -127
x = torch.full((8,), 2**-127, dtype=torch.float, device=device)
self.assertEqual(x, x8.float(), atol=0, rtol=0)
else:
x = torch.zeros(8, dtype=torch.float, device=device)
self.assertEqual(x, x8.float(), atol=0, rtol=0)
@dtypes(*FLOAT8_DTYPES)
@dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
@ -217,12 +256,69 @@ class TestFloat8Dtype(TestCase):
"""Numerical test of float8 conversion, by performing a round-trip cast
to the float8 dtype and back to float32, comparing against simulated
lower precision."""
if dtype is torch.float8_e8m0fnu:
return unittest.skip("numerics for e8m0fnu are tested elsewhere")
x = get_input(dtype, device)
x = torch.cat((x, -x))
x8 = x.to(dtype)
x8_simulated = simulate_fp8_precision(x, dtype)
self.assertEqual(x8_simulated, x8.float())
def test_float8_e8m0fnu_rne_rounding(self, device):
"""
For every possible e8m0 exponent (256 options) and for every possible
g, r, s bits of the float32 mantissa, verify that RNE rounding is
correctly applied when casting from float32 to e8m0
Note: this code is morally similar to `test_cast_round_trip`, but
IMO simpler to special case e8m0 here.
"""
for biased_exponent in range(0, 256):
# iterate through all the possible options of guard, round, sticky bits
# for the current exponent
for grs in range(8):
# create a positive floating point number with the specified exponent
# and mantissa guard, round, sticky bits
uint32_t_start = (biased_exponent << 23) + (grs << 20)
fp32_start = _int_bits_to_float(uint32_t_start)
# create an RNE rounded version of the exponent
if biased_exponent == 255:
new_biased_exponent = biased_exponent
else:
lsb = biased_exponent > 0
g = grs >> 2
r = (grs >> 1) & 0b1
s = grs & 0b1
new_biased_exponent = _round_e8m0_rne(biased_exponent, lsb, g, r, s)
# create an RNE rounded version of the float
fp32_e8m0_fp32_emulated = _int_bits_to_float(new_biased_exponent << 23)
# now, do the same in PyTorch and see if results match
fp32_pt_start = torch.full(
(1,), fp32_start, device=device, dtype=torch.float
)
fp32_pt_e8m0 = fp32_pt_start.to(torch.float8_e8m0fnu)
fp32_pt_e8m0_fp32 = fp32_pt_e8m0.to(torch.float)
expected = fp32_e8m0_fp32_emulated
if biased_exponent == 254 and grs >= 4:
# special case rounding up from the largest representable float32 exponent, which
# saturates to nan
expected = float("nan")
elif biased_exponent == 255:
# special case inf and nan, which becomes nan
expected = float("nan")
actual = fp32_pt_e8m0_fp32.item()
self.assertEqual(
expected, actual, f"expected: {expected}, actual: {actual}"
)
@dtypes(*FLOAT8_DTYPES)
@dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
def test_special_numbers(self, dtype, device):
@ -269,6 +365,32 @@ class TestFloat8Dtype(TestCase):
torch.use_deterministic_algorithms(use_deterministic)
torch.empty(4, 4, device=device, dtype=dtype)
@dtypes(*FLOAT8_DTYPES)
@dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
def test_to_string(self, dtype, device):
x = torch.empty(4, 4, device=device, dtype=dtype)
str(x)
@dtypes(*FLOAT8_DTYPES)
def test_finfo(self, dtype, device):
torch.finfo(dtype)
@dtypes(*FLOAT8_DTYPES)
@dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
def test_cat(self, dtype, device):
x1 = torch.empty(4, 4, device=device, dtype=dtype)
x2 = torch.empty(4, 4, device=device, dtype=dtype)
torch.cat([x1, x2])
@dtypes(*FLOAT8_DTYPES)
@dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
def test_save_load(self, dtype, device):
x1 = torch.randint(0, 10, (4, 4), device=device, dtype=torch.uint8).view(dtype)
with TemporaryFileName() as fname:
torch.save(x1, fname)
x1_save_load = torch.load(fname)
torch.testing.assert_close(x1, x1_save_load, atol=0, rtol=0)
instantiate_device_type_tests(TestFloat8Dtype, globals())
@ -285,6 +407,9 @@ class TestFloat8DtypeCPUOnly(TestCase):
@dtypes(*CUDA_FLOAT8_DTYPES)
def test_mul(self, dtype):
# TODO(#113663): remove arithmetic support from all float8 dtypes
if dtype is torch.float8_e8m0fnu:
return unittest.skip("arithmetic not supported for torch.float8_e8m0fnu")
shape = (10, 10)
a = torch.randn(shape)
a8_simulated = simulate_fp8_precision(a, dtype)
@ -299,6 +424,11 @@ class TestFloat8DtypeCPUOnly(TestCase):
@unittest.skipIf(IS_WINDOWS, "torch.compile not supported on Windows yet")
@dtypes(*CUDA_FLOAT8_DTYPES)
def test_pt2_traceable_aot_eager(self, dtype):
if dtype is torch.float8_e8m0fnu:
return unittest.skip(
"PT2 support for torch.float8_e8m0fnu is not implemented yet"
)
@torch.compile(backend="aot_eager", fullgraph=True)
def f(x):
x = x.to(dtype)


@ -1362,7 +1362,7 @@ def gen_pyi(
# Generate type signatures for dtype classes
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# TODO: don't explicitly list dtypes here; get it from canonical
# TODO(#146647): don't explicitly list dtypes here; get it from canonical
# source
dtype_class_hints = [
f"{n}: dtype = ..."
@ -1377,6 +1377,7 @@ def gen_pyi(
"float8_e4m3fnuz",
"float8_e5m2",
"float8_e5m2fnuz",
"float8_e8m0fnu",
"half",
"uint8",
"uint16",


@ -150,7 +150,17 @@ class _Formatter:
# no valid number, do nothing
return
if tensor.dtype == torch.float8_e8m0fnu: # type: ignore[attr-defined]
# float8_e8m0fnu is special and does not define arithmetic ops,
# and printing code further in this file assumes the existence
# of various arithmetic ops to figure out what to print. We hack
# and convert to float here to make printing work correctly.
# TODO(#113663): also add the other float8 dtypes here after arithmetic
# support for them is removed
nonzero_finite_vals = nonzero_finite_vals.float()
# Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs())
nonzero_finite_min = tensor_totype(nonzero_finite_abs.min())
nonzero_finite_max = tensor_totype(nonzero_finite_abs.max())
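
Because e8m0 defines no arithmetic ops, the formatter converts to float before computing its print statistics (the fallback above). A quick sanity sketch matching the `test_to_string` test added in this PR (the exact printed representation is not shown here, since it is not captured in the diff):

```python
import torch

x = torch.full((2, 2), 1.0).to(torch.float8_e8m0fnu)
print(x)  # formatting goes through the float() fallback added above
```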


@ -123,16 +123,15 @@ static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) {
}
#define _AT_DISPATCH_FINFO_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6( \
at::kHalf, \
at::ScalarType::BFloat16, \
at::ScalarType::Float8_e5m2, \
at::ScalarType::Float8_e5m2fnuz, \
at::ScalarType::Float8_e4m3fn, \
at::ScalarType::Float8_e4m3fnuz, \
AT_DISPATCH_V2( \
TYPE, \
NAME, \
__VA_ARGS__)
AT_WRAP(__VA_ARGS__), \
AT_EXPAND(AT_FLOATING_TYPES), \
AT_EXPAND(AT_COMPLEX_TYPES), \
at::kHalf, \
at::ScalarType::BFloat16, \
AT_EXPAND(AT_FLOAT8_TYPES))
static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
HANDLE_TH_ERRORS


@ -79,6 +79,7 @@ inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) {
*(at::BFloat16*)data =
at::convert<at::BFloat16, double>(THPUtils_unpackDouble(obj));
break;
// TODO(#146647): simplify below with macros
case at::kFloat8_e5m2:
*(at::Float8_e5m2*)data =
at::convert<at::Float8_e5m2, double>(THPUtils_unpackDouble(obj));
@ -95,8 +96,12 @@ inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) {
*(at::Float8_e4m3fnuz*)data =
at::convert<at::Float8_e4m3fnuz, double>(THPUtils_unpackDouble(obj));
break;
case at::kFloat8_e8m0fnu:
*(at::Float8_e8m0fnu*)data =
at::convert<at::Float8_e8m0fnu, double>(THPUtils_unpackDouble(obj));
break;
default:
throw std::runtime_error("invalid type");
throw std::runtime_error("store_scalar: invalid type");
}
}
@ -143,6 +148,7 @@ inline PyObject* load_scalar(const void* data, at::ScalarType scalarType) {
case at::kBFloat16:
return PyFloat_FromDouble(
at::convert<double, at::BFloat16>(*(at::BFloat16*)data));
// TODO(#146647): simplify below with macros
case at::kFloat8_e5m2:
return PyFloat_FromDouble(
at::convert<double, at::Float8_e5m2>(*(at::Float8_e5m2*)data));
@ -155,8 +161,11 @@ inline PyObject* load_scalar(const void* data, at::ScalarType scalarType) {
case at::kFloat8_e4m3fnuz:
return PyFloat_FromDouble(at::convert<double, at::Float8_e4m3fnuz>(
*(at::Float8_e4m3fnuz*)data));
case at::kFloat8_e8m0fnu:
return PyFloat_FromDouble(
at::convert<double, at::Float8_e8m0fnu>(*(at::Float8_e8m0fnu*)data));
default:
throw std::runtime_error("invalid type");
throw std::runtime_error("load_scalar: invalid type");
}
}


@ -535,6 +535,7 @@ def _new_dtypes():
torch.float8_e4m3fn,
torch.float8_e5m2fnuz,
torch.float8_e4m3fnuz,
torch.float8_e8m0fnu,
torch.bits8,
torch.bits16,
torch.bits1x8,


@ -51,6 +51,7 @@ float8_e5m2T = BaseCppType("at", "Float8_e5m2")
float8_e5m2fnuzT = BaseCppType("at", "Float8_e5m2fnuz")
float8_e4m3fnT = BaseCppType("at", "Float8_e4m3fn")
float8_e4m3fnuzT = BaseCppType("at", "Float8_e4m3fnuz")
float8_e8m0fnuT = BaseCppType("at", "Float8_e8m0fnu")
stringT = BaseCppType("c10", "string_view")
generatorT = BaseCppType("at", "Generator")
scalarTypeT = BaseCppType("at", "ScalarType")
@ -102,6 +103,7 @@ ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = {
ScalarType.Float8_e5m2fnuz: float8_e5m2fnuzT,
ScalarType.Float8_e4m3fn: float8_e4m3fnT,
ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT,
ScalarType.Float8_e8m0fnu: float8_e8m0fnuT,
}
BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {


@ -374,6 +374,7 @@ class ScalarType(Enum):
Float8_e5m2fnuz = auto()
Float8_e4m3fn = auto()
Float8_e4m3fnuz = auto()
Float8_e8m0fnu = auto()
def __str__(self) -> str:
return self.name