add the torch.float8_e8m0fnu dtype to PyTorch (#147466)

Summary:

Continuing the work from https://github.com/pytorch/pytorch/pull/146427

Adds the `torch.float8_e8m0fnu` dtype to PyTorch, as detailed in
https://github.com/pytorch/pytorch/issues/146414. Please see the issue for a detailed definition of the format. Example of basic functionality:

```python
import torch

# round trip
x0 = torch.randn(4, 4, dtype=torch.float32)
x1 = x0.to(torch.float8_e8m0fnu)  # RNE rounding
x2 = x1.to(torch.float32)  # 2 ** exponent

# creation with empty
x0 = torch.empty(4, 4, dtype=torch.float8_e8m0fnu)

# printing
print(x0)
```
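
Only powers of two in `[2**-127, 2**127]` (plus a single NaN encoding) are representable, so such values round-trip exactly while everything else is rounded to a neighboring power of two. A small illustrative sketch (not part of the PR text; worked out from the format definition in the issue):

```python
import torch

# powers of two round-trip exactly
x = torch.tensor([0.25, 1.0, 16.0], dtype=torch.float32)
assert torch.equal(x, x.to(torch.float8_e8m0fnu).to(torch.float32))

# other values snap to a nearby power of two under RNE rounding:
# 5.0 = 1.25 * 2**2 with a zero guard bit, so it rounds down to 4.0
y = torch.tensor([5.0], dtype=torch.float32)
print(y.to(torch.float8_e8m0fnu).to(torch.float32))  # -> 4.0
```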

Done in this PR:
* numerical correctness
* op coverage (except for `torch._scaled_mm`): create tensor, cast to/from float32
* printing a tensor works

For future PRs:
* performance optimizations for casting
* torch._scaled_mm
* PT2
* various cleanups (detailed in comments with issue numbers)

Test Plan:

```
pytest test/quantization/core/experimental/test_float8.py -s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147466
Approved by: https://github.com/drisspg
Author: vasiliy
Committed: 2025-02-20 13:55:42 +00:00 by PyTorch MergeBot
Parent: 574371d828
Commit: 382fbcc1e4

25 changed files with 535 additions and 44 deletions

```diff
@@ -63,10 +63,12 @@ DLDataType getDLDataType(const Tensor& t) {
     case ScalarType::BFloat16:
       dtype.code = DLDataTypeCode::kDLBfloat;
       break;
+    // TODO(#146647): use macro here instead of spelling out each shell dtype
     case ScalarType::Float8_e5m2:
     case ScalarType::Float8_e5m2fnuz:
     case ScalarType::Float8_e4m3fn:
     case ScalarType::Float8_e4m3fnuz:
+    case ScalarType::Float8_e8m0fnu:
       TORCH_CHECK(false, "float8 types are not supported by dlpack");
       break;
     case ScalarType::QInt8:
```

```diff
@@ -87,7 +87,7 @@
 #define AT_FLOAT8_TYPES \
   c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
-      c10::kFloat8_e4m3fnuz
+      c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu

 #define AT_INTEGRAL_TYPES \
   c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
```

```diff
@@ -59,8 +59,8 @@ bool copy_transpose_valid(const Tensor& self, const Tensor& src) {
 #if !defined(C10_MOBILE)
 #define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_V2( \
-      TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, \
-      kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
+      TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kHalf, kBool, kBFloat16, \
+      AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #else
 #define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
```

```diff
@@ -460,7 +460,8 @@ Tensor isinf(const Tensor& self) {
 Tensor isfinite(const Tensor& self) {
   // Note: Integral tensor values are always finite
-  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
+  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true) ||
+      self.scalar_type() == kFloat8_e8m0fnu) {
     return at::ones_like(self, at::kBool, at::MemoryFormat::Preserve);
   }
```
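
Because e8m0 has a NaN encoding but no infinity, `isfinite` now takes the same shortcut as integral dtypes and returns all-true. A minimal usage sketch (illustrative, assuming the CPU casts added elsewhere in this PR):

```python
import torch

x = torch.randn(4).to(torch.float8_e8m0fnu)
# the format cannot represent +/-inf, so the fast path above is always all-true
assert torch.isfinite(x).all()
```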

```diff
@@ -204,12 +204,12 @@ static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_ne
 #define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
       kComplexHalf, kHalf, kBool, \
-      kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
-      kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
+      kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), \
+      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...) \
   AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
-      kBool, kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
-      kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
+      kBool, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), \
+      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #else
 #define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
```

```diff
@@ -51,6 +51,9 @@ void fill_kernel(TensorIterator& iter, const Scalar& value_scalar) {
     fill_non_native_type<at::Float8_e4m3fnuz>(iter, value_scalar);
   } else if (iter.dtype() == ScalarType::Float8_e5m2fnuz) {
     fill_non_native_type<at::Float8_e5m2fnuz>(iter, value_scalar);
+  } else if (iter.dtype() == ScalarType::Float8_e8m0fnu) {
+    // TODO(#146647): use macro here instead of spelling out each float8 dtype
+    fill_non_native_type<at::Float8_e8m0fnu>(iter, value_scalar);
   } else {
     AT_DISPATCH_V2(
         iter.dtype(), "fill_cpu", AT_WRAP([&]() {
```

```diff
@@ -184,7 +184,13 @@ void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
       }
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    AT_EXPAND(AT_FLOAT8_TYPES),
+    // AT_EXPAND(AT_FLOAT8_TYPES),
+    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+    // should not be supported here, then reenable AT_FLOAT8_DTYPES
+    kFloat8_e4m3fn,
+    kFloat8_e5m2,
+    kFloat8_e4m3fnuz,
+    kFloat8_e5m2fnuz,
     kComplexHalf,
     kHalf,
     kBool,
```

```diff
@@ -144,6 +144,28 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
         gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2fnuz x) { return x; });
         break;
     }
+  } else if (dtype == kFloat8_e8m0fnu) {
+    // TODO(#146647): clean this up, too much copy-pasta
+    switch (other_dtype) {
+      case kFloat:
+        gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
+          return Float8_e8m0fnu(value);
+        });
+        break;
+      case kHalf:
+        gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
+          return Float8_e8m0fnu(value);
+        });
+        break;
+      case kBFloat16:
+        gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
+          return Float8_e8m0fnu(value);
+        });
+        break;
+      default:
+        gpu_kernel(iter, [] GPU_LAMBDA(Float8_e8m0fnu x) { return x; });
+        break;
+    }
   } else {
     TORCH_CHECK(false, "This supposed ot be called only for Float8 types");
   }
@@ -157,7 +179,7 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
     AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] {
       gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
     });
-  } else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn || dtype == kFloat8_e5m2fnuz || dtype == kFloat8_e4m3fnuz) {
+  } else if (isFloat8Type(dtype)) {
     float8_copy_kernel_cuda(iter);
   } else if (iter.dtype(1) == kFloat && (dtype == kBFloat16 || dtype == kHalf)) {
     if (dtype == kBFloat16) {
```

```diff
@@ -582,7 +582,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    AT_EXPAND(AT_FLOAT8_TYPES),
+    // AT_EXPAND(AT_FLOAT8_TYPES),
+    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+    // should not be supported here, then reenable AT_FLOAT8_DTYPES
+    kFloat8_e4m3fn,
+    kFloat8_e5m2,
+    kFloat8_e4m3fnuz,
+    kFloat8_e5m2fnuz,
     kComplexHalf,
     kHalf,
     kBool,
@@ -606,7 +612,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    AT_EXPAND(AT_FLOAT8_TYPES),
+    // AT_EXPAND(AT_FLOAT8_TYPES),
+    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+    // should not be supported here, then reenable AT_FLOAT8_DTYPES
+    kFloat8_e4m3fn,
+    kFloat8_e5m2,
+    kFloat8_e4m3fnuz,
+    kFloat8_e5m2fnuz,
     kComplexHalf,
     kHalf,
     kBool,
@@ -630,7 +642,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    AT_EXPAND(AT_FLOAT8_TYPES),
+    // AT_EXPAND(AT_FLOAT8_TYPES),
+    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+    // should not be supported here, then reenable AT_FLOAT8_DTYPES
+    kFloat8_e4m3fn,
+    kFloat8_e5m2,
+    kFloat8_e4m3fnuz,
+    kFloat8_e5m2fnuz,
     kComplexHalf,
     kHalf,
     kBool,
@@ -652,7 +670,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    AT_EXPAND(AT_FLOAT8_TYPES),
+    // AT_EXPAND(AT_FLOAT8_TYPES),
+    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+    // should not be supported here, then reenable AT_FLOAT8_DTYPES
+    kFloat8_e4m3fn,
+    kFloat8_e5m2,
+    kFloat8_e4m3fnuz,
+    kFloat8_e5m2fnuz,
     kComplexHalf,
     kHalf,
     kBool,
@@ -677,7 +701,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
      C10_CUDA_KERNEL_LAUNCH_CHECK();
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    AT_EXPAND(AT_FLOAT8_TYPES),
+    // AT_EXPAND(AT_FLOAT8_TYPES),
+    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+    // should not be supported here, then reenable AT_FLOAT8_DTYPES
+    kFloat8_e4m3fn,
+    kFloat8_e5m2,
+    kFloat8_e4m3fnuz,
+    kFloat8_e5m2fnuz,
     kComplexHalf,
     kHalf,
     kBool,
```

```diff
@@ -228,6 +228,10 @@ template <> inline std::string typeName<at::Float8_e5m2fnuz>() {
 template <> inline std::string typeName<at::Float8_e4m3fnuz>() {
   return "at::Float8_e4m3fnuz";
 }
+template <> inline std::string typeName<at::Float8_e8m0fnu>() {
+  // TODO(#146647): Can the code here be made generic for any scalartype?
+  return "at::Float8_e8m0fnu";
+}

 #define TYPE_NAME_CASE(ctype, scalartype) \
   case ScalarType::scalartype: return typeName<ctype>();
```

```diff
@@ -49,16 +49,9 @@ class C10_API Scalar {
 #define DEFINE_IMPLICIT_CTOR(type, name) \
   Scalar(type vv) : Scalar(vv, true) {}

-  AT_FORALL_SCALAR_TYPES_AND7(
-      Half,
-      BFloat16,
-      Float8_e5m2,
-      Float8_e4m3fn,
-      Float8_e5m2fnuz,
-      Float8_e4m3fnuz,
-      ComplexHalf,
-      DEFINE_IMPLICIT_CTOR)
+  AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR)
   AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR)
+  AT_FORALL_FLOAT8_TYPES(DEFINE_IMPLICIT_CTOR)

 // Helper constructors to allow Scalar creation from long and long long types
 // As std::is_same_v<long, long long> is false(except Android), one needs to
```

```diff
@@ -222,6 +222,9 @@ std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
       return std::make_pair("float8_e5m2fnuz", "");
     case c10::ScalarType::Float8_e4m3fnuz:
       return std::make_pair("float8_e4m3fnuz", "");
+    case c10::ScalarType::Float8_e8m0fnu:
+      // TODO(#146647): macroify all of this
+      return std::make_pair("float8_e8m0fnu", "");
     default:
       throw std::runtime_error("Unimplemented scalar type");
   }
```

```diff
@@ -7,6 +7,7 @@
 #include <c10/util/Float8_e4m3fnuz.h>
 #include <c10/util/Float8_e5m2.h>
 #include <c10/util/Float8_e5m2fnuz.h>
+#include <c10/util/Float8_e8m0fnu.h>
 #include <c10/util/Half.h>
 #include <c10/util/bits.h>
 #include <c10/util/complex.h>
@@ -102,7 +103,8 @@ struct dummy_int1_7_t {};
   _(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \
   _(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \
   _(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \
-  _(c10::dummy_int1_7_t<7>, Int7) /* 43 */
+  _(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \
+  _(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */

 // If you want to support ComplexHalf for real, add ComplexHalf
 // into this macro (and change the name). But beware: convert()
@@ -146,7 +148,8 @@
   _(at::Float8_e5m2, Float8_e5m2) \
   _(at::Float8_e4m3fn, Float8_e4m3fn) \
   _(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
-  _(at::Float8_e4m3fnuz, Float8_e4m3fnuz)
+  _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
+  _(at::Float8_e8m0fnu, Float8_e8m0fnu)

 enum class ScalarType : int8_t {
 #define DEFINE_ST_ENUM_VAL_(_1, n) n,
@@ -317,6 +320,13 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
   _(c10::quint4x2, QUInt4x2) \
   _(c10::quint2x4, QUInt2x4)

+#define AT_FORALL_FLOAT8_TYPES(_)         \
+  _(at::Float8_e5m2, Float8_e5m2)         \
+  _(at::Float8_e4m3fn, Float8_e4m3fn)     \
+  _(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
+  _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
+  _(at::Float8_e8m0fnu, Float8_e8m0fnu)
+
 #define AT_FORALL_COMPLEX_TYPES(_)     \
   _(c10::complex<float>, ComplexFloat) \
   _(c10::complex<double>, ComplexDouble)
@@ -372,7 +382,8 @@ inline bool isIntegralType(ScalarType t) {
 inline bool isFloat8Type(ScalarType t) {
   return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e5m2fnuz ||
-      t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz;
+      t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz ||
+      t == ScalarType::Float8_e8m0fnu;
 }

 inline bool isReducedFloatingType(ScalarType t) {
@@ -446,6 +457,10 @@ inline bool isSignedType(ScalarType t) {
   return std::numeric_limits<            \
       ::c10::impl::ScalarTypeToCPPTypeT<ScalarType::name>>::is_signed;

+  // TODO(#146647): If we expect to have numeric_limits for everything,
+  // let's just have a big macro for the whole thing.
+  // If we're hardcoding it, let's just use the macro and a "true"/"false"
+  // below?
   switch (t) {
     case ScalarType::QInt8:
     case ScalarType::QUInt8:
@@ -467,6 +482,7 @@ inline bool isSignedType(ScalarType t) {
       CASE_ISSIGNED(Float8_e5m2fnuz);
       CASE_ISSIGNED(Float8_e4m3fn);
       CASE_ISSIGNED(Float8_e4m3fnuz);
+      CASE_ISSIGNED(Float8_e8m0fnu);
       CASE_ISSIGNED(Byte);
       CASE_ISSIGNED(Char);
       CASE_ISSIGNED(Short);
```

@@ -0,0 +1,112 @@
```cpp
#pragma once

#include <c10/macros/Macros.h>
#include <c10/util/floating_point_utils.h>

#include <cstring>
#include <limits>

// TODO(#146647): Can we remove the below warning?
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif

namespace c10 {

/// Constructors
inline C10_HOST_DEVICE Float8_e8m0fnu::Float8_e8m0fnu(float value)
    : x(detail::fp8e8m0fnu_from_fp32_value(value)) {}

/// Implicit conversions
inline C10_HOST_DEVICE Float8_e8m0fnu::operator float() const {
  // TODO(#146647): maybe rewrite without control flow

  // if exponent is zero, need to special case to return 2^-127 instead of zero
  if (x == 0) {
    return c10::detail::fp32_from_bits(0x00400000);
  }

  // if exponent is NaN, need to special case to return properly encoded NaN
  if (isnan()) {
    return c10::detail::fp32_from_bits(0x7f800001);
  }

  // leave sign at 0, set the exponent bits, leave stored mantissa at 0
  uint32_t res = x << 23;
  return c10::detail::fp32_from_bits(res);
}

/// Special values helper
inline C10_HOST_DEVICE bool Float8_e8m0fnu::isnan() const {
  return x == 0b11111111;
}

/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Float8_e8m0fnu to float.

} // namespace c10

namespace std {

template <>
class numeric_limits<c10::Float8_e8m0fnu> {
 public:
  static constexpr bool is_specialized = true;
  static constexpr bool is_signed = false;
  static constexpr bool is_integer = false;
  static constexpr bool is_exact = false;
  static constexpr bool has_infinity = false;
  static constexpr bool has_quiet_NaN = true;
  static constexpr bool has_signaling_NaN = false;
  static constexpr auto has_denorm = false;
  static constexpr auto has_denorm_loss = false;
  static constexpr auto round_style = numeric_limits<float>::round_style;
  static constexpr bool is_iec559 = false;
  static constexpr bool is_bounded = true;
  static constexpr bool is_modulo = false;
  static constexpr int digits = 1;
  static constexpr int digits10 = 0;
  static constexpr int max_digits10 = 1; // just a 2!
  static constexpr int radix = 2;
  static constexpr int min_exponent = -126;
  static constexpr int min_exponent10 = -38;
  static constexpr int max_exponent = 128;
  static constexpr int max_exponent10 = 38;
  static constexpr auto traps = numeric_limits<float>::traps;
  static constexpr auto tinyness_before = false;

  static constexpr c10::Float8_e8m0fnu min() {
    // 2^-127
    return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits());
  }
  static constexpr c10::Float8_e8m0fnu lowest() {
    // 2^-127
    return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits());
  }
  static constexpr c10::Float8_e8m0fnu max() {
    // 254 biased, which is 127 unbiased, so 2^127
    return c10::Float8_e8m0fnu(0b11111110, c10::Float8_e8m0fnu::from_bits());
  }
  static constexpr c10::Float8_e8m0fnu epsilon() {
    // according to https://en.cppreference.com/w/cpp/types/numeric_limits, this
    // is "the difference between 1.0 and the next representable value of the
    // given floating-point type". The next representable value is 2.0, so the
    // difference is 1.0 which is 2^0. 0 unbiased is 127 biased.
    return c10::Float8_e8m0fnu(0b01111111, c10::Float8_e8m0fnu::from_bits());
  }
  static constexpr c10::Float8_e8m0fnu round_error() {
    // 0.5 in float, which is 2^-1, and -1 + 127 = 126
    return c10::Float8_e8m0fnu(0b01111110, c10::Float8_e8m0fnu::from_bits());
  }
  static constexpr c10::Float8_e8m0fnu quiet_NaN() {
    return c10::Float8_e8m0fnu(0b11111111, c10::Float8_e8m0fnu::from_bits());
  }
};

} // namespace std

C10_CLANG_DIAGNOSTIC_POP()
```
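
For reference, a pure-Python mirror of the decode path and the `numeric_limits` values above (illustrative; `e8m0_bits_to_float` is a made-up helper name, not part of the PR):

```python
import math
import struct


def e8m0_bits_to_float(x: int) -> float:
    """Decode an e8m0 byte (0..255) the same way the header above does."""
    if x == 0b11111111:
        return float("nan")  # the single NaN encoding
    if x == 0:
        return 2.0**-127  # fp32 bit pattern 0x00400000, a denormal
    # place the byte into the float32 exponent field; sign and mantissa stay 0
    return struct.unpack("!f", struct.pack("!I", x << 23))[0]


assert e8m0_bits_to_float(0b00000000) == 2.0**-127  # min() / lowest()
assert e8m0_bits_to_float(0b01111111) == 1.0  # epsilon(), i.e. 2**0
assert e8m0_bits_to_float(0b11111110) == 2.0**127  # max()
assert math.isnan(e8m0_bits_to_float(0b11111111))  # quiet_NaN()
```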

@@ -0,0 +1,12 @@
```cpp
#include <c10/macros/Macros.h>
#include <c10/util/Float8_e8m0fnu.h>

namespace c10 {

// TODO(#146647): Can we have these in a single shared cpp file
// built with macro to remove the need for a new cpp file?
static_assert(
    std::is_standard_layout_v<Float8_e8m0fnu>,
    "c10::Float8_e8m0fnu must be standard layout.");

} // namespace c10
```

c10/util/Float8_e8m0fnu.h (new file)

@@ -0,0 +1,120 @@
```cpp
#pragma once

/// Defines the Float8_e8m0fnu type (8-bit floating-point) including
/// conversions to standard C types
/// Binary configuration :
/// eeeeeeee
/// no sign bits
/// 8 exponent bits
/// no mantissa bits
///
/// This is the E8M0 dtype from the OCP MX format spec
/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
/// Section 5.4.1)

#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/floating_point_utils.h>

#include <type_traits>

// TODO(#146647): do we need to special case OPENCL?
#if defined(__cplusplus)
#include <cstdint>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
#include <stdint.h>
#endif

#include <iosfwd>
#include <ostream>

namespace c10 {

namespace detail {

/*
 * Convert a 32-bit floating-point number in IEEE single-precision format to a
 * 8-bit floating-point number in fp8 e8m0fnu format, in bit representation.
 */
inline C10_HOST_DEVICE uint8_t fp8e8m0fnu_from_fp32_value(float f) {
  // TODO(#146647): maybe rewrite without control flow

  uint32_t f_bits = c10::detail::fp32_to_bits(f);

  // extract the exponent
  uint32_t exponent = (f_bits >> 23) & 0b11111111;

  // special case float32 NaN and +-inf to map to e8m0 nan
  if (exponent == 0b11111111) {
    return exponent;
  }

  // next, we use guard, round, sticky bits and the LSB to implement round to
  // nearest, with ties to even

  // guard bit - bit 23, or 22 zero-indexed
  uint8_t g = (f_bits & 0x400000) > 0;
  // round bit - bit 22, or 21 zero-indexed
  uint8_t r = (f_bits & 0x200000) > 0;
  // sticky bit - bits 21 to 1, or 20 to 0 zero-indexed
  uint8_t s = (f_bits & 0x1FFFFF) > 0;
  // in casting to e8m0, LSB is the implied mantissa bit. It equals to 0 if the
  // original float32 is denormal, and to 1 if the original float32 is normal.
  uint8_t lsb = exponent > 0;

  // implement the RNE logic
  bool round_up = false;

  // if g == 0, round down (no-op)
  if (g == 1) {
    if ((r == 1) || (s == 1)) {
      // round up
      round_up = true;
    } else {
      if (lsb == 1) {
        // round up
        round_up = true;
      }
      // if lsb == 0, round down (no-op)
    }
  }

  if (round_up) {
    // adjust exponent
    // note that if exponent was 255 we would have already returned earlier, so
    // we know we can add one safely without running out of bounds
    exponent++;
  }

  return exponent;
}

} // namespace detail

struct alignas(1) Float8_e8m0fnu {
  uint8_t x;

  struct from_bits_t {};
  C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
    return from_bits_t();
  }

  Float8_e8m0fnu() = default;

  constexpr C10_HOST_DEVICE Float8_e8m0fnu(uint8_t bits, from_bits_t)
      : x(bits) {}
  inline C10_HOST_DEVICE Float8_e8m0fnu(float value);
  inline C10_HOST_DEVICE operator float() const;
  inline C10_HOST_DEVICE bool isnan() const;
};

C10_API inline std::ostream& operator<<(
    std::ostream& out,
    const Float8_e8m0fnu& value) {
  out << (float)value;
  return out;
}

} // namespace c10

#include <c10/util/Float8_e8m0fnu-inl.h> // IWYU pragma: keep
```
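```

A pure-Python sketch of the encode path, mirroring the guard/round/sticky RNE logic in `fp8e8m0fnu_from_fp32_value` above (`fp32_to_e8m0_bits` is a made-up helper name, not part of the PR):

```python
import struct


def fp32_to_e8m0_bits(f: float) -> int:
    """Round a float32 value to an e8m0 biased exponent (RNE), mirroring the C++ above."""
    f_bits = struct.unpack("!I", struct.pack("!f", f))[0]
    exponent = (f_bits >> 23) & 0xFF
    if exponent == 0xFF:  # NaN and +/-inf map to the e8m0 NaN encoding
        return exponent
    g = (f_bits >> 22) & 0x1  # guard bit
    r = (f_bits >> 21) & 0x1  # round bit
    s = int((f_bits & 0x1FFFFF) != 0)  # sticky bit
    lsb = int(exponent > 0)  # implied mantissa bit; 0 for float32 denormals
    round_up = g == 1 and (r == 1 or s == 1 or lsb == 1)
    return exponent + 1 if round_up else exponent


assert fp32_to_e8m0_bits(2.0) == 128  # exact: decodes back to 2 ** 1
assert fp32_to_e8m0_bits(3.0) == 129  # guard bit set, rounds up: decodes to 2 ** 2
assert fp32_to_e8m0_bits(5.0) == 129  # guard bit 0, rounds down to 2 ** 2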

```diff
@@ -5,6 +5,7 @@
 #include <c10/util/Float8_e4m3fnuz.h>
 #include <c10/util/Float8_e5m2.h>
 #include <c10/util/Float8_e5m2fnuz.h>
+#include <c10/util/Float8_e8m0fnu.h>
 #include <c10/util/Half.h>
 #include <c10/util/complex.h>
 #include <c10/util/overflows.h>
@@ -151,6 +152,19 @@ struct static_cast_with_inter_type<
   }
 };

+// TODO(#146647): Can we make all these template specialization happen
+// based off our apply macros?
+template <>
+struct static_cast_with_inter_type<
+    c10::complex<c10::Half>,
+    c10::Float8_e8m0fnu> {
+  C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
+      c10::Half>
+  apply(c10::Float8_e8m0fnu src) {
+    return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
+  }
+};
+
 template <>
 struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::Half> {
   C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
```

```diff
@@ -1,5 +1,6 @@
 # Owner(s): ["oncall: quantization"]

+import struct
 import unittest

 import torch
@@ -14,6 +15,7 @@ from torch.testing._internal.common_utils import (
     parametrize,
     run_tests,
     subtest,
+    TemporaryFileName,
     TestCase,
 )
@@ -23,11 +25,13 @@ FLOAT8_DTYPES = [
     torch.float8_e5m2fnuz,
     torch.float8_e4m3fn,
     torch.float8_e4m3fnuz,
+    torch.float8_e8m0fnu,
 ]

 CUDA_FLOAT8_DTYPES = [
     torch.float8_e5m2,
     torch.float8_e4m3fn,
+    torch.float8_e8m0fnu,
 ]

 # The following information are not yet provided by torch.finfo.
@@ -37,6 +41,7 @@ MANTISSA_BITS = {
     torch.float8_e5m2fnuz: 2,
     torch.float8_e4m3fn: 3,
     torch.float8_e4m3fnuz: 3,
+    torch.float8_e8m0fnu: 0,
 }

 # As in np.finfo(dtype).minexp
@@ -45,6 +50,7 @@ MINEXP = {
     torch.float8_e5m2fnuz: -15,
     torch.float8_e4m3fn: -6,
     torch.float8_e4m3fnuz: -7,
+    torch.float8_e8m0fnu: -127,
 }

 SPECIAL_NUMBERS = {
@@ -108,11 +114,24 @@ SPECIAL_NUMBERS = {
         ("00000001", 0.125 * (2**-7), "min_subnorm"),
         ("10000001", -0.125 * (2**-7), "neg_min_subnorm"),
     ],
+    torch.float8_e8m0fnu: [
+        ("00000000", float(2**-127), "smallest_number"),
+        ("11111110", float(2**127), "largest_number"),
+        ("01111110", 0.5, "zero_point_five"),
+        ("01111111", 1.0, "one"),
+        ("10000000", 2.0, "two"),
+        ("11111111", float("nan"), "nan"),
+    ],
 }

 FLOAT8_DTYPES_WITH_INF = [torch.float8_e5m2]


+def _int_bits_to_float(x):
+    y = struct.unpack("!f", struct.pack("!I", x))[0]
+    return y
+
+
 def simulate_fp8_precision(input, variant):
     """Round input (as float32) to the given float8 datatype variant."""
@@ -165,6 +184,24 @@ def simulate_fp8_precision(input, variant):
     return vals * signs


+def _round_e8m0_rne(biased_exponent, lsb, g, r, s):
+    round_up = False
+
+    # apply g,r,s rounding rules for RNE rounding
+    if g == 1:
+        if (r == 1) or (s == 1):
+            round_up = True
+        else:
+            if lsb:
+                round_up = True
+
+    # round up if necessary
+    if round_up:
+        biased_exponent += 1
+
+    return biased_exponent
+
+
 ROUND_TRIP_TEST_CASES = (
     # A general 'soak test'.
     subtest(
@@ -198,17 +235,19 @@ ROUND_TRIP_TEST_CASES = (
 class TestFloat8Dtype(TestCase):
-    """
-    Sanity test for zeros comparison
-    """
-
     @dtypes(*FLOAT8_DTYPES)
     @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
     def test_creation_with_zeros(self, dtype, device):
         """Sanity test, round-trip casting of zeros."""
-        x = torch.zeros(8, dtype=torch.float, device=device)
         x8 = torch.zeros(8, dtype=dtype, device=device)
-        self.assertEqual(x, x8.float(), atol=0, rtol=0)
+        if dtype is torch.float8_e8m0fnu:
+            # zeros are not supported for this dtype, values get clamped
+            # to 2 ^ -127
+            x = torch.full((8,), 2**-127, dtype=torch.float, device=device)
+            self.assertEqual(x, x8.float(), atol=0, rtol=0)
+        else:
+            x = torch.zeros(8, dtype=torch.float, device=device)
+            self.assertEqual(x, x8.float(), atol=0, rtol=0)

     @dtypes(*FLOAT8_DTYPES)
     @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
@@ -217,12 +256,69 @@ class TestFloat8Dtype(TestCase):
         """Numerical test of float8 conversion, by performing a round-trip cast
         to the float8 dtype and back to float32, comparing against simulated
         lower precision."""
+        if dtype is torch.float8_e8m0fnu:
+            return unittest.skip("numerics for e8m0fnu are tested elsewhere")
+
         x = get_input(dtype, device)
         x = torch.cat((x, -x))
         x8 = x.to(dtype)
         x8_simulated = simulate_fp8_precision(x, dtype)
         self.assertEqual(x8_simulated, x8.float())

+    def test_float8_e8m0fnu_rne_rounding(self, device):
+        """
+        For every possible e8m0 exponent (256 options) and for every possible
+        g, r, s bits of the float32 mantissa, verify that RNE rounding is
+        correctly applied when casting from float32 to e8m0
+
+        Note: this code is morally similar to `test_cast_round_trip`, but
+        IMO simpler to special case e8m0 here.
+        """
+        for biased_exponent in range(0, 256):
+            # iterate through all the possible options of guard, round, sticky bits
+            # for the current exponent
+            for grs in range(8):
+                # create a positive floating point number with the specified exponent
+                # and mantissa guard, round, sticky bits
+                uint32_t_start = (biased_exponent << 23) + (grs << 20)
+                fp32_start = _int_bits_to_float(uint32_t_start)
+
+                # create an RNE rounded version of the exponent
+                if biased_exponent == 255:
+                    new_biased_exponent = biased_exponent
+                else:
+                    lsb = biased_exponent > 0
+                    g = grs >> 2
+                    r = (grs >> 1) & 0b1
+                    s = grs & 0b1
+                    new_biased_exponent = _round_e8m0_rne(biased_exponent, lsb, g, r, s)
+
+                # create an RNE rounded version of the float
+                fp32_e8m0_fp32_emulated = _int_bits_to_float(new_biased_exponent << 23)
+
+                # now, do the same in PyTorch and see if results match
+                fp32_pt_start = torch.full(
+                    (1,), fp32_start, device=device, dtype=torch.float
+                )
+                fp32_pt_e8m0 = fp32_pt_start.to(torch.float8_e8m0fnu)
+                fp32_pt_e8m0_fp32 = fp32_pt_e8m0.to(torch.float)
+
+                expected = fp32_e8m0_fp32_emulated
+                if biased_exponent == 254 and grs >= 4:
+                    # special case rounding up from the largest representable float32 exponent, which
+                    # saturates to nan
+                    expected = float("nan")
+                elif biased_exponent == 255:
+                    # special case inf and nan, which becomes nan
+                    expected = float("nan")
+                actual = fp32_pt_e8m0_fp32.item()
+                self.assertEqual(
+                    expected, actual, f"expected: {expected}, actual: {actual}"
+                )
+
     @dtypes(*FLOAT8_DTYPES)
     @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
     def test_special_numbers(self, dtype, device):
@@ -269,6 +365,32 @@
             torch.use_deterministic_algorithms(use_deterministic)
             torch.empty(4, 4, device=device, dtype=dtype)

+    @dtypes(*FLOAT8_DTYPES)
+    @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
+    def test_to_string(self, dtype, device):
+        x = torch.empty(4, 4, device=device, dtype=dtype)
+        str(x)
+
+    @dtypes(*FLOAT8_DTYPES)
+    def test_finfo(self, dtype, device):
+        torch.finfo(dtype)
+
+    @dtypes(*FLOAT8_DTYPES)
+    @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
+    def test_cat(self, dtype, device):
+        x1 = torch.empty(4, 4, device=device, dtype=dtype)
+        x2 = torch.empty(4, 4, device=device, dtype=dtype)
+        torch.cat([x1, x2])
+
+    @dtypes(*FLOAT8_DTYPES)
+    @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
+    def test_save_load(self, dtype, device):
+        x1 = torch.randint(0, 10, (4, 4), device=device, dtype=torch.uint8).view(dtype)
+        with TemporaryFileName() as fname:
+            torch.save(x1, fname)
+            x1_save_load = torch.load(fname)
+            torch.testing.assert_close(x1, x1_save_load, atol=0, rtol=0)
+

 instantiate_device_type_tests(TestFloat8Dtype, globals())
@@ -285,6 +407,9 @@ class TestFloat8DtypeCPUOnly(TestCase):
     @dtypes(*CUDA_FLOAT8_DTYPES)
     def test_mul(self, dtype):
+        # TODO(#113663): remove arithmetic support from all float8 dtypes
+        if dtype is torch.float8_e8m0fnu:
+            return unittest.skip("arithmetic not supported for torch.float8_e8m0fnu")
         shape = (10, 10)
         a = torch.randn(shape)
         a8_simulated = simulate_fp8_precision(a, dtype)
@@ -299,6 +424,11 @@
     @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on Windows yet")
     @dtypes(*CUDA_FLOAT8_DTYPES)
     def test_pt2_traceable_aot_eager(self, dtype):
+        if dtype is torch.float8_e8m0fnu:
+            return unittest.skip(
+                "PT2 support for torch.float8_e8m0fnu is not implemented yet"
+            )
+
         @torch.compile(backend="aot_eager", fullgraph=True)
         def f(x):
             x = x.to(dtype)
```

```diff
@@ -1362,7 +1362,7 @@ def gen_pyi(
     # Generate type signatures for dtype classes
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-    # TODO: don't explicitly list dtypes here; get it from canonical
+    # TODO(#146647): don't explicitly list dtypes here; get it from canonical
     # source
     dtype_class_hints = [
         f"{n}: dtype = ..."
@@ -1377,6 +1377,7 @@
         "float8_e4m3fnuz",
         "float8_e5m2",
         "float8_e5m2fnuz",
+        "float8_e8m0fnu",
         "half",
         "uint8",
         "uint16",
```

```diff
@@ -150,7 +150,17 @@ class _Formatter:
                 # no valid number, do nothing
                 return

+            if tensor.dtype == torch.float8_e8m0fnu:  # type: ignore[attr-defined]
+                # float8_e8m0fnu is special and does not define arithmetic ops,
+                # and printing code further in this file assumes the existence
+                # of various arithmetic ops to figure out what to print. We hack
+                # and convert to float here to make printing work correctly.
+                # TODO(#113663): also add the other float8 dtypes here after arithmetic
+                # support for them is removed
+                nonzero_finite_vals = nonzero_finite_vals.float()
+
             # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
             nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs())
             nonzero_finite_min = tensor_totype(nonzero_finite_abs.min())
             nonzero_finite_max = tensor_totype(nonzero_finite_abs.max())
```
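
A quick way to exercise this printing path (illustrative; the exact rendering is not asserted here):

```python
import torch

# printing routes through the float conversion added above, since the dtype
# defines no arithmetic ops of its own
x = torch.full((2, 2), 4.0).to(torch.float8_e8m0fnu)
print(x)
```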

```diff
@@ -123,16 +123,15 @@ static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) {
 }

 #define _AT_DISPATCH_FINFO_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6( \
-      at::kHalf, \
-      at::ScalarType::BFloat16, \
-      at::ScalarType::Float8_e5m2, \
-      at::ScalarType::Float8_e5m2fnuz, \
-      at::ScalarType::Float8_e4m3fn, \
-      at::ScalarType::Float8_e4m3fnuz, \
+  AT_DISPATCH_V2( \
       TYPE, \
       NAME, \
-      __VA_ARGS__)
+      AT_WRAP(__VA_ARGS__), \
+      AT_EXPAND(AT_FLOATING_TYPES), \
+      AT_EXPAND(AT_COMPLEX_TYPES), \
+      at::kHalf, \
+      at::ScalarType::BFloat16, \
+      AT_EXPAND(AT_FLOAT8_TYPES))

 static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
   HANDLE_TH_ERRORS
```

```diff
@@ -79,6 +79,7 @@ inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) {
       *(at::BFloat16*)data =
           at::convert<at::BFloat16, double>(THPUtils_unpackDouble(obj));
       break;
+    // TODO(#146647): simplify below with macros
     case at::kFloat8_e5m2:
       *(at::Float8_e5m2*)data =
           at::convert<at::Float8_e5m2, double>(THPUtils_unpackDouble(obj));
@@ -95,8 +96,12 @@ inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) {
       *(at::Float8_e4m3fnuz*)data =
           at::convert<at::Float8_e4m3fnuz, double>(THPUtils_unpackDouble(obj));
       break;
+    case at::kFloat8_e8m0fnu:
+      *(at::Float8_e8m0fnu*)data =
+          at::convert<at::Float8_e8m0fnu, double>(THPUtils_unpackDouble(obj));
+      break;
     default:
-      throw std::runtime_error("invalid type");
+      throw std::runtime_error("store_scalar: invalid type");
   }
 }
@@ -143,6 +148,7 @@ inline PyObject* load_scalar(const void* data, at::ScalarType scalarType) {
     case at::kBFloat16:
       return PyFloat_FromDouble(
           at::convert<double, at::BFloat16>(*(at::BFloat16*)data));
+    // TODO(#146647): simplify below with macros
     case at::kFloat8_e5m2:
       return PyFloat_FromDouble(
           at::convert<double, at::Float8_e5m2>(*(at::Float8_e5m2*)data));
@@ -155,8 +161,11 @@ inline PyObject* load_scalar(const void* data, at::ScalarType scalarType) {
     case at::kFloat8_e4m3fnuz:
       return PyFloat_FromDouble(at::convert<double, at::Float8_e4m3fnuz>(
           *(at::Float8_e4m3fnuz*)data));
+    case at::kFloat8_e8m0fnu:
+      return PyFloat_FromDouble(
+          at::convert<double, at::Float8_e8m0fnu>(*(at::Float8_e8m0fnu*)data));
     default:
-      throw std::runtime_error("invalid type");
+      throw std::runtime_error("load_scalar: invalid type");
   }
 }
```

```diff
@@ -535,6 +535,7 @@ def _new_dtypes():
         torch.float8_e4m3fn,
         torch.float8_e5m2fnuz,
         torch.float8_e4m3fnuz,
+        torch.float8_e8m0fnu,
         torch.bits8,
         torch.bits16,
         torch.bits1x8,
```

```diff
@@ -51,6 +51,7 @@ float8_e5m2T = BaseCppType("at", "Float8_e5m2")
 float8_e5m2fnuzT = BaseCppType("at", "Float8_e5m2fnuz")
 float8_e4m3fnT = BaseCppType("at", "Float8_e4m3fn")
 float8_e4m3fnuzT = BaseCppType("at", "Float8_e4m3fnuz")
+float8_e8m0fnuT = BaseCppType("at", "Float8_e8m0fnu")
 stringT = BaseCppType("c10", "string_view")
 generatorT = BaseCppType("at", "Generator")
 scalarTypeT = BaseCppType("at", "ScalarType")
@@ -102,6 +103,7 @@ ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = {
     ScalarType.Float8_e5m2fnuz: float8_e5m2fnuzT,
     ScalarType.Float8_e4m3fn: float8_e4m3fnT,
     ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT,
+    ScalarType.Float8_e8m0fnu: float8_e8m0fnuT,
 }

 BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
```

```diff
@@ -374,6 +374,7 @@ class ScalarType(Enum):
     Float8_e5m2fnuz = auto()
     Float8_e4m3fn = auto()
     Float8_e4m3fnuz = auto()
+    Float8_e8m0fnu = auto()

     def __str__(self) -> str:
         return self.name
```