Compare commits

...

1 Commits

Author SHA1 Message Date
66f30ce6dd [ONNX] Support float4 (#151069)
- Support exporting float4 models (note: currently we use IR version 10 universally in the exporter, which does not include float 4 support. Eventually when onnx runtime and the ecosystem moves to support the new IR version 11 we should bump our version to 11 in the exporter as well)
- The shape of the type is set according to https://github.com/pytorch/pytorch/pull/148791#discussion_r2038704986 (added last dim with size 2)
- Use ml_dtypes types when converting to numpy for consistency with ONNX IR

Fix https://github.com/pytorch/pytorch/issues/150202

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151069
Approved by: https://github.com/titaiwangms
2025-05-18 14:06:21 -07:00
9 changed files with 121 additions and 6 deletions

View File

@ -1585,6 +1585,38 @@ class TestCutlassBackend(TestCase):
)
torch.testing.assert_close(result, ref_result)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@use_evt_config
def test_evt_dynamic_batch_size(self):
op = torch.add
class TestModel(torch.nn.Module):
def forward(self, a, b, extra_args):
acc = a @ b
return op(acc.relu(), *extra_args)
B = 3
M = 1024
N = 512
K = 512
def run_model(B, M, N, K):
a = torch.ones(B, M, K).cuda().half()
b = torch.ones(B, K, N).cuda().half()
extra_args = gen_args(op, (M, N))
model = TestModel().cuda()
result = torch.compile(model)(a, b, extra_args)
ref_result = model(a, b, extra_args)
self.assertEqual(
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
)
torch.testing.assert_close(result, ref_result)
run_model(B, M, N, K)
run_model(B + 1, M, N, K)
if __name__ == "__main__":
from torch._inductor.utils import is_big_gpu

View File

@ -35,10 +35,14 @@ class TorchTensorTest(common_utils.TestCase):
(torch.uint32, np.uint32),
(torch.uint64, np.uint64),
(torch.uint8, np.uint8),
(torch.float4_e2m1fn_x2, ml_dtypes.float4_e2m1fn),
],
)
def test_numpy_returns_correct_dtype(self, dtype: torch.dtype, np_dtype):
tensor = _core.TorchTensor(torch.tensor([1], dtype=dtype))
if dtype == torch.float4_e2m1fn_x2:
tensor = _core.TorchTensor(torch.tensor([1], dtype=torch.uint8).view(dtype))
else:
tensor = _core.TorchTensor(torch.tensor([1], dtype=dtype))
self.assertEqual(tensor.numpy().dtype, np_dtype)
self.assertEqual(tensor.__array__().dtype, np_dtype)
self.assertEqual(np.array(tensor).dtype, np_dtype)
@ -71,6 +75,12 @@ class TorchTensorTest(common_utils.TestCase):
tensor = _core.TorchTensor(torch.tensor([1], dtype=dtype))
self.assertEqual(tensor.tobytes(), tensor.numpy().tobytes())
def test_tobytes_float4(self):
tensor = _core.TorchTensor(
torch.tensor([1], dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
)
self.assertEqual(tensor.tobytes(), b"\x01")
if __name__ == "__main__":
common_utils.run_tests()

View File

@ -216,7 +216,19 @@ class DynamoExporterTest(common_utils.TestCase):
input = input.to(float8_type)
return input
_ = self.export(Float8Module(), (torch.randn(1, 2),))
onnx_program = self.export(Float8Module(), (torch.randn(1, 2),))
self.assertEqual(onnx_program.model.graph.outputs[0].dtype, onnx_type)
def test_float4_support(self):
class Float4Module(torch.nn.Module):
def forward(self):
return torch.empty([1], dtype=torch.float4_e2m1fn_x2)
onnx_program = self.export(Float4Module())
output = onnx_program.model.graph.outputs[0]
self.assertEqual(output.dtype, ir.DataType.FLOAT4E2M1)
# The shape is [*shape, 2] because ONNX stores the shape of the unpacked tensor
self.assertEqual(output.shape.dims, [1, 2])
def test_bfloat16_support(self):
class BfloatModel(torch.nn.Module):

View File

@ -277,7 +277,9 @@ class CutlassEVTCodegen(CutlassEVTOpsMixIn):
) -> bool:
return all(
sympy.Eq(l, r) or sympy.Eq(l, 0) or sympy.Eq(r, 0)
for l, r in (zip(left, right))
for l, r in (
itertools.zip_longest(reversed(left), reversed(right), fillvalue=1)
) # fill with non-zero to verify the other side is zero
)
def _render_input_signature(self) -> str:

View File

@ -38,6 +38,7 @@ from torch.onnx._internal.exporter import (
_registration,
_reporting,
_tensors,
_type_casting,
_verification,
)
@ -61,6 +62,7 @@ _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
torch.float4_e2m1fn_x2: ir.DataType.FLOAT4E2M1,
torch.int16: ir.DataType.INT16,
torch.int32: ir.DataType.INT32,
torch.int64: ir.DataType.INT64,
@ -109,8 +111,17 @@ def torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
class TorchTensor(ir.Tensor):
def __init__(self, tensor: torch.Tensor, name: str | None = None):
# Pass the tensor as the raw data to ir.Tensor's constructor
if tensor.dtype == torch.float4_e2m1fn_x2:
# Change the shape to the unpacked shape
shape = ir.Shape(_type_casting.get_float4_shape(tensor), frozen=True)
else:
# The base class will set the shape to the tensor's shape
shape = None
super().__init__(
tensor, dtype=torch_dtype_to_onnx_dtype(tensor.dtype), name=name
tensor,
dtype=torch_dtype_to_onnx_dtype(tensor.dtype),
shape=shape,
name=name,
)
def numpy(self) -> npt.NDArray:
@ -132,6 +143,10 @@ class TorchTensor(ir.Tensor):
ir.DataType.FLOAT8E5M2FNUZ,
}:
return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
if self.dtype == ir.DataType.FLOAT4E2M1:
return _type_casting.unpack_float4x2_as_uint8(self.raw).view(
self.dtype.numpy()
)
return self.raw.numpy(force=True)
@ -213,7 +228,13 @@ def _set_shape_type(
logger.warning("Setting shape and type of tensors is not supported yet")
if isinstance(meta_val, torch.Tensor):
dims = []
for dim in meta_val.shape:
shape: tuple[int, ...]
if meta_val.dtype == torch.float4_e2m1fn_x2:
# Change the shape to the unpacked shape
shape = _type_casting.get_float4_shape(meta_val)
else:
shape = meta_val.shape
for dim in shape:
if isinstance(dim, int):
dims.append(dim)
else:

View File

@ -27,6 +27,7 @@ _TORCH_DTYPE_TO_ONNX_COMPATIBLE: dict[torch.dtype, ir.DataType] = {
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
torch.float4_e2m1fn_x2: ir.DataType.FLOAT4E2M1,
torch.int16: ir.DataType.INT16,
torch.int32: ir.DataType.INT32,
torch.int64: ir.DataType.INT64,
@ -95,6 +96,7 @@ def _param_type_compatible_with_arg(
ir.TensorType(ir.DataType.INT32),
ir.TensorType(ir.DataType.INT64),
# Int inputs can be casted to a float too
ir.TensorType(ir.DataType.FLOAT4E2M1),
ir.TensorType(ir.DataType.FLOAT8E4M3FN),
ir.TensorType(ir.DataType.FLOAT8E4M3FNUZ),
ir.TensorType(ir.DataType.FLOAT8E5M2),
@ -105,6 +107,7 @@ def _param_type_compatible_with_arg(
}:
return True
if isinstance(value, float) and param.type_constraint.allowed_types & {
ir.TensorType(ir.DataType.FLOAT4E2M1),
ir.TensorType(ir.DataType.FLOAT8E4M3FN),
ir.TensorType(ir.DataType.FLOAT8E4M3FNUZ),
ir.TensorType(ir.DataType.FLOAT8E5M2),

View File

@ -0,0 +1,32 @@
import numpy as np
import torch
def unpack_float4x2_as_uint8(tensor: torch.Tensor) -> np.ndarray:
"""Convert a float4x2 tensor to unpacked uint8 np array."""
assert tensor.dtype == torch.float4_e2m1fn_x2
data = tensor.view(torch.uint8).numpy(force=True).flatten()
result_size = tensor.numel() * 2
result = np.empty([result_size], dtype=np.uint8)
array_low = data & np.uint8(0x0F)
array_high = data & np.uint8(0xF0)
array_high >>= np.uint8(4)
result[0::2] = array_low
result[1::2] = array_high
result.resize(get_float4_shape(tensor), refcheck=False)
return result
def get_float4_shape(tensor: torch.Tensor) -> tuple[int, ...]:
"""Get the shape of an unpacked float4 tensor.
The float4_e2m1fn_x2 type is a shell type described in
https://github.com/pytorch/pytorch/issues/146414.
the shell dtype is takes up 1 byte per element and semantically represents
two fp4 values packed into 1 byte. Semantically it represents (*tensor.shape, 2)
fp4 elements.
"""
assert tensor.dtype == torch.float4_e2m1fn_x2
return (*tensor.shape, 2)

View File

@ -38,6 +38,9 @@ _TORCH_DTYPE_TO_ONNX_DTYPE = {
torch.float8_e4m3fnuz: 18, # FLOAT8E4M3FNUZ
torch.float8_e5m2: 19, # FLOAT8E5M2
torch.float8_e5m2fnuz: 20, # FLOAT8E5M2FNUZ
# 21 = UINT4
# 22 = INT4
torch.float4_e2m1fn_x2: 23, # FLOAT4E2M1
}

View File

@ -40,7 +40,7 @@ _ONNX_DTYPE_TO_TORCH_DTYPE: dict[int, torch.dtype] = {
20: torch.float8_e5m2fnuz, # FLOAT8E5M2FNUZ
21: torch.uint8, # UINT4
22: torch.uint8, # INT4
23: torch.uint8, # FLOAT4E2M1
23: torch.float4_e2m1fn_x2, # FLOAT4E2M1
}
_INT_TYPE = "i"