# FP-Quant

FP-Quant is a family of quantization algorithms tailored for the Blackwell generation of Nvidia GPUs. The goal is to allow for efficient post-training quantization (PTQ) and quantization-aware training (QAT) of LLMs in the MXFP4 and NVFP4 data types.

This integration accompanies the pre-print Bridging the Gap Between Promise and Performance for Microscaling FP4 Quantization.

Currently, QAT is only supported with `pseudoquantization=True`. Models can either be quantized on the fly with `quantization_config=FPQuantConfig()`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, FPQuantConfig
import torch

model = AutoModelForCausalLM.from_pretrained(
    "qwen/Qwen3-8B",
    quantization_config=FPQuantConfig(),
    device_map="auto",
    dtype=torch.bfloat16,
)
```

or pre-processed with GPTQ for better quality (see FP Format Quantization Harness).

You can choose between MXFP4 and NVFP4 with `FPQuantConfig(forward_dtype="mxfp4")`. NVFP4 provides better quality but uses a little more memory.
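
For example, a minimal sketch of selecting the NVFP4 format; the `forward_dtype="nvfp4"` value is an assumption here, so check the `FPQuantConfig` docstring for the exact accepted strings:

```python
from transformers import AutoModelForCausalLM, FPQuantConfig
import torch

model = AutoModelForCausalLM.from_pretrained(
    "qwen/Qwen3-8B",
    # Assumption: "nvfp4" is the value that selects the NVFP4 format
    quantization_config=FPQuantConfig(forward_dtype="nvfp4"),
    device_map="auto",
    dtype=torch.bfloat16,
)
```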

A Blackwell-generation GPU is required to run the kernels. Runtime support for FP-Quant is implemented through the QuTLASS library and a lightweight PyTorch interface library, fp_quant. We recommend installing the former from source and the latter with `pip install fp_quant`.

Users without a Blackwell-generation GPU can use the method with `quantization_config=FPQuantConfig(pseudoquantization=True)` without having to install QuTLASS. This provides no speedups but fully emulates the effect of quantization.
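
For instance, a minimal sketch of pseudoquantization on non-Blackwell hardware, reusing the model from the example above:

```python
from transformers import AutoModelForCausalLM, FPQuantConfig
import torch

# Pseudoquantization emulates the effect of quantization without the QuTLASS
# kernels, so it runs without a Blackwell GPU but provides no speedup.
model = AutoModelForCausalLM.from_pretrained(
    "qwen/Qwen3-8B",
    quantization_config=FPQuantConfig(pseudoquantization=True),
    device_map="auto",
    dtype=torch.bfloat16,
)
```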

> [!TIP]
> Find models pre-quantized with FP-Quant in the official ISTA-DASLab collection.
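
Loading such a checkpoint typically needs no explicit `FPQuantConfig`, since the quantization settings stored with the model are picked up automatically; the repository ID below is a hypothetical placeholder:

```python
from transformers import AutoModelForCausalLM

# Hypothetical repository ID; browse the ISTA-DASLab collection for real
# FP-Quant checkpoints.
model = AutoModelForCausalLM.from_pretrained(
    "ISTA-DASLab/some-fp-quant-model",
    device_map="auto",
)
```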

## torch.compile

FP-Quant is fully compatible with torch.compile.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, FPQuantConfig

model = AutoModelForCausalLM.from_pretrained(
    "qwen/Qwen3-8B",
    quantization_config=FPQuantConfig(),
    device_map="auto",
    dtype=torch.bfloat16,
)

model.forward = torch.compile(model.forward, mode="max-autotune", fullgraph=True)
```
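
As a usage sketch, the compiled forward pass is then exercised through the regular generation API; the prompt is illustrative, and the first call is slower because it triggers compilation:

```python
tokenizer = AutoTokenizer.from_pretrained("qwen/Qwen3-8B")
inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)

# The first generate call compiles the graph; later calls reuse it.
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```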

## Speedups

FP-Quant currently performs best when processing very large batch sizes.

See the QuTLASS README for speedup measurements.