🚨 🚨 🚨 [Breaking change] Deformable DETR intermediate representations (#19678)

* [Breaking change] Deformable DETR intermediate representations

- Naturally fixes the `object-detection` pipeline.
- Moves intermediate representations from `[n_decoders, batch_size, ...]`
  to `[batch_size, n_decoders, ...]` (see the migration sketch below).
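For downstream code that indexed the old layer-first tensors, the change amounts to transposing the first two dimensions. A minimal migration sketch, with all tensor sizes assumed for illustration (none of these names come from the diff):

```python
import torch

# Assumed sizes, for illustration only.
n_decoders, batch_size, num_queries, hidden_size = 6, 2, 300, 256

# Old layout: [n_decoders, batch_size, num_queries, hidden_size]
old = torch.randn(n_decoders, batch_size, num_queries, hidden_size)

# New layout: [batch_size, n_decoders, num_queries, hidden_size]
new = old.transpose(0, 1)

# Code that used to select decoder layer `level` along dim 0
# now selects it along dim 1.
level = 3
assert torch.equal(old[level], new[:, level])
```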

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Author: Nicolas Patry
Date: 2022-10-18 15:00:39 +02:00
Committed by: GitHub
Parent: fd99ce3329
Commit: 713eab45d3
2 changed files with 19 additions and 23 deletions

src/transformers/models/deformable_detr/modeling_deformable_detr.py

@@ -135,9 +135,9 @@ class DeformableDetrDecoderOutput(ModelOutput):
     Args:
         last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
             Sequence of hidden-states at the output of the last layer of the model.
-        intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
+        intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
             Stacked intermediate hidden states (output of each layer of the decoder).
-        intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+        intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`):
             Stacked intermediate reference points (reference points of each layer of the decoder).
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
@@ -171,9 +171,9 @@ class DeformableDetrModelOutput(ModelOutput):
             Initial reference points sent through the Transformer decoder.
         last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
             Sequence of hidden-states at the output of the last layer of the decoder of the model.
-        intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`):
+        intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
             Stacked intermediate hidden states (output of each layer of the decoder).
-        intermediate_reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 4)`):
+        intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
             Stacked intermediate reference points (reference points of each layer of the decoder).
         decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
@@ -266,9 +266,9 @@ class DeformableDetrObjectDetectionOutput(ModelOutput):
             Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_heads, 4,
             4)`. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average
             in the self-attention heads.
-        intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`):
+        intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
             Stacked intermediate hidden states (output of each layer of the decoder).
-        intermediate_reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 4)`):
+        intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
             Stacked intermediate reference points (reference points of each layer of the decoder).
         init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
             Initial reference points sent through the Transformer decoder.
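The three output docstrings above now agree on the batch-first convention, which can be checked directly at runtime. A hedged sketch (the checkpoint name and test image are assumptions, not part of this diff):

```python
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, DeformableDetrModel

# Checkpoint assumed for illustration; any Deformable DETR checkpoint works.
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
feature_extractor = AutoFeatureExtractor.from_pretrained("SenseTime/deformable-detr")
model = DeformableDetrModel.from_pretrained("SenseTime/deformable-detr")

with torch.no_grad():
    outputs = model(**feature_extractor(images=image, return_tensors="pt"))

# Both intermediate outputs are batch-first now:
# [batch_size, config.decoder_layers, num_queries, hidden_size]
print(outputs.intermediate_hidden_states.shape)
# [batch_size, config.decoder_layers, num_queries, 4]
print(outputs.intermediate_reference_points.shape)
```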
@@ -1390,8 +1390,9 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
                 if encoder_hidden_states is not None:
                     all_cross_attentions += (layer_outputs[2],)
-        intermediate = torch.stack(intermediate)
-        intermediate_reference_points = torch.stack(intermediate_reference_points)
+        # Keep batch_size as first dimension
+        intermediate = torch.stack(intermediate, dim=1)
+        intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)

         # add hidden states from the last decoder layer
         if output_hidden_states:
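The `dim=1` argument is what keeps `batch_size` first: `torch.stack` inserts the new layer dimension at position 1 instead of 0. A minimal sketch with assumed sizes:

```python
import torch

# Assumed sizes, for illustration only.
batch_size, num_queries, hidden_size = 2, 300, 256
per_layer = [torch.randn(batch_size, num_queries, hidden_size) for _ in range(6)]

# dim=0 would give [n_decoders, batch_size, ...]; dim=1 keeps batch first.
stacked = torch.stack(per_layer, dim=1)
print(stacked.shape)  # torch.Size([2, 6, 300, 256])
```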
@@ -1913,14 +1914,14 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
         outputs_classes = []
         outputs_coords = []
-        for level in range(hidden_states.shape[0]):
+        for level in range(hidden_states.shape[1]):
             if level == 0:
                 reference = init_reference
             else:
-                reference = inter_references[level - 1]
+                reference = inter_references[:, level - 1]
             reference = inverse_sigmoid(reference)
-            outputs_class = self.class_embed[level](hidden_states[level])
-            delta_bbox = self.bbox_embed[level](hidden_states[level])
+            outputs_class = self.class_embed[level](hidden_states[:, level])
+            delta_bbox = self.bbox_embed[level](hidden_states[:, level])
             if reference.shape[-1] == 4:
                 outputs_coord_logits = delta_bbox + reference
             elif reference.shape[-1] == 2:
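With batch-first tensors, both the loop bound and the per-level slices move to dim 1. A sketch of the equivalence between the old and new indexing (sizes assumed):

```python
import torch

# Assumed sizes, for illustration only.
batch_size, n_decoders, num_queries, hidden_size = 2, 6, 300, 256
hidden_states_old = torch.randn(n_decoders, batch_size, num_queries, hidden_size)
hidden_states_new = hidden_states_old.transpose(0, 1)

for level in range(hidden_states_new.shape[1]):
    # hidden_states_new[:, level] is the same tensor the old code
    # obtained as hidden_states_old[level].
    assert torch.equal(hidden_states_new[:, level], hidden_states_old[level])
```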
@@ -1931,11 +1932,12 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
             outputs_coord = outputs_coord_logits.sigmoid()
             outputs_classes.append(outputs_class)
             outputs_coords.append(outputs_coord)
-        outputs_class = torch.stack(outputs_classes)
-        outputs_coord = torch.stack(outputs_coords)
+        # Keep batch_size as first dimension
+        outputs_class = torch.stack(outputs_classes, dim=1)
+        outputs_coord = torch.stack(outputs_coords, dim=1)
-        logits = outputs_class[-1]
-        pred_boxes = outputs_coord[-1]
+        logits = outputs_class[:, -1]
+        pred_boxes = outputs_coord[:, -1]

         loss, loss_dict, auxiliary_outputs = None, None, None
         if labels is not None:
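Because the stacked per-layer predictions are batch-first too, the final decoder layer's predictions are read with `[:, -1]` rather than `[-1]`. A sketch (sizes assumed, including the label count):

```python
import torch

# Assumed sizes, for illustration only.
batch_size, n_decoders, num_queries, num_labels = 2, 6, 300, 91
outputs_class = torch.randn(batch_size, n_decoders, num_queries, num_labels)
outputs_coord = torch.rand(batch_size, n_decoders, num_queries, 4)

logits = outputs_class[:, -1]      # [batch_size, num_queries, num_labels]
pred_boxes = outputs_coord[:, -1]  # [batch_size, num_queries, 4]
```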
@@ -2000,8 +2002,8 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
             encoder_hidden_states=outputs.encoder_hidden_states,
             encoder_attentions=outputs.encoder_attentions,
             intermediate_hidden_states=outputs.intermediate_hidden_states,
-            init_reference_points=outputs.init_reference_points,
             intermediate_reference_points=outputs.intermediate_reference_points,
+            init_reference_points=outputs.init_reference_points,
             enc_outputs_class=outputs.enc_outputs_class,
             enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
         )

tests/pipelines/test_pipelines_object_detection.py

@@ -44,12 +44,6 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
     model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING

     def get_test_pipeline(self, model, tokenizer, feature_extractor):
-        if model.__class__.__name__ == "DeformableDetrForObjectDetection":
-            self.skipTest(
-                """Deformable DETR requires a custom CUDA kernel.
-                """
-            )
-
         object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)
         return object_detector, ["./tests/fixtures/tests_samples/COCO/000000039769.png"]
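With the batch-first layout, the pipeline's generic post-processing indexes the model output correctly, so the Deformable DETR skip can be dropped. A hedged usage sketch (the checkpoint name is an assumption, not part of this diff):

```python
from transformers import pipeline

# Checkpoint assumed for illustration.
detector = pipeline("object-detection", model="SenseTime/deformable-detr")
predictions = detector("./tests/fixtures/tests_samples/COCO/000000039769.png")
# Each prediction is a dict like {"score": ..., "label": ..., "box": {...}}
print(predictions)
```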