Move Float4 to headeronly (#159414)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159414
Approved by: https://github.com/desertfire
This commit is contained in:
Jane Xu
2025-07-29 13:35:25 -07:00
committed by PyTorch MergeBot
parent 52a52d1b78
commit b268f22ab2
4 changed files with 42 additions and 28 deletions

View File

@@ -1,28 +1 @@
#pragma once
#include <cstdint>
#include <c10/macros/Macros.h>
/// Defines the Float4_e2m1fn_x2 type (4-bit floating-point, two elements packed
/// into one byte). This is the FP4 dtype from the OCP MX format spec
/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
/// Section 5.3.3)
///
/// Given two high precision values val0 and val1, here is the
/// binary configuration of their packed representation, from MSB to LSB:
///
/// original value | val1 : val0
/// ========================================
/// bit index (MSB==7, LSB==0) | 7654 : 3210
/// sign/exponent/mantissa | seem : seem
///
namespace c10 {
// Two FP4 (e2m1) values packed into a single byte; nibble layout is described
// in the header comment above (high nibble = val1, low nibble = val0).
struct alignas(1) Float4_e2m1fn_x2 {
// Raw packed byte holding both 4-bit elements.
uint8_t val_;
// Trivial default ctor: val_ is left uninitialized, matching POD semantics.
Float4_e2m1fn_x2() = default;
// Construct from an already-packed byte; callable from host and device code.
C10_HOST_DEVICE explicit Float4_e2m1fn_x2(uint8_t val) : val_(val) {}
};
} // namespace c10
#include <torch/headeronly/util/Float4_e2m1fn_x2.h>

View File

@@ -8,6 +8,7 @@
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>
#include <torch/headeronly/util/Float4_e2m1fn_x2.h>
#include <torch/headeronly/util/bits.h>
#include <torch/headeronly/util/qint32.h>
@@ -86,6 +87,11 @@ TEST(TestDtype, TestFloat8_e5m2fnuz) {
EXPECT_EQ(a / b, div);
}
TEST(TestDtype, TestFloat4) {
// Float4_e2m1fn_x2 exposes no arithmetic; this test only verifies that the
// header-only type is reachable and constructible (a compile/link check).
[[maybe_unused]] torch::headeronly::Float4_e2m1fn_x2 packed(5);
}
TEST(TestDtype, TestHalf) {
c10::Half a = 1.0f;
c10::Half b = 2.0f;

View File

@@ -12,6 +12,9 @@ bit_cast
# c10/util/BFloat16-math.h, c10/util/BFloat16.h
BFloat16
# torch/headeronly/util/Float4_e2m1fn_x2.h
Float4_e2m1fn_x2
# c10/util/Float8_e4m3fn.h
Float8_e4m3fn

View File

@@ -0,0 +1,32 @@
#pragma once
#include <cstdint>
#include <torch/headeronly/macros/Macros.h>
/// Defines the Float4_e2m1fn_x2 type (4-bit floating-point, two elements packed
/// into one byte). This is the FP4 dtype from the OCP MX format spec
/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
/// Section 5.3.3)
///
/// Given two high precision values val0 and val1, here is the
/// binary configuration of their packed representation, from MSB to LSB:
///
/// original value | val1 : val0
/// ========================================
/// bit index (MSB==7, LSB==0) | 7654 : 3210
/// sign/exponent/mantissa | seem : seem
///
namespace c10 {
// Two FP4 (e2m1) values packed into a single byte; nibble layout is described
// in the header comment above (high nibble = val1, low nibble = val0).
struct alignas(1) Float4_e2m1fn_x2 {
// Raw packed byte holding both 4-bit elements.
uint8_t val_;
// Trivial default ctor: val_ is left uninitialized, matching POD semantics.
Float4_e2m1fn_x2() = default;
// Construct from an already-packed byte; callable from host and device code.
C10_HOST_DEVICE explicit Float4_e2m1fn_x2(uint8_t val) : val_(val) {}
};
} // namespace c10
namespace torch::headeronly {
// Re-export under the header-only namespace so users can spell the type as
// torch::headeronly::Float4_e2m1fn_x2 without depending on c10 directly.
using c10::Float4_e2m1fn_x2;
} // namespace torch::headeronly