test_masktrack_rcnn_tracker.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine.registry import init_default_scope
  5. from parameterized import parameterized
  6. from mmdet.registry import MODELS
  7. from mmdet.testing import demo_track_inputs, get_detector_cfg, random_boxes
  class TestMaskTrackRCNNTracker(TestCase):
      """Unit tests for the ``MaskTrackRCNNTracker`` built from ``MODELS``."""

      @classmethod
      def setUpClass(cls):
          """Build one tracker instance shared by every test in this class."""
          init_default_scope('mmdet')
          tracker_cfg = dict(
              type='MaskTrackRCNNTracker',
              match_weights=dict(det_score=1.0, iou=2.0, det_label=10.0),
              num_frames_retain=20)
          cls.tracker = MODELS.build(tracker_cfg)
          cls.num_objs = 5

      def test_get_match_score(self):
          """The match-score matrix must have the same size as the logits."""
          bboxes = random_boxes(self.num_objs, 64)
          labels = torch.arange(self.num_objs)
          scores = torch.arange(self.num_objs, dtype=torch.float32)
          # One more column than objects; presumably the extra column is the
          # "start a new track" option — TODO confirm against the tracker.
          similarity_logits = torch.randn(self.num_objs, self.num_objs + 1)
          # The same boxes/labels are passed for both argument pairs
          # (presumably current vs. previous detections — TODO confirm).
          match_score = self.tracker.get_match_score(bboxes, labels, scores,
                                                     bboxes, labels,
                                                     similarity_logits)
          assert match_score.size() == similarity_logits.size()

      @parameterized.expand([
          'masktrack_rcnn/masktrack-rcnn_mask-rcnn_r50_fpn_8xb1-12e_youtubevis2019.py'  # noqa: E501
      ])
      def test_track(self, cfg_file):
          """Track a 2-frame demo video and check per-frame output shapes.

          Args:
              cfg_file: Config path passed to ``get_detector_cfg``.
          """
          _model = get_detector_cfg(cfg_file)
          # _scope_ will be popped after build
          model = MODELS.build(_model)
          packed_inputs = demo_track_inputs(
              batch_size=1, num_frames=2, with_mask=True)
          track_data_sample = packed_inputs['data_samples'][0]
          imgs = packed_inputs['inputs'][0]
          # Loop-invariant: the number of feature levels the RoI extractor
          # expects does not change between frames.
          num_levels = len(model.track_head.roi_extractor.featmap_strides)
          video_len = len(track_data_sample)
          for frame_id in range(video_len):
              img_data_sample = track_data_sample[frame_id]
              single_image = imgs[frame_id]
              # Use the ground truth as the "detections" of this frame.
              img_data_sample.pred_instances = \
                  img_data_sample.gt_instances.clone()
              # add fake scores (all ones, float32 as before)
              num_dets = len(img_data_sample.pred_instances.bboxes)
              img_data_sample.pred_instances.scores = torch.ones(
                  num_dets, dtype=torch.float32)
              # Random multi-level features: level i is 256 // 2**(i + 2)
              # on each side. 256 channels presumably matches the RoI
              # extractor's expected input channels — TODO confirm.
              feats = tuple(
                  torch.rand(1, 256, 256 // (2**(i + 2)),
                             256 // (2**(i + 2)), device='cpu')
                  for i in range(num_levels))
              pred_track_instances = self.tracker.track(
                  model=model,
                  img=single_image,
                  feats=feats,
                  data_sample=img_data_sample)
              bboxes = pred_track_instances.bboxes
              labels = pred_track_instances.labels
              ids = pred_track_instances.instances_id
              assert bboxes.shape[1] == 4
              assert bboxes.shape[0] == labels.shape[0]
              assert bboxes.shape[0] == ids.shape[0]