Summary:

Continuing the work from https://github.com/pytorch/pytorch/pull/146427

Adds the `torch.float8_e8m0fnu` dtype to PyTorch, as detailed in https://github.com/pytorch/pytorch/issues/146414. Please see the issue for a detailed definition of the format.

Example of basic functionality:

```python
import torch

# round trip
x0 = torch.randn(4, 4, dtype=torch.float32)
x1 = x0.to(torch.float8_e8m0fnu)  # RNE rounding
x2 = x1.to(torch.float32)  # 2 ** exponent

# creation with empty
x0 = torch.empty(4, 4, dtype=torch.float8_e8m0fnu)

# printing
print(x0)
```

Done in this PR:

* numerical correctness
* op coverage (except for `torch._scaled_mm`): create tensor, cast to/from float32
* printing a tensor works

For future PRs:

* performance optimizations for casting
* `torch._scaled_mm`
* PT2
* various cleanups (detailed in comments with issue numbers)

Test Plan:

```
pytest test/quantization/core/experimental/test_float8.py -s
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147466
Approved by: https://github.com/drisspg
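The decode direction mentioned above (`2 ** exponent`) is simple enough to sketch standalone. The following is a minimal illustration, assuming the OCP MX-style definition referenced in issue #146414: 8 exponent bits with bias 127, no sign or mantissa bits, no infinities, and a single NaN encoding at 0xFF. `decode_e8m0fnu` is a hypothetical helper written for this sketch, not part of the c10 API.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical standalone decoder, for illustration only; the real
// conversion logic lives in c10/util/Float8_e8m0fnu.h.
// e8m0fnu: 8 exponent bits, no sign/mantissa, bias 127, 0xFF is NaN.
float decode_e8m0fnu(uint8_t bits) {
  if (bits == 0xFF) {
    return std::nanf("");  // the single NaN encoding
  }
  // Every non-NaN encoding is a pure power of two: 2 ** (bits - 127).
  return std::ldexp(1.0f, static_cast<int>(bits) - 127);
}

int main() {
  std::printf("%g\n", decode_e8m0fnu(127));  // 1 == 2**0
  std::printf("%g\n", decode_e8m0fnu(130));  // 8 == 2**3
  std::printf("%g\n", decode_e8m0fnu(0));    // 2**-127, the smallest value
  return 0;
}
```

The encode direction (float32 to e8m0fnu, with round-to-nearest-even as noted in the example) is the numerically interesting part, and is what the test plan above exercises.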
#include <c10/macros/Macros.h>
#include <c10/util/Float8_e8m0fnu.h>

#include <type_traits> // for std::is_standard_layout_v

namespace c10 {

// TODO(#146647): Can we have these in a single shared cpp file
// built with macro to remove the need for a new cpp file?
static_assert(
    std::is_standard_layout_v<Float8_e8m0fnu>,
    "c10::Float8_e8m0fnu must be standard layout.");

} // namespace c10
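A brief, hedged illustration of what the assertion above buys: a standard-layout one-byte type can be copied to and from raw storage with `memcpy`, which is how dtype wrappers like this are typically backed by plain byte buffers. `Float8Like` is a hypothetical stand-in for `c10::Float8_e8m0fnu`, not the real class.

```cpp
#include <cstdint>
#include <cstring>
#include <type_traits>

// Hypothetical stand-in for c10::Float8_e8m0fnu: a standard-layout
// wrapper around the raw 8-bit encoding.
struct Float8Like {
  uint8_t x;
};
static_assert(std::is_standard_layout_v<Float8Like>, "must be standard layout");
static_assert(sizeof(Float8Like) == 1, "must occupy exactly one byte");

int main() {
  uint8_t raw = 130;  // raw bits, e.g. read out of a tensor's storage
  Float8Like f;
  std::memcpy(&f, &raw, sizeof(f));  // well-defined for standard-layout types
  return f.x == raw ? 0 : 1;
}
```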