mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add torch.float4_e2m1fn_x2 to PyTorch (#148791)
Summary: Redo of https://github.com/pytorch/pytorch/pull/146578 to get around rebase conflicts.

Test Plan:

```
pytest test/quantization/core/experimental/test_floatx.py -s
```

Reviewers:
Subscribers:
Tasks:
Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148791
Approved by: https://github.com/drisspg, https://github.com/eqy, https://github.com/jeffdaily
committed by PyTorch MergeBot
parent ac91f8765b
commit e33bc41958
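For orientation, here is a minimal sketch of what the new shell dtype supports after this change, mirroring the tests added in test_floatx.py below; most other operations on torch.float4_e2m1fn_x2 remain unimplemented at this point:

```python
import torch

# Allocate uninitialized packed-fp4 storage; each element is one byte holding
# two e2m1 values.
x = torch.empty(128, 128, dtype=torch.float4_e2m1fn_x2)

# Printing works (the raw bytes are shown), and bit-level reinterpretation
# to and from uint8 works in both directions.
print(x)
u = x.view(torch.uint8)
x_again = u.view(torch.float4_e2m1fn_x2)
```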
@@ -71,6 +71,9 @@ DLDataType getDLDataType(const Tensor& t) {
     case ScalarType::Float8_e8m0fnu:
       TORCH_CHECK(false, "float8 types are not supported by dlpack");
       break;
+    case ScalarType::Float4_e2m1fn_x2:
+      TORCH_CHECK(false, "float4 types are not supported by dlpack");
+      break;
     case ScalarType::QInt8:
     case ScalarType::QUInt8:
     case ScalarType::QInt32:
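The new case mirrors the existing float8 handling: exporting a packed-fp4 tensor through DLPack fails loudly instead of producing a bogus DLDataType. A hedged sketch of the resulting Python-side behaviour (error text taken from the TORCH_CHECK above):

```python
import torch
from torch.utils.dlpack import to_dlpack

x = torch.empty(2, 2, dtype=torch.float4_e2m1fn_x2)
try:
    to_dlpack(x)  # expected to hit the TORCH_CHECK added above
except RuntimeError as err:
    print(err)  # "float4 types are not supported by dlpack"
```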
@@ -507,7 +507,8 @@ TORCH_IMPL_FUNC(cat_out_cuda)
          kBool,
          kBFloat16,
          AT_EXPAND(AT_FLOAT8_TYPES),
-         AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
+         AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
+         kFloat4_e2m1fn_x2);
     }
   } else if (materialized.size() > 1 &&
              result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
@@ -542,7 +543,9 @@ TORCH_IMPL_FUNC(cat_out_cuda)
          kFloat8_e4m3fnuz,
          kFloat8_e5m2,
          kFloat8_e5m2fnuz,
-         AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
+         AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
+         // TODO(#146647): extend this to other shell dtypes
+         kFloat4_e2m1fn_x2);
     }
   } else {
     int64_t offset = 0;
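Both dispatch lists for the CUDA cat kernel now include kFloat4_e2m1fn_x2, so concatenating packed-fp4 tensors on CUDA should no longer fall outside the dispatched dtypes. A hedged sketch of the intended effect (this exact snippet is not part of the PR's tests):

```python
import torch

a = torch.randint(0, 256, (4, 4), dtype=torch.uint8, device="cuda").view(
    torch.float4_e2m1fn_x2
)
b = torch.randint(0, 256, (4, 4), dtype=torch.uint8, device="cuda").view(
    torch.float4_e2m1fn_x2
)
# Expected: shape (8, 4), dtype torch.float4_e2m1fn_x2, bytes concatenated as-is.
c = torch.cat([a, b], dim=0)
```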
@@ -225,6 +225,8 @@ std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
     case c10::ScalarType::Float8_e8m0fnu:
      // TODO(#146647): macroify all of this
      return std::make_pair("float8_e8m0fnu", "");
+    case c10::ScalarType::Float4_e2m1fn_x2:
+      return std::make_pair("float4_e2m1fn_x2", "");
     default:
       throw std::runtime_error("Unimplemented scalar type");
  }
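The registered primary name is what surfaces as the Python-level dtype attribute; the empty second entry means the new dtype gets no legacy alias. A tiny sketch of the expected result:

```python
import torch

# Expected: the dtype prints under the name registered above.
assert str(torch.float4_e2m1fn_x2) == "torch.float4_e2m1fn_x2"
```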
@@ -3,6 +3,7 @@
 #include <c10/util/BFloat16.h>
 #include <c10/util/Deprecated.h>
 #include <c10/util/Exception.h>
+#include <c10/util/Float4_e2m1fn_x2.h>
 #include <c10/util/Float8_e4m3fn.h>
 #include <c10/util/Float8_e4m3fnuz.h>
 #include <c10/util/Float8_e5m2.h>
@@ -104,7 +105,8 @@ struct dummy_int1_7_t {};
   _(c10::dummy_int1_7_t<5>, Int5) /* 41 */              \
   _(c10::dummy_int1_7_t<6>, Int6) /* 42 */              \
   _(c10::dummy_int1_7_t<7>, Int7) /* 43 */              \
-  _(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */
+  _(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */       \
+  _(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */

 // If you want to support ComplexHalf for real, add ComplexHalf
 // into this macro (and change the name). But beware: convert()
@@ -387,7 +389,8 @@ inline bool isFloat8Type(ScalarType t) {
 }

 inline bool isReducedFloatingType(ScalarType t) {
-  return t == ScalarType::Half || t == ScalarType::BFloat16 || isFloat8Type(t);
+  return t == ScalarType::Half || t == ScalarType::BFloat16 ||
+      isFloat8Type(t) || t == ScalarType::Float4_e2m1fn_x2;
 }

 inline bool isFloatingType(ScalarType t) {
@@ -502,6 +505,7 @@ inline bool isSignedType(ScalarType t) {
     case ScalarType::Int5:
     case ScalarType::Int6:
     case ScalarType::Int7:
+    case ScalarType::Float4_e2m1fn_x2:
       return true;
     case ScalarType::UInt1:
     case ScalarType::UInt2:
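After these edits, the C++ predicates classify the packed fp4 dtype as a reduced-precision floating type and (like the other small floats) as signed. For illustration only, a rough Python mirror of the isReducedFloatingType logic; the float8 set below is assumed from the dtypes that appear elsewhere in this diff:

```python
import torch

_FLOAT8_DTYPES = {
    torch.float8_e4m3fn,
    torch.float8_e4m3fnuz,
    torch.float8_e5m2,
    torch.float8_e5m2fnuz,
    torch.float8_e8m0fnu,
}


def is_reduced_floating_type(dtype: torch.dtype) -> bool:
    # Mirrors c10::isReducedFloatingType after this change: half, bfloat16,
    # the float8 family, and now the packed fp4 dtype.
    return (
        dtype in (torch.half, torch.bfloat16)
        or dtype in _FLOAT8_DTYPES
        or dtype == torch.float4_e2m1fn_x2
    )


assert is_reduced_floating_type(torch.float4_e2m1fn_x2)
assert not is_reduced_floating_type(torch.float32)
```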
c10/util/Float4_e2m1fn_x2.h (new file, 28 lines)
@@ -0,0 +1,28 @@
+#pragma once
+#include <cstdint>
+
+#include <c10/macros/Macros.h>
+
+/// Defines the Float4_e2m1fn_x2 type (4-bit floating-point, two elements packed
+/// into one byte). This is the FP4 dtype from the OCP MX format spec
+/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
+/// Section 5.3.3)
+///
+/// Given two high precision values val0 and val1, here is the
+/// binary configuration of their packed representation, from MSB to LSB:
+///
+///   original value              | val1 : val0
+///   ========================================
+///   bit index (MSB==7, LSB==0)  | 7654 : 3210
+///   sign/exponent/mantissa      | seem : seem
+///
+
+namespace c10 {
+
+struct alignas(1) Float4_e2m1fn_x2 {
+  uint8_t val_;
+  Float4_e2m1fn_x2() = default;
+  C10_HOST_DEVICE explicit Float4_e2m1fn_x2(uint8_t val) : val_(val) {}
+};
+
+} // namespace c10
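The layout comment above fully specifies how two fp4 values share one byte: val0 occupies the low nibble, val1 the high nibble, each nibble ordered sign/exponent/exponent/mantissa from its MSB down. A small Python sketch of packing and splitting such a byte; the helper names are illustrative and not part of the PR:

```python
def pack_fp4_pair(val0_bits: int, val1_bits: int) -> int:
    """Pack two 4-bit e2m1 codes into one byte: val1 in bits 7..4, val0 in bits 3..0."""
    assert 0 <= val0_bits < 16 and 0 <= val1_bits < 16
    return (val1_bits << 4) | val0_bits


def unpack_fp4_pair(byte: int) -> tuple[int, int]:
    """Recover (val0_bits, val1_bits) from a packed byte."""
    return byte & 0x0F, (byte >> 4) & 0x0F


packed = pack_fp4_pair(0b0101, 0b1001)  # val0: s=0 e=10 m=1, val1: s=1 e=00 m=1
assert packed == 0b1001_0101
assert unpack_fp4_pair(packed) == (0b0101, 0b1001)
```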
@@ -392,7 +392,37 @@ class TestFloat8Dtype(TestCase):
         torch.testing.assert_close(x1, x1_save_load, atol=0, rtol=0)


+class TestFloat4Dtype(TestCase):
+    # TODO(#146647): make the testing generic for shell dtypes
+    def test_float4_e2m1fn_x2(self, device):
+        # can create a tensor of dtype float4
+        x1 = torch.empty(4096, 4096, device=device, dtype=torch.float4_e2m1fn_x2)
+
+        # can create a string (so printing will work)
+        str(x1)
+
+        # can view float4_e2m1fn_x2 as uint8
+        x2 = x1.view(torch.uint8)
+
+        # can view uint8 as float4_e2m1fn_x2
+        x2.view(torch.float4_e2m1fn_x2)
+
+    def test_f4_save_load(self, device):
+        x1 = torch.randint(0, 10, (4, 4), device=device, dtype=torch.uint8).view(
+            torch.float4_e2m1fn_x2
+        )
+        with TemporaryFileName() as fname:
+            torch.save(x1, fname)
+            x1_save_load = torch.load(fname)
+            # TODO(#146647): make this and all other shell dtypes support equality
+            # comparison
+            torch.testing.assert_close(
+                x1.view(torch.uint8), x1_save_load.view(torch.uint8), atol=0, rtol=0
+            )
+
+
 instantiate_device_type_tests(TestFloat8Dtype, globals())
+instantiate_device_type_tests(TestFloat4Dtype, globals())


 class TestFloat8DtypeCPUOnly(TestCase):
@@ -163,15 +163,15 @@ try:
 except ImportError as e:
     logging.warning(e)
 try:
-    from quantization.core.experimental.test_float8 import TestFloat8DtypeCPU  # noqa: F401
+    from quantization.core.experimental.test_floatx import TestFloat8DtypeCPU  # noqa: F401
 except ImportError as e:
     logging.warning(e)
 try:
-    from quantization.core.experimental.test_float8 import TestFloat8DtypeCUDA  # noqa: F401
+    from quantization.core.experimental.test_floatx import TestFloat8DtypeCUDA  # noqa: F401
 except ImportError as e:
     logging.warning(e)
 try:
-    from quantization.core.experimental.test_float8 import TestFloat8DtypeCPUOnlyCPU  # noqa: F401
+    from quantization.core.experimental.test_floatx import TestFloat8DtypeCPUOnlyCPU  # noqa: F401
 except ImportError as e:
     logging.warning(e)

@@ -1378,6 +1378,7 @@ def gen_pyi(
         "float8_e5m2",
         "float8_e5m2fnuz",
         "float8_e8m0fnu",
+        "float4_e2m1fn_x2",
         "half",
         "uint8",
         "uint16",
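Adding the name to the stub-generation list means the regenerated .pyi files expose torch.float4_e2m1fn_x2 as a torch.dtype attribute, so static type checkers should accept code like the following once the updated stubs are in place:

```python
import torch

# Plain dtype annotation; should type-check cleanly with the regenerated stubs.
fp4_packed: torch.dtype = torch.float4_e2m1fn_x2
```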
@@ -142,6 +142,14 @@ class _Formatter:
                 self.max_width = max(self.max_width, len(value_str))

         else:
+            if tensor.dtype == torch.float4_e2m1fn_x2:  # type: ignore[attr-defined]
+                # torch.float4_e2m1fn_x2 is special and does not support the casts necessary
+                # to print it, we choose to display the uint8 representation here for
+                # convenience of being able to print a tensor.
+                # TODO(#146647): extend this to other dtypes without casts defined, such
+                # as the bits, uint1..7 and int1..7 dtypes.
+                tensor_view = tensor_view.view(torch.uint8)
+
             nonzero_finite_vals = torch.masked_select(
                 tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
             )
@@ -258,6 +266,14 @@ def _vector_str(self, indent, summarize, formatter1, formatter2=None):
         else:
             return formatter1.format(val)

+    if self.dtype == torch.float4_e2m1fn_x2:  # type: ignore[attr-defined]
+        # torch.float4_e2m1fn_x2 is special and does not support the casts necessary
+        # to print it, we choose to display the uint8 representation here for
+        # convenience of being able to print a tensor.
+        # TODO(#146647): extend this to other dtypes without casts defined, such
+        # as the bits, uint1..7 and int1..7 dtypes.
+        self = self.view(torch.uint8)
+
     if summarize and not PRINT_OPTS.edgeitems:
         # Deal with edge case that negative zero is zero
         data = ["..."]
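With the two insertions above, printing a packed-fp4 tensor routes through its uint8 view instead of attempting casts the dtype does not define. A hedged sketch of the observable result (exact formatting may differ):

```python
import torch

x = torch.randint(0, 256, (2, 2), dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
# The displayed numbers are the raw packed bytes, e.g. roughly:
# tensor([[ 23, 180],
#         [  7,  91]], dtype=torch.float4_e2m1fn_x2)
print(x)
```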
@@ -540,6 +540,7 @@ def _new_dtypes():
             torch.float8_e5m2fnuz,
             torch.float8_e4m3fnuz,
             torch.float8_e8m0fnu,
+            torch.float4_e2m1fn_x2,
             torch.bits8,
             torch.bits16,
             torch.bits1x8,
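This list is assumed here to feed the dtype allowlist used for weights-only deserialization (the enclosing file is not shown in this diff). Under that assumption, the practical effect is that saved fp4 tensors load back without opting out of the safe loader; a hedged sketch:

```python
import torch

x = torch.randint(0, 256, (4, 4), dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
torch.save(x, "fp4_checkpoint.pt")  # hypothetical file name

# Expected to succeed once float4_e2m1fn_x2 is on the allowlist above.
y = torch.load("fp4_checkpoint.pt", weights_only=True)
assert y.dtype == torch.float4_e2m1fn_x2
```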