[1/N][Feat] Add weight prefetch feature for Attention layers (#3146)

### What this PR does / why we need it?

- Refacotr and integrate a unified `WeightPrefetchMethod`
- Integrate `qkv_proj.weight` and `o_proj.weight` in quantized Attention
modules
- Prefetching these weights ahead of matmul-like operators imporves
performance by reducing L2 cache transfer latency

### Does this PR introduce _any_ user-facing change?

Add a new config in `--additional-config` for configuration:
```json
{
    "weight_prefetch_config": {
        "enabled": false,
        "prefetch_ratio": {
            "attn": {
                "qkv": 1.0,
                "o": 1.0,
            },
        },
    },
}
```
This feature is enabled by default, and can be disabled through this
configuration

### How was this patch tested?


- vLLM version: v0.11.0

---------

Signed-off-by: yuzhup <15705211260@163.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
Co-authored-by: yuzhup <15705211260@163.com>
This commit is contained in:
Ruri
2025-10-09 20:38:39 +08:00
committed by GitHub
parent 23db56a340
commit ff37575936
13 changed files with 264 additions and 69 deletions

View File

@ -25,12 +25,12 @@ LLM(model="Qwen/Qwen3-8B", additional_config={"config_key":"config_value"})
The following table lists the additional configuration options available in vLLM Ascend:
| Name | Type | Default | Description |
|-------------------------------| ---- |------|-----------------------------------------------------------------------------------------------|
|-------------------------------------|------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------|
| `torchair_graph_config` | dict | `{}` | The config options for torchair graph mode |
| `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler |
| `weight_prefetch_config` | dict | `{}` | The config options for weight prefetch |
| `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. |
| `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
| `enable_prefetch` | bool | `False` | Whether to enable weight prefetch. |
| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
| `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. |
| `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of lmhead. |
@ -71,6 +71,13 @@ The details of each config option are as follows:
ascend_scheduler_config also support the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well.
**weight_prefetch_config**
| Name | Type | Default | Description |
|------------------|------|------------------------------------|------------------------------------|
| `enabled` | bool | `False` | Whether to enable weight prefetch. |
| `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}}` | Prefetch ratio of each weights. |
### Example
An example of additional configuration is as follows:
@ -90,6 +97,15 @@ An example of additional configuration is as follows:
"max_long_partial_prefills": 1,
"long_prefill_token_threshold": 4096,
},
"weight_prefetch_config": {
"enabled": True,
"prefetch_ratio": {
"attn": {
"qkv": 1.0,
"o": 1.0,
},
},
},
"multistream_overlap_shared_expert": True,
"refresh": False,
}

View File

@ -495,7 +495,7 @@ class TestAscendMLAImpl(TestBase):
mock_up_proj.assert_called_once()
mock_npu_fused_infer_attention_score.assert_called_once()
@patch("vllm_ascend.attention.mla_v1.npu_prefetch")
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
def test_mla_preprocess(self, magic_npu_fetch):
magic_npu_fetch.return_value = MagicMock()
batch_size = 4

View File

@ -68,16 +68,23 @@ class TestAscendW8A8LinearMethod(TestBase):
self.assertEqual(params['weight_scale'].shape, (10, 1))
self.assertEqual(params['weight_offset'].shape, (10, 1))
@patch("vllm_ascend.quantization.w8a8.get_forward_context")
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
@patch("torch_npu.npu_quant_matmul")
def test_apply_with_x_not_int8(self, mock_npu_quant_matmul,
mock_quant_per_tensor):
mock_quant_per_tensor,
mock_get_forward_context):
layer = MagicMock()
layer.aclnn_input_scale = 0.1
layer.aclnn_input_offset = 0.2
layer.weight = torch.randn(128, 256)
layer.deq_scale = 0.3
mock_forward_context = MagicMock()
mock_get_forward_context.return_value = mock_forward_context
mock_weight_prefetch_method = MagicMock()
mock_forward_context.weight_prefetch_method = mock_weight_prefetch_method
x = torch.randn(32, 128)
bias = torch.randn(256)
mock_quant_per_tensor.return_value = torch.randint(-128,

View File

@ -45,6 +45,12 @@ class AscendConfig:
"ascend_scheduler_config", {})
self.ascend_scheduler_config = AscendSchedulerConfig(
ascend_scheduler_config)
weight_prefetch_config = additional_config.get(
"weight_prefetch_config", {})
self.weight_prefetch_config = WeightPrefetchConfig(
weight_prefetch_config)
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this config
self.expert_map_path = additional_config.get("expert_map_path", None)
self.expert_map_record_path = additional_config.get(
@ -65,7 +71,6 @@ class AscendConfig:
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
self.multistream_overlap_shared_expert = additional_config.get(
"multistream_overlap_shared_expert", False)
self.enable_prefetch = additional_config.get("enable_prefetch", False)
self.lmhead_tensor_parallel_size = additional_config.get(
"lmhead_tensor_parallel_size", None)
if self.lmhead_tensor_parallel_size is not None:
@ -185,6 +190,24 @@ class AscendSchedulerConfig:
setattr(self, k, v)
class WeightPrefetchConfig:
"""
Configuration Object for weight_prefetch_config from additional_config
"""
prefetch_ratio: dict = {
"attn": {
"qkv": 1.0,
"o": 1.0,
},
}
def __init__(self, weight_prefetch_config: dict):
self.enabled = weight_prefetch_config.get("enabled", False)
self.prefetch_ratio = weight_prefetch_config.get(
"prefetch_ratio", self.prefetch_ratio)
_ASCEND_CONFIG: Optional[AscendConfig] = None

View File

@ -1,7 +1,7 @@
import math
from contextlib import contextmanager
from enum import Enum
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm.config import CUDAGraphMode, VllmConfig
@ -13,6 +13,11 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import enable_sp
if TYPE_CHECKING:
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
else:
WeightPrefetchMethod = None
class FusedMoEState(Enum):
AllGather = 0
@ -65,7 +70,8 @@ def set_ascend_forward_context(
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None,
prefetch_stream: torch.npu.Stream = None,
model_instance: torch.nn.Module = None):
model_instance: torch.nn.Module = None,
weight_prefetch_method: Optional[WeightPrefetchMethod] = None):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
We add some additional param into forward_context.
@ -127,6 +133,7 @@ def set_ascend_forward_context(
hasattr(model_instance.model, "start_layer"):
forward_context.layer_idx = model_instance.model.start_layer
# TODO(rjg-lyh): refactor mlp weight prefetch method
# set for mlp weight prefetch
prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
@ -138,6 +145,8 @@ def set_ascend_forward_context(
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
# TODO(yuzhup): integrate moe weight prefetch method
forward_context.weight_prefetch_method = weight_prefetch_method
# TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
# It will be improved later by implementing operator fusion through the FX graph.

View File

@ -24,7 +24,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.utils import npu_prefetch
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.worker.npu_input_batch import InputBatch
if TYPE_CHECKING:
@ -493,7 +493,7 @@ class AscendMLAImpl(MLAAttentionImpl):
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.enable_prefetch
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
vllm_config = get_current_vllm_config()
@ -877,8 +877,8 @@ class AscendMLAImpl(MLAAttentionImpl):
num_decode_tokens = attn_metadata.num_decode_tokens
num_actual_tokens = attn_metadata.num_actual_tokens
if self.q_a_proj is not None:
npu_prefetch(self.q_a_proj.weight,
hidden_states,
maybe_npu_prefetch(inputs=self.q_a_proj.weight,
dependency=hidden_states,
enabled=self.enable_prefetch)
ckq = self.q_a_proj(hidden_states)[0]
q_c = self.q_a_layernorm(ckq)
@ -1005,8 +1005,8 @@ class AscendMLAImpl(MLAAttentionImpl):
current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
if current_ms_metadata is None:
npu_prefetch(self.o_proj.weight,
o_proj_input,
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
@ -1016,8 +1016,8 @@ class AscendMLAImpl(MLAAttentionImpl):
is_force_scatter=self.enable_shared_expert_dp)[0]
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
npu_prefetch(self.o_proj.weight,
o_proj_input,
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(

View File

@ -11,6 +11,8 @@ from vllm.utils import direct_register_custom_op
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.utils import npu_stream_switch, prefetch_stream
def _maybe_chunk_residual_impl(x: torch.Tensor,
@ -148,6 +150,33 @@ def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
return
def _prefetch_preprocess_impl(weight: torch.Tensor, start_flag: torch.Tensor,
max_weight_size: int) -> None:
calculation_stream = torch_npu.npu.current_stream()
weight_prefetch_stream = prefetch_stream()
weight_prefetch_stream.wait_stream(calculation_stream)
with npu_stream_switch(weight_prefetch_stream):
maybe_npu_prefetch(inputs=weight,
dependency=start_flag,
max_size=max_weight_size)
def _prefetch_preprocess_impl_fake(weight: torch.Tensor,
start_flag: torch.Tensor,
max_weight_size: int) -> None:
return
def _prefetch_postprocess_impl(stop_flag: torch.Tensor) -> None:
calculation_stream = torch_npu.npu.current_stream()
weight_prefetch_stream = prefetch_stream()
calculation_stream.wait_stream(weight_prefetch_stream)
def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None:
return
def _maybe_all_reduce_tensor_model_parallel_impl(
final_hidden_states: torch.Tensor) -> torch.Tensor:
forward_context = get_forward_context()
@ -194,6 +223,18 @@ direct_register_custom_op(op_name="maybe_wait_prefetch_done",
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="prefetch_preprocess",
op_func=_prefetch_preprocess_impl,
fake_impl=_prefetch_preprocess_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="prefetch_postprocess",
op_func=_prefetch_postprocess_impl,
fake_impl=_prefetch_postprocess_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
op_func=_maybe_all_reduce_tensor_model_parallel_impl,
fake_impl=lambda x: x,

View File

@ -0,0 +1,75 @@
from dataclasses import dataclass, field
import torch
import torch_npu
from vllm_ascend.ascend_config import WeightPrefetchConfig
SUPPORTED_MODULES = ["attn", "mlp", "moe"]
@dataclass
class ModuleWeightPrefetchConfig:
module_name: str
enable: bool = False
prefetch_ratio: dict = field(default_factory=dict)
def __post_init__(self) -> None:
self.prefetch_ratio = {
prefix: ratio
for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1
}
assert self.module_name in SUPPORTED_MODULES, (
f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}"
)
if self.module_name in SUPPORTED_MODULES:
self.enable = self.enable and any(self.prefetch_ratio.values()) > 0
class WeightPrefetchMethod:
"""
Unified weight prefetch method.
"""
def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
self.attn = ModuleWeightPrefetchConfig(
module_name="attn",
enable=weight_prefetch_config.enabled,
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
"attn", {}))
def maybe_prefetch_attn_weight_preprocess(
self, prefix: str, weight: torch.Tensor,
start_flag: torch.Tensor) -> None:
if not self.attn.enable:
return
weight_size = weight.data.element_size() * weight.data.numel(
) * self.attn.prefetch_ratio.get(prefix, 0)
torch.ops.vllm.prefetch_preprocess(weight=weight,
start_flag=start_flag,
max_weight_size=int(weight_size))
def maybe_prefetch_attn_weight_postprocess(
self, stop_flag: torch.Tensor) -> None:
if not self.attn.enable:
return
torch.ops.vllm.prefetch_postprocess(stop_flag)
def maybe_npu_prefetch(inputs: torch.Tensor,
dependency: torch.Tensor,
max_size: int = 0,
offset: int = 0,
*,
enabled: bool = True) -> None:
if not enabled:
return
input_size = inputs.element_size() * inputs.numel()
if max_size <= 0 or max_size > input_size:
max_size = input_size
torch_npu.npu_prefetch(inputs, dependency, max_size, offset)

View File

@ -21,6 +21,7 @@ import torch
import torch_npu
from vllm.attention.backends.abstract import AttentionType
from vllm.distributed.parallel_state import get_ep_group
from vllm.forward_context import get_forward_context
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.moe.experts_selector import select_experts
@ -97,11 +98,32 @@ class AscendW8A8LinearMethod:
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
if x.dtype != torch.int8:
attn_weight_map = {
"AscendQKVParallelLinear": "qkv",
"AscendRowParallelLinear": "o",
}
layer_cls_name = layer.__class__.__name__
weight_prefetch_method = get_forward_context(
).weight_prefetch_method
assert weight_prefetch_method is not None
# prefetch_qkvo_proj.weight preprocess
weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
prefix=attn_weight_map.get(layer_cls_name, ""),
weight=layer.weight,
start_flag=x,
)
# quant
x = quant_per_tensor(
x,
layer.aclnn_input_scale_reciprocal,
layer.aclnn_input_offset,
)
# prefetch_qkvo_proj.weight postprocess
if layer_cls_name in attn_weight_map.keys():
weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
x)
quant_bias = layer.quant_bias if tp_rank == 0 else None
if is_310p():
# On 300I Duo platform, we need transpose again if

View File

@ -70,11 +70,12 @@ from vllm.sequence import IntermediateTensors
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.sfa import Indexer
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
TorchairAscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor, npu_prefetch, oproj_tp_enable
from vllm_ascend.utils import dispose_tensor, oproj_tp_enable
class TorchairDeepseekV2SiluAndMul(SiluAndMul):
@ -589,7 +590,7 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
and attn_metadata.num_decodes > 0)
forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
if self.q_lora_rank is not None:
npu_prefetch(self.q_a_proj.weight,
maybe_npu_prefetch(self.q_a_proj.weight,
hidden_states,
enabled=enable_multistream_mla)
ckq = self.q_a_proj(hidden_states)[0]

View File

@ -23,9 +23,9 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
npu_stream_switch, npu_wait_tensor)
from vllm_ascend.utils import npu_prefetch
from vllm_ascend.worker.npu_input_batch import InputBatch
if TYPE_CHECKING:
@ -684,7 +684,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
if hasattr(self, "running_in_graph") and not self.running_in_graph:
return x
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
npu_prefetch(self.o_proj.weight,
maybe_npu_prefetch(self.o_proj.weight,
x,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)
@ -1281,7 +1281,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
if current_ms_metadata is None:
npu_prefetch(self.o_proj.weight,
maybe_npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)
@ -1292,7 +1292,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
is_force_scatter=self.enable_shared_expert_dp)[0]
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
npu_prefetch(self.o_proj.weight,
maybe_npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)

View File

@ -51,6 +51,7 @@ _CUSTOM_OP_ENABLED = None
_IS_310P = None
_SLEEP_MODE_ENABLED = None
_CURRENT_STREAM = None
_PREFETCH_STREAM = None
_ASCEND_CUSTOMOP_IS_REIGISTERED = False
@ -241,6 +242,15 @@ def current_stream() -> torch.npu.Stream:
return _CURRENT_STREAM
def prefetch_stream() -> torch.npu.Stream:
global _PREFETCH_STREAM
if _PREFETCH_STREAM is None:
# when this function is called before any stream is set,
# we return the default stream.
_PREFETCH_STREAM = torch_npu.npu.Stream()
return _PREFETCH_STREAM
def adapt_patch(is_global_patch: bool = False):
if is_global_patch:
from vllm_ascend.patch import platform # noqa: F401
@ -446,20 +456,6 @@ class ProfileExecuteDuration:
return durations
# TODO(wxy): Move to ops module
def npu_prefetch(input: torch.Tensor,
dependency: torch.Tensor,
max_size: int = 0,
*,
enabled: bool = True):
if not enabled:
return
input_size = input.element_size() * input.numel()
if max_size <= 0 or max_size > input_size:
max_size = input_size
torch_npu.npu_prefetch(input, dependency, max_size)
# TODO(ttanzhiqiang): rm_router_logits
# dp>1 will trigger
# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.

View File

@ -113,6 +113,7 @@ from vllm_ascend.eplb.eplb_updator import EplbUpdator
from vllm_ascend.eplb.utils import model_register
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.logits_processor import build_logitsprocs
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
@ -285,6 +286,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
else:
self.chunked_prefill_enabled = True
self.weight_prefetch_method = WeightPrefetchMethod(
self.ascend_config.weight_prefetch_config)
if self.cache_config.cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
@ -1856,7 +1859,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_actual_tokens=scheduler_output.
total_num_scheduled_tokens,
prefetch_stream=self.prefetch_stream,
model_instance=self.model):
model_instance=self.model,
weight_prefetch_method=self.weight_prefetch_method):
self.maybe_setup_kv_connector(scheduler_output)
hidden_states = self._generate_process_reqs_hidden_states(
@ -2370,7 +2374,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor,
prefetch_stream=self.prefetch_stream,
model_instance=self.model):
model_instance=self.model,
weight_prefetch_method=self.weight_prefetch_method):
hidden_states = self._generate_dummy_run_hidden_states(
with_prefill, is_torchair_compile, input_ids, positions,
attn_metadata, num_tokens, intermediate_tensors,