123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- import unittest
- import numpy as np
- from mmdet.datasets.transforms import BaseFrameSample, UniformRefFrameSample
class TestFrameSample(unittest.TestCase):
    """Tests for the ``BaseFrameSample`` and ``UniformRefFrameSample``
    video-frame sampling transforms."""

    def setUp(self):
        """Build the synthetic video shared by every test method.

        ``unittest`` calls this before each test method. The video has 10
        frames; every array of frame ``i`` is filled with the value ``i`` and
        its ``ori_shape`` is ``(H + i, W + i)``, so each sampled entry can be
        traced back to the frame it came from. Frame 4 is the key frame.
        """
        self.H, self.W = 5, 8
        num_frames = 10
        self.video_infos = dict(
            video_id=0, video_length=num_frames, key_frame_id=4)
        self.video_infos['images'] = [
            dict(
                img=np.zeros((self.H, self.W, 3)) + i,
                gt_bboxes=np.zeros((2, 4)) + i,
                gt_bboxes_labels=np.zeros((2, )) + i,
                gt_instances_id=np.zeros((2, ), dtype=np.int32) + i,
                ori_shape=(self.H + i, self.W + i),
                frame_id=i,
                img_id=i) for i in range(num_frames)
        ]
        # Keys every sampler output must contain, holding one entry per
        # sampled frame.
        self.info_keys = [
            'video_id', 'video_length', 'img', 'gt_bboxes', 'gt_bboxes_labels',
            'gt_instances_id', 'img_id', 'frame_id'
        ]

    def _assert_is_key_frame(self, results, idx):
        """Assert that entry ``idx`` of each per-frame field in ``results``
        was taken from the key frame of ``self.video_infos``.

        Relies on the fixture built in ``setUp``: every array of frame ``i``
        is filled with ``i`` and its ``ori_shape`` is ``(H + i, W + i)``.
        """
        key_frame_id = self.video_infos['key_frame_id']
        expected_shapes = {
            'img': (self.H, self.W, 3),
            'gt_bboxes': (2, 4),
            'gt_bboxes_labels': (2, ),
            'gt_instances_id': (2, ),
        }
        for key, shape in expected_shapes.items():
            np.testing.assert_array_equal(
                results[key][idx],
                np.full(shape, key_frame_id),
                err_msg=f'{key}[{idx}] is not the key frame')
        self.assertEqual(results['ori_shape'][idx],
                         (self.H + key_frame_id, self.W + key_frame_id))
        self.assertEqual(results['img_id'][idx], key_frame_id)

    def test_base_frame_sample(self):
        """``BaseFrameSample`` returns only the key frame."""
        key_frame_id = self.video_infos['key_frame_id']
        sampler = BaseFrameSample()
        results = sampler(self.video_infos)

        self.assertIsInstance(results, dict)
        for key in self.info_keys:
            self.assertIn(key, results)
            self.assertEqual(len(results[key]), 1)
        self.assertEqual(results['frame_id'], [key_frame_id])
        self._assert_is_key_frame(results, 0)

    def test_uniform_ref_frame_sample(self):
        """``UniformRefFrameSample`` returns the key frame plus references."""
        key_frame_id = self.video_infos['key_frame_id']

        # frame_range=[-1, 1] with the key frame filtered out: the expected
        # output is the key frame flanked by its two neighbours.
        sampler = UniformRefFrameSample(
            num_ref_imgs=2, frame_range=[-1, 1], filter_key_img=True)
        results = sampler(self.video_infos)

        self.assertIsInstance(results, dict)
        for key in self.info_keys:
            self.assertIn(key, results)
            self.assertEqual(len(results[key]), 3)
        self.assertEqual(results['frame_id'],
                         [key_frame_id - 1, key_frame_id, key_frame_id + 1])
        self._assert_is_key_frame(results, 1)

        # test the filter_key_img and the correctness of returned frame index:
        # with filter_key_img=False the key frame is itself a reference
        # candidate, so it is expected twice (key + reference) and the next
        # frame once.
        sampler = UniformRefFrameSample(
            num_ref_imgs=2, frame_range=[0, 1], filter_key_img=False)
        results = sampler(self.video_infos)
        self.assertEqual(results['img_id'].count(key_frame_id), 2)
        self.assertEqual(results['img_id'].count(key_frame_id + 1), 1)
        self.assertEqual(results['key_frame_flags'], [True, False, False])

    def test_repr(self):
        """``repr`` of both transforms lists their configuration."""
        transform = BaseFrameSample()
        self.assertEqual(
            repr(transform),
            "BaseFrameSample(collect_video_keys=['video_id', 'video_length'])")

        # A scalar frame_range is expected to be expanded to [-r, r].
        transform = UniformRefFrameSample(
            num_ref_imgs=2, frame_range=10, filter_key_img=True)
        self.assertEqual(
            repr(transform),
            ('UniformRefFrameSample(num_ref_imgs=2, '
             'frame_range=[-10, 10], filter_key_img=True, '
             "collect_video_keys=['video_id', 'video_length'])"))
|