123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from unittest import TestCase
- import mmengine
- import torch
- from mmengine.structures import InstanceData
- from mmdet.models.tracking_heads import RoIEmbedHead
- from mmdet.registry import TASK_UTILS
def _dummy_bbox_sampling(rpn_results_list, batch_gt_instances):
    """Assign and sample proposals against ground truth for every image.

    A ``MaxIoUAssigner`` (IoU thresholds 0.5, ``ignore_iof_thr=-1``) and a
    ``RandomSampler`` (512 samples, 25% positives, GT boxes not added as
    proposals) are built from ``TASK_UTILS``. The i-th proposal set is
    assigned and sampled against the i-th ground-truth set.

    Args:
        rpn_results_list: Per-image proposal ``InstanceData`` objects.
        batch_gt_instances: Per-image ground-truth ``InstanceData`` objects,
            indexed in lockstep with ``rpn_results_list``.

    Returns:
        list: One sampling result per entry of ``rpn_results_list``, suitable
        for passing to a head's target/loss computation.
    """
    # Random placeholder passed as ``feats``; presumably not inspected by
    # RandomSampler -- NOTE(review): confirm.
    dummy_feats = torch.rand(1, 1, 3, 3)

    assigner = TASK_UTILS.build(
        dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0.5,
            ignore_iof_thr=-1))
    sampler = TASK_UTILS.build(
        dict(
            type='RandomSampler',
            num=512,
            pos_fraction=0.25,
            neg_pos_ub=-1,
            add_gt_as_proposals=False))

    def _sample_one(proposals, gt_instances):
        # Assign first, then draw the sample from that assignment.
        assignment = assigner.assign(proposals, gt_instances)
        return sampler.sample(
            assignment, proposals, gt_instances, feats=dummy_feats)

    return [
        _sample_one(proposals, batch_gt_instances[img_idx])
        for img_idx, proposals in enumerate(rpn_results_list)
    ]
class TestRoIEmbedHead(TestCase):
    """Unit tests for the loss and prediction paths of ``RoIEmbedHead``."""

    @staticmethod
    def _build_embed_head():
        """Build the small ``RoIEmbedHead`` configuration shared by the tests."""
        cfg = mmengine.Config(
            dict(
                num_convs=2,
                num_fcs=2,
                roi_feat_size=7,
                in_channels=16,
                fc_out_channels=32))
        return RoIEmbedHead(**cfg)

    def test_roi_embed_head_loss(self):
        """Test roi embed head loss when truth is non-empty."""
        embed_head = self._build_embed_head()

        # One 16-channel 7x7 RoI feature per key image and per reference
        # image, matching ``in_channels`` / ``roi_feat_size`` above.
        x = torch.rand(1, 16, 7, 7)
        ref_x = torch.rand(1, 16, 7, 7)
        num_x_per_img = [1]
        num_x_per_ref_img = [1]
        # Call the module (not ``.forward``) so registered hooks also run.
        x_split, ref_x_split = embed_head(x, ref_x, num_x_per_img,
                                          num_x_per_ref_img)

        gt_instance_ids = [torch.LongTensor([2])]
        ref_gt_instance_ids = [torch.LongTensor([2])]

        # The single proposal has exactly the GT box coordinates, so its IoU
        # is 1 and it clears the 0.5 positive threshold of the assigner.
        rpn_results = InstanceData()
        rpn_results.labels = torch.LongTensor([2])
        rpn_results.priors = torch.Tensor(
            [[23.6667, 23.8757, 238.6326, 151.8874]])
        rpn_results_list = [rpn_results]

        gt_instance = InstanceData()
        gt_instance.labels = torch.LongTensor([2])
        gt_instance.bboxes = torch.Tensor(
            [[23.6667, 23.8757, 238.6326, 151.8874]])
        gt_instance.instances_id = torch.LongTensor([2])
        batch_gt_instances = [gt_instance]

        sampling_results = _dummy_bbox_sampling(rpn_results_list,
                                                batch_gt_instances)

        gt_losses = embed_head.loss_by_feat(x_split, ref_x_split,
                                            sampling_results, gt_instance_ids,
                                            ref_gt_instance_ids)
        self.assertGreater(gt_losses['loss_match'], 0,
                           'match loss should be positive')
        self.assertGreaterEqual(gt_losses['match_accuracy'], 0,
                                'match accuracy should be non-negative')

    def test_roi_embed_head_predict(self):
        """Test that predict returns one similarity entry per key image."""
        embed_head = self._build_embed_head()

        x = torch.rand(1, 16, 7, 7)
        ref_x = torch.rand(1, 16, 7, 7)
        similarity_logits = embed_head.predict(x, ref_x)
        self.assertIsInstance(similarity_logits, list)
        self.assertEqual(len(similarity_logits), 1)
|