1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from unittest import TestCase
- import torch
- from mmengine.registry import init_default_scope
- from parameterized import parameterized
- from mmdet.registry import MODELS
- from mmdet.testing import demo_track_inputs, get_detector_cfg, random_boxes
class TestMaskTrackRCNNTracker(TestCase):
    """Unit tests for the MaskTrackRCNN video-instance tracker.

    Builds one shared tracker instance for the whole test class and
    checks (a) the shape contract of ``get_match_score`` and (b) that
    ``track`` produces consistently-sized bboxes/labels/ids per frame.
    """

    @classmethod
    def setUpClass(cls):
        # Register mmdet modules so MODELS.build can resolve the type.
        init_default_scope('mmdet')
        cls.tracker = MODELS.build(
            dict(
                type='MaskTrackRCNNTracker',
                match_weights=dict(det_score=1.0, iou=2.0, det_label=10.0),
                num_frames_retain=20))
        cls.num_objs = 5

    def test_get_match_score(self):
        num = self.num_objs
        bboxes = random_boxes(num, 64)
        labels = torch.arange(num)
        scores = torch.arange(num, dtype=torch.float32)
        # One extra column for the "new object" (background) hypothesis.
        similarity_logits = torch.randn(num, num + 1)
        match_score = self.tracker.get_match_score(
            bboxes, labels, scores, bboxes, labels, similarity_logits)
        # The fused match score must keep the similarity-logit shape.
        assert match_score.size() == similarity_logits.size()

    @parameterized.expand([
        'masktrack_rcnn/masktrack-rcnn_mask-rcnn_r50_fpn_8xb1-12e_youtubevis2019.py'  # noqa: E501
    ])
    def test_track(self, cfg_file):
        model_cfg = get_detector_cfg(cfg_file)
        # _scope_ will be popped after build
        model = MODELS.build(model_cfg)
        packed_inputs = demo_track_inputs(
            batch_size=1, num_frames=2, with_mask=True)
        track_data_sample = packed_inputs['data_samples'][0]
        imgs = packed_inputs['inputs'][0]
        num_levels = len(model.track_head.roi_extractor.featmap_strides)
        for frame_id in range(len(track_data_sample)):
            img_data_sample = track_data_sample[frame_id]
            single_image = imgs[frame_id]
            # Use ground truth as fake predictions for the tracker input.
            img_data_sample.pred_instances = \
                img_data_sample.gt_instances.clone()
            # add fake scores
            num_preds = len(img_data_sample.pred_instances.bboxes)
            img_data_sample.pred_instances.scores = torch.FloatTensor(
                torch.ones(num_preds))
            # Fake FPN features: one tensor per stride level, halving the
            # spatial size at each level (starting from 256 // 4).
            feats = tuple(
                torch.rand(1, 256, 256 // (2**(lvl + 2)),
                           256 // (2**(lvl + 2))).to(device='cpu')
                for lvl in range(num_levels))
            pred_track_instances = self.tracker.track(
                model=model,
                img=single_image,
                feats=feats,
                data_sample=img_data_sample)
            bboxes = pred_track_instances.bboxes
            labels = pred_track_instances.labels
            ids = pred_track_instances.instances_id
            # Boxes are (N, 4); labels and ids must align one-to-one.
            assert bboxes.shape[1] == 4
            assert bboxes.shape[0] == labels.shape[0]
            assert bboxes.shape[0] == ids.shape[0]
|