[torchao] Support quantization configs using module swap (#21982)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>

Author: Jerry Zhang <jerryzh168@gmail.com>
Date: 2025-09-10 23:53:24 -07:00
Committed by: GitHub
Parent: d13360183a
Commit: 2048c4e379

3 changed files with 33 additions and 7 deletions


@@ -507,6 +507,10 @@ steps:
   commands:
     # temporary install here since we need nightly, will move to requirements/test.in
     # after torchao 0.12 release, and pin a working version of torchao nightly here
+    # since torchao nightly is only compatible with torch nightly currently
+    # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
+    # we can only upgrade after this is resolved
     - pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128
     - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization

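For context on the skip markers in the test diff below: they gate on a `TORCHAO_AVAILABLE` flag. A minimal sketch of how such availability/version gating can be written, assuming `packaging` is installed; `torchao_version_at_least` is a hypothetical helper, not part of this commit:

    # Hedged sketch: one way a test module could define TORCHAO_AVAILABLE and
    # a version gate matching the skip rationale above. torchao_version_at_least
    # is a hypothetical name, not the repo's actual helper.
    import importlib.metadata
    import importlib.util

    from packaging.version import Version

    TORCHAO_AVAILABLE = importlib.util.find_spec("torchao") is not None


    def torchao_version_at_least(minimum: str) -> bool:
        """True if torchao is installed and its version is >= `minimum`."""
        if not TORCHAO_AVAILABLE:
            return False
        return Version(importlib.metadata.version("torchao")) >= Version(minimum)

Once https://github.com/pytorch/ao/issues/2919 is resolved, a version gate like this could replace the unconditional `pytest.mark.skip` below.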

@@ -75,5 +75,25 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
     print(output)


+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+@pytest.mark.skip(
+    reason="since torchao nightly is currently only compatible with torch "
+    "nightly (https://github.com/pytorch/ao/issues/2919), we'll have to "
+    "skip torchao tests that require newer versions (0.14.0.dev+) for now")
+def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
+    torch._dynamo.reset()
+    model_name = ("torchao-testing/opt-125m-AWQConfig-Int4WeightOnlyConfig-v2"
+                  "-0.14.0.dev")
+    with vllm_runner(model_name=model_name,
+                     quantization="torchao",
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0") as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+        assert output
+        print(output)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

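The same load path can also be exercised outside pytest; a minimal sketch using vLLM's offline `LLM` API with the model name and flags taken from the diff above (greedy decoding approximated with temperature 0; this mirrors the test rather than reproducing the `vllm_runner` fixture):

    # Hedged sketch of loading the torchao-quantized checkpoint directly.
    from vllm import LLM, SamplingParams

    llm = LLM(model="torchao-testing/opt-125m-AWQConfig-Int4WeightOnlyConfig-v2"
                    "-0.14.0.dev",
              quantization="torchao",
              dtype="bfloat16")
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.0, max_tokens=32))
    print(outputs[0].outputs[0].text)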

@@ -157,13 +157,15 @@ def torchao_quantize_param_data(param: torch.Tensor,
         end up setting it to param.
     """
     with torch.device("meta"):
-        dummy_linear = torch.nn.Linear(param.shape[1],
-                                       param.shape[0],
-                                       bias=False)
-        dummy_linear.weight = param
+        # the linear can't be a top-level module since quantize_ works
+        # in-place, some of our configs need to do module swap, and only
+        # non-top-level modules can be swapped
+        dummy_linear = torch.nn.Sequential(
+            torch.nn.Linear(param.shape[1], param.shape[0], bias=False))
+        dummy_linear[0].weight = param
         quantize_(dummy_linear, torchao_config)
-        return dummy_linear.weight
+        return dummy_linear[0].weight


 class TorchAOLinearMethod(LinearMethodBase):
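
Why the `nn.Sequential` wrapper in the hunk above matters: `quantize_` mutates a model in place, and configs that work by module swap replace a `Linear` with a new module object, which requires a parent module on which to reassign the attribute; a bare top-level `Linear` has no parent, so it cannot be swapped. A minimal sketch of the pattern, using `Int4WeightOnlyConfig` as a readily importable stand-in config (the swap requirement in this commit actually comes from configs like the AWQ config exercised in the new test); shapes and device are illustrative:

    # Hedged sketch of the wrapper pattern from the diff above.
    import torch
    from torchao.quantization import Int4WeightOnlyConfig, quantize_

    linear = torch.nn.Linear(128, 64, bias=False,
                             dtype=torch.bfloat16, device="cuda")

    # Wrapping gives the Linear a parent, so quantize_ can swap wrapper[0]
    # (swap-based configs reassign the child module on its parent).
    wrapper = torch.nn.Sequential(linear)
    quantize_(wrapper, Int4WeightOnlyConfig())
    print(type(wrapper[0]))  # quantized (possibly swapped) module

This is also why the function now returns `dummy_linear[0].weight` rather than `dummy_linear.weight`: after quantization, the weight must be read back from the wrapped (and possibly swapped) child module.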