mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Add Float8 support to onnx exporter (#121281)
Fixes #106877 Pull Request resolved: https://github.com/pytorch/pytorch/pull/121281 Approved by: https://github.com/BowenBao, https://github.com/titaiwangms
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							5a2527db22
						
					
				
				
					commit
					418568d2e3
				
			| @ -650,6 +650,35 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase): | ||||
|  | ||||
|         self.assertTrue(onnx_program.model_signature, torch.export.ExportGraphSignature) | ||||
|  | ||||
|     @common_utils.parametrize( | ||||
|         "float8_type", | ||||
|         [ | ||||
|             common_utils.subtest( | ||||
|                 torch.float8_e5m2, | ||||
|                 name="torch_float8_e5m2", | ||||
|             ), | ||||
|             common_utils.subtest( | ||||
|                 torch.float8_e5m2fnuz, | ||||
|                 name="torch_float8_e5m2fnuz", | ||||
|             ), | ||||
|             common_utils.subtest( | ||||
|                 torch.float8_e4m3fn, | ||||
|                 name="torch_float8_e4m3fn", | ||||
|             ), | ||||
|             common_utils.subtest( | ||||
|                 torch.float8_e4m3fnuz, | ||||
|                 name="torch_float8_e4m3fnuz", | ||||
|             ), | ||||
|         ], | ||||
|     ) | ||||
|     def test_float8_support(self, float8_type): | ||||
|         class Float8Module(torch.nn.Module): | ||||
|             def forward(self, input: torch.Tensor): | ||||
|                 input = input.to(float8_type) | ||||
|                 return input + torch.tensor(1.0, dtype=float8_type) | ||||
|  | ||||
|         _ = torch.onnx.dynamo_export(Float8Module(), torch.randn(1, 2, 3, 4)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     common_utils.run_tests() | ||||
|  | ||||
		Reference in New Issue
	
	Block a user