Add Float8 support to the ONNX exporter (#121281)

Fixes #106877

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121281
Approved by: https://github.com/BowenBao, https://github.com/titaiwangms
This commit is contained in:
Thiago Crepaldi
2024-03-06 18:46:56 +00:00
committed by PyTorch MergeBot
parent 5a2527db22
commit 418568d2e3
2 changed files with 41 additions and 0 deletions

View File

@ -650,6 +650,35 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
self.assertTrue(onnx_program.model_signature, torch.export.ExportGraphSignature)
@common_utils.parametrize(
    "float8_type",
    [
        common_utils.subtest(dtype, name=label)
        for dtype, label in (
            (torch.float8_e5m2, "torch_float8_e5m2"),
            (torch.float8_e5m2fnuz, "torch_float8_e5m2fnuz"),
            (torch.float8_e4m3fn, "torch_float8_e4m3fn"),
            (torch.float8_e4m3fnuz, "torch_float8_e4m3fnuz"),
        )
    ],
)
def test_float8_support(self, float8_type):
    """Export a small module that computes in ``float8_type``.

    The module casts its input to the parametrized float8 dtype and adds a
    float8 scalar constant. The export result is discarded and nothing is
    asserted on it, so the test succeeds when ``torch.onnx.dynamo_export``
    returns without raising.
    """

    class Float8Module(torch.nn.Module):
        def forward(self, input: torch.Tensor):
            # Cast first so the addition below runs on float8 operands.
            as_float8 = input.to(float8_type)
            one = torch.tensor(1.0, dtype=float8_type)
            return as_float8 + one

    model = Float8Module()
    sample = torch.randn(1, 2, 3, 4)
    torch.onnx.dynamo_export(model, sample)
# Script entry point: when this file is executed directly (not imported),
# hand control to the shared test runner from common_utils.
if __name__ == "__main__":
common_utils.run_tests()

View File

@ -124,6 +124,10 @@ _TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS: Dict[
torch.float64: {"tensor(double)"},
torch.float32: {"tensor(float)"},
torch.float16: {"tensor(float16)"},
torch.float8_e4m3fn: {"tensor(float8_e4m3fn)"},
torch.float8_e4m3fnuz: {"tensor(float8_e4m3fnuz)"},
torch.float8_e5m2: {"tensor(float8_e5m2)"},
torch.float8_e5m2fnuz: {"tensor(float8_e5m2fnuz)"},
torch.int16: {"tensor(int16)"},
torch.int32: {"tensor(int32)"},
torch.int64: {"tensor(int64)"},
@ -174,6 +178,10 @@ _TORCH_DTYPE_TO_ABBREVIATION = {
torch.float64: "f64",
torch.float32: "f32",
torch.float16: "f16",
torch.float8_e4m3fn: "e4m3fn",
torch.float8_e4m3fnuz: "e4m3fnuz",
torch.float8_e5m2: "f8e5m2",
torch.float8_e5m2fnuz: "e5m2fnuz",
torch.complex32: "c32",
torch.complex64: "c64",
torch.complex128: "c128",
@ -200,6 +208,10 @@ _TORCH_DTYPE_TO_NUMPY_DTYPE = {
_ONNX_TENSOR_ELEMENT_TYPE_TO_TORCH_DTYPE = {
onnx.TensorProto.FLOAT: torch.float32, # type: ignore[attr-defined]
onnx.TensorProto.FLOAT16: torch.float16, # type: ignore[attr-defined]
onnx.TensorProto.FLOAT8E5M2: torch.float8_e5m2, # type: ignore[attr-defined]
onnx.TensorProto.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz, # type: ignore[attr-defined]
onnx.TensorProto.FLOAT8E4M3FN: torch.float8_e4m3fn, # type: ignore[attr-defined]
onnx.TensorProto.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz, # type: ignore[attr-defined]
onnx.TensorProto.DOUBLE: torch.float64, # type: ignore[attr-defined]
onnx.TensorProto.BOOL: torch.bool, # type: ignore[attr-defined]
onnx.TensorProto.UINT8: torch.uint8, # type: ignore[attr-defined]