Compare commits

...

2 Commits

2 changed files with 13 additions and 4 deletions

View File

@@ -224,8 +224,13 @@ class Llama4TextConfig(PretrainedConfig):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
<TODO>
<TODO>
no_rope_layers (`int`, *optional*): TODO
no_rope_layer_interval (`int`, *optional*, defaults to 4): TODO
no_rope_layers (`List[int]`, *optional*):
    List with at least the same length as the number of layers in the model.
    A `1` at an index position indicates that the corresponding layer will use RoPE,
    while a `0` indicates that it's a NoPE layer.
no_rope_layer_interval (`int`, *optional*, defaults to 4):
    If `no_rope_layers` is `None`, it will be created using a NoPE layer every
    `no_rope_layer_interval` layers.
attention_chunk_size (`int`, *optional*, defaults to 8192):
<TODO>
attn_temperature_tuning (`int`, *optional*, defaults to 4): TODO
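
To illustrate the documented semantics (an editor sketch, not part of the diff): with the default `no_rope_layer_interval` of 4, the comprehension added in this PR marks every fourth layer as NoPE. Hypothetical 8-layer example:

    # Standalone sketch of the default no_rope_layers pattern; the comprehension
    # matches the one added to Llama4TextConfig.__init__ in the next hunk.
    num_hidden_layers = 8        # hypothetical layer count, for illustration only
    no_rope_layer_interval = 4   # documented default

    default_no_rope_layers = [
        int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(num_hidden_layers)
    ]
    print(default_no_rope_layers)  # [1, 1, 1, 0, 1, 1, 1, 0] -> indices 3 and 7 are NoPE layers
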
@@ -335,11 +340,15 @@ class Llama4TextConfig(PretrainedConfig):
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.router_jitter_noise = router_jitter_noise
# Backwards compatibility
if no_rope_layers == []:
    no_rope_layers = None
default_no_rope_layers = [
    int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(self.num_hidden_layers)
]
# no_rope_layers == [] is invalid as we cannot have 0 layers
self.no_rope_layers = no_rope_layers if no_rope_layers else default_no_rope_layers
self.interleave_moe_layer_step = interleave_moe_layer_step
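
The backwards-compatibility branch above treats an empty list like `None`, falling back to the interval-based default. A minimal standalone sketch of that fallback (editor illustration, no transformers import; the helper name is hypothetical):

    def resolve_no_rope_layers(no_rope_layers, num_hidden_layers, no_rope_layer_interval=4):
        # Mirrors the logic added in this hunk: [] is treated like None, and both
        # fall back to the interval-based default pattern.
        if no_rope_layers == []:
            no_rope_layers = None
        default_no_rope_layers = [
            int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(num_hidden_layers)
        ]
        return no_rope_layers if no_rope_layers else default_no_rope_layers

    print(resolve_no_rope_layers([], num_hidden_layers=4))            # [1, 1, 1, 0] -> default used
    print(resolve_no_rope_layers([1, 0, 1, 0], num_hidden_layers=4))  # [1, 0, 1, 0] -> explicit list kept
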

View File

@@ -397,7 +397,7 @@ class Llama4TextDecoderLayer(nn.Module):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Llama4TextAttention(config, layer_idx)
self.use_chunked_attention = int((layer_idx + 1) % 4 != 0) # <=> use rope
self.use_chunked_attention = bool(config.no_rope_layers[layer_idx])
self.is_moe_layer = layer_idx in config.moe_layers
if self.is_moe_layer:  # the 128E model interleaves dense / sparse
    self.feed_forward = Llama4TextMoe(config)
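
With this change the decoder layer reads the per-layer flag from `config.no_rope_layers` instead of hard-coding the modulo-4 check. Editor sketch (not part of the diff) showing the two expressions agree for the default interval of 4:

    # Hypothetical 8-layer config using the default no_rope_layer_interval of 4.
    no_rope_layers = [int((i + 1) % 4 != 0) for i in range(8)]

    for layer_idx in range(8):
        old_flag = int((layer_idx + 1) % 4 != 0)    # removed hard-coded expression
        new_flag = bool(no_rope_layers[layer_idx])  # new config-driven lookup
        assert bool(old_flag) == new_flag
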