123456789101112131415161718192021222324252627282930313233343536373839 |
- from unittest import TestCase
- import numpy as np
- from mmengine.registry import init_default_scope
- from mmdet.registry import TASK_UTILS
class TestInterpolateTracklets(TestCase):
    """Unit tests for the ``InterpolateTracklets`` task util built via
    ``TASK_UTILS``."""

    @classmethod
    def setUpClass(cls):
        """Activate the mmdet registry scope and store a shared config."""
        # Make 'mmdet' the default registry scope so the string ``type``
        # below can be resolved by ``TASK_UTILS.build``.
        init_default_scope('mmdet')
        cls.cfg = dict(
            type='InterpolateTracklets',
            min_num_frames=5,
            max_num_frames=20,
            use_gsi=True,
            smooth_tau=10)

    def test_init(self):
        """The built instance exposes the values given in the config."""
        interpolation = TASK_UTILS.build(self.cfg)
        self.assertEqual(interpolation.min_num_frames, 5)
        self.assertEqual(interpolation.max_num_frames, 20)
        self.assertTrue(interpolation.use_gsi)
        self.assertEqual(interpolation.smooth_tau, 10)

    def test_forward(self):
        """``forward`` on a single 5-row track returns a (5, 7) ndarray."""
        # Seeded generator so the random fixture is reproducible between
        # runs (unseeded ``np.random.randn`` made failures unrepeatable).
        rng = np.random.RandomState(0)
        pred_track = rng.randn(5, 7)
        # Presumably column 0 holds frame ids (note the gap at frames 3-4)
        # and column 1 the track id -- confirm against InterpolateTracklets.
        pred_track[:, 0] = np.array([1, 2, 5, 6, 7])
        pred_track[:, 1] = 1
        interpolation = TASK_UTILS.build(self.cfg)
        linked_track = interpolation.forward(pred_track)
        self.assertIsInstance(linked_track, np.ndarray)
        # NOTE(review): the row count is unchanged despite the frame gap,
        # presumably because a 5-frame track does not exceed
        # min_num_frames=5 -- verify against the implementation.
        self.assertEqual(linked_track.shape, (5, 7))
|