# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
from mmengine.registry import init_default_scope

from mmdet.registry import TASK_UTILS


class TestInterpolateTracklets(TestCase):

    @classmethod
    def setUpClass(cls):
        init_default_scope('mmdet')
        cls.cfg = dict(
            type='InterpolateTracklets',
            min_num_frames=5,
            max_num_frames=20,
            use_gsi=True,
            smooth_tau=10)

    def test_init(self):
        interpolation = TASK_UTILS.build(self.cfg)
        assert interpolation.min_num_frames == 5
        assert interpolation.max_num_frames == 20
        assert interpolation.use_gsi
        assert interpolation.smooth_tau == 10

    def test_forward(self):
        pred_track = np.random.randn(5, 7)

        # set frame_id and target_id
        pred_track[:, 0] = np.array([1, 2, 5, 6, 7])
        pred_track[:, 1] = 1

        interpolation = TASK_UTILS.build(self.cfg)
        linked_track = interpolation.forward(pred_track)
        assert isinstance(linked_track, np.ndarray)
        assert linked_track.shape == (5, 7)