# test_track_data_sample.py (1.8 KB)
  1. from unittest import TestCase
  2. import pytest
  3. from mmdet.structures import DetDataSample, TrackDataSample
  4. class TestDetDataSample(TestCase):
  5. def test_init(self):
  6. track_data_sample = TrackDataSample(
  7. metainfo=dict(key_frames_inds=[0], ref_frames_inds=[1]))
  8. assert 'key_frames_inds' in track_data_sample.metainfo and \
  9. 'ref_frames_inds' in track_data_sample.metainfo
  10. assert track_data_sample.key_frames_inds == [0]
  11. assert track_data_sample.ref_frames_inds == [1]
  12. with pytest.raises(AssertionError):
  13. track_data_sample.get_key_frames()
  14. with pytest.raises(AssertionError):
  15. track_data_sample.get_ref_frames()
  16. def test_setter(self):
  17. det_data_sample_1 = DetDataSample(
  18. metainfo=dict(scale_factor=(1.5, 1.5)))
  19. det_data_sample_2 = DetDataSample(metainfo=dict(scale_factor=(2., 2.)))
  20. track_data_sample = TrackDataSample(
  21. metainfo=dict(key_frames_inds=[0], ref_frames_inds=[1]))
  22. track_data_sample.video_data_samples = [
  23. det_data_sample_1, det_data_sample_2
  24. ]
  25. assert track_data_sample.get_key_frames()[0].scale_factor == (1.5, 1.5)
  26. assert track_data_sample.get_ref_frames()[0].scale_factor == (2., 2.)
  27. def test_deleter(self):
  28. det_data_sample_1 = DetDataSample(
  29. metainfo=dict(scale_factor=(1.5, 1.5)))
  30. det_data_sample_2 = DetDataSample(metainfo=dict(scale_factor=(2., 2.)))
  31. track_data_sample = TrackDataSample(
  32. metainfo=dict(key_frames_inds=[0], ref_frames_inds=[1]))
  33. track_data_sample.video_data_samples = [
  34. det_data_sample_1, det_data_sample_2
  35. ]
  36. assert 'video_data_samples' in track_data_sample
  37. del track_data_sample.video_data_samples
  38. assert 'video_data_samples' not in track_data_sample