[Inductor][Float8] Add float8_e4m3fn into assertion dtype list. (#157684)

Fix an assertion failure for FP8 qlinear: add torch.float8_e4m3fn to the output- and weight-dtype assertion lists used by the Inductor qlinear path, and pass the transposed weight to the FP8 reference fallback so it aligns with oneDNN's prepacked weight layout.
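
As a minimal illustration (not code from this PR), this is the kind of allow-list check being relaxed, assuming an fp8-cast tensor supplies the output dtype:

import torch

# Illustration only: the Inductor qlinear lowering asserts the matched output dtype
# against an allow-list; torch.float8_e4m3fn was previously missing from that list.
allowed = [torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.float8_e4m3fn]
output_dtype = torch.rand(2, 2).to(torch.float8_e4m3fn).dtype
assert output_dtype in allowed  # failed before float8_e4m3fn was added to the list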

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157684
Approved by: https://github.com/Xia-Weiwen, https://github.com/leslie-fang-intel, https://github.com/jansel
Author: wengshiy
Date: 2025-07-15 06:01:57 +00:00
Committed by: PyTorch MergeBot
Parent: 3341c131b7
Commit: c8c221c0b3
6 changed files with 126 additions and 8 deletions


@@ -1118,8 +1118,9 @@ static at::Tensor linear_int8_with_onednn_weight(
   if(is_fp8 && !cpuinfo_has_x86_amx_int8()) {
 #endif
     // Fall back to ref impl on old platforms because not supported
+    // Transpose weight to align with behavior in oneDNN
     return fp8_qlinear_onednn_ref(
-        input, input_scale, onednn_weight, weight_scales, bias,
+        input, input_scale, onednn_weight.t(), weight_scales, bias,
         output_scale, output_dtype, other, other_scale,
         binary_post_op, binary_alpha, unary_post_op,
         unary_post_op_args, unary_post_op_algorithm);


@@ -305,11 +305,12 @@ static inline at::Tensor pack_weight_to_onednn_tensor(
 #if defined(__powerpc__)
   if (is_fp8){
 #else
   if(is_fp8 && !cpuinfo_has_x86_amx_int8()) {
 #endif
     // oneDNN's fp8 requires AMX support
     // If AMX is not available, fall back to reference implementation
-    return weight;
+    // Transpose weight to align with behavior in oneDNN
+    return weight.t();
   }
   std::vector<int64_t> w_dims = weight.sizes().vec();
   auto w_data_type = is_fp8
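
Both hunks above hand the reference fallback a transposed weight. A minimal float32 sketch (not the fp8 kernels themselves) of why the transpose matters, assuming the prepacked layout is (in_features, out_features) while a plain linear expects (out_features, in_features):

import torch
import torch.nn.functional as F

# Sketch: the prepacked weight is kept transposed, i.e. (ic, oc), to align with oneDNN,
# so a reference implementation built on F.linear must transpose it back to (oc, ic).
ic, oc = 4, 16
w_packed = torch.rand(ic, oc)      # layout returned by the prepack fallback: (ic, oc)
x = torch.rand(1, ic)
y = F.linear(x, w_packed.t())      # reference linear wants (oc, ic)
assert y.shape == (1, oc)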


@@ -2952,6 +2952,104 @@ class TestPatternMatcher(TestPatternMatcherBase):
             is_dynamic=is_dynamic,
         )
 
+    def _test_qlinear_fp8_inductor_cpu_helper(self, qlinear_op, post_op="none"):
+        dtype = torch.float8_e4m3fn
+        qlinear_prepack = torch.ops.onednn.qlinear_prepack
+        post_op_algo = "none"
+        unary_post_op_args = ()
+        batch_size = 1
+        output_dtype = torch.float8_e4m3fn
+        y_scale, y_zp = 0.07, 0
+        ic = 4
+        oc = 16
+        torch._dynamo.reset()
+        used_y_scale = y_scale
+        used_y_zp = y_zp
+        x = torch.rand(batch_size, ic)
+        w = torch.rand(oc, ic)
+        qx = x.to(dtype)
+        qw = w.to(dtype)
+        x_scale = 0.5
+        w_scales = torch.randn(oc)
+        b = torch.rand(oc)
+        x_zp = 0
+        w_zps = torch.zeros_like(w_scales, dtype=torch.int)
+
+        if post_op == "none":
+
+            class Mod(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.qw_packed = qlinear_prepack(qw, x.shape)
+
+                def forward(self, qx):
+                    qy = qlinear_op(
+                        qx,
+                        x_scale,
+                        x_zp,
+                        self.qw_packed,
+                        w_scales,
+                        w_zps,
+                        b,
+                        used_y_scale,
+                        used_y_zp,
+                        output_dtype,
+                        post_op,
+                        unary_post_op_args,
+                        post_op_algo,
+                    )
+                    return qy
+
+        elif post_op == "add":
+            x2 = torch.rand(batch_size, oc)
+            binary_alpha = 1.0  # we only support alpha=1.0 now
+
+            class Mod(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.qw_packed = qlinear_prepack(qw, x.shape)
+
+                def forward(self, qx):
+                    qy = qlinear_op(
+                        qx,
+                        x_scale,
+                        x_zp,
+                        self.qw_packed,
+                        w_scales,
+                        w_zps,
+                        x2,
+                        b,
+                        used_y_scale,
+                        used_y_zp,
+                        output_dtype,
+                        1.0,
+                        0,
+                        "add",
+                        binary_alpha,
+                        "none",
+                        unary_post_op_args,
+                        post_op_algo,
+                    )
+                    return qy
+
+        with torch.no_grad():
+            model = Mod()
+            y_refe = model(qx)
+            y_test = torch.compile(model)(qx)
+            self.assertEqual(y_refe.float(), y_test.float())
+
+    @skipIfNoONEDNN
+    def test_qlinear_fp8_inductor_cpu(self):
+        qlinear_op = torch.ops.onednn.qlinear_pointwise.default
+        self._test_qlinear_fp8_inductor_cpu_helper(qlinear_op, "none")
+
+    @skipIfNoONEDNN
+    def test_qlinear_add_fp8_inductor_cpu(self):
+        qlinear_op = torch.ops.onednn.qlinear_pointwise.binary
+        self._test_qlinear_fp8_inductor_cpu_helper(qlinear_op, "add")
+
     def _qlinear_dequant_promotion_test_helper(
         self,
         inputs,


@@ -72,7 +72,13 @@ def _get_pattern_output_dtype(match: Match):
     output_node = pattern_output_nodes[0]
     assert isinstance(output_node, torch.fx.Node)
     output_dtype = output_node.meta["val"].dtype
-    assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16]
+    assert output_dtype in [
+        torch.int8,
+        torch.uint8,
+        torch.float32,
+        torch.bfloat16,
+        torch.float8_e4m3fn,
+    ]
     return output_dtype
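
For context, a hedged sketch of where the asserted dtype comes from: after fake-tensor propagation, each FX node's meta["val"] holds a tensor whose dtype the pattern matcher reads, so an fp8 output node reports torch.float8_e4m3fn (illustrative toy graph, not the matcher's real input):

import torch
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

# Illustrative only: trace a tiny graph that casts to fp8, propagate fake tensors so
# node.meta["val"] is populated, then read the dtype that the assert above checks.
def f(x):
    return x.to(torch.float8_e4m3fn)

gm = torch.fx.symbolic_trace(f)
FakeTensorProp(gm).propagate(torch.rand(4, 4))
cast_node = [n for n in gm.graph.nodes if n.op == "call_method"][-1]
assert cast_node.meta["val"].dtype == torch.float8_e4m3fn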


@@ -675,8 +675,8 @@ def register_onednn_fusion_ops():
         algorithm,
         layout=None,
     ):
-        assert packed_weight.get_dtype() is torch.int8, (
-            "Only int8 weights are supported by oneDNN qlinear."
+        assert packed_weight.get_dtype() in [torch.int8, torch.float8_e4m3fn], (
+            "Only int8 and e4m3fn weights are supported by oneDNN qlinear."
         )
         x_size = x.get_size()
         if len(x_size) > 2:


@@ -2789,7 +2789,13 @@ if torch._C._has_mkldnn:
         output_shape = list(x.shape)
         # The weight has been transposed during the qlinear weight prepack process.
         output_shape[-1] = w.shape[1]
-        assert output_dtype in [torch.float32, torch.bfloat16, torch.int8, torch.uint8]
+        assert output_dtype in [
+            torch.float32,
+            torch.bfloat16,
+            torch.int8,
+            torch.uint8,
+            torch.float8_e4m3fn,
+        ]
         out = x.new_empty(output_shape, dtype=output_dtype)
         return out
@@ -2820,7 +2826,13 @@
         output_shape = list(x.shape)
         # The weight has been transposed during the qlinear weight prepack process.
         output_shape[-1] = w.shape[1]
-        assert output_dtype in [torch.float32, torch.bfloat16, torch.uint8, torch.int8]
+        assert output_dtype in [
+            torch.float32,
+            torch.bfloat16,
+            torch.uint8,
+            torch.int8,
+            torch.float8_e4m3fn,
+        ]
         out = x.new_empty(output_shape, dtype=output_dtype)
         return out
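
A small sketch of what this shape/dtype inference computes, run on the meta device so no real data is involved (shapes chosen purely for illustration):

import torch

# Illustrative only: the output keeps x's leading dims, takes its last dim from the
# transposed weight's second dim, and may now use the fp8 output dtype.
x = torch.empty(1, 4, device="meta")
w = torch.empty(4, 16, device="meta")   # weight already transposed to (ic, oc) by prepack
output_shape = list(x.shape)
output_shape[-1] = w.shape[1]
out = x.new_empty(output_shape, dtype=torch.float8_e4m3fn)
assert out.shape == (1, 16) and out.dtype == torch.float8_e4m3fn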