Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Move Float4 to headeronly (#159414)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159414 Approved by: https://github.com/desertfire
committed by PyTorch MergeBot
parent 52a52d1b78
commit b268f22ab2
c10/util/Float4_e2m1fn_x2.h
@@ -1,28 +1 @@
-#pragma once
-#include <cstdint>
-
-#include <c10/macros/Macros.h>
-
-/// Defines the Float4_e2m1fn_x2 type (4-bit floating-point, two elements packed
-/// into one byte). This is the FP4 dtype from the OCP MX format spec
-/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
-/// Section 5.3.3)
-///
-/// Given two high precision values val0 and val1, here is the
-/// binary configuration of their packed representation, from MSB to LSB:
-///
-/// original value              | val1 : val0
-/// ========================================
-/// bit index (MSB==7, LSB==0) | 7654 : 3210
-/// sign/exponent/mantissa     | seem : seem
-///
-
-namespace c10 {
-
-struct alignas(1) Float4_e2m1fn_x2 {
-  uint8_t val_;
-  Float4_e2m1fn_x2() = default;
-  C10_HOST_DEVICE explicit Float4_e2m1fn_x2(uint8_t val) : val_(val) {}
-};
-
-} // namespace c10
+#include <torch/headeronly/util/Float4_e2m1fn_x2.h>
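With this change, the old c10 header becomes a one-line shim, so both spellings name the same type and existing c10 call sites keep compiling. A minimal sketch of that equivalence (the static_assert is an illustration, not part of this commit):

#include <type_traits>

#include <c10/util/Float4_e2m1fn_x2.h> // now just forwards to the headeronly header

// The headeronly header defines the struct in namespace c10 and aliases it
// into torch::headeronly, so the two names denote one type.
static_assert(
    std::is_same_v<c10::Float4_e2m1fn_x2, torch::headeronly::Float4_e2m1fn_x2>,
    "both spellings resolve to the same struct");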
test/cpp/aoti_abi_check/test_dtype.cpp
@@ -8,6 +8,7 @@
 #include <c10/util/Float8_e5m2fnuz.h>
 #include <c10/util/Half.h>
 #include <c10/util/complex.h>
+#include <torch/headeronly/util/Float4_e2m1fn_x2.h>
 
 #include <torch/headeronly/util/bits.h>
 #include <torch/headeronly/util/qint32.h>
@@ -86,6 +87,11 @@ TEST(TestDtype, TestFloat8_e5m2fnuz) {
   EXPECT_EQ(a / b, div);
 }
 
+TEST(TestDtype, TestFloat4) {
+  // not much you can do with this type, just make sure it compiles
+  torch::headeronly::Float4_e2m1fn_x2 a(5);
+}
+
 TEST(TestDtype, TestHalf) {
   c10::Half a = 1.0f;
   c10::Half b = 2.0f;
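The new test only asserts compilability, since the type carries no arithmetic. As a hypothetical extension (not in this commit), one could also check that the packed byte round-trips verbatim:

#include <gtest/gtest.h>
#include <torch/headeronly/util/Float4_e2m1fn_x2.h>

TEST(TestDtype, TestFloat4RoundTrip) {
  // 0x35 packs val1 = 0b0011 into the high nibble and val0 = 0b0101
  // into the low nibble, per the layout documented in the header.
  torch::headeronly::Float4_e2m1fn_x2 a(0x35);
  EXPECT_EQ(a.val_, 0x35);
}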
torch/header_only_apis.txt
@@ -12,6 +12,9 @@ bit_cast
 # c10/util/BFloat16-math.h, c10/util/BFloat16.h
 BFloat16
 
+# torch/headeronly/util/Float4_e2m1fn_x2.h
+Float4_e2m1fn_x2
+
 # c10/util/Float8_e4m3fn.h
 Float8_e4m3fn
 
torch/headeronly/util/Float4_e2m1fn_x2.h (new file, 32 lines)
@@ -0,0 +1,32 @@
+#pragma once
+#include <cstdint>
+
+#include <torch/headeronly/macros/Macros.h>
+
+/// Defines the Float4_e2m1fn_x2 type (4-bit floating-point, two elements packed
+/// into one byte). This is the FP4 dtype from the OCP MX format spec
+/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
+/// Section 5.3.3)
+///
+/// Given two high precision values val0 and val1, here is the
+/// binary configuration of their packed representation, from MSB to LSB:
+///
+/// original value              | val1 : val0
+/// ========================================
+/// bit index (MSB==7, LSB==0) | 7654 : 3210
+/// sign/exponent/mantissa     | seem : seem
+///
+
+namespace c10 {
+
+struct alignas(1) Float4_e2m1fn_x2 {
+  uint8_t val_;
+  Float4_e2m1fn_x2() = default;
+  C10_HOST_DEVICE explicit Float4_e2m1fn_x2(uint8_t val) : val_(val) {}
+};
+
+} // namespace c10
+
+namespace torch::headeronly {
+using c10::Float4_e2m1fn_x2;
+} // namespace torch::headeronly
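Since the struct stores only the raw byte, callers do the nibble packing themselves. A hypothetical pair of helpers (not part of this commit) following the documented layout, assuming the inputs are already e2m1-encoded nibbles (1 sign, 2 exponent, 1 mantissa bit; per the OCP spec, e.g. 0b0101 encodes 3.0):

#include <cstdint>

#include <torch/headeronly/util/Float4_e2m1fn_x2.h>

// Pack two e2m1 nibbles: val1 into bits 7..4, val0 into bits 3..0.
inline torch::headeronly::Float4_e2m1fn_x2 pack_fp4x2(uint8_t val0, uint8_t val1) {
  return torch::headeronly::Float4_e2m1fn_x2(
      static_cast<uint8_t>(((val1 & 0x0F) << 4) | (val0 & 0x0F)));
}

// Recover the two nibbles from the packed byte.
inline uint8_t unpack_val0(torch::headeronly::Float4_e2m1fn_x2 x) {
  return x.val_ & 0x0F; // low nibble holds val0
}
inline uint8_t unpack_val1(torch::headeronly::Float4_e2m1fn_x2 x) {
  return (x.val_ >> 4) & 0x0F; // high nibble holds val1
}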