123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Optional, Tuple, Union
- import numpy as np
- import torch
- from mmengine.model import BaseModule
- from torch import Tensor
- from mmdet.registry import MODELS
- from .utils import weighted_loss
@weighted_loss
def l2_loss(pred: Tensor, target: Tensor) -> Tensor:
    """Element-wise L2 (squared) loss.

    Args:
        pred (torch.Tensor): The prediction.
        target (torch.Tensor): The learning target of the prediction. Must
            have the same shape as ``pred``.

    Returns:
        torch.Tensor: Element-wise squared difference, same shape as ``pred``.
    """
    assert pred.size() == target.size()
    diff = pred - target
    return diff.abs().pow(2)
@MODELS.register_module()
class L2Loss(BaseModule):
    """L2 loss with optional margins and hard negative mining.

    Targets are interpreted as: ``1`` for positive samples, ``0`` for
    negative samples; entries whose weight is non-positive are marked
    invalid (set to ``-1``) and excluded from the loss.

    Args:
        neg_pos_ub (int, optional): Upper bound of the negative-to-positive
            sample ratio. If positive and exceeded, surplus negatives are
            down-weighted to 0 (selected by hard mining or at random).
            Defaults to -1 (disabled).
        pos_margin (float, optional): Margin subtracted from predictions of
            positive samples before computing the loss. Defaults to -1
            (disabled).
        neg_margin (float, optional): Margin subtracted from predictions of
            negative samples before computing the loss. Defaults to -1
            (disabled).
        hard_mining (bool, optional): If True, keep the negatives with the
            largest loss instead of a random subset when ``neg_pos_ub`` is
            exceeded. Defaults to False.
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        loss_weight (float, optional): The weight of loss. Defaults to 1.0.
    """

    def __init__(self,
                 neg_pos_ub: int = -1,
                 pos_margin: float = -1,
                 neg_margin: float = -1,
                 hard_mining: bool = False,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0):
        super().__init__()
        self.neg_pos_ub = neg_pos_ub
        self.pos_margin = pos_margin
        self.neg_margin = neg_margin
        self.hard_mining = hard_mining
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[float] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (float, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            torch.Tensor: The calculated loss, scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        # NOTE: update_weight may modify ``pred``/``target`` in place
        # (margins, invalid-sample marking) before the loss is computed.
        pred, weight, avg_factor = self.update_weight(pred, target, weight,
                                                      avg_factor)
        loss_bbox = self.loss_weight * l2_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox

    def update_weight(self, pred: Tensor, target: Tensor, weight: Tensor,
                      avg_factor: float) -> Tuple[Tensor, Tensor, float]:
        """Update the weight according to targets.

        Applies the positive/negative margins, optionally rebalances
        negatives against ``neg_pos_ub``, and recomputes ``avg_factor`` as
        the number of samples with positive weight.

        NOTE: mutates ``pred`` and ``target`` in place (invalid entries of
        ``target`` are set to -1; margins are subtracted from ``pred``).
        """
        if weight is None:
            weight = target.new_ones(target.size())
        # Samples with non-positive weight are marked invalid so they match
        # neither the positive (==1) nor negative (==0) masks below.
        invalid_inds = weight <= 0
        target[invalid_inds] = -1
        pos_inds = target == 1
        neg_inds = target == 0

        if self.pos_margin > 0:
            pred[pos_inds] -= self.pos_margin
        if self.neg_margin > 0:
            pred[neg_inds] -= self.neg_margin
        pred = torch.clamp(pred, min=0, max=1)

        num_pos = int(pos_inds.sum())
        num_neg = int(neg_inds.sum())
        if self.neg_pos_ub > 0 and num_neg / (num_pos +
                                              1e-6) > self.neg_pos_ub:
            num_neg = num_pos * self.neg_pos_ub
            # assumes target is 2D here (row/col indexing below) — matches
            # the (num_preds, num_gts) match matrices this loss is used on.
            neg_idx = torch.nonzero(target == 0, as_tuple=False)

            if self.hard_mining:
                # Keep the negatives with the largest per-element loss.
                costs = l2_loss(
                    pred, target, reduction='none')[neg_idx[:, 0],
                                                    neg_idx[:, 1]].detach()
                neg_idx = neg_idx[costs.topk(num_neg)[1], :]
            else:
                neg_idx = self.random_choice(neg_idx, num_neg)

            new_neg_inds = neg_inds.new_zeros(neg_inds.size()).bool()
            new_neg_inds[neg_idx[:, 0], neg_idx[:, 1]] = True

            # Negatives not selected above get zero weight.
            invalid_neg_inds = torch.logical_xor(neg_inds, new_neg_inds)
            weight[invalid_neg_inds] = 0

        avg_factor = (weight > 0).sum()
        return pred, weight, avg_factor

    @staticmethod
    def random_choice(gallery: Union[list, np.ndarray, Tensor],
                      num: int) -> np.ndarray:
        """Randomly select ``num`` elements from the gallery.

        It seems that Pytorch's implementation is slower than numpy so we use
        numpy to randperm the indices.
        """
        assert len(gallery) >= num
        if isinstance(gallery, list):
            gallery = np.array(gallery)
        cands = np.arange(len(gallery))
        np.random.shuffle(cands)
        rand_inds = cands[:num]
        if not isinstance(gallery, np.ndarray):
            # Tensor gallery: index with a LongTensor on the same device.
            rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
        return gallery[rand_inds]
|