[Bugfix] Fix dynamic FP8 quantization for Mixtral (#4793)

Philipp Moritz
2024-05-13 16:00:27 -07:00
committed by GitHub
parent 1356df53bd
commit 33d3914b1e


@@ -95,7 +95,7 @@ class MixtralMoE(nn.Module):
                                      params_dtype=self.params_dtype,
                                      quant_config=None)
 
-        if self.use_fp8:
+        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
 
         self.w13_weight = nn.Parameter(
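
The fix narrows the condition under which the expert weights are created in FP8. Previously, any run with use_fp8 allocated the weights as torch.float8_e4m3fn, which is only correct when the checkpoint itself is already serialized in FP8; under dynamic quantization the checkpoint still holds the original-precision weights, which are quantized after loading, so the parameters must be created in the model's dtype. A minimal, self-contained sketch of the dtype selection the patch enforces (the helper select_moe_weight_dtype is hypothetical, not vLLM API; the flag names mirror the diff):

import torch

def select_moe_weight_dtype(use_fp8: bool,
                            is_checkpoint_fp8_serialized: bool,
                            params_dtype: torch.dtype) -> torch.dtype:
    # Only checkpoints already serialized in FP8 store the weights in
    # float8_e4m3fn. For dynamic FP8 quantization the checkpoint holds
    # higher-precision weights that are quantized after loading, so the
    # parameters keep the original dtype here.
    if use_fp8 and is_checkpoint_fp8_serialized:
        return torch.float8_e4m3fn
    return params_dtype

# Dynamic FP8: weights stay in the model dtype until quantized at load time.
assert select_moe_weight_dtype(True, False, torch.float16) == torch.float16
# FP8-serialized checkpoint: weights are created directly in FP8.
assert select_moe_weight_dtype(True, True, torch.float16) == torch.float8_e4m3fn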