# Taken from https://github.com/pytorch/vision
# So that we don't need torchvision to be installed

from collections import OrderedDict

import torch
from torch import nn
from torch.jit.annotations import Dict
from torch.nn import functional as F


try:
    from scipy.optimize import linear_sum_assignment

    scipy_available = True
except Exception:
    scipy_available = False


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1),
    # according to "Deep Residual Learning for Image Recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(
        self,
        block,
        layers,
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm_layer=None,
    ):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
        )
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
        )
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    # if pretrained:
    #     state_dict = load_state_dict_from_url(model_urls[arch],
    #                                           progress=progress)
    #     model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
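

# A minimal usage sketch (not part of the original file): the builders above
# follow the torchvision API, so a plain classification forward looks like this.
def _demo_resnet_forward():
    model = resnet18().eval()
    with torch.no_grad():
        logits = model(torch.rand(1, 3, 224, 224))
    # one logit per class; 1000 is the default num_classes
    assert logits.shape == (1, 1000)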


class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Args:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).

    Examples::

        >>> m = torchvision.models.resnet18(pretrained=True)
        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>> [('feat1', torch.Size([1, 64, 56, 56])),
        >>>  ('feat2', torch.Size([1, 256, 14, 14]))]
    """

    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model, return_layers):
        if not set(return_layers).issubset(
            [name for name, _ in model.named_children()]
        ):
            raise ValueError("return_layers are not present in model")
        orig_return_layers = return_layers
        return_layers = {str(k): str(v) for k, v in return_layers.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super().__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out
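

# A hedged usage sketch (not in the original file): the docstring example uses
# torchvision, but the local resnet50 defined above works the same way.
def _demo_intermediate_layers():
    body = IntermediateLayerGetter(
        resnet50(), return_layers={"layer1": "feat1", "layer3": "feat2"}
    )
    out = body(torch.rand(1, 3, 224, 224))
    # Bottleneck blocks expand channels 4x: layer1 -> 256, layer3 -> 1024
    assert out["feat1"].shape == (1, 256, 56, 56)
    assert out["feat2"].shape == (1, 1024, 14, 14)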


class _SimpleSegmentationModel(nn.Module):
    __constants__ = ["aux_classifier"]

    def __init__(self, backbone, classifier, aux_classifier=None):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.aux_classifier = aux_classifier

    def forward(self, x):
        input_shape = x.shape[-2:]
        # contract: features is a dict of tensors
        features = self.backbone(x)

        result = OrderedDict()
        x = features["out"]
        x = self.classifier(x)
        x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
        result["out"] = x

        if self.aux_classifier is not None:
            x = features["aux"]
            x = self.aux_classifier(x)
            x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
            result["aux"] = x

        return result


class FCN(_SimpleSegmentationModel):
    """
    Implements a Fully-Convolutional Network for semantic segmentation.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "out" for the last feature map used, and "aux" if an auxiliary classifier
            is used.
        classifier (nn.Module): module that takes the "out" element returned from
            the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training
    """


class FCNHead(nn.Sequential):
    def __init__(self, in_channels, channels):
        inter_channels = in_channels // 4
        layers = [
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(inter_channels, channels, 1),
        ]

        super().__init__(*layers)


def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
    # backbone = resnet.__dict__[backbone_name](
    #     pretrained=pretrained_backbone,
    #     replace_stride_with_dilation=[False, True, True])
    # Hardcoded resnet 50
    assert backbone_name == "resnet50"
    backbone = resnet50(
        pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]
    )

    return_layers = {"layer4": "out"}
    if aux:
        return_layers["layer3"] = "aux"
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    aux_classifier = None
    if aux:
        inplanes = 1024
        aux_classifier = FCNHead(inplanes, num_classes)

    model_map = {
        # 'deeplabv3': (DeepLabHead, DeepLabV3),  # Not used
        "fcn": (FCNHead, FCN),
    }
    inplanes = 2048
    classifier = model_map[name][0](inplanes, num_classes)
    base_model = model_map[name][1]

    model = base_model(backbone, classifier, aux_classifier)
    return model


def _load_model(
    arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs
):
    if pretrained:
        aux_loss = True
    model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)
    # if pretrained:
    #     arch = arch_type + '_' + backbone + '_coco'
    #     model_url = model_urls[arch]
    #     if model_url is None:
    #         raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
    #     else:
    #         state_dict = load_state_dict_from_url(model_url, progress=progress)
    #         model.load_state_dict(state_dict)
    return model


def fcn_resnet50(
    pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs
):
    """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _load_model(
        "fcn", "resnet50", pretrained, progress, num_classes, aux_loss, **kwargs
    )
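

# A hedged sanity-check sketch (not in the original file): the backbone keeps
# an output stride of 8 via dilation and the head's logits are upsampled back
# to the input size, so the dense prediction matches the input spatial shape.
def _demo_fcn_forward():
    model = fcn_resnet50(num_classes=21).eval()
    with torch.no_grad():
        out = model(torch.rand(1, 3, 224, 224))
    # "out" holds the per-pixel class scores at input resolution
    assert out["out"].shape == (1, 21, 224, 224)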


# Taken from @fmassa example slides and https://github.com/facebookresearch/detr
class DETR(nn.Module):
    """
    Demo DETR implementation.

    Demo implementation of DETR in a minimal number of lines, with the
    following differences w.r.t. DETR in the paper:
    * learned positional encoding (instead of sine)
    * positional encoding is passed at input (instead of attention)
    * fc bbox predictor (instead of MLP)
    The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100.
    Only batch size 1 supported.
    """

    def __init__(
        self,
        num_classes,
        hidden_dim=256,
        nheads=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
    ):
        super().__init__()

        # create ResNet-50 backbone
        self.backbone = resnet50()
        del self.backbone.fc

        # create conversion layer
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers
        )

        # prediction heads, with one extra class for predicting non-empty slots
        # note that in baseline DETR the linear_bbox layer is a 3-layer MLP
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)

        # output positional encodings (object queries)
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))

        # spatial positional encodings
        # note that in baseline DETR we use sine positional encodings
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        # propagate inputs through ResNet-50 up to the avg-pool layer
        x = self.backbone.conv1(inputs)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)

        # convert from 2048 to 256 feature planes for the transformer
        h = self.conv(x)

        # construct positional encodings
        H, W = h.shape[-2:]
        pos = (
            torch.cat(
                [
                    self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
                    self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
                ],
                dim=-1,
            )
            .flatten(0, 1)
            .unsqueeze(1)
        )

        # propagate through the transformer
        # TODO (alband) Why is this not automatically broadcasted? (had to add the repeat)
        f = pos + 0.1 * h.flatten(2).permute(2, 0, 1)
        s = self.query_pos.unsqueeze(1)
        s = s.expand(s.size(0), inputs.size(0), s.size(2))
        h = self.transformer(f, s).transpose(0, 1)

        # finally project transformer outputs to class labels and bounding boxes
        return {
            "pred_logits": self.linear_class(h),
            "pred_boxes": self.linear_bbox(h).sigmoid(),
        }
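

# A hedged usage sketch (not in the original file): the demo DETR emits a fixed
# set of 100 query slots per image; 91 classes below is an assumption matching
# COCO's label space (plus one extra "no object" logit).
def _demo_detr_forward():
    model = DETR(num_classes=91).eval()
    with torch.no_grad():
        out = model(torch.rand(1, 3, 224, 224))
    assert out["pred_logits"].shape == (1, 100, 92)  # 100 queries, 91 + 1 classes
    assert out["pred_boxes"].shape == (1, 100, 4)  # normalized cxcywh boxes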


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format.
    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2).
    """
    # degenerate boxes give inf / nan results,
    # so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - (area - union) / area


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_area(boxes):
    """
    Computes the area of a set of bounding boxes, which are specified by their
    (x1, y1, x2, y2) coordinates.

    Args:
        boxes (Tensor[N, 4]): boxes for which the area will be computed. They
            are expected to be in (x1, y1, x2, y2) format
    Returns:
        area (Tensor[N]): area for each box
    """
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union
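

# A hedged sanity check (not in the original file): for a box matched against
# itself, intersection == union == enclosing area, so IoU and GIoU are both 1.
def _demo_box_ops():
    b = torch.tensor([[0.0, 0.0, 1.0, 1.0]])
    iou, union = box_iou(b, b)
    assert torch.allclose(iou, torch.ones(1, 1))
    assert torch.allclose(generalized_box_iou(b, b), torch.ones(1, 1))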


def is_dist_avail_and_initialized():
    return False


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    # Unreachable with the stub above, but completes the upstream contract.
    return torch.distributed.get_world_size()


@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class SetCriterion(nn.Module):
    """This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        """Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat(
            [t["labels"][J] for t, (_, J) in zip(targets, indices)]
        )
        target_classes = torch.full(
            src_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=src_logits.device,
        )
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(
            src_logits.transpose(1, 2), target_classes, self.empty_weight
        )
        losses = {"loss_ce": loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked into this one here
            losses["class_error"] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """Compute the cardinality error, i.e., the absolute error in the number of predicted non-empty boxes.
        This is not really a loss; it is intended for logging purposes only and doesn't propagate gradients.
        """
        pred_logits = outputs["pred_logits"]
        device = pred_logits.device
        tgt_lengths = torch.as_tensor(
            [len(v["labels"]) for v in targets], device=device
        )
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {"cardinality_error": card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes: the L1 regression loss and the GIoU loss.
        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert "pred_boxes" in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat(
            [t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0
        )

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")

        losses = {}
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(
            generalized_box_iou(
                box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes)
            )
        )
        losses["loss_giou"] = loss_giou.sum() / num_boxes
        return losses

    def loss_masks(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the masks: the focal loss and the dice loss.
        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]

        NOTE: nested_tensor_from_tensor_list, interpolate, sigmoid_focal_loss and
        dice_loss come from the upstream DETR repo and are not defined in this file,
        so this loss only works if those helpers are provided.
        """
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)

        src_masks = outputs["pred_masks"]

        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list(  # noqa: F821
            [t["masks"] for t in targets]
        ).decompose()
        target_masks = target_masks.to(src_masks)

        src_masks = src_masks[src_idx]
        # upsample predictions to the target size
        src_masks = interpolate(  # noqa: F821
            src_masks[:, None],
            size=target_masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks[tgt_idx].flatten(1)

        losses = {
            "loss_mask": sigmoid_focal_loss(  # noqa: F821
                src_masks, target_masks, num_boxes
            ),
            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),  # noqa: F821
        }
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat(
            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
        )
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat(
            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]
        )
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            "labels": self.loss_labels,
            "cardinality": self.loss_cardinality,
            "boxes": self.loss_boxes,
            "masks": self.loss_masks,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets):
        """This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depend on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor(
            [num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == "masks":
                        # Intermediate masks losses are too costly to compute, so we ignore them.
                        continue
                    kwargs = {}
                    if loss == "labels":
                        # Logging is enabled only for the last layer
                        kwargs = {"log": False}
                    l_dict = self.get_loss(
                        loss, aux_outputs, targets, indices, num_boxes, **kwargs
                    )
                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses


class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network.
    For efficiency reasons, the targets don't include the no_object class. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are unmatched (and thus treated as non-objects).
    """

    def __init__(
        self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1
    ):
        """Creates the matcher
        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
            "all costs can't be 0"
        )

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Performs the matching
        Params:
            outputs: This is a dict that contains at least these entries:
                "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                          objects in the target) containing the class labels
                "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = (
            outputs["pred_logits"].flatten(0, 1).softmax(-1)
        )  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it by 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, so it can be omitted.
        cost_class = -out_prob[:, tgt_ids]

        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # Compute the giou cost between boxes
        cost_giou = -generalized_box_iou(
            box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
        )

        # Final cost matrix
        C = (
            self.cost_bbox * cost_bbox
            + self.cost_class * cost_class
            + self.cost_giou * cost_giou
        )
        C = C.view(bs, num_queries, -1).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        if not scipy_available:
            raise RuntimeError(
                "The 'detr' model requires scipy to run. Please make sure you have it installed"
                " if you enable the 'detr' model."
            )
        indices = [
            linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))
        ]
        return [
            (
                torch.as_tensor(i, dtype=torch.int64),
                torch.as_tensor(j, dtype=torch.int64),
            )
            for i, j in indices
        ]
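

# A hedged end-to-end sketch (not in the original file): match 5 random
# predictions against a 2-box toy target and compute the DETR losses. The
# weights below are illustrative, not the paper's settings; scipy is required
# for the Hungarian assignment.
def _demo_matcher_and_criterion():
    if not scipy_available:
        return
    matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)
    criterion = SetCriterion(
        num_classes=2,  # labels in {0, 1}; index 2 is the no-object class
        matcher=matcher,
        weight_dict={"loss_ce": 1, "loss_bbox": 5, "loss_giou": 2},
        eos_coef=0.1,
        losses=["labels", "boxes", "cardinality"],
    )
    outputs = {
        "pred_logits": torch.rand(1, 5, 3),
        "pred_boxes": torch.rand(1, 5, 4),  # cxcywh, already in [0, 1)
    }
    targets = [{"labels": torch.tensor([0, 1]), "boxes": torch.rand(2, 4)}]
    # yields loss_ce, class_error, loss_bbox, loss_giou and cardinality_error
    print(sorted(criterion(outputs, targets)))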