test_aflink.py 1.0 KB

123456789101112131415161718192021222324252627282930313233343536
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. from mmengine.registry import init_default_scope
  5. from torch import nn
  6. from mmdet.registry import TASK_UTILS
  7. class TestAppearanceFreeLink(TestCase):
  8. @classmethod
  9. def setUpClass(cls):
  10. init_default_scope('mmdet')
  11. cls.cfg = dict(
  12. type='AppearanceFreeLink',
  13. checkpoint='',
  14. temporal_threshold=(0, 30),
  15. spatial_threshold=75,
  16. confidence_threshold=0.95,
  17. )
  18. def test_init(self):
  19. aflink = TASK_UTILS.build(self.cfg)
  20. assert aflink.temporal_threshold == (0, 30)
  21. assert aflink.spatial_threshold == 75
  22. assert aflink.confidence_threshold == 0.95
  23. assert isinstance(aflink.model, nn.Module)
  24. def test_forward(self):
  25. pred_track = np.random.randn(10, 7)
  26. aflink = TASK_UTILS.build(self.cfg)
  27. linked_track = aflink.forward(pred_track)
  28. assert isinstance(linked_track, np.ndarray)
  29. assert linked_track.shape == (10, 7)