mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add Float8 support to onnx exporter (#121281)
Fixes #106877 Pull Request resolved: https://github.com/pytorch/pytorch/pull/121281 Approved by: https://github.com/BowenBao, https://github.com/titaiwangms
This commit is contained in:
committed by
PyTorch MergeBot
parent
5a2527db22
commit
418568d2e3
@ -650,6 +650,35 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
|
||||
|
||||
self.assertTrue(onnx_program.model_signature, torch.export.ExportGraphSignature)
|
||||
|
||||
@common_utils.parametrize(
|
||||
"float8_type",
|
||||
[
|
||||
common_utils.subtest(
|
||||
torch.float8_e5m2,
|
||||
name="torch_float8_e5m2",
|
||||
),
|
||||
common_utils.subtest(
|
||||
torch.float8_e5m2fnuz,
|
||||
name="torch_float8_e5m2fnuz",
|
||||
),
|
||||
common_utils.subtest(
|
||||
torch.float8_e4m3fn,
|
||||
name="torch_float8_e4m3fn",
|
||||
),
|
||||
common_utils.subtest(
|
||||
torch.float8_e4m3fnuz,
|
||||
name="torch_float8_e4m3fnuz",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_float8_support(self, float8_type):
|
||||
class Float8Module(torch.nn.Module):
|
||||
def forward(self, input: torch.Tensor):
|
||||
input = input.to(float8_type)
|
||||
return input + torch.tensor(1.0, dtype=float8_type)
|
||||
|
||||
_ = torch.onnx.dynamo_export(Float8Module(), torch.randn(1, 2, 3, 4))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
common_utils.run_tests()
|
||||
|
||||
Reference in New Issue
Block a user