Files
pytorch/aten/src/ATen/native/cpu/FillKernel.cpp
vasiliy 382fbcc1e4 add the torch.float8_e8m0fnu dtype to PyTorch (#147466)
Summary:

Continuing the work from https://github.com/pytorch/pytorch/pull/146427

Adds the `torch.float8_e8m0fnu` dtype to PyTorch, as detailed in
https://github.com/pytorch/pytorch/issues/146414 (see the issue for a detailed definition of the format). Example of basic functionality:

```python
import torch

# round trip
x0 = torch.randn(4, 4, dtype=torch.float32)
x1 = x0.to(torch.float8_e8m0fnu)  # RNE rounding
x2 = x1.to(torch.float32)  # 2 ** exponent

# creation with empty
x0 = torch.empty(4, 4, dtype=torch.float8_e8m0fnu)

# printing
print(x0)
```

Done in this PR:
* numerical correctness
* op coverage (except for `torch._scaled_mm`): create tensor, cast to/from float32
* printing a tensor works

For future PRs:
* performance optimizations for casting
* torch._scaled_mm
* PT2
* various cleanups (detailed in comments with issue numbers)

Test Plan:

```
pytest test/quantization/core/experimental/test_float8.py -s
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147466
Approved by: https://github.com/drisspg
2025-02-20 13:55:42 +00:00

76 lines
2.9 KiB
C++

#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch_v2.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/Fill.h>
#include <c10/core/Scalar.h>
#include <cstring>
namespace at::native {
namespace {
// Fills every output element of `iter` with `value_scalar` for a dtype
// `scalar_t` whose storage is exposed as a raw integer member `.x`
// (used here for Half, BFloat16 and the float8 types).
//
// The scalar is first converted to `scalar_t`. Its raw bits are then
// replicated as a same-width *signed* integer, so the fill is a pure bit copy
// through the integer Vectorized path. The original author notes that signed
// integer types get more acceleration there.
// NOTE(review): check_dynamic_cast=false presumably lets the kernel's integer
// return type differ from the tensor dtype; confirm against Loops.h.
template <typename scalar_t>
void fill_non_native_type(TensorIterator& iter, const Scalar& value_scalar) {
  auto value = value_scalar.to<scalar_t>().x;
  using H = std::make_signed_t<decltype(value)>;
  static_assert(sizeof(H) == sizeof(value), "signed carrier must match storage width");
  // Preserve the exact bit representation of `value`. static_cast<H>(value)
  // would be implementation-defined for values above H's max (pre-C++20).
  // memcpy is the well-defined way to reinterpret bytes; it does not depend on
  // the signed/unsigned aliasing exemption the old pointer cast relied on.
  H val{};
  std::memcpy(&val, &value, sizeof(val));
  cpu_kernel_vec</*check_dynamic_cast=*/false>(
      iter,
      [val]() -> H { return val; },
      [val]() { return Vectorized<H>(val); });
}
// Specialization for ComplexHalf (c10::complex<at::Half>), which has no `.x`
// member. The whole 32-bit complex value (real and imaginary Half) is filled
// as one int32_t bit pattern through the integer Vectorized path.
template <>
void fill_non_native_type<c10::complex<at::Half>>(TensorIterator& iter, const Scalar& value_scalar) {
  static_assert(sizeof(c10::complex<at::Half>) == sizeof(int32_t), "Size of ComplexHalf should be 32-bits");
  const auto value = c10::complex<at::Half>(value_scalar.to<c10::complex<float>>());
  // Reading a complex<Half> object through an int32_t* (the previous
  // reinterpret_cast) violates strict aliasing and is undefined behavior.
  // memcpy reinterprets the 32-bit representation in a well-defined way.
  int32_t val{};
  std::memcpy(&val, &value, sizeof(val));
  cpu_kernel_vec</*check_dynamic_cast=*/false>(
      iter,
      [val]() -> int32_t { return val; },
      [val]() { return Vectorized<int32_t>(val); });
}
// CPU kernel for fill_stub: writes `value_scalar` into every output element of
// `iter`.
//
// Half, BFloat16, ComplexHalf and each float8 dtype are routed to
// fill_non_native_type, which fills by raw bit pattern. Every other dtype
// accepted by the AT_DISPATCH_V2 type list below (all standard/complex types,
// bool, and the barebones unsigned types) converts the scalar to `scalar_t`
// and fills it directly.
void fill_kernel(TensorIterator& iter, const Scalar& value_scalar) {
  switch (iter.dtype()) {
    case ScalarType::Half:
      fill_non_native_type<at::Half>(iter, value_scalar);
      break;
    case ScalarType::BFloat16:
      fill_non_native_type<at::BFloat16>(iter, value_scalar);
      break;
    case ScalarType::ComplexHalf:
      fill_non_native_type<c10::complex<at::Half>>(iter, value_scalar);
      break;
    // TODO(#146647): use macro here instead of spelling out each float8 dtype
    case ScalarType::Float8_e4m3fn:
      fill_non_native_type<at::Float8_e4m3fn>(iter, value_scalar);
      break;
    case ScalarType::Float8_e5m2:
      fill_non_native_type<at::Float8_e5m2>(iter, value_scalar);
      break;
    case ScalarType::Float8_e4m3fnuz:
      fill_non_native_type<at::Float8_e4m3fnuz>(iter, value_scalar);
      break;
    case ScalarType::Float8_e5m2fnuz:
      fill_non_native_type<at::Float8_e5m2fnuz>(iter, value_scalar);
      break;
    case ScalarType::Float8_e8m0fnu:
      fill_non_native_type<at::Float8_e8m0fnu>(iter, value_scalar);
      break;
    default: {
      AT_DISPATCH_V2(
          iter.dtype(), "fill_cpu", AT_WRAP([&]() {
            scalar_t fill_value = value_scalar.to<scalar_t>();
            cpu_kernel_vec(
                iter,
                [=]() -> scalar_t { return fill_value; },
                [=]() { return Vectorized<scalar_t>(fill_value); });
          }),
          AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kBool, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
      );
      break;
    }
  }
}
} // namespace
// Register fill_kernel (above) as this file's implementation of fill_stub.
// NOTE(review): fill_stub is presumably declared in ATen/native/Fill.h (included above); confirm.
REGISTER_DISPATCH(fill_stub, &fill_kernel)
} // namespace at::native