1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 |
- from unittest import TestCase
- import torch
- from mmdet.registry import MODELS, TASK_UTILS
- from mmdet.testing import demo_track_inputs, random_boxes
- from mmdet.utils import register_all_modules
class TestByteTracker(TestCase):
    """Unit tests for a ``ByteTracker`` built through the model registry."""

    @classmethod
    def setUpClass(cls):
        """Register all modules and build one tracker shared by the tests."""
        register_all_modules(init_default_scope=True)
        tracker_cfg = {
            'type': 'ByteTracker',
            'motion': {'type': 'KalmanFilter'},
            'obj_score_thrs': {'high': 0.6, 'low': 0.1},
            'init_track_thr': 0.7,
            'weight_iou_with_det_scores': True,
            'match_iou_thrs': {'high': 0.1, 'low': 0.5, 'tentative': 0.3},
            'num_tentatives': 3,
            'num_frames_retain': 30,
        }
        cls.tracker = MODELS.build(tracker_cfg)
        # A separately built Kalman filter is attached as ``kf`` on top of
        # the ``motion`` entry already present in the config.
        cls.tracker.kf = TASK_UTILS.build({'type': 'KalmanFilter'})
        cls.num_frames_retain = tracker_cfg['num_frames_retain']
        cls.num_objs = 30

    def test_init(self):
        """Feed ``update`` with random boxes, then check ids and memo items."""
        count = self.num_objs
        track_ids = torch.arange(count)
        # NOTE(review): keyword order mirrors the ``memo_items`` list asserted
        # below -- presumably the tracker records it in call order; keep the
        # two in sync.
        self.tracker.update(
            ids=track_ids,
            bboxes=random_boxes(count, 512),
            scores=torch.ones(count),
            labels=torch.zeros(count),
            frame_ids=0)

        expected_memo_items = [
            'ids', 'bboxes', 'scores', 'labels', 'frame_ids'
        ]
        assert self.tracker.ids == list(track_ids)
        assert self.tracker.memo_items == expected_memo_items

    def test_track(self):
        """Run ``track`` on each frame of a demo clip and check shapes."""
        with torch.no_grad():
            demo_batch = demo_track_inputs(batch_size=1, num_frames=2)
            video_sample = demo_batch['data_samples'][0]
            for frame_idx in range(len(video_sample)):
                frame_sample = video_sample[frame_idx]
                # Use a copy of the ground truth as the predictions, with
                # every score set to 1.
                frame_sample.pred_instances = (
                    frame_sample.gt_instances.clone())
                unit_scores = torch.ones(
                    len(frame_sample.gt_instances.bboxes))
                frame_sample.pred_instances.scores = torch.FloatTensor(
                    unit_scores)

                tracked = self.tracker.track(data_sample=frame_sample)
                out_bboxes = tracked.bboxes
                out_labels = tracked.labels
                # Boxes have four columns and one label per box.
                assert out_bboxes.shape[1] == 4
                assert out_bboxes.shape[0] == out_labels.shape[0]
|