from unittest import TestCase

import pytest

from mmdet.structures import DetDataSample, TrackDataSample


class TestDetDataSample(TestCase):
    """Checks for ``TrackDataSample`` construction and frame storage.

    NOTE(review): the class name mentions ``DetDataSample`` although every
    test here targets ``TrackDataSample`` -- presumably a copy/paste
    leftover; confirm before renaming, since test ids depend on it.
    """

    @staticmethod
    def _build_populated_sample():
        """Return a ``TrackDataSample`` that holds two per-frame samples.

        Index 0 is declared as the key frame (scale factor ``(1.5, 1.5)``)
        and index 1 as the reference frame (scale factor ``(2., 2.)``).
        """
        frames = [
            DetDataSample(metainfo=dict(scale_factor=factor))
            for factor in ((1.5, 1.5), (2., 2.))
        ]
        sample = TrackDataSample(
            metainfo=dict(key_frames_inds=[0], ref_frames_inds=[1]))
        sample.video_data_samples = frames
        return sample

    def test_init(self):
        """Metainfo is exposed; frame getters assert without frame data."""
        sample = TrackDataSample(
            metainfo=dict(key_frames_inds=[0], ref_frames_inds=[1]))

        # Both index lists must be registered as metainfo entries.
        for key in ('key_frames_inds', 'ref_frames_inds'):
            assert key in sample.metainfo

        assert sample.key_frames_inds == [0]
        assert sample.ref_frames_inds == [1]

        # No ``video_data_samples`` were assigned, so each getter is
        # expected to fail its internal assertion.
        for getter in (sample.get_key_frames, sample.get_ref_frames):
            with pytest.raises(AssertionError):
                getter()

    def test_setter(self):
        """Assigned frames are returned according to the index lists."""
        sample = self._build_populated_sample()

        key_frame = sample.get_key_frames()[0]
        assert key_frame.scale_factor == (1.5, 1.5)

        ref_frame = sample.get_ref_frames()[0]
        assert ref_frame.scale_factor == (2., 2.)

    def test_deleter(self):
        """Deleting ``video_data_samples`` removes it from the sample."""
        sample = self._build_populated_sample()

        assert 'video_data_samples' in sample
        del sample.video_data_samples
        assert 'video_data_samples' not in sample