Arthur
2025-11-10 14:30:27 +01:00
parent 86a4e51647
commit 09bcd2ee11
5 changed files with 42 additions and 60 deletions

View File

@@ -1702,8 +1702,8 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
# We can't initialize the model on meta device as some weights are modified during the initialization
_no_split_modules = None
_tied_weights_keys = {
r"bbox_embed.(\d+)": "bbox_embed.0",
r"class_embed.(\d+)": "class_embed.0",
r"bbox_embed.(?![0])\d+": "bbox_embed.0",
r"class_embed.(?![0])\d+": "class_embed.0",
}
def __init__(self, config: DeformableDetrConfig):
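The negative lookahead is the substance of the keys above: `(?![0])\d+` matches every clone index except 0, so only the repeated heads get mapped back onto `bbox_embed.0` / `class_embed.0`, while the plain `(\d+)` form would also match index 0 itself. A minimal check with plain `re`, outside of transformers:

```python
import re

# Quick check of the pattern used above; the negative lookahead (?![0])
# rejects index 0, so only the clone heads are mapped onto *_embed.0.
pattern = re.compile(r"bbox_embed.(?![0])\d+")

for name in ["bbox_embed.0", "bbox_embed.1", "bbox_embed.5", "class_embed.2"]:
    print(name, bool(pattern.fullmatch(name)))
# bbox_embed.0 False  -> layer 0 keeps its own weights
# bbox_embed.1 True   -> tied to bbox_embed.0
# bbox_embed.5 True   -> tied to bbox_embed.0
# class_embed.2 False -> handled by the class_embed pattern instead
```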
@@ -1733,17 +1733,9 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
]
)
if config.with_box_refine:
self._tied_weights_keys.update(
{
"model.decoder.bbox_embed": "bbox_embed",
}
)
self._tied_weights_keys["model.decoder.bbox_embed"] = "bbox_embed"
if config.two_stage:
self._tied_weights_keys.update(
{
"model.decoder.class_embed": "class_embed",
}
)
self._tied_weights_keys["model.decoder.class_embed"] = "class_embed"
self.post_init()
@auto_docstring

View File

@@ -2413,35 +2413,32 @@ def build_text_mask(logits, attention_mask):
class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel):
# When using clones, all layers > 0 will be clones, but layer 0 *is* required.
# The bbox_embed modules in the decoder are all clones, though.
_tied_weights_keys = {"bbox_embed": "model.decoder.bbox_embed"}
_tied_weights_keys = {
"model.decoder.bbox_embed":"bbox_embed",
"model.decoder.class_embed":"class_embed",
r"class_embed.(?![0])\d+": "class_embed.0",
}
def __init__(self, config: GroundingDinoConfig):
super().__init__(config)
self.model = GroundingDinoModel(config)
_class_embed = GroundingDinoContrastiveEmbedding(config)
if config.decoder_bbox_embed_share:
# a single shared instance
shared_head = GroundingDinoMLPPredictionHead(
input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
)
self.bbox_embed = nn.ModuleList([shared_head] * config.decoder_layers)
else:
# each layer has its own head (implicit deep copy through a new instance)
self.bbox_embed = nn.ModuleList(
[
GroundingDinoMLPPredictionHead(
input_dim=config.d_model,
hidden_dim=config.d_model,
output_dim=4,
num_layers=3,
)
for _ in range(config.decoder_layers)
]
)
self._tied_weights_keys[r"bbox_embed.(?![0])\d+"]= "bbox_embed.0"
self.class_embed = nn.ModuleList([_class_embed for _ in range(config.decoder_layers)])
self.bbox_embed = nn.ModuleList(
[
GroundingDinoMLPPredictionHead(
input_dim=config.d_model,
hidden_dim=config.d_model,
output_dim=4,
num_layers=3,
)
for _ in range(config.decoder_layers)
]
)
self.class_embed = nn.ModuleList([GroundingDinoContrastiveEmbedding(config) for _ in range(config.decoder_layers)])
# hack for box-refinement
self.model.decoder.bbox_embed = self.bbox_embed
# hack implementation for two-stage
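The `decoder_bbox_embed_share` branch above leans on a PyTorch detail: repeating one module instance inside an `nn.ModuleList` registers shared parameters, whereas constructing a fresh instance per layer does not. A standalone sketch (torch only, with a toy `nn.Linear` standing in for the prediction head):

```python
import torch.nn as nn

# Toy stand-in for the prediction head; only the sharing behaviour matters here.
shared_head = nn.Linear(8, 4)
shared = nn.ModuleList([shared_head] * 3)                         # one instance, registered 3 times
independent = nn.ModuleList([nn.Linear(8, 4) for _ in range(3)])  # three separate instances

print(shared[0].weight is shared[2].weight)              # True: a single weight tensor
print(independent[0].weight is independent[2].weight)    # False: per-layer weights
print(sum(p.numel() for p in shared.parameters()))       # 36 (counted once)
print(sum(p.numel() for p in independent.parameters()))  # 108 (three copies)
```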

View File

@@ -26,7 +26,7 @@ from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, logging
from ..deepseek_v3.modeling_deepseek_v3 import (
@@ -34,14 +34,13 @@ from ..deepseek_v3.modeling_deepseek_v3 import (
DeepseekV3ForCausalLM,
DeepseekV3MLP,
DeepseekV3Model,
DeepseekV3PreTrainedModel,
DeepseekV3RMSNorm,
DeepseekV3RotaryEmbedding,
DeepseekV3TopkRouter,
apply_rotary_pos_emb_interleave,
eager_attention_forward,
)
from .configuration_longcat_flash import LongcatFlashConfig
logger = logging.get_logger(__name__)
@@ -324,7 +323,18 @@ class LongcatFlashDecoderLayer(GradientCheckpointingLayer):
return hidden_states
class LongcatFlashPreTrainedModel(DeepseekV3PreTrainedModel):
@auto_docstring
class LongcatFlashPreTrainedModel(PreTrainedModel):
config: LongcatFlashConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LongcatFlashDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
_can_compile_fullgraph = False
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": LongcatFlashDecoderLayer,
"attentions": LongcatFlashMLA,

View File

@@ -2388,11 +2388,8 @@ def build_text_mask(logits, attention_mask):
)
class MMGroundingDinoForObjectDetection(MMGroundingDinoPreTrainedModel):
_tied_weights_keys = {
r"bbox_embed\.[1-9]\d*": [
r"model\.decoder\.bbox_embed\.[0-9]\d*",
r"class_embed\.[1-9]\d*",
r"model\.decoder\.class_embed\.[0-9]\d*",
]
r"bbox_embed.(?![0])\d+": "bbox_embed.0",
r"class_embed.(?![0])\d+": "class_embed.0",
}
def __init__(self, config: MMGroundingDinoConfig):
@@ -2412,12 +2409,6 @@ class MMGroundingDinoForObjectDetection(MMGroundingDinoPreTrainedModel):
for _ in range(config.decoder_layers)
]
)
# hack for box-refinement
self.model.decoder.bbox_embed = self.bbox_embed
# hack implementation for two-stage
self.model.decoder.class_embed = self.class_embed
# Initialize weights and apply final processing
self.post_init()
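Read each entry of the new dict as "any parameter matching the regex key is tied to the module named by the value". As an illustration only (not the transformers resolution code, and with hypothetical parameter names), expanding the mapping looks like this:

```python
import re

# Hypothetical parameter names, just to show how the regex keys expand.
tied_weights_keys = {
    r"bbox_embed.(?![0])\d+": "bbox_embed.0",
    r"class_embed.(?![0])\d+": "class_embed.0",
}
param_names = [
    f"{prefix}.{i}.layers.0.weight"
    for prefix in ("bbox_embed", "class_embed")
    for i in range(3)
]

for pattern, target in tied_weights_keys.items():
    for name in param_names:
        match = re.match(pattern, name)
        if match:
            print(f"{name} -> {name.replace(match.group(0), target, 1)}")
# bbox_embed.1.layers.0.weight -> bbox_embed.0.layers.0.weight
# bbox_embed.2.layers.0.weight -> bbox_embed.0.layers.0.weight
# class_embed.1.layers.0.weight -> class_embed.0.layers.0.weight
# class_embed.2.layers.0.weight -> class_embed.0.layers.0.weight
```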

View File

@@ -399,11 +399,9 @@ class MMGroundingDinoMLPPredictionHead(GroundingDinoMLPPredictionHead):
class MMGroundingDinoForObjectDetection(GroundingDinoForObjectDetection, MMGroundingDinoPreTrainedModel):
_tied_weights_keys = {
r"bbox_embed\.[1-9]\d*": [
r"model\.decoder\.bbox_embed\.[0-9]\d*",
r"class_embed\.[1-9]\d*",
r"model\.decoder\.class_embed\.[0-9]\d*",
]
"model.decoder.bbox_embed":"bbox_embed",
"model.decoder.class_embed":"class_embed",
r"class_embed.(?![0])\d+": "class_embed.0",
}
def __init__(self, config: MMGroundingDinoConfig):
@@ -423,12 +421,6 @@ class MMGroundingDinoForObjectDetection(GroundingDinoForObjectDetection, MMGroundingDinoPreTrainedModel):
for _ in range(config.decoder_layers)
]
)
# hack for box-refinement
self.model.decoder.bbox_embed = self.bbox_embed
# hack implementation for two-stage
self.model.decoder.class_embed = self.class_embed
# Initialize weights and apply final processing
self.post_init()