mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[torchao] Support quantization configs using module swap (#21982)
Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
@ -507,6 +507,10 @@ steps:
|
||||
commands:
|
||||
# temporary install here since we need nightly, will move to requirements/test.in
|
||||
# after torchao 0.12 release, and pin a working version of torchao nightly here
|
||||
|
||||
# since torchao nightly is only compatible with torch nightly currently
|
||||
# https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
|
||||
# we can only upgrade after this is resolved
|
||||
- pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128
|
||||
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
|
||||
|
||||
|
@ -75,5 +75,25 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
|
||||
print(output)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
|
||||
@pytest.mark.skip(
|
||||
reason="since torchao nightly is only compatible with torch nightly"
|
||||
"currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
|
||||
"torchao tests that requires newer versions (0.14.0.dev+) for now")
|
||||
def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
|
||||
torch._dynamo.reset()
|
||||
model_name = ("torchao-testing/opt-125m-AWQConfig-Int4WeightOnlyConfig-v2"
|
||||
"-0.14.0.dev")
|
||||
with vllm_runner(model_name=model_name,
|
||||
quantization="torchao",
|
||||
dtype="bfloat16",
|
||||
pt_load_map_location="cuda:0") as llm:
|
||||
output = llm.generate_greedy(["The capital of France is"],
|
||||
max_tokens=32)
|
||||
|
||||
assert output
|
||||
print(output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
@ -157,13 +157,15 @@ def torchao_quantize_param_data(param: torch.Tensor,
|
||||
end up setting it to param.
|
||||
"""
|
||||
with torch.device("meta"):
|
||||
dummy_linear = torch.nn.Linear(param.shape[1],
|
||||
param.shape[0],
|
||||
bias=False)
|
||||
# linear can't be top level module since quantize_ is inplace
|
||||
# while some of our configs need to do module swap, and only non-top
|
||||
# level modules support module swap
|
||||
dummy_linear = torch.nn.Sequential(
|
||||
torch.nn.Linear(param.shape[1], param.shape[0], bias=False))
|
||||
|
||||
dummy_linear.weight = param
|
||||
dummy_linear[0].weight = param
|
||||
quantize_(dummy_linear, torchao_config)
|
||||
return dummy_linear.weight
|
||||
return dummy_linear[0].weight
|
||||
|
||||
|
||||
class TorchAOLinearMethod(LinearMethodBase):
|
||||
|
Reference in New Issue
Block a user