mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 13:43:53 +08:00
### What this PR does / why we need it?
Follow the lint setup used by vllm-project/vllm:
https://github.com/vllm-project/vllm/blob/main/.pre-commit-config.yaml
Enable pre-commit to catch as many low-level errors as possible.
This PR is one step of #1241. Its purpose is to make the linting system
clearer and more convenient. This step mainly covers the following:
yapf, actionlint, ruff, typos, isort, mypy, png-lint, signoff-commit,
enforce-import-regex-instead-of-re.
TODO:
- clang-format (check csrc against Google style)
needs code cleanup first; disabled for now
- pymarkdown
needs code cleanup first; disabled for now
- shellcheck
needs code cleanup first; disabled for now
### Does this PR introduce _any_ user-facing change?
Only developer UX change:
https://vllm-ascend--1256.org.readthedocs.build/en/1256/developer_guide/contributing.html#run-lint-locally
```
pip install -r requirements-lint.txt && pre-commit install
bash format.sh
```
### How was this patch tested?
CI passed with newly added and existing tests.
Co-authored-by: Yikun [yikunkero@gmail.com](mailto:yikunkero@gmail.com)
Co-authored-by: wangli
[wangli858794774@gmail.com](mailto:wangli858794774@gmail.com)
- vLLM version: v0.9.1
- vLLM main:
5358cce5ff
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
159 lines
4.5 KiB
Python
159 lines
4.5 KiB
Python
from typing import Tuple
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
import torch_npu # noqa: F401
|
|
import vllm # noqa: F401
|
|
|
|
import vllm_ascend.platform # noqa: F401
|
|
|
|
|
|
def benchmark_npu(fn, num_iterations=100, num_warmup_iterations=50):
    """
    Benchmark function for NPU operations.

    Every iteration records a start event, calls ``fn``, records an end
    event and synchronizes the device before reading the elapsed time.
    Warm-up iterations run exactly the same way, but their timings are
    discarded.

    Args:
        fn: Zero-argument callable to benchmark.
        num_iterations: Number of timed iterations; must be at least 1.
        num_warmup_iterations: Number of untimed warm-up iterations that run
            before the timed ones; must not be negative.

    Returns:
        float: Minimum elapsed time over the timed iterations, in seconds.

    Raises:
        ValueError: If ``num_iterations`` is smaller than 1 or
            ``num_warmup_iterations`` is negative.
    """
    # Fail fast with a clear message. Without this check an empty timing
    # window only surfaces after all warm-up work as an opaque NumPy
    # "zero-size array" error, and a negative warm-up count silently
    # mis-slices the samples.
    if num_iterations < 1:
        raise ValueError(
            f"num_iterations must be at least 1, got {num_iterations}")
    if num_warmup_iterations < 0:
        raise ValueError("num_warmup_iterations must not be negative, "
                         f"got {num_warmup_iterations}")

    start = torch.npu.Event(enable_timing=True)
    end = torch.npu.Event(enable_timing=True)
    # Only the timed iterations need a slot; warm-up timings are never kept.
    times = np.zeros(num_iterations)

    with torch.no_grad():
        for i in range(num_warmup_iterations + num_iterations):
            start.record()
            fn()  # Execute the function
            end.record()
            torch.npu.synchronize()
            if i >= num_warmup_iterations:
                times[i - num_warmup_iterations] = start.elapsed_time(end)

    # The recorded values are divided by 1000 to report seconds.
    elapsed_time = np.amin(times) / 1000
    return elapsed_time
|
|
|
|
|
|
def get_masked_input_and_mask_ref(
    input_: torch.Tensor,
    org_vocab_start_index: int,
    org_vocab_end_index: int,
    num_org_vocab_padding: int,
    added_vocab_start_index: int,
    added_vocab_end_index: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference implementation used to verify the custom op.

    Ids in ``[org_vocab_start_index, org_vocab_end_index)`` are shifted down
    by ``org_vocab_start_index``. Ids in
    ``[added_vocab_start_index, added_vocab_end_index)`` are shifted so that
    they land after the original range plus ``num_org_vocab_padding``.
    Every other id is replaced by 0.

    Returns:
        A ``(masked_input, mask)`` tuple, where ``mask`` is True for the ids
        that fall in neither range.
    """
    # Membership of each id in the two half-open ranges.
    in_org = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
    in_added = (input_ >= added_vocab_start_index) & (
        input_ < added_vocab_end_index)

    # Amount subtracted from ids of the added range.
    org_size = org_vocab_end_index - org_vocab_start_index
    added_shift = added_vocab_start_index - org_size - num_org_vocab_padding

    # Per-element shift: one of the two offsets, or 0 outside both ranges.
    shift = (org_vocab_start_index * in_org) + (added_shift * in_added)

    in_vocab = in_org | in_added
    # Multiplying by the boolean mask zeroes every out-of-range id.
    return in_vocab * (input_ - shift), ~in_vocab
|
|
|
|
|
|
# Parameter grids for the pytest parametrization of the test in this module.
DTYPES = [torch.int32]
SHAPES = [(3, 4, 5)]
DEVICES = ["npu:0"]
SEEDS = [0]
|
|
|
|
|
|
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_get_masked_input_and_mask(
    shape: Tuple[int, ...],
    dtype: torch.dtype,
    device: str,
    seed: int,
) -> None:
    """Compare the custom ``get_masked_input_and_mask`` op with the reference.

    A random integer tensor is created on ``device``. Both implementations
    run on it with the same vocabulary ranges, are benchmarked, and their
    outputs are asserted to be close.

    Args:
        shape: Shape of the random input tensor.
        dtype: Dtype of the random input tensor.
        device: Device string the input tensor is created on.
        seed: Seed passed to ``torch.manual_seed``.
    """
    # Set random seed
    torch.manual_seed(seed)

    # Generate random input tensor. The device is passed explicitly instead
    # of calling torch.set_default_device(), which changes process-wide
    # state and would leak into tests that run later in the same session.
    input_tensor = torch.randint(0, 1000, shape, dtype=dtype, device=device)

    # Test parameters
    test_case = {
        "org_start": 100,
        "org_end": 200,
        "padding": 0,
        "added_start": 300,
        "added_end": 400,
    }

    # Define reference function
    def ref_fn():
        return get_masked_input_and_mask_ref(
            input_tensor,
            test_case["org_start"],
            test_case["org_end"],
            test_case["padding"],
            test_case["added_start"],
            test_case["added_end"],
        )

    # Define custom function
    def custom_fn():
        return torch.ops._C.get_masked_input_and_mask(
            input_tensor,
            test_case["org_start"],
            test_case["org_end"],
            test_case["padding"],
            test_case["added_start"],
            test_case["added_end"],
        )

    # Get results for correctness testing
    ref_masked_input, ref_mask = ref_fn()
    custom_masked_input, custom_mask = custom_fn()

    # Benchmark both implementations
    ref_time = benchmark_npu(ref_fn)
    custom_time = benchmark_npu(custom_fn)

    # Print performance results
    print("\nPerformance Results:")
    print(f"Reference implementation: {ref_time * 1000:.3f} ms")
    print(f"Custom implementation: {custom_time * 1000:.3f} ms")
    print(f"Speedup: {ref_time / custom_time:.2f}x")

    # Compare results for correctness. The reference output is cast to the
    # input dtype before the comparison.
    ref_masked_input = ref_masked_input.to(dtype)
    print("\nResults comparison:")
    print("custom_masked_input:", custom_masked_input)
    print("ref_masked_input:", ref_masked_input)
    print("custom_mask:", custom_mask)
    print("ref_mask:", ref_mask)
    torch.testing.assert_close(
        custom_masked_input,
        ref_masked_input,
        rtol=1e-5,
        atol=1e-5,
        msg=f"Masked input mismatch for case: {test_case}",
    )
    torch.testing.assert_close(
        custom_mask,
        ref_mask,
        rtol=1e-5,
        atol=1e-5,
        msg=f"Mask mismatch for case: {test_case}",
    )
|