add torch.float4_e2m1fn_x2 to PyTorch (#148791)

Summary:

Redo of https://github.com/pytorch/pytorch/pull/146578 to get around
rebase conflicts.

Test Plan:

```
pytest test/quantization/core/experimental/test_floatx.py -s
```
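
For reference, a minimal sketch of what the new shell dtype supports after this change, mirroring the tests exercised by the command above (creation, printing, and bitcasting to and from uint8). Casts are not defined for this dtype yet, see the printing changes further below:

```python
import torch

# One float4_e2m1fn_x2 element is a single byte holding two packed fp4 values.
x = torch.empty(4, 8, dtype=torch.float4_e2m1fn_x2)

# Printing works (it falls back to the uint8 representation, see below).
print(x)

# Bitcast to and from uint8; shapes are preserved since both dtypes are 1 byte per element.
x_u8 = x.view(torch.uint8)
x_fp4 = x_u8.view(torch.float4_e2m1fn_x2)
```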

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148791
Approved by: https://github.com/drisspg, https://github.com/eqy, https://github.com/jeffdaily
Author: vasiliy
Date: 2025-03-26 12:48:10 -07:00
Committed by: PyTorch MergeBot
Parent: ac91f8765b
Commit: e33bc41958
10 changed files with 95 additions and 7 deletions

View File

@ -71,6 +71,9 @@ DLDataType getDLDataType(const Tensor& t) {
case ScalarType::Float8_e8m0fnu:
TORCH_CHECK(false, "float8 types are not supported by dlpack");
break;
case ScalarType::Float4_e2m1fn_x2:
TORCH_CHECK(false, "float4 types are not supported by dlpack");
break;
case ScalarType::QInt8:
case ScalarType::QUInt8:
case ScalarType::QInt32:
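
DLPack export is rejected for the new dtype, matching the float8 handling directly above. A sketch of the expected Python-side behavior (the error text comes from the TORCH_CHECK in this hunk):

```python
import torch
from torch.utils.dlpack import to_dlpack

x = torch.zeros(2, 2, dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
try:
    to_dlpack(x)  # expected to trip the TORCH_CHECK added in this hunk
except RuntimeError as err:
    print(err)  # "float4 types are not supported by dlpack"
```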

View File

@ -507,7 +507,8 @@ TORCH_IMPL_FUNC(cat_out_cuda)
kBool,
kBFloat16,
AT_EXPAND(AT_FLOAT8_TYPES),
-AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
+AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
+kFloat4_e2m1fn_x2);
}
} else if (materialized.size() > 1 &&
result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
@ -542,7 +543,9 @@ TORCH_IMPL_FUNC(cat_out_cuda)
kFloat8_e4m3fnuz,
kFloat8_e5m2,
kFloat8_e5m2fnuz,
-AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
+AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
+// TODO(#146647): extend this to other shell dtypes
+kFloat4_e2m1fn_x2);
}
} else {
int64_t offset = 0;
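
Both CUDA cat paths now dispatch on kFloat4_e2m1fn_x2, so concatenating packed-fp4 tensors should work on GPU. A sketch, assuming a CUDA device is available:

```python
import torch

# Build packed-fp4 tensors by bitcasting uint8 storage.
a = torch.randint(0, 256, (4, 8), dtype=torch.uint8, device="cuda")
b = torch.randint(0, 256, (4, 8), dtype=torch.uint8, device="cuda")
a_fp4 = a.view(torch.float4_e2m1fn_x2)
b_fp4 = b.view(torch.float4_e2m1fn_x2)

out = torch.cat([a_fp4, b_fp4], dim=0)  # goes through the dispatch lists above
assert out.dtype == torch.float4_e2m1fn_x2 and out.shape == (8, 8)
```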

View File

@ -225,6 +225,8 @@ std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
case c10::ScalarType::Float8_e8m0fnu:
// TODO(#146647): macroify all of this
return std::make_pair("float8_e8m0fnu", "");
case c10::ScalarType::Float4_e2m1fn_x2:
return std::make_pair("float4_e2m1fn_x2", "");
default:
throw std::runtime_error("Unimplemented scalar type");
}

View File

@ -3,6 +3,7 @@
#include <c10/util/BFloat16.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Exception.h>
#include <c10/util/Float4_e2m1fn_x2.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
@ -104,7 +105,8 @@ struct dummy_int1_7_t {};
_(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \
_(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \
_(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \
-_(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */
+_(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \
+_(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */
// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
@ -387,7 +389,8 @@ inline bool isFloat8Type(ScalarType t) {
}
inline bool isReducedFloatingType(ScalarType t) {
-return t == ScalarType::Half || t == ScalarType::BFloat16 || isFloat8Type(t);
+return t == ScalarType::Half || t == ScalarType::BFloat16 ||
+isFloat8Type(t) || t == ScalarType::Float4_e2m1fn_x2;
}
inline bool isFloatingType(ScalarType t) {
@ -502,6 +505,7 @@ inline bool isSignedType(ScalarType t) {
case ScalarType::Int5:
case ScalarType::Int6:
case ScalarType::Int7:
case ScalarType::Float4_e2m1fn_x2:
return true;
case ScalarType::UInt1:
case ScalarType::UInt2:
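
These classification changes surface as dtype introspection on the Python side. A small sketch; itemsize and is_signed follow from this hunk, while is_floating_point depends on isFloatingType, which is not shown here, so treat that one as an assumption:

```python
import torch

dt = torch.float4_e2m1fn_x2
print(dt.itemsize)           # 1: each element is one byte holding two fp4 values
print(dt.is_signed)          # True, per the isSignedType() case added above
print(dt.is_floating_point)  # assumed True via the reduced-floating-type path
```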

View File

@ -0,0 +1,28 @@
#pragma once
#include <cstdint>
#include <c10/macros/Macros.h>
/// Defines the Float4_e2m1fn_x2 type (4-bit floating-point, two elements packed
/// into one byte). This is the FP4 dtype from the OCP MX format spec
/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
/// Section 5.3.3)
///
/// Given two high precision values val0 and val1, here is the
/// binary configuration of their packed representation, from MSB to LSB:
///
/// original value | val1 : val0
/// ========================================
/// bit index (MSB==7, LSB==0) | 7654 : 3210
/// sign/exponent/mantissa | seem : seem
///
namespace c10 {
struct alignas(1) Float4_e2m1fn_x2 {
uint8_t val_;
Float4_e2m1fn_x2() = default;
C10_HOST_DEVICE explicit Float4_e2m1fn_x2(uint8_t val) : val_(val) {}
};
} // namespace c10
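
To make the packed layout concrete, here is an illustrative sketch (pack_e2m1_pair is a hypothetical helper, not part of this PR) that packs two e2m1 codes into one byte following the nibble order documented above, then bitcasts the byte into the new dtype:

```python
import torch

def pack_e2m1_pair(val0_bits: int, val1_bits: int) -> int:
    # Per the layout above: val1 occupies the high nibble, val0 the low nibble.
    return ((val1_bits & 0xF) << 4) | (val0_bits & 0xF)

# In the e2m1 encoding (OCP MX spec, Section 5.3.3), 0b0010 is 1.0 and 0b0111 is 6.0.
byte = pack_e2m1_pair(0b0010, 0b0111)  # 0b0111_0010

packed = torch.tensor([byte], dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
print(packed.shape)  # torch.Size([1]): one element stores both fp4 values
```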

View File

@ -392,7 +392,37 @@ class TestFloat8Dtype(TestCase):
torch.testing.assert_close(x1, x1_save_load, atol=0, rtol=0)
class TestFloat4Dtype(TestCase):
# TODO(#146647): make the testing generic for shell dtypes
def test_float4_e2m1fn_x2(self, device):
# can create a tensor of dtype float4
x1 = torch.empty(4096, 4096, device=device, dtype=torch.float4_e2m1fn_x2)
# can create a string (so printing will work)
str(x1)
# can view float4_e2m1fn_x2 as uint8
x2 = x1.view(torch.uint8)
# can view uint8 as float4_e2m1fn_x2
x2.view(torch.float4_e2m1fn_x2)
def test_f4_save_load(self, device):
x1 = torch.randint(0, 10, (4, 4), device=device, dtype=torch.uint8).view(
torch.float4_e2m1fn_x2
)
with TemporaryFileName() as fname:
torch.save(x1, fname)
x1_save_load = torch.load(fname)
# TODO(#146647): make this and all other shell dtypes support equality
# comparison
torch.testing.assert_close(
x1.view(torch.uint8), x1_save_load.view(torch.uint8), atol=0, rtol=0
)
instantiate_device_type_tests(TestFloat8Dtype, globals())
instantiate_device_type_tests(TestFloat4Dtype, globals())
class TestFloat8DtypeCPUOnly(TestCase):

View File

@ -163,15 +163,15 @@ try:
except ImportError as e:
logging.warning(e)
try:
-from quantization.core.experimental.test_float8 import TestFloat8DtypeCPU # noqa: F401
+from quantization.core.experimental.test_floatx import TestFloat8DtypeCPU # noqa: F401
except ImportError as e:
logging.warning(e)
try:
-from quantization.core.experimental.test_float8 import TestFloat8DtypeCUDA # noqa: F401
+from quantization.core.experimental.test_floatx import TestFloat8DtypeCUDA # noqa: F401
except ImportError as e:
logging.warning(e)
try:
-from quantization.core.experimental.test_float8 import TestFloat8DtypeCPUOnlyCPU # noqa: F401
+from quantization.core.experimental.test_floatx import TestFloat8DtypeCPUOnlyCPU # noqa: F401
except ImportError as e:
logging.warning(e)

View File

@ -1378,6 +1378,7 @@ def gen_pyi(
"float8_e5m2",
"float8_e5m2fnuz",
"float8_e8m0fnu",
"float4_e2m1fn_x2",
"half",
"uint8",
"uint16",

View File

@ -142,6 +142,14 @@ class _Formatter:
self.max_width = max(self.max_width, len(value_str))
else:
if tensor.dtype == torch.float4_e2m1fn_x2: # type: ignore[attr-defined]
# torch.float4_e2m1fn_x2 is special and does not support the casts necessary
# to print it, we choose to display the uint8 representation here for
# convenience of being able to print a tensor.
# TODO(#146647): extend this to other dtypes without casts defined, such
# as the bits, uint1..7 and int1..7 dtypes.
tensor_view = tensor_view.view(torch.uint8)
nonzero_finite_vals = torch.masked_select(
tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
)
@ -258,6 +266,14 @@ def _vector_str(self, indent, summarize, formatter1, formatter2=None):
else:
return formatter1.format(val)
if self.dtype == torch.float4_e2m1fn_x2: # type: ignore[attr-defined]
# torch.float4_e2m1fn_x2 is special and does not support the casts necessary
# to print it, we choose to display the uint8 representation here for
# convenience of being able to print a tensor.
# TODO(#146647): extend this to other dtypes without casts defined, such
# as the bits, uint1..7 and int1..7 dtypes.
self = self.view(torch.uint8)
if summarize and not PRINT_OPTS.edgeitems:
# Deal with edge case that negative zero is zero
data = ["..."]
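
The user-visible effect of these two changes is that printing a packed-fp4 tensor shows the raw packed bytes (the uint8 view) rather than decoded values; a sketch:

```python
import torch

x = torch.tensor([[0, 255], [16, 32]], dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
# The formatter falls back to the uint8 view, so the numbers printed below are
# the packed byte values, not decoded fp4 values.
print(x)
```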

View File

@ -540,6 +540,7 @@ def _new_dtypes():
torch.float8_e5m2fnuz,
torch.float8_e4m3fnuz,
torch.float8_e8m0fnu,
torch.float4_e2m1fn_x2,
torch.bits8,
torch.bits16,
torch.bits1x8,