Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[Inductor][Float8] Add float8_e4m3fn into assertion dtype list. (#157684)
Fixes an assertion failure by adding float8_e4m3fn to the supported dtype lists.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157684
Approved by: https://github.com/Xia-Weiwen, https://github.com/leslie-fang-intel, https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: 3341c131b7
Commit: c8c221c0b3
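For context, the assertion being relaxed is a dtype allowlist that the Inductor pattern matcher applies to a matched qlinear pattern's output. Below is a minimal, illustrative sketch of that kind of check; it assumes a PyTorch build that exposes torch.float8_e4m3fn, and the names _ALLOWED_OUTPUT_DTYPES and check_output_dtype are stand-ins rather than the actual helpers (the real check is _get_pattern_output_dtype, shown in the hunks that follow).

import torch

# Allowed output dtypes for a matched qlinear pattern; this change adds fp8 (e4m3).
_ALLOWED_OUTPUT_DTYPES = [
    torch.int8,
    torch.uint8,
    torch.float32,
    torch.bfloat16,
    torch.float8_e4m3fn,
]

def check_output_dtype(output_dtype: torch.dtype) -> torch.dtype:
    # Reject any dtype outside the allowlist, as the pattern matcher does.
    assert output_dtype in _ALLOWED_OUTPUT_DTYPES, f"unsupported output dtype: {output_dtype}"
    return output_dtype

# Before this change an fp8 output dtype tripped the assertion; now it is accepted.
check_output_dtype(torch.float8_e4m3fn)

With float8_e4m3fn in the allowlist, fp8 qlinear patterns can be compiled instead of failing the assertion, which is what the new tests added below exercise.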
@@ -1118,8 +1118,9 @@ static at::Tensor linear_int8_with_onednn_weight(
   if(is_fp8 && !cpuinfo_has_x86_amx_int8()) {
 #endif
     // Fall back to ref impl on old platforms because not supported
+    // Transpose weight to align with behavior in oneDNN
     return fp8_qlinear_onednn_ref(
-        input, input_scale, onednn_weight, weight_scales, bias,
+        input, input_scale, onednn_weight.t(), weight_scales, bias,
         output_scale, output_dtype, other, other_scale,
         binary_post_op, binary_alpha, unary_post_op,
         unary_post_op_args, unary_post_op_algorithm);
@@ -305,11 +305,12 @@ static inline at::Tensor pack_weight_to_onednn_tensor(
 #if defined(__powerpc__)
   if (is_fp8){
 #else
   if(is_fp8 && !cpuinfo_has_x86_amx_int8()) {
 #endif
     // oneDNN's fp8 requires AMX support
     // If AMX is not available, fall back to reference implementation
-    return weight;
+    // Transpose weight to align with behavior in oneDNN
+    return weight.t();
   }
   std::vector<int64_t> w_dims = weight.sizes().vec();
   auto w_data_type = is_fp8
@@ -2952,6 +2952,104 @@ class TestPatternMatcher(TestPatternMatcherBase):
             is_dynamic=is_dynamic,
         )
 
+    def _test_qlinear_fp8_inductor_cpu_helper(self, qlinear_op, post_op="none"):
+        dtype = torch.float8_e4m3fn
+        qlinear_prepack = torch.ops.onednn.qlinear_prepack
+        post_op_algo = "none"
+        unary_post_op_args = ()
+        batch_size = 1
+        output_dtype = torch.float8_e4m3fn
+        y_scale, y_zp = 0.07, 0
+        ic = 4
+        oc = 16
+
+        torch._dynamo.reset()
+        used_y_scale = y_scale
+        used_y_zp = y_zp
+        x = torch.rand(batch_size, ic)
+        w = torch.rand(oc, ic)
+        qx = x.to(dtype)
+        qw = w.to(dtype)
+        x_scale = 0.5
+        w_scales = torch.randn(oc)
+        b = torch.rand(oc)
+
+        x_zp = 0
+        w_zps = torch.zeros_like(w_scales, dtype=torch.int)
+
+        if post_op == "none":
+
+            class Mod(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.qw_packed = qlinear_prepack(qw, x.shape)
+
+                def forward(self, qx):
+                    qy = qlinear_op(
+                        qx,
+                        x_scale,
+                        x_zp,
+                        self.qw_packed,
+                        w_scales,
+                        w_zps,
+                        b,
+                        used_y_scale,
+                        used_y_zp,
+                        output_dtype,
+                        post_op,
+                        unary_post_op_args,
+                        post_op_algo,
+                    )
+                    return qy
+
+        elif post_op == "add":
+            x2 = torch.rand(batch_size, oc)
+            binary_alpha = 1.0  # we only support alpha=1.0 now
+
+            class Mod(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.qw_packed = qlinear_prepack(qw, x.shape)
+
+                def forward(self, qx):
+                    qy = qlinear_op(
+                        qx,
+                        x_scale,
+                        x_zp,
+                        self.qw_packed,
+                        w_scales,
+                        w_zps,
+                        x2,
+                        b,
+                        used_y_scale,
+                        used_y_zp,
+                        output_dtype,
+                        1.0,
+                        0,
+                        "add",
+                        binary_alpha,
+                        "none",
+                        unary_post_op_args,
+                        post_op_algo,
+                    )
+                    return qy
+
+        with torch.no_grad():
+            model = Mod()
+            y_refe = model(qx)
+            y_test = torch.compile(model)(qx)
+            self.assertEqual(y_refe.float(), y_test.float())
+
+    @skipIfNoONEDNN
+    def test_qlinear_fp8_inductor_cpu(self):
+        qlinear_op = torch.ops.onednn.qlinear_pointwise.default
+        self._test_qlinear_fp8_inductor_cpu_helper(qlinear_op, "none")
+
+    @skipIfNoONEDNN
+    def test_qlinear_add_fp8_inductor_cpu(self):
+        qlinear_op = torch.ops.onednn.qlinear_pointwise.binary
+        self._test_qlinear_fp8_inductor_cpu_helper(qlinear_op, "add")
+
     def _qlinear_dequant_promotion_test_helper(
         self,
         inputs,
@@ -72,7 +72,13 @@ def _get_pattern_output_dtype(match: Match):
     output_node = pattern_output_nodes[0]
     assert isinstance(output_node, torch.fx.Node)
     output_dtype = output_node.meta["val"].dtype
-    assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16]
+    assert output_dtype in [
+        torch.int8,
+        torch.uint8,
+        torch.float32,
+        torch.bfloat16,
+        torch.float8_e4m3fn,
+    ]
     return output_dtype
 
 
@@ -675,8 +675,8 @@ def register_onednn_fusion_ops():
         algorithm,
         layout=None,
     ):
-        assert packed_weight.get_dtype() is torch.int8, (
-            "Only int8 weights are supported by oneDNN qlinear."
+        assert packed_weight.get_dtype() in [torch.int8, torch.float8_e4m3fn], (
+            "Only int8 and e4m3fn weights are supported by oneDNN qlinear."
         )
         x_size = x.get_size()
         if len(x_size) > 2:
@@ -2789,7 +2789,13 @@ if torch._C._has_mkldnn:
         output_shape = list(x.shape)
         # The weight has been transposed during the qlinear weight prepack process.
         output_shape[-1] = w.shape[1]
-        assert output_dtype in [torch.float32, torch.bfloat16, torch.int8, torch.uint8]
+        assert output_dtype in [
+            torch.float32,
+            torch.bfloat16,
+            torch.int8,
+            torch.uint8,
+            torch.float8_e4m3fn,
+        ]
         out = x.new_empty(output_shape, dtype=output_dtype)
         return out
 
@@ -2820,7 +2826,13 @@ if torch._C._has_mkldnn:
         output_shape = list(x.shape)
         # The weight has been transposed during the qlinear weight prepack process.
         output_shape[-1] = w.shape[1]
-        assert output_dtype in [torch.float32, torch.bfloat16, torch.uint8, torch.int8]
+        assert output_dtype in [
+            torch.float32,
+            torch.bfloat16,
+            torch.uint8,
+            torch.int8,
+            torch.float8_e4m3fn,
+        ]
         out = x.new_empty(output_shape, dtype=output_dtype)
         return out
 