Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
add the torch.float8_e8m0fnu dtype to PyTorch (#147466)
Summary:

Continuing the work from https://github.com/pytorch/pytorch/pull/146427

Adds the `torch.float8_e8m0fnu` dtype to PyTorch, as detailed in
https://github.com/pytorch/pytorch/issues/146414 . Please see the issue for a detailed definition of the format.

Example of basic functionality:

```python
import torch

# round trip
x0 = torch.randn(4, 4, dtype=torch.float32)
x1 = x0.to(torch.float8_e8m0fnu)  # RNE rounding
x2 = x1.to(torch.float32)  # 2 ** exponent

# creation with empty
x0 = torch.empty(4, 4, dtype=torch.float8_e8m0fnu)

# printing
print(x0)
```

Done in this PR:
* numerical correctness
* op coverage (except for `torch._scaled_mm`): create tensor, cast to/from float32
* printing a tensor works

For future PRs:
* performance optimizations for casting
* torch._scaled_mm
* PT2
* various cleanups (detailed in comments with issue numbers)

Test Plan:

```
pytest test/quantization/core/experimental/test_float8.py -s
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147466
Approved by: https://github.com/drisspg
Committed by: PyTorch MergeBot
Parent: 574371d828
Commit: 382fbcc1e4
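As a quick orientation before the diff: the sketch below is editorial (not part of the commit) and assumes a PyTorch build that already contains this PR; it shows the rounding behavior the new dtype is expected to have on a float32 round trip.

```python
import torch

# e8m0 stores only an 8-bit biased exponent, so every finite value is a power
# of two. Casting float32 -> float8_e8m0fnu rounds the exponent to nearest-even
# using the guard/round/sticky bits of the float32 mantissa.
x = torch.tensor([0.1, 1.0, 3.0, 1e30], dtype=torch.float32)
print(x.to(torch.float8_e8m0fnu).float())  # expected: 0.125, 1.0, 4.0 (tie rounds up), 2**100

# There is no encoding for zero; the all-zeros bit pattern decodes as 2**-127.
z = torch.zeros(2, dtype=torch.float32).to(torch.float8_e8m0fnu)
print(z.float())  # expected: [2**-127, 2**-127]
```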
@@ -63,10 +63,12 @@ DLDataType getDLDataType(const Tensor& t) {
     case ScalarType::BFloat16:
       dtype.code = DLDataTypeCode::kDLBfloat;
       break;
+    // TODO(#146647): use macro here instead of spelling out each float8 dtype
     case ScalarType::Float8_e5m2:
     case ScalarType::Float8_e5m2fnuz:
     case ScalarType::Float8_e4m3fn:
     case ScalarType::Float8_e4m3fnuz:
+    case ScalarType::Float8_e8m0fnu:
       TORCH_CHECK(false, "float8 types are not supported by dlpack");
       break;
     case ScalarType::QInt8:
@@ -87,7 +87,7 @@
 
 #define AT_FLOAT8_TYPES \
   c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
-      c10::kFloat8_e4m3fnuz
+      c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu
 
 #define AT_INTEGRAL_TYPES \
   c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
@@ -59,8 +59,8 @@ bool copy_transpose_valid(const Tensor& self, const Tensor& src) {
 #if !defined(C10_MOBILE)
 #define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_V2( \
-      TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, \
-      kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
+      TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kHalf, kBool, kBFloat16, \
+      AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #else
 #define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
@@ -460,7 +460,8 @@ Tensor isinf(const Tensor& self) {
 
 Tensor isfinite(const Tensor& self) {
   // Note: Integral tensor values are always finite
-  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
+  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true) ||
+      self.scalar_type() == kFloat8_e8m0fnu) {
     return at::ones_like(self, at::kBool, at::MemoryFormat::Preserve);
   }
 
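A small editorial illustration of the branch added above (assuming a build with this PR): because no arithmetic ops are defined for the dtype, `isfinite` short-circuits and reports every e8m0 element as finite.

```python
import torch

x = torch.empty(4, dtype=torch.float8_e8m0fnu)
print(torch.isfinite(x))  # expected: tensor([True, True, True, True])
```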
@@ -204,12 +204,12 @@ static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_ne
 #define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
       kComplexHalf, kHalf, kBool, \
-      kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
-      kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
+      kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), \
+      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...) \
   AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
-      kBool, kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
-      kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
+      kBool, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), \
+      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #else
 #define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
@@ -51,6 +51,9 @@ void fill_kernel(TensorIterator& iter, const Scalar& value_scalar) {
     fill_non_native_type<at::Float8_e4m3fnuz>(iter, value_scalar);
   } else if (iter.dtype() == ScalarType::Float8_e5m2fnuz) {
     fill_non_native_type<at::Float8_e5m2fnuz>(iter, value_scalar);
+  } else if (iter.dtype() == ScalarType::Float8_e8m0fnu) {
+    // TODO(#146647): use macro here instead of spelling out each float8 dtype
+    fill_non_native_type<at::Float8_e8m0fnu>(iter, value_scalar);
   } else {
     AT_DISPATCH_V2(
         iter.dtype(), "fill_cpu", AT_WRAP([&]() {
@@ -184,7 +184,13 @@ void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
       }
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    AT_EXPAND(AT_FLOAT8_TYPES),
+    // AT_EXPAND(AT_FLOAT8_TYPES),
+    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+    // should not be supported here, then reenable AT_FLOAT8_DTYPES
+    kFloat8_e4m3fn,
+    kFloat8_e5m2,
+    kFloat8_e4m3fnuz,
+    kFloat8_e5m2fnuz,
     kComplexHalf,
     kHalf,
     kBool,
@@ -144,6 +144,28 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
         gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2fnuz x) { return x; });
         break;
     }
+  } else if (dtype == kFloat8_e8m0fnu) {
+    // TODO(#146647): clean this up, too much copy-pasta
+    switch (other_dtype) {
+      case kFloat:
+        gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
+          return Float8_e8m0fnu(value);
+        });
+        break;
+      case kHalf:
+        gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
+          return Float8_e8m0fnu(value);
+        });
+        break;
+      case kBFloat16:
+        gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
+          return Float8_e8m0fnu(value);
+        });
+        break;
+      default:
+        gpu_kernel(iter, [] GPU_LAMBDA(Float8_e8m0fnu x) { return x; });
+        break;
+    }
   } else {
     TORCH_CHECK(false, "This supposed ot be called only for Float8 types");
   }
@@ -157,7 +179,7 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
     AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] {
       gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
     });
-  } else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn || dtype == kFloat8_e5m2fnuz || dtype == kFloat8_e4m3fnuz) {
+  } else if (isFloat8Type(dtype)) {
     float8_copy_kernel_cuda(iter);
   } else if (iter.dtype(1) == kFloat && (dtype == kBFloat16 || dtype == kHalf)) {
     if (dtype == kBFloat16) {
@@ -582,7 +582,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    AT_EXPAND(AT_FLOAT8_TYPES),
+    // AT_EXPAND(AT_FLOAT8_TYPES),
+    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+    // should not be supported here, then reenable AT_FLOAT8_DTYPES
+    kFloat8_e4m3fn,
+    kFloat8_e5m2,
+    kFloat8_e4m3fnuz,
+    kFloat8_e5m2fnuz,
     kComplexHalf,
     kHalf,
     kBool,
@@ -606,7 +612,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    AT_EXPAND(AT_FLOAT8_TYPES),
+    // AT_EXPAND(AT_FLOAT8_TYPES),
+    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+    // should not be supported here, then reenable AT_FLOAT8_DTYPES
+    kFloat8_e4m3fn,
+    kFloat8_e5m2,
+    kFloat8_e4m3fnuz,
+    kFloat8_e5m2fnuz,
     kComplexHalf,
     kHalf,
     kBool,
@@ -630,7 +642,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
      C10_CUDA_KERNEL_LAUNCH_CHECK();
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    AT_EXPAND(AT_FLOAT8_TYPES),
+    // AT_EXPAND(AT_FLOAT8_TYPES),
+    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+    // should not be supported here, then reenable AT_FLOAT8_DTYPES
+    kFloat8_e4m3fn,
+    kFloat8_e5m2,
+    kFloat8_e4m3fnuz,
+    kFloat8_e5m2fnuz,
     kComplexHalf,
     kHalf,
     kBool,
@@ -652,7 +670,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    AT_EXPAND(AT_FLOAT8_TYPES),
+    // AT_EXPAND(AT_FLOAT8_TYPES),
+    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+    // should not be supported here, then reenable AT_FLOAT8_DTYPES
+    kFloat8_e4m3fn,
+    kFloat8_e5m2,
+    kFloat8_e4m3fnuz,
+    kFloat8_e5m2fnuz,
     kComplexHalf,
     kHalf,
     kBool,
@@ -677,7 +701,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    AT_EXPAND(AT_FLOAT8_TYPES),
+    // AT_EXPAND(AT_FLOAT8_TYPES),
+    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+    // should not be supported here, then reenable AT_FLOAT8_DTYPES
+    kFloat8_e4m3fn,
+    kFloat8_e5m2,
+    kFloat8_e4m3fnuz,
+    kFloat8_e5m2fnuz,
     kComplexHalf,
     kHalf,
     kBool,
@@ -228,6 +228,10 @@ template <> inline std::string typeName<at::Float8_e5m2fnuz>() {
 template <> inline std::string typeName<at::Float8_e4m3fnuz>() {
   return "at::Float8_e4m3fnuz";
 }
+template <> inline std::string typeName<at::Float8_e8m0fnu>() {
+  // TODO(#146647): Can the code here be made generic for any scalartype?
+  return "at::Float8_e8m0fnu";
+}
 
 #define TYPE_NAME_CASE(ctype, scalartype) \
   case ScalarType::scalartype: return typeName<ctype>();
@@ -49,16 +49,9 @@ class C10_API Scalar {
 #define DEFINE_IMPLICIT_CTOR(type, name) \
   Scalar(type vv) : Scalar(vv, true) {}
 
-  AT_FORALL_SCALAR_TYPES_AND7(
-      Half,
-      BFloat16,
-      Float8_e5m2,
-      Float8_e4m3fn,
-      Float8_e5m2fnuz,
-      Float8_e4m3fnuz,
-      ComplexHalf,
-      DEFINE_IMPLICIT_CTOR)
+  AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR)
   AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR)
+  AT_FORALL_FLOAT8_TYPES(DEFINE_IMPLICIT_CTOR)
 
   // Helper constructors to allow Scalar creation from long and long long types
   // As std::is_same_v<long, long long> is false(except Android), one needs to
@@ -222,6 +222,9 @@ std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
       return std::make_pair("float8_e5m2fnuz", "");
     case c10::ScalarType::Float8_e4m3fnuz:
       return std::make_pair("float8_e4m3fnuz", "");
+    case c10::ScalarType::Float8_e8m0fnu:
+      // TODO(#146647): macroify all of this
+      return std::make_pair("float8_e8m0fnu", "");
     default:
       throw std::runtime_error("Unimplemented scalar type");
   }
@@ -7,6 +7,7 @@
 #include <c10/util/Float8_e4m3fnuz.h>
 #include <c10/util/Float8_e5m2.h>
 #include <c10/util/Float8_e5m2fnuz.h>
+#include <c10/util/Float8_e8m0fnu.h>
 #include <c10/util/Half.h>
 #include <c10/util/bits.h>
 #include <c10/util/complex.h>
@@ -102,7 +103,8 @@ struct dummy_int1_7_t {};
   _(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \
   _(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \
   _(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \
-  _(c10::dummy_int1_7_t<7>, Int7) /* 43 */
+  _(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \
+  _(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */
 
 // If you want to support ComplexHalf for real, add ComplexHalf
 // into this macro (and change the name). But beware: convert()
@@ -146,7 +148,8 @@ struct dummy_int1_7_t {};
   _(at::Float8_e5m2, Float8_e5m2) \
   _(at::Float8_e4m3fn, Float8_e4m3fn) \
   _(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
-  _(at::Float8_e4m3fnuz, Float8_e4m3fnuz)
+  _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
+  _(at::Float8_e8m0fnu, Float8_e8m0fnu)
 
 enum class ScalarType : int8_t {
 #define DEFINE_ST_ENUM_VAL_(_1, n) n,
@@ -317,6 +320,13 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
   _(c10::quint4x2, QUInt4x2) \
   _(c10::quint2x4, QUInt2x4)
 
+#define AT_FORALL_FLOAT8_TYPES(_) \
+  _(at::Float8_e5m2, Float8_e5m2) \
+  _(at::Float8_e4m3fn, Float8_e4m3fn) \
+  _(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
+  _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
+  _(at::Float8_e8m0fnu, Float8_e8m0fnu)
+
 #define AT_FORALL_COMPLEX_TYPES(_) \
   _(c10::complex<float>, ComplexFloat) \
   _(c10::complex<double>, ComplexDouble)
@@ -372,7 +382,8 @@ inline bool isIntegralType(ScalarType t) {
 
 inline bool isFloat8Type(ScalarType t) {
   return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e5m2fnuz ||
-      t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz;
+      t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz ||
+      t == ScalarType::Float8_e8m0fnu;
 }
 
 inline bool isReducedFloatingType(ScalarType t) {
@@ -446,6 +457,10 @@ inline bool isSignedType(ScalarType t) {
   return std::numeric_limits< \
       ::c10::impl::ScalarTypeToCPPTypeT<ScalarType::name>>::is_signed;
 
+  // TODO(#146647): If we expect to have numeric_limits for everything,
+  // let's just have a big macro for the whole thing.
+  // If we're hardcoding it, let's just use the macro and a "true"/"false"
+  // below?
   switch (t) {
     case ScalarType::QInt8:
     case ScalarType::QUInt8:
@@ -467,6 +482,7 @@ inline bool isSignedType(ScalarType t) {
     CASE_ISSIGNED(Float8_e5m2fnuz);
     CASE_ISSIGNED(Float8_e4m3fn);
     CASE_ISSIGNED(Float8_e4m3fnuz);
+    CASE_ISSIGNED(Float8_e8m0fnu);
     CASE_ISSIGNED(Byte);
     CASE_ISSIGNED(Char);
     CASE_ISSIGNED(Short);
c10/util/Float8_e8m0fnu-inl.h (new file, 112 lines)
@@ -0,0 +1,112 @@
+#pragma once
+
+#include <c10/macros/Macros.h>
+#include <c10/util/floating_point_utils.h>
+#include <cstring>
+#include <limits>
+
+// TODO(#146647): Can we remove the below warning?
+C10_CLANG_DIAGNOSTIC_PUSH()
+#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
+C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
+#endif
+
+namespace c10 {
+
+/// Constructors
+
+inline C10_HOST_DEVICE Float8_e8m0fnu::Float8_e8m0fnu(float value)
+    : x(detail::fp8e8m0fnu_from_fp32_value(value)) {}
+
+/// Implicit conversions
+
+inline C10_HOST_DEVICE Float8_e8m0fnu::operator float() const {
+  // TODO(#146647): maybe rewrite without control flow
+
+  // if exponent is zero, need to special case to return 2^-127 instead of zero
+  if (x == 0) {
+    return c10::detail::fp32_from_bits(0x00400000);
+  }
+
+  // if exponent is NaN, need to special case to return properly encoded NaN
+  if (isnan()) {
+    return c10::detail::fp32_from_bits(0x7f800001);
+  }
+
+  // leave sign at 0, set the exponent bits, leave stored mantissa at 0
+  uint32_t res = x << 23;
+
+  return c10::detail::fp32_from_bits(res);
+}
+
+/// Special values helper
+
+inline C10_HOST_DEVICE bool Float8_e8m0fnu::isnan() const {
+  return x == 0b11111111;
+}
+
+/// NOTE: we do not define comparisons directly and instead rely on the implicit
+/// conversion from c10::Float8_e8m0fnu to float.
+
+} // namespace c10
+
+namespace std {
+
+template <>
+class numeric_limits<c10::Float8_e8m0fnu> {
+ public:
+  static constexpr bool is_specialized = true;
+  static constexpr bool is_signed = false;
+  static constexpr bool is_integer = false;
+  static constexpr bool is_exact = false;
+  static constexpr bool has_infinity = false;
+  static constexpr bool has_quiet_NaN = true;
+  static constexpr bool has_signaling_NaN = false;
+  static constexpr auto has_denorm = false;
+  static constexpr auto has_denorm_loss = false;
+  static constexpr auto round_style = numeric_limits<float>::round_style;
+  static constexpr bool is_iec559 = false;
+  static constexpr bool is_bounded = true;
+  static constexpr bool is_modulo = false;
+  static constexpr int digits = 1;
+  static constexpr int digits10 = 0;
+  static constexpr int max_digits10 = 1; // just a 2!
+  static constexpr int radix = 2;
+  static constexpr int min_exponent = -126;
+  static constexpr int min_exponent10 = -38;
+  static constexpr int max_exponent = 128;
+  static constexpr int max_exponent10 = 38;
+  static constexpr auto traps = numeric_limits<float>::traps;
+  static constexpr auto tinyness_before = false;
+
+  static constexpr c10::Float8_e8m0fnu min() {
+    // 2^-127
+    return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits());
+  }
+  static constexpr c10::Float8_e8m0fnu lowest() {
+    // 2^-127
+    return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits());
+  }
+  static constexpr c10::Float8_e8m0fnu max() {
+    // 254 biased, which is 127 unbiased, so 2^127
+    return c10::Float8_e8m0fnu(0b11111110, c10::Float8_e8m0fnu::from_bits());
+  }
+  static constexpr c10::Float8_e8m0fnu epsilon() {
+    // according to https://en.cppreference.com/w/cpp/types/numeric_limits, this
+    // is "the difference between 1.0 and the next representable value of the
+    // given floating-point type". The next representable value is 2.0, so the
+    // difference is 1.0 which is 2^0. 0 unbiased is 127 biased.
+    return c10::Float8_e8m0fnu(0b01111111, c10::Float8_e8m0fnu::from_bits());
+  }
+  static constexpr c10::Float8_e8m0fnu round_error() {
+    // 0.5 in float, which is 2^-1, and -1 + 127 = 126
+    return c10::Float8_e8m0fnu(0b01111110, c10::Float8_e8m0fnu::from_bits());
+  }
+  static constexpr c10::Float8_e8m0fnu quiet_NaN() {
+    return c10::Float8_e8m0fnu(0b11111111, c10::Float8_e8m0fnu::from_bits());
+  }
+};
+
+} // namespace std
+
+C10_CLANG_DIAGNOSTIC_POP()
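The decode path defined in this new header is compact enough to mirror in Python. The sketch below is editorial (not part of the PR) and reproduces the logic of `operator float()` on a raw 8-bit pattern, including the special cases for the all-zeros and all-ones encodings.

```python
import struct


def e8m0_bits_to_float(x: int) -> float:
    """Mirror of Float8_e8m0fnu::operator float() for a raw bit pattern."""
    assert 0 <= x <= 0xFF
    if x == 0:
        # biased exponent 0 decodes as 2**-127; there is no zero encoding
        return 2.0 ** -127
    if x == 0xFF:
        return float("nan")
    # sign 0, exponent bits x, stored mantissa 0
    return struct.unpack("!f", struct.pack("!I", x << 23))[0]


assert e8m0_bits_to_float(0b01111111) == 1.0          # bias 127 -> 2**0
assert e8m0_bits_to_float(0b10000000) == 2.0          # 2**1
assert e8m0_bits_to_float(0b11111110) == 2.0 ** 127   # largest value
```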
c10/util/Float8_e8m0fnu.cpp (new file, 12 lines)
@@ -0,0 +1,12 @@
+#include <c10/macros/Macros.h>
+#include <c10/util/Float8_e8m0fnu.h>
+
+namespace c10 {
+
+// TODO(#146647): Can we have these in a single shared cpp file
+// built with macro to remove the need for a new cpp file?
+static_assert(
+    std::is_standard_layout_v<Float8_e8m0fnu>,
+    "c10::Float8_e8m0fnu must be standard layout.");
+
+} // namespace c10
c10/util/Float8_e8m0fnu.h (new file, 120 lines)
@@ -0,0 +1,120 @@
+#pragma once
+
+/// Defines the Float8_e8m0fnu type (8-bit floating-point) including
+/// conversions to standard C types
+/// Binary configuration :
+/// eeeeeeee
+/// no sign bits
+/// 8 exponent bits
+/// no mantissa bits
+///
+/// This is the E8M0 dtype from the OCP MX format spec
+/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
+/// Section 5.4.1)
+
+#include <c10/macros/Export.h>
+#include <c10/macros/Macros.h>
+#include <c10/util/floating_point_utils.h>
+#include <type_traits>
+
+// TODO(#146647): do we need to special case OPENCL?
+#if defined(__cplusplus)
+#include <cstdint>
+#elif !defined(__OPENCL_VERSION__)
+#include <math.h>
+#include <stdint.h>
+#endif
+
+#include <iosfwd>
+#include <ostream>
+
+namespace c10 {
+
+namespace detail {
+
+/*
+ * Convert a 32-bit floating-point number in IEEE single-precision format to a
+ * 8-bit floating-point number in fp8 e8m0fnu format, in bit representation.
+ */
+inline C10_HOST_DEVICE uint8_t fp8e8m0fnu_from_fp32_value(float f) {
+  // TODO(#146647): maybe rewrite without control flow
+
+  uint32_t f_bits = c10::detail::fp32_to_bits(f);
+
+  // extract the exponent
+  uint32_t exponent = (f_bits >> 23) & 0b11111111;
+
+  // special case float32 NaN and +-inf to map to e8m0 nan
+  if (exponent == 0b11111111) {
+    return exponent;
+  }
+
+  // next, we use guard, round, sticky bits and the LSB to implement round to
+  // nearest, with ties to even
+
+  // guard bit - bit 23, or 22 zero-indexed
+  uint8_t g = (f_bits & 0x400000) > 0;
+  // round bit - bit 22, or 21 zero-indexed
+  uint8_t r = (f_bits & 0x200000) > 0;
+  // sticky bit - bits 21 to 1, or 20 to 0 zero-indexed
+  uint8_t s = (f_bits & 0x1FFFFF) > 0;
+  // in casting to e8m0, LSB is the implied mantissa bit. It equals to 0 if the
+  // original float32 is denormal, and to 1 if the original float32 is normal.
+  uint8_t lsb = exponent > 0;
+
+  // implement the RNE logic
+  bool round_up = false;
+
+  // if g == 0, round down (no-op)
+  if (g == 1) {
+    if ((r == 1) || (s == 1)) {
+      // round up
+      round_up = true;
+    } else {
+      if (lsb == 1) {
+        // round up
+        round_up = true;
+      }
+      // if lsb == 0, round down (no-op)
+    }
+  }
+
+  if (round_up) {
+    // adjust exponent
+    // note that if exponent was 255 we would have already returned earlier, so
+    // we know we can add one safely without running out of bounds
+    exponent++;
+  }
+
+  return exponent;
+}
+
+} // namespace detail
+
+struct alignas(1) Float8_e8m0fnu {
+  uint8_t x;
+
+  struct from_bits_t {};
+  C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
+    return from_bits_t();
+  }
+
+  Float8_e8m0fnu() = default;
+
+  constexpr C10_HOST_DEVICE Float8_e8m0fnu(uint8_t bits, from_bits_t)
+      : x(bits) {}
+  inline C10_HOST_DEVICE Float8_e8m0fnu(float value);
+  inline C10_HOST_DEVICE operator float() const;
+  inline C10_HOST_DEVICE bool isnan() const;
+};
+
+C10_API inline std::ostream& operator<<(
+    std::ostream& out,
+    const Float8_e8m0fnu& value) {
+  out << (float)value;
+  return out;
+}
+
+} // namespace c10
+
+#include <c10/util/Float8_e8m0fnu-inl.h> // IWYU pragma: keep
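As a companion to `fp8e8m0fnu_from_fp32_value` above, this editorial Python sketch (not part of the PR) spells out the same guard/round/sticky round-to-nearest-even logic on the float32 bit pattern.

```python
import struct


def fp32_to_e8m0_bits(f: float) -> int:
    """Round a float32 value to an e8m0 biased exponent with RNE."""
    f_bits = struct.unpack("!I", struct.pack("!f", f))[0]
    exponent = (f_bits >> 23) & 0xFF
    if exponent == 0xFF:  # NaN and +-inf map to the e8m0 NaN encoding
        return exponent
    g = (f_bits >> 22) & 1              # guard bit (top mantissa bit)
    r = (f_bits >> 21) & 1              # round bit
    s = int((f_bits & 0x1FFFFF) != 0)   # sticky bit
    lsb = int(exponent > 0)             # implied mantissa bit (0 for denormals)
    if g and (r or s or lsb):           # round up under round-to-nearest-even
        exponent += 1
    return exponent


assert fp32_to_e8m0_bits(1.0) == 127   # 2**0
assert fp32_to_e8m0_bits(3.0) == 129   # 1.5 * 2**1 ties up to 2**2
assert fp32_to_e8m0_bits(0.1) == 124   # rounds up to 2**-3
```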
@@ -5,6 +5,7 @@
 #include <c10/util/Float8_e4m3fnuz.h>
 #include <c10/util/Float8_e5m2.h>
 #include <c10/util/Float8_e5m2fnuz.h>
+#include <c10/util/Float8_e8m0fnu.h>
 #include <c10/util/Half.h>
 #include <c10/util/complex.h>
 #include <c10/util/overflows.h>
@@ -151,6 +152,19 @@ struct static_cast_with_inter_type<
   }
 };
 
+// TODO(#146647): Can we make all these template specialization happen
+// based off our apply macros?
+template <>
+struct static_cast_with_inter_type<
+    c10::complex<c10::Half>,
+    c10::Float8_e8m0fnu> {
+  C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
+      c10::Half>
+  apply(c10::Float8_e8m0fnu src) {
+    return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
+  }
+};
+
 template <>
 struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::Half> {
   C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
@@ -1,5 +1,6 @@
 # Owner(s): ["oncall: quantization"]
 
+import struct
 import unittest
 
 import torch
@@ -14,6 +15,7 @@ from torch.testing._internal.common_utils import (
     parametrize,
     run_tests,
     subtest,
+    TemporaryFileName,
     TestCase,
 )
 
@@ -23,11 +25,13 @@ FLOAT8_DTYPES = [
     torch.float8_e5m2fnuz,
     torch.float8_e4m3fn,
     torch.float8_e4m3fnuz,
+    torch.float8_e8m0fnu,
 ]
 
 CUDA_FLOAT8_DTYPES = [
     torch.float8_e5m2,
     torch.float8_e4m3fn,
+    torch.float8_e8m0fnu,
 ]
 
 # The following information are not yet provided by torch.finfo.
@@ -37,6 +41,7 @@ MANTISSA_BITS = {
     torch.float8_e5m2fnuz: 2,
     torch.float8_e4m3fn: 3,
     torch.float8_e4m3fnuz: 3,
+    torch.float8_e8m0fnu: 0,
 }
 
 # As in np.finfo(dtype).minexp
@@ -45,6 +50,7 @@ MINEXP = {
     torch.float8_e5m2fnuz: -15,
     torch.float8_e4m3fn: -6,
     torch.float8_e4m3fnuz: -7,
+    torch.float8_e8m0fnu: -127,
 }
 
 SPECIAL_NUMBERS = {
@@ -108,11 +114,24 @@ SPECIAL_NUMBERS = {
         ("00000001", 0.125 * (2**-7), "min_subnorm"),
         ("10000001", -0.125 * (2**-7), "neg_min_subnorm"),
     ],
+    torch.float8_e8m0fnu: [
+        ("00000000", float(2**-127), "smallest_number"),
+        ("11111110", float(2**127), "largest_number"),
+        ("01111110", 0.5, "zero_point_five"),
+        ("01111111", 1.0, "one"),
+        ("10000000", 2.0, "two"),
+        ("11111111", float("nan"), "nan"),
+    ],
 }
 
 FLOAT8_DTYPES_WITH_INF = [torch.float8_e5m2]
 
 
+def _int_bits_to_float(x):
+    y = struct.unpack("!f", struct.pack("!I", x))[0]
+    return y
+
+
 def simulate_fp8_precision(input, variant):
     """Round input (as float32) to the given float8 datatype variant."""
 
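The new `SPECIAL_NUMBERS` entries can be spot-checked by bit-casting a uint8 pattern to the dtype, in the same spirit as `test_save_load` further below (editorial sketch, assuming a build with this PR):

```python
import torch

for bits, expected, name in [
    (0b00000000, 2.0 ** -127, "smallest_number"),
    (0b01111111, 1.0, "one"),
    (0b11111110, 2.0 ** 127, "largest_number"),
]:
    t = torch.tensor([bits], dtype=torch.uint8).view(torch.float8_e8m0fnu)
    assert t.float().item() == expected, name
```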
@@ -165,6 +184,24 @@ def simulate_fp8_precision(input, variant):
     return vals * signs
 
 
+def _round_e8m0_rne(biased_exponent, lsb, g, r, s):
+    round_up = False
+
+    # apply g,r,s rounding rules for RNE rounding
+    if g == 1:
+        if (r == 1) or (s == 1):
+            round_up = True
+        else:
+            if lsb:
+                round_up = True
+
+    # round up if necessary
+    if round_up:
+        biased_exponent += 1
+
+    return biased_exponent
+
+
 ROUND_TRIP_TEST_CASES = (
     # A general 'soak test'.
     subtest(
@@ -198,17 +235,19 @@
 
 
 class TestFloat8Dtype(TestCase):
     """
     Sanity test for zeros comparison
     """
 
     @dtypes(*FLOAT8_DTYPES)
     @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
     def test_creation_with_zeros(self, dtype, device):
         """Sanity test, round-trip casting of zeros."""
-        x = torch.zeros(8, dtype=torch.float, device=device)
         x8 = torch.zeros(8, dtype=dtype, device=device)
-        self.assertEqual(x, x8.float(), atol=0, rtol=0)
+        if dtype is torch.float8_e8m0fnu:
+            # zeros are not supported for this dtype, values get clamped
+            # to 2 ^ -127
+            x = torch.full((8,), 2**-127, dtype=torch.float, device=device)
+            self.assertEqual(x, x8.float(), atol=0, rtol=0)
+        else:
+            x = torch.zeros(8, dtype=torch.float, device=device)
+            self.assertEqual(x, x8.float(), atol=0, rtol=0)
 
     @dtypes(*FLOAT8_DTYPES)
     @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
@@ -217,12 +256,69 @@ class TestFloat8Dtype(TestCase):
         """Numerical test of float8 conversion, by performing a round-trip cast
         to the float8 dtype and back to float32, comparing against simulated
         lower precision."""
+        if dtype is torch.float8_e8m0fnu:
+            return unittest.skip("numerics for e8m0fnu are tested elsewhere")
+
         x = get_input(dtype, device)
         x = torch.cat((x, -x))
         x8 = x.to(dtype)
         x8_simulated = simulate_fp8_precision(x, dtype)
         self.assertEqual(x8_simulated, x8.float())
 
+    def test_float8_e8m0fnu_rne_rounding(self, device):
+        """
+        For every possible e8m0 exponent (256 options) and for every possible
+        g, r, s bits of the float32 mantissa, verify that RNE rounding is
+        correctly applied when casting from float32 to e8m0
+
+        Note: this code is morally similar to `test_cast_round_trip`, but
+        IMO simpler to special case e8m0 here.
+        """
+
+        for biased_exponent in range(0, 256):
+            # iterate through all the possible options of guard, round, sticky bits
+            # for the current exponent
+            for grs in range(8):
+                # create a positive floating point number with the specified exponent
+                # and mantissa guard, round, sticky bits
+                uint32_t_start = (biased_exponent << 23) + (grs << 20)
+                fp32_start = _int_bits_to_float(uint32_t_start)
+
+                # create an RNE rounded version of the exponent
+                if biased_exponent == 255:
+                    new_biased_exponent = biased_exponent
+                else:
+                    lsb = biased_exponent > 0
+                    g = grs >> 2
+                    r = (grs >> 1) & 0b1
+                    s = grs & 0b1
+                    new_biased_exponent = _round_e8m0_rne(biased_exponent, lsb, g, r, s)
+
+                # create an RNE rounded version of the float
+                fp32_e8m0_fp32_emulated = _int_bits_to_float(new_biased_exponent << 23)
+
+                # now, do the same in PyTorch and see if results match
+                fp32_pt_start = torch.full(
+                    (1,), fp32_start, device=device, dtype=torch.float
+                )
+                fp32_pt_e8m0 = fp32_pt_start.to(torch.float8_e8m0fnu)
+                fp32_pt_e8m0_fp32 = fp32_pt_e8m0.to(torch.float)
+
+                expected = fp32_e8m0_fp32_emulated
+                if biased_exponent == 254 and grs >= 4:
+                    # special case rounding up from the largest representable float32 exponent, which
+                    # saturates to nan
+                    expected = float("nan")
+                elif biased_exponent == 255:
+                    # special case inf and nan, which becomes nan
+                    expected = float("nan")
+
+                actual = fp32_pt_e8m0_fp32.item()
+
+                self.assertEqual(
+                    expected, actual, f"expected: {expected}, actual: {actual}"
+                )
+
     @dtypes(*FLOAT8_DTYPES)
     @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
     def test_special_numbers(self, dtype, device):
@@ -269,6 +365,32 @@ class TestFloat8Dtype(TestCase):
         torch.use_deterministic_algorithms(use_deterministic)
         torch.empty(4, 4, device=device, dtype=dtype)
 
+    @dtypes(*FLOAT8_DTYPES)
+    @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
+    def test_to_string(self, dtype, device):
+        x = torch.empty(4, 4, device=device, dtype=dtype)
+        str(x)
+
+    @dtypes(*FLOAT8_DTYPES)
+    def test_finfo(self, dtype, device):
+        torch.finfo(dtype)
+
+    @dtypes(*FLOAT8_DTYPES)
+    @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
+    def test_cat(self, dtype, device):
+        x1 = torch.empty(4, 4, device=device, dtype=dtype)
+        x2 = torch.empty(4, 4, device=device, dtype=dtype)
+        torch.cat([x1, x2])
+
+    @dtypes(*FLOAT8_DTYPES)
+    @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
+    def test_save_load(self, dtype, device):
+        x1 = torch.randint(0, 10, (4, 4), device=device, dtype=torch.uint8).view(dtype)
+        with TemporaryFileName() as fname:
+            torch.save(x1, fname)
+            x1_save_load = torch.load(fname)
+            torch.testing.assert_close(x1, x1_save_load, atol=0, rtol=0)
+
 
 instantiate_device_type_tests(TestFloat8Dtype, globals())
 
@@ -285,6 +407,9 @@ class TestFloat8DtypeCPUOnly(TestCase):
 
     @dtypes(*CUDA_FLOAT8_DTYPES)
     def test_mul(self, dtype):
+        # TODO(#113663): remove arithmetic support from all float8 dtypes
+        if dtype is torch.float8_e8m0fnu:
+            return unittest.skip("arithmetic not supported for torch.float8_e8m0fnu")
         shape = (10, 10)
         a = torch.randn(shape)
         a8_simulated = simulate_fp8_precision(a, dtype)
@@ -299,6 +424,11 @@ class TestFloat8DtypeCPUOnly(TestCase):
     @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on Windows yet")
     @dtypes(*CUDA_FLOAT8_DTYPES)
     def test_pt2_traceable_aot_eager(self, dtype):
+        if dtype is torch.float8_e8m0fnu:
+            return unittest.skip(
+                "PT2 support for torch.float8_e8m0fnu is not implemented yet"
+            )
+
         @torch.compile(backend="aot_eager", fullgraph=True)
         def f(x):
             x = x.to(dtype)
@@ -1362,7 +1362,7 @@ def gen_pyi(
     # Generate type signatures for dtype classes
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-    # TODO: don't explicitly list dtypes here; get it from canonical
+    # TODO(#146647): don't explicitly list dtypes here; get it from canonical
     # source
     dtype_class_hints = [
         f"{n}: dtype = ..."
@@ -1377,6 +1377,7 @@
         "float8_e4m3fnuz",
         "float8_e5m2",
         "float8_e5m2fnuz",
+        "float8_e8m0fnu",
         "half",
         "uint8",
         "uint16",
@@ -150,7 +150,17 @@ class _Formatter:
             # no valid number, do nothing
             return
 
+        if tensor.dtype == torch.float8_e8m0fnu:  # type: ignore[attr-defined]
+            # float8_e8m0fnu is special and does not define arithmetic ops,
+            # and printing code further in this file assumes the existence
+            # of various arithmetic ops to figure out what to print. We hack
+            # and convert to float here to make printing work correctly.
+            # TODO(#113663): also add the other float8 dtypes here after arithmetic
+            # support for them is removed
+            nonzero_finite_vals = nonzero_finite_vals.float()
+
         # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
+
         nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs())
         nonzero_finite_min = tensor_totype(nonzero_finite_abs.min())
         nonzero_finite_max = tensor_totype(nonzero_finite_abs.max())
@@ -123,16 +123,15 @@ static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) {
 }
 
 #define _AT_DISPATCH_FINFO_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6( \
-      at::kHalf, \
-      at::ScalarType::BFloat16, \
-      at::ScalarType::Float8_e5m2, \
-      at::ScalarType::Float8_e5m2fnuz, \
-      at::ScalarType::Float8_e4m3fn, \
-      at::ScalarType::Float8_e4m3fnuz, \
+  AT_DISPATCH_V2( \
       TYPE, \
      NAME, \
-      __VA_ARGS__)
+      AT_WRAP(__VA_ARGS__), \
+      AT_EXPAND(AT_FLOATING_TYPES), \
+      AT_EXPAND(AT_COMPLEX_TYPES), \
+      at::kHalf, \
+      at::ScalarType::BFloat16, \
+      AT_EXPAND(AT_FLOAT8_TYPES))
 
 static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
   HANDLE_TH_ERRORS
@@ -79,6 +79,7 @@ inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) {
       *(at::BFloat16*)data =
           at::convert<at::BFloat16, double>(THPUtils_unpackDouble(obj));
       break;
+    // TODO(#146647): simplify below with macros
     case at::kFloat8_e5m2:
       *(at::Float8_e5m2*)data =
           at::convert<at::Float8_e5m2, double>(THPUtils_unpackDouble(obj));
@@ -95,8 +96,12 @@ inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) {
       *(at::Float8_e4m3fnuz*)data =
          at::convert<at::Float8_e4m3fnuz, double>(THPUtils_unpackDouble(obj));
       break;
+    case at::kFloat8_e8m0fnu:
+      *(at::Float8_e8m0fnu*)data =
+          at::convert<at::Float8_e8m0fnu, double>(THPUtils_unpackDouble(obj));
+      break;
     default:
-      throw std::runtime_error("invalid type");
+      throw std::runtime_error("store_scalar: invalid type");
   }
 }
 
@@ -143,6 +148,7 @@ inline PyObject* load_scalar(const void* data, at::ScalarType scalarType) {
     case at::kBFloat16:
       return PyFloat_FromDouble(
           at::convert<double, at::BFloat16>(*(at::BFloat16*)data));
+    // TODO(#146647): simplify below with macros
     case at::kFloat8_e5m2:
       return PyFloat_FromDouble(
           at::convert<double, at::Float8_e5m2>(*(at::Float8_e5m2*)data));
@@ -155,8 +161,11 @@ inline PyObject* load_scalar(const void* data, at::ScalarType scalarType) {
     case at::kFloat8_e4m3fnuz:
       return PyFloat_FromDouble(at::convert<double, at::Float8_e4m3fnuz>(
           *(at::Float8_e4m3fnuz*)data));
+    case at::kFloat8_e8m0fnu:
+      return PyFloat_FromDouble(
+          at::convert<double, at::Float8_e8m0fnu>(*(at::Float8_e8m0fnu*)data));
     default:
-      throw std::runtime_error("invalid type");
+      throw std::runtime_error("load_scalar: invalid type");
   }
 }
 
@@ -535,6 +535,7 @@ def _new_dtypes():
         torch.float8_e4m3fn,
         torch.float8_e5m2fnuz,
         torch.float8_e4m3fnuz,
+        torch.float8_e8m0fnu,
         torch.bits8,
         torch.bits16,
         torch.bits1x8,
@@ -51,6 +51,7 @@ float8_e5m2T = BaseCppType("at", "Float8_e5m2")
 float8_e5m2fnuzT = BaseCppType("at", "Float8_e5m2fnuz")
 float8_e4m3fnT = BaseCppType("at", "Float8_e4m3fn")
 float8_e4m3fnuzT = BaseCppType("at", "Float8_e4m3fnuz")
+float8_e8m0fnuT = BaseCppType("at", "Float8_e8m0fnu")
 stringT = BaseCppType("c10", "string_view")
 generatorT = BaseCppType("at", "Generator")
 scalarTypeT = BaseCppType("at", "ScalarType")
@@ -102,6 +103,7 @@ ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = {
     ScalarType.Float8_e5m2fnuz: float8_e5m2fnuzT,
     ScalarType.Float8_e4m3fn: float8_e4m3fnT,
     ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT,
+    ScalarType.Float8_e8m0fnu: float8_e8m0fnuT,
 }
 
 BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
@@ -374,6 +374,7 @@ class ScalarType(Enum):
     Float8_e5m2fnuz = auto()
     Float8_e4m3fn = auto()
     Float8_e4m3fnuz = auto()
+    Float8_e8m0fnu = auto()
 
     def __str__(self) -> str:
         return self.name