Add Float8 support to onnx exporter (#121281)

Fixes #106877

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121281
Approved by: https://github.com/BowenBao, https://github.com/titaiwangms
This commit is contained in:
Thiago Crepaldi
2024-03-06 18:46:56 +00:00
committed by PyTorch MergeBot
parent 5a2527db22
commit 418568d2e3
2 changed files with 41 additions and 0 deletions

View File

@ -650,6 +650,35 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
self.assertTrue(onnx_program.model_signature, torch.export.ExportGraphSignature)
@common_utils.parametrize(
    "float8_type",
    [
        common_utils.subtest(
            torch.float8_e5m2,
            name="torch_float8_e5m2",
        ),
        common_utils.subtest(
            torch.float8_e5m2fnuz,
            name="torch_float8_e5m2fnuz",
        ),
        common_utils.subtest(
            torch.float8_e4m3fn,
            name="torch_float8_e4m3fn",
        ),
        common_utils.subtest(
            torch.float8_e4m3fnuz,
            name="torch_float8_e4m3fnuz",
        ),
    ],
)
def test_float8_support(self, float8_type):
    """Check that ``torch.onnx.dynamo_export`` handles float8 dtypes.

    Parametrized over all four float8 variants (e5m2, e5m2fnuz, e4m3fn,
    e4m3fnuz) so the exporter must map each dtype. Regression test for
    https://github.com/pytorch/pytorch/issues/106877.
    """

    class Float8Module(torch.nn.Module):
        def forward(self, input: torch.Tensor) -> torch.Tensor:
            # Cast the input and add a float8 constant so the exported
            # graph contains both a float8 cast and a float8 initializer.
            input = input.to(float8_type)
            return input + torch.tensor(1.0, dtype=float8_type)

    # Keep and check the export result: the original discarded it with
    # `_ = ...`, which only asserted "no exception was raised".
    onnx_program = torch.onnx.dynamo_export(
        Float8Module(), torch.randn(1, 2, 3, 4)
    )
    self.assertIsNotNone(onnx_program)
# Allow running this test file directly; run_tests() dispatches to the
# PyTorch test harness (handles parametrization, filtering, and reporting).
if __name__ == "__main__":
    common_utils.run_tests()