Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Misc] Medusa supports custom bias (#10361)
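The change threads the model config into ResidualBlock so that the bias of its fully connected layers is controlled by a `medusa_fc_bias` config attribute; because the flag is read with getattr(config, "medusa_fc_bias", False), configs that do not define it keep the previous bias-free behaviour.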
@@ -14,11 +14,14 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 
 class ResidualBlock(nn.Module):
 
-    def __init__(self, hidden_size: int, num_layers: int) -> None:
+    def __init__(self, config: VllmConfig, hidden_size: int,
+                 num_layers: int) -> None:
         super().__init__()
 
         self.layers = nn.ModuleList([
-            nn.Linear(hidden_size, hidden_size, bias=False)
+            nn.Linear(hidden_size,
+                      hidden_size,
+                      bias=getattr(config, "medusa_fc_bias", False))
             for _ in range(num_layers)
         ])
         self.act = nn.SiLU()
@@ -49,7 +52,8 @@ class Medusa(nn.Module):
         super().__init__()
         self.config = config
         self.blocks = nn.ModuleList([
-            ResidualBlock(hidden_size=self.config.hidden_size,
+            ResidualBlock(config=config,
+                          hidden_size=self.config.hidden_size,
                           num_layers=self.config.num_hidden_layers)
             for _ in range(self.config.num_heads)
         ])
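Below is a minimal, self-contained sketch of the behaviour the patch enables. It is not the vLLM source: the SimpleNamespace config object and the layer sizes are placeholders, and the forward pass is written as the usual Medusa-style residual update (x + SiLU(Linear(x))) purely for illustration. The only assumption about the real config is that it may carry a boolean `medusa_fc_bias` attribute, read exactly as the diff does with getattr.

import torch
import torch.nn as nn
from types import SimpleNamespace


class ResidualBlock(nn.Module):
    """Stripped-down copy of the patched block: bias comes from the config."""

    def __init__(self, config, hidden_size: int, num_layers: int) -> None:
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(hidden_size,
                      hidden_size,
                      # Missing attribute -> bias-free layers, as before the patch.
                      bias=getattr(config, "medusa_fc_bias", False))
            for _ in range(num_layers)
        ])
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumed residual update for illustration: x + SiLU(W x (+ b)).
        for layer in self.layers:
            x = x + self.act(layer(x))
        return x


# Config that opts in: every Linear layer now owns a bias parameter.
cfg = SimpleNamespace(medusa_fc_bias=True)
block = ResidualBlock(cfg, hidden_size=16, num_layers=2)
assert all(layer.bias is not None for layer in block.layers)

# Config without the field: behaviour is unchanged (bias=False).
legacy = ResidualBlock(SimpleNamespace(), hidden_size=16, num_layers=2)
assert all(layer.bias is None for layer in legacy.layers)

print(block(torch.randn(1, 16)).shape)  # torch.Size([1, 16])

Defaulting the flag to False keeps existing Medusa checkpoints loading exactly as before, while draft heads trained with biased projections can set `medusa_fc_bias` in their config and have the matching bias parameters created.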