Add finfo properties for float8 dtypes (#109744)

Add float8 finfo checks to `test_type_info.py`
Fixes https://github.com/pytorch/pytorch/issues/109737
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109744
Approved by: https://github.com/drisspg
This commit is contained in:
Nikita Shulga
2023-09-20 16:37:55 -07:00
committed by PyTorch MergeBot
parent e2e9d15726
commit cddd0db241
3 changed files with 79 additions and 11 deletions

View File

@ -371,6 +371,22 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
#define AT_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \

View File

@ -68,6 +68,26 @@ class TestDTypeInfo(TestCase):
with set_default_dtype(x.dtype):
self.assertEqual(torch.finfo(x.dtype), torch.finfo())
# Special test case for Float8_E5M2
xinfo = torch.finfo(torch.float8_e5m2)
self.assertEqual(xinfo.bits, 8)
self.assertEqual(xinfo.max, 57344.0)
self.assertEqual(xinfo.min, -57344.0)
self.assertEqual(xinfo.eps, .25)
self.assertEqual(xinfo.tiny, 6.10352e-05)
self.assertEqual(xinfo.resolution, 1.0)
self.assertEqual(xinfo.dtype, "float8_e5m2")
# Special test case for Float8_E4M3FN
xinfo = torch.finfo(torch.float8_e4m3fn)
self.assertEqual(xinfo.bits, 8)
self.assertEqual(xinfo.max, 448.0)
self.assertEqual(xinfo.min, -448.0)
self.assertEqual(xinfo.eps, .125)
self.assertEqual(xinfo.tiny, 0.015625)
self.assertEqual(xinfo.resolution, 1.0)
self.assertEqual(xinfo.dtype, "float8_e4m3fn")
if __name__ == '__main__':
TestCase._default_dtype_check_enabled = True
run_tests()

View File

@ -112,8 +112,14 @@ static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) {
}
static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::kHalf, at::ScalarType::BFloat16, self->type, "epsilon", [] {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
at::kHalf,
at::ScalarType::BFloat16,
at::ScalarType::Float8_e4m3fn,
at::ScalarType::Float8_e5m2,
self->type,
"epsilon",
[] {
return PyFloat_FromDouble(
std::numeric_limits<
at::scalar_value_type<scalar_t>::type>::epsilon());
@ -121,16 +127,28 @@ static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
}
static PyObject* THPFInfo_max(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::kHalf, at::ScalarType::BFloat16, self->type, "max", [] {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
at::kHalf,
at::ScalarType::BFloat16,
at::ScalarType::Float8_e4m3fn,
at::ScalarType::Float8_e5m2,
self->type,
"max",
[] {
return PyFloat_FromDouble(
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::max());
});
}
static PyObject* THPFInfo_min(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::kHalf, at::ScalarType::BFloat16, self->type, "lowest", [] {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
at::kHalf,
at::ScalarType::BFloat16,
at::ScalarType::Float8_e4m3fn,
at::ScalarType::Float8_e5m2,
self->type,
"lowest",
[] {
return PyFloat_FromDouble(
std::numeric_limits<
at::scalar_value_type<scalar_t>::type>::lowest());
@ -169,8 +187,14 @@ static PyObject* THPIInfo_dtype(THPIInfo* self, void*) {
}
static PyObject* THPFInfo_smallest_normal(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::kHalf, at::ScalarType::BFloat16, self->type, "min", [] {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
at::kHalf,
at::ScalarType::BFloat16,
at::ScalarType::Float8_e4m3fn,
at::ScalarType::Float8_e5m2,
self->type,
"smallest",
[] {
return PyFloat_FromDouble(
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::min());
});
@ -182,8 +206,14 @@ static PyObject* THPFInfo_tiny(THPFInfo* self, void*) {
}
static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::kHalf, at::ScalarType::BFloat16, self->type, "digits10", [] {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
at::kHalf,
at::ScalarType::BFloat16,
at::ScalarType::Float8_e4m3fn,
at::ScalarType::Float8_e5m2,
self->type,
"digits10",
[] {
return PyFloat_FromDouble(std::pow(
10,
-std::numeric_limits<
@ -193,9 +223,11 @@ static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
static PyObject* THPFInfo_dtype(THPFInfo* self, void*) {
auto primary_name = torch::utils::getDtypeNames(self->type).first;
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
at::kHalf,
at::ScalarType::BFloat16,
at::ScalarType::Float8_e4m3fn,
at::ScalarType::Float8_e5m2,
self->type,
"dtype",
[&primary_name] { return PyUnicode_FromString(primary_name.data()); });