mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add finfo
properties for float8 dtypes (#109744)
Add float8 finfo checks to `test_type_info.py` Fixes https://github.com/pytorch/pytorch/issues/109737 Pull Request resolved: https://github.com/pytorch/pytorch/pull/109744 Approved by: https://github.com/drisspg
This commit is contained in:
committed by
PyTorch MergeBot
parent
e2e9d15726
commit
cddd0db241
@ -371,6 +371,22 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
|
||||
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
|
||||
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
|
||||
|
||||
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
|
||||
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
|
||||
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
|
||||
|
||||
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4( \
|
||||
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
|
||||
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
|
||||
|
||||
#define AT_DISPATCH_CASE_INTEGRAL_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
||||
|
@ -68,6 +68,26 @@ class TestDTypeInfo(TestCase):
|
||||
with set_default_dtype(x.dtype):
|
||||
self.assertEqual(torch.finfo(x.dtype), torch.finfo())
|
||||
|
||||
# Special test case for Float8_E5M2
|
||||
xinfo = torch.finfo(torch.float8_e5m2)
|
||||
self.assertEqual(xinfo.bits, 8)
|
||||
self.assertEqual(xinfo.max, 57344.0)
|
||||
self.assertEqual(xinfo.min, -57344.0)
|
||||
self.assertEqual(xinfo.eps, .25)
|
||||
self.assertEqual(xinfo.tiny, 6.10352e-05)
|
||||
self.assertEqual(xinfo.resolution, 1.0)
|
||||
self.assertEqual(xinfo.dtype, "float8_e5m2")
|
||||
|
||||
# Special test case for Float8_E4M3FN
|
||||
xinfo = torch.finfo(torch.float8_e4m3fn)
|
||||
self.assertEqual(xinfo.bits, 8)
|
||||
self.assertEqual(xinfo.max, 448.0)
|
||||
self.assertEqual(xinfo.min, -448.0)
|
||||
self.assertEqual(xinfo.eps, .125)
|
||||
self.assertEqual(xinfo.tiny, 0.015625)
|
||||
self.assertEqual(xinfo.resolution, 1.0)
|
||||
self.assertEqual(xinfo.dtype, "float8_e4m3fn")
|
||||
|
||||
if __name__ == '__main__':
|
||||
TestCase._default_dtype_check_enabled = True
|
||||
run_tests()
|
||||
|
@ -112,8 +112,14 @@ static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) {
|
||||
}
|
||||
|
||||
static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
|
||||
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
|
||||
at::kHalf, at::ScalarType::BFloat16, self->type, "epsilon", [] {
|
||||
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
|
||||
at::kHalf,
|
||||
at::ScalarType::BFloat16,
|
||||
at::ScalarType::Float8_e4m3fn,
|
||||
at::ScalarType::Float8_e5m2,
|
||||
self->type,
|
||||
"epsilon",
|
||||
[] {
|
||||
return PyFloat_FromDouble(
|
||||
std::numeric_limits<
|
||||
at::scalar_value_type<scalar_t>::type>::epsilon());
|
||||
@ -121,16 +127,28 @@ static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
|
||||
}
|
||||
|
||||
static PyObject* THPFInfo_max(THPFInfo* self, void*) {
|
||||
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
|
||||
at::kHalf, at::ScalarType::BFloat16, self->type, "max", [] {
|
||||
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
|
||||
at::kHalf,
|
||||
at::ScalarType::BFloat16,
|
||||
at::ScalarType::Float8_e4m3fn,
|
||||
at::ScalarType::Float8_e5m2,
|
||||
self->type,
|
||||
"max",
|
||||
[] {
|
||||
return PyFloat_FromDouble(
|
||||
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::max());
|
||||
});
|
||||
}
|
||||
|
||||
static PyObject* THPFInfo_min(THPFInfo* self, void*) {
|
||||
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
|
||||
at::kHalf, at::ScalarType::BFloat16, self->type, "lowest", [] {
|
||||
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
|
||||
at::kHalf,
|
||||
at::ScalarType::BFloat16,
|
||||
at::ScalarType::Float8_e4m3fn,
|
||||
at::ScalarType::Float8_e5m2,
|
||||
self->type,
|
||||
"lowest",
|
||||
[] {
|
||||
return PyFloat_FromDouble(
|
||||
std::numeric_limits<
|
||||
at::scalar_value_type<scalar_t>::type>::lowest());
|
||||
@ -169,8 +187,14 @@ static PyObject* THPIInfo_dtype(THPIInfo* self, void*) {
|
||||
}
|
||||
|
||||
static PyObject* THPFInfo_smallest_normal(THPFInfo* self, void*) {
|
||||
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
|
||||
at::kHalf, at::ScalarType::BFloat16, self->type, "min", [] {
|
||||
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
|
||||
at::kHalf,
|
||||
at::ScalarType::BFloat16,
|
||||
at::ScalarType::Float8_e4m3fn,
|
||||
at::ScalarType::Float8_e5m2,
|
||||
self->type,
|
||||
"smallest",
|
||||
[] {
|
||||
return PyFloat_FromDouble(
|
||||
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::min());
|
||||
});
|
||||
@ -182,8 +206,14 @@ static PyObject* THPFInfo_tiny(THPFInfo* self, void*) {
|
||||
}
|
||||
|
||||
static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
|
||||
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
|
||||
at::kHalf, at::ScalarType::BFloat16, self->type, "digits10", [] {
|
||||
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
|
||||
at::kHalf,
|
||||
at::ScalarType::BFloat16,
|
||||
at::ScalarType::Float8_e4m3fn,
|
||||
at::ScalarType::Float8_e5m2,
|
||||
self->type,
|
||||
"digits10",
|
||||
[] {
|
||||
return PyFloat_FromDouble(std::pow(
|
||||
10,
|
||||
-std::numeric_limits<
|
||||
@ -193,9 +223,11 @@ static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
|
||||
|
||||
static PyObject* THPFInfo_dtype(THPFInfo* self, void*) {
|
||||
auto primary_name = torch::utils::getDtypeNames(self->type).first;
|
||||
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
|
||||
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
|
||||
at::kHalf,
|
||||
at::ScalarType::BFloat16,
|
||||
at::ScalarType::Float8_e4m3fn,
|
||||
at::ScalarType::Float8_e5m2,
|
||||
self->type,
|
||||
"dtype",
|
||||
[&primary_name] { return PyUnicode_FromString(primary_name.data()); });
|
||||
|
Reference in New Issue
Block a user