test_byte_tracker.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmdet.registry import MODELS, TASK_UTILS
  5. from mmdet.testing import demo_track_inputs, random_boxes
  6. from mmdet.utils import register_all_modules
  7. class TestByteTracker(TestCase):
  8. @classmethod
  9. def setUpClass(cls):
  10. register_all_modules(init_default_scope=True)
  11. cfg = dict(
  12. type='ByteTracker',
  13. motion=dict(type='KalmanFilter'),
  14. obj_score_thrs=dict(high=0.6, low=0.1),
  15. init_track_thr=0.7,
  16. weight_iou_with_det_scores=True,
  17. match_iou_thrs=dict(high=0.1, low=0.5, tentative=0.3),
  18. num_tentatives=3,
  19. num_frames_retain=30)
  20. cls.tracker = MODELS.build(cfg)
  21. cls.tracker.kf = TASK_UTILS.build(dict(type='KalmanFilter'))
  22. cls.num_frames_retain = cfg['num_frames_retain']
  23. cls.num_objs = 30
  24. def test_init(self):
  25. bboxes = random_boxes(self.num_objs, 512)
  26. labels = torch.zeros(self.num_objs)
  27. scores = torch.ones(self.num_objs)
  28. ids = torch.arange(self.num_objs)
  29. self.tracker.update(
  30. ids=ids, bboxes=bboxes, scores=scores, labels=labels, frame_ids=0)
  31. assert self.tracker.ids == list(ids)
  32. assert self.tracker.memo_items == [
  33. 'ids', 'bboxes', 'scores', 'labels', 'frame_ids'
  34. ]
  35. def test_track(self):
  36. with torch.no_grad():
  37. packed_inputs = demo_track_inputs(batch_size=1, num_frames=2)
  38. track_data_sample = packed_inputs['data_samples'][0]
  39. video_len = len(track_data_sample)
  40. for frame_id in range(video_len):
  41. img_data_sample = track_data_sample[frame_id]
  42. img_data_sample.pred_instances = \
  43. img_data_sample.gt_instances.clone()
  44. # add fake scores
  45. scores = torch.ones(len(img_data_sample.gt_instances.bboxes))
  46. img_data_sample.pred_instances.scores = torch.FloatTensor(
  47. scores)
  48. pred_track_instances = self.tracker.track(
  49. data_sample=img_data_sample)
  50. bboxes = pred_track_instances.bboxes
  51. labels = pred_track_instances.labels
  52. assert bboxes.shape[1] == 4
  53. assert bboxes.shape[0] == labels.shape[0]