efficientdet_head.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Tuple
  3. import torch
  4. import torch.nn as nn
  5. from mmcv.cnn.bricks import Swish, build_norm_layer
  6. from mmengine.model import bias_init_with_prob
  7. from torch import Tensor
  8. from mmdet.models.dense_heads.anchor_head import AnchorHead
  9. from mmdet.models.utils import images_to_levels, multi_apply
  10. from mmdet.registry import MODELS
  11. from mmdet.structures.bbox import cat_boxes, get_box_tensor
  12. from mmdet.utils import (InstanceList, OptConfigType, OptInstanceList,
  13. OptMultiConfig, reduce_mean)
  14. from .utils import DepthWiseConvBlock
  15. @MODELS.register_module()
  16. class EfficientDetSepBNHead(AnchorHead):
  17. """EfficientDetHead with separate BN.
  18. num_classes (int): Number of categories num_ins (int): Number of the input
  19. feature map. in_channels (int): Number of channels in the input feature
  20. map. feat_channels (int): Number of hidden channels. stacked_convs (int):
  21. Number of repetitions of conv norm_cfg (dict): Config dict for
  22. normalization layer. anchor_generator (dict): Config dict for anchor
  23. generator bbox_coder (dict): Config of bounding box coder. loss_cls (dict):
  24. Config of classification loss. loss_bbox (dict): Config of localization
  25. loss. train_cfg (dict): Training config of anchor head. test_cfg (dict):
  26. Testing config of anchor head. init_cfg (dict or list[dict], optional):
  27. Initialization config dict.
  28. """
  29. def __init__(self,
  30. num_classes: int,
  31. num_ins: int,
  32. in_channels: int,
  33. feat_channels: int,
  34. stacked_convs: int = 3,
  35. norm_cfg: OptConfigType = dict(
  36. type='BN', momentum=1e-2, eps=1e-3),
  37. init_cfg: OptMultiConfig = None,
  38. **kwargs) -> None:
  39. self.num_ins = num_ins
  40. self.stacked_convs = stacked_convs
  41. self.norm_cfg = norm_cfg
  42. super().__init__(
  43. num_classes=num_classes,
  44. in_channels=in_channels,
  45. feat_channels=feat_channels,
  46. init_cfg=init_cfg,
  47. **kwargs)
  48. def _init_layers(self) -> None:
  49. """Initialize layers of the head."""
  50. self.reg_conv_list = nn.ModuleList()
  51. self.cls_conv_list = nn.ModuleList()
  52. for i in range(self.stacked_convs):
  53. channels = self.in_channels if i == 0 else self.feat_channels
  54. self.reg_conv_list.append(
  55. DepthWiseConvBlock(
  56. channels, self.feat_channels, apply_norm=False))
  57. self.cls_conv_list.append(
  58. DepthWiseConvBlock(
  59. channels, self.feat_channels, apply_norm=False))
  60. self.reg_bn_list = nn.ModuleList([
  61. nn.ModuleList([
  62. build_norm_layer(
  63. self.norm_cfg, num_features=self.feat_channels)[1]
  64. for j in range(self.num_ins)
  65. ]) for i in range(self.stacked_convs)
  66. ])
  67. self.cls_bn_list = nn.ModuleList([
  68. nn.ModuleList([
  69. build_norm_layer(
  70. self.norm_cfg, num_features=self.feat_channels)[1]
  71. for j in range(self.num_ins)
  72. ]) for i in range(self.stacked_convs)
  73. ])
  74. self.cls_header = DepthWiseConvBlock(
  75. self.in_channels,
  76. self.num_base_priors * self.cls_out_channels,
  77. apply_norm=False)
  78. self.reg_header = DepthWiseConvBlock(
  79. self.in_channels, self.num_base_priors * 4, apply_norm=False)
  80. self.swish = Swish()
  81. def init_weights(self) -> None:
  82. """Initialize weights of the head."""
  83. for m in self.reg_conv_list:
  84. nn.init.constant_(m.pointwise_conv.bias, 0.0)
  85. for m in self.cls_conv_list:
  86. nn.init.constant_(m.pointwise_conv.bias, 0.0)
  87. bias_cls = bias_init_with_prob(0.01)
  88. nn.init.constant_(self.cls_header.pointwise_conv.bias, bias_cls)
  89. nn.init.constant_(self.reg_header.pointwise_conv.bias, 0.0)
  90. def forward_single_bbox(self, feat: Tensor, level_id: int,
  91. i: int) -> Tensor:
  92. conv_op = self.reg_conv_list[i]
  93. bn = self.reg_bn_list[i][level_id]
  94. feat = conv_op(feat)
  95. feat = bn(feat)
  96. feat = self.swish(feat)
  97. return feat
  98. def forward_single_cls(self, feat: Tensor, level_id: int,
  99. i: int) -> Tensor:
  100. conv_op = self.cls_conv_list[i]
  101. bn = self.cls_bn_list[i][level_id]
  102. feat = conv_op(feat)
  103. feat = bn(feat)
  104. feat = self.swish(feat)
  105. return feat
  106. def forward(self, feats: Tuple[Tensor]) -> tuple:
  107. cls_scores = []
  108. bbox_preds = []
  109. for level_id in range(self.num_ins):
  110. feat = feats[level_id]
  111. for i in range(self.stacked_convs):
  112. feat = self.forward_single_bbox(feat, level_id, i)
  113. bbox_pred = self.reg_header(feat)
  114. bbox_preds.append(bbox_pred)
  115. for level_id in range(self.num_ins):
  116. feat = feats[level_id]
  117. for i in range(self.stacked_convs):
  118. feat = self.forward_single_cls(feat, level_id, i)
  119. cls_score = self.cls_header(feat)
  120. cls_scores.append(cls_score)
  121. return cls_scores, bbox_preds
  122. def loss_by_feat(
  123. self,
  124. cls_scores: List[Tensor],
  125. bbox_preds: List[Tensor],
  126. batch_gt_instances: InstanceList,
  127. batch_img_metas: List[dict],
  128. batch_gt_instances_ignore: OptInstanceList = None) -> dict:
  129. """Calculate the loss based on the features extracted by the detection
  130. head.
  131. Args:
  132. cls_scores (list[Tensor]): Box scores for each scale level
  133. has shape (N, num_anchors * num_classes, H, W).
  134. bbox_preds (list[Tensor]): Box energies / deltas for each scale
  135. level with shape (N, num_anchors * 4, H, W).
  136. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  137. gt_instance. It usually includes ``bboxes`` and ``labels``
  138. attributes.
  139. batch_img_metas (list[dict]): Meta information of each image, e.g.,
  140. image size, scaling factor, etc.
  141. batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
  142. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
  143. data that is ignored during training and testing.
  144. Defaults to None.
  145. Returns:
  146. dict: A dictionary of loss components.
  147. """
  148. featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
  149. assert len(featmap_sizes) == self.prior_generator.num_levels
  150. device = cls_scores[0].device
  151. anchor_list, valid_flag_list = self.get_anchors(
  152. featmap_sizes, batch_img_metas, device=device)
  153. cls_reg_targets = self.get_targets(
  154. anchor_list,
  155. valid_flag_list,
  156. batch_gt_instances,
  157. batch_img_metas,
  158. batch_gt_instances_ignore=batch_gt_instances_ignore)
  159. (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
  160. avg_factor) = cls_reg_targets
  161. # anchor number of multi levels
  162. num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
  163. # concat all level anchors and flags to a single tensor
  164. concat_anchor_list = []
  165. for i in range(len(anchor_list)):
  166. concat_anchor_list.append(cat_boxes(anchor_list[i]))
  167. all_anchor_list = images_to_levels(concat_anchor_list,
  168. num_level_anchors)
  169. avg_factor = reduce_mean(
  170. torch.tensor(avg_factor, dtype=torch.float, device=device)).item()
  171. avg_factor = max(avg_factor, 1.0)
  172. losses_cls, losses_bbox = multi_apply(
  173. self.loss_by_feat_single,
  174. cls_scores,
  175. bbox_preds,
  176. all_anchor_list,
  177. labels_list,
  178. label_weights_list,
  179. bbox_targets_list,
  180. bbox_weights_list,
  181. avg_factor=avg_factor)
  182. return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
  183. def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
  184. anchors: Tensor, labels: Tensor,
  185. label_weights: Tensor, bbox_targets: Tensor,
  186. bbox_weights: Tensor, avg_factor: int) -> tuple:
  187. """Calculate the loss of a single scale level based on the features
  188. extracted by the detection head.
  189. Args:
  190. cls_score (Tensor): Box scores for each scale level
  191. Has shape (N, num_anchors * num_classes, H, W).
  192. bbox_pred (Tensor): Box energies / deltas for each scale
  193. level with shape (N, num_anchors * 4, H, W).
  194. anchors (Tensor): Box reference for each scale level with shape
  195. (N, num_total_anchors, 4).
  196. labels (Tensor): Labels of each anchors with shape
  197. (N, num_total_anchors).
  198. label_weights (Tensor): Label weights of each anchor with shape
  199. (N, num_total_anchors)
  200. bbox_targets (Tensor): BBox regression targets of each anchor
  201. weight shape (N, num_total_anchors, 4).
  202. bbox_weights (Tensor): BBox regression loss weights of each anchor
  203. with shape (N, num_total_anchors, 4).
  204. avg_factor (int): Average factor that is used to average the loss.
  205. Returns:
  206. tuple: loss components.
  207. """
  208. # classification loss
  209. labels = labels.reshape(-1)
  210. label_weights = label_weights.reshape(-1)
  211. cls_score = cls_score.permute(0, 2, 3,
  212. 1).reshape(-1, self.cls_out_channels)
  213. loss_cls = self.loss_cls(
  214. cls_score, labels, label_weights, avg_factor=avg_factor)
  215. # regression loss
  216. target_dim = bbox_targets.size(-1)
  217. bbox_targets = bbox_targets.reshape(-1, target_dim)
  218. bbox_weights = bbox_weights.reshape(-1, target_dim)
  219. bbox_pred = bbox_pred.permute(0, 2, 3,
  220. 1).reshape(-1,
  221. self.bbox_coder.encode_size)
  222. if self.reg_decoded_bbox:
  223. # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
  224. # is applied directly on the decoded bounding boxes, it
  225. # decodes the already encoded coordinates to absolute format.
  226. anchors = anchors.reshape(-1, anchors.size(-1))
  227. bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
  228. bbox_pred = get_box_tensor(bbox_pred)
  229. loss_bbox = self.loss_bbox(
  230. bbox_pred, bbox_targets, bbox_weights, avg_factor=avg_factor * 4)
  231. return loss_cls, loss_bbox