test_interpolation.py 1.1 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. from mmengine.registry import init_default_scope
  5. from mmdet.registry import TASK_UTILS
  6. class TestInterpolateTracklets(TestCase):
  7. @classmethod
  8. def setUpClass(cls):
  9. init_default_scope('mmdet')
  10. cls.cfg = dict(
  11. type='InterpolateTracklets',
  12. min_num_frames=5,
  13. max_num_frames=20,
  14. use_gsi=True,
  15. smooth_tau=10)
  16. def test_init(self):
  17. interpolation = TASK_UTILS.build(self.cfg)
  18. assert interpolation.min_num_frames == 5
  19. assert interpolation.max_num_frames == 20
  20. assert interpolation.use_gsi
  21. assert interpolation.smooth_tau == 10
  22. def test_forward(self):
  23. pred_track = np.random.randn(5, 7)
  24. # set frame_id and target_id
  25. pred_track[:, 0] = np.array([1, 2, 5, 6, 7])
  26. pred_track[:, 1] = 1
  27. interpolation = TASK_UTILS.build(self.cfg)
  28. linked_track = interpolation.forward(pred_track)
  29. assert isinstance(linked_track, np.ndarray)
  30. assert linked_track.shape == (5, 7)