test_loading.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. import os
  4. import os.path as osp
  5. import sys
  6. import unittest
  7. from unittest.mock import MagicMock, Mock, patch
  8. import mmcv
  9. import numpy as np
  10. from mmdet.datasets.transforms import (FilterAnnotations, LoadAnnotations,
  11. LoadEmptyAnnotations,
  12. LoadImageFromNDArray,
  13. LoadMultiChannelImageFromFiles,
  14. LoadProposals, LoadTrackAnnotations)
  15. from mmdet.evaluation import INSTANCE_OFFSET
  16. from mmdet.structures.mask import BitmapMasks, PolygonMasks
  17. try:
  18. import panopticapi
  19. except ImportError:
  20. panopticapi = None
  21. class TestLoadAnnotations(unittest.TestCase):
  22. def setUp(self):
  23. """Setup the model and optimizer which are used in every test method.
  24. TestCase calls functions in this order: setUp() -> testMethod() ->
  25. tearDown() -> cleanUp()
  26. """
  27. data_prefix = osp.join(osp.dirname(__file__), '../../data')
  28. seg_map = osp.join(data_prefix, 'gray.jpg')
  29. self.results = {
  30. 'ori_shape': (300, 400),
  31. 'seg_map_path':
  32. seg_map,
  33. 'instances': [{
  34. 'bbox': [0, 0, 10, 20],
  35. 'bbox_label': 1,
  36. 'mask': [[0, 0, 0, 20, 10, 20, 10, 0]],
  37. 'ignore_flag': 0
  38. }, {
  39. 'bbox': [10, 10, 110, 120],
  40. 'bbox_label': 2,
  41. 'mask': [[10, 10, 110, 10, 110, 120, 110, 10]],
  42. 'ignore_flag': 0
  43. }, {
  44. 'bbox': [50, 50, 60, 80],
  45. 'bbox_label': 2,
  46. 'mask': [[50, 50, 60, 50, 60, 80, 50, 80]],
  47. 'ignore_flag': 1
  48. }]
  49. }
  50. def test_load_bboxes(self):
  51. transform = LoadAnnotations(
  52. with_bbox=True,
  53. with_label=False,
  54. with_seg=False,
  55. with_mask=False,
  56. box_type=None)
  57. results = transform(copy.deepcopy(self.results))
  58. self.assertIn('gt_bboxes', results)
  59. self.assertTrue((results['gt_bboxes'] == np.array([[0, 0, 10, 20],
  60. [10, 10, 110, 120],
  61. [50, 50, 60,
  62. 80]])).all())
  63. self.assertEqual(results['gt_bboxes'].dtype, np.float32)
  64. self.assertTrue((results['gt_ignore_flags'] == np.array([0, 0,
  65. 1])).all())
  66. self.assertEqual(results['gt_ignore_flags'].dtype, bool)
  67. def test_load_labels(self):
  68. transform = LoadAnnotations(
  69. with_bbox=False,
  70. with_label=True,
  71. with_seg=False,
  72. with_mask=False,
  73. )
  74. results = transform(copy.deepcopy(self.results))
  75. self.assertIn('gt_bboxes_labels', results)
  76. self.assertTrue((results['gt_bboxes_labels'] == np.array([1, 2,
  77. 2])).all())
  78. self.assertEqual(results['gt_bboxes_labels'].dtype, np.int64)
  79. def test_load_mask(self):
  80. transform = LoadAnnotations(
  81. with_bbox=False,
  82. with_label=False,
  83. with_seg=False,
  84. with_mask=True,
  85. poly2mask=False)
  86. results = transform(copy.deepcopy(self.results))
  87. self.assertIn('gt_masks', results)
  88. self.assertEqual(len(results['gt_masks']), 3)
  89. self.assertIsInstance(results['gt_masks'], PolygonMasks)
  90. def test_load_mask_poly2mask(self):
  91. transform = LoadAnnotations(
  92. with_bbox=False,
  93. with_label=False,
  94. with_seg=False,
  95. with_mask=True,
  96. poly2mask=True)
  97. results = transform(copy.deepcopy(self.results))
  98. self.assertIn('gt_masks', results)
  99. self.assertEqual(len(results['gt_masks']), 3)
  100. self.assertIsInstance(results['gt_masks'], BitmapMasks)
  101. def test_load_semseg(self):
  102. transform = LoadAnnotations(
  103. with_bbox=False, with_label=False, with_seg=True, with_mask=False)
  104. results = transform(copy.deepcopy(self.results))
  105. self.assertIn('gt_seg_map', results)
  106. self.assertIn('ignore_index', results)
  107. self.assertEqual(results['gt_seg_map'].shape, (288, 512))
  108. # test reduce_zero_label and ignore_index
  109. transform = LoadAnnotations(
  110. with_bbox=False,
  111. with_label=False,
  112. with_seg=True,
  113. with_mask=False,
  114. reduce_zero_label=True,
  115. ignore_index=10)
  116. results = transform(copy.deepcopy(self.results))
  117. self.assertIn('gt_seg_map', results)
  118. self.assertIn('ignore_index', results)
  119. self.assertEqual(results['ignore_index'], 10)
  120. self.assertEqual(results['gt_seg_map'].shape, (288, 512))
  121. def test_repr(self):
  122. transform = LoadAnnotations(
  123. with_bbox=True,
  124. with_label=False,
  125. with_seg=False,
  126. with_mask=False,
  127. )
  128. self.assertEqual(
  129. repr(transform), ('LoadAnnotations(with_bbox=True, '
  130. 'with_label=False, with_mask=False, '
  131. 'with_seg=False, poly2mask=True, '
  132. "imdecode_backend='cv2', "
  133. 'backend_args=None)'))
  134. class TestFilterAnnotations(unittest.TestCase):
  135. def setUp(self):
  136. """Setup the model and optimizer which are used in every test method.
  137. TestCase calls functions in this order: setUp() -> testMethod() ->
  138. tearDown() -> cleanUp()
  139. """
  140. rng = np.random.RandomState(0)
  141. self.results = {
  142. 'img':
  143. np.random.random((224, 224, 3)),
  144. 'img_shape': (224, 224),
  145. 'gt_bboxes_labels':
  146. np.array([1, 2, 3], dtype=np.int64),
  147. 'gt_bboxes':
  148. np.array([[10, 10, 20, 20], [20, 20, 40, 40], [40, 40, 80, 80]]),
  149. 'gt_ignore_flags':
  150. np.array([0, 0, 1], dtype=np.bool8),
  151. 'gt_masks':
  152. BitmapMasks(rng.rand(3, 224, 224), height=224, width=224),
  153. }
  154. def test_transform(self):
  155. # test keep_empty = True
  156. transform = FilterAnnotations(
  157. min_gt_bbox_wh=(50, 50),
  158. keep_empty=True,
  159. )
  160. results = transform(copy.deepcopy(self.results))
  161. self.assertIsNone(results)
  162. # test keep_empty = False
  163. transform = FilterAnnotations(
  164. min_gt_bbox_wh=(50, 50),
  165. keep_empty=False,
  166. )
  167. results = transform(copy.deepcopy(self.results))
  168. self.assertTrue(isinstance(results, dict))
  169. # test filter annotations
  170. transform = FilterAnnotations(min_gt_bbox_wh=(15, 15), )
  171. results = transform(copy.deepcopy(self.results))
  172. self.assertIsInstance(results, dict)
  173. self.assertTrue((results['gt_bboxes_labels'] == np.array([2,
  174. 3])).all())
  175. self.assertTrue((results['gt_bboxes'] == np.array([[20, 20, 40, 40],
  176. [40, 40, 80,
  177. 80]])).all())
  178. self.assertEqual(len(results['gt_masks']), 2)
  179. self.assertEqual(len(results['gt_ignore_flags']), 2)
  180. def test_repr(self):
  181. transform = FilterAnnotations(
  182. min_gt_bbox_wh=(1, 1),
  183. keep_empty=False,
  184. )
  185. self.assertEqual(
  186. repr(transform), ('FilterAnnotations(min_gt_bbox_wh=(1, 1), '
  187. 'keep_empty=False)'))
  188. class TestLoadPanopticAnnotations(unittest.TestCase):
  189. def setUp(self):
  190. seg_map = np.zeros((10, 10), dtype=np.int32)
  191. seg_map[:5, :10] = 1 + 10 * INSTANCE_OFFSET
  192. seg_map[5:10, :5] = 4 + 11 * INSTANCE_OFFSET
  193. seg_map[5:10, 5:10] = 6 + 0 * INSTANCE_OFFSET
  194. rgb_seg_map = np.zeros((10, 10, 3), dtype=np.uint8)
  195. rgb_seg_map[:, :, 0] = seg_map / (256 * 256)
  196. rgb_seg_map[:, :, 1] = seg_map % (256 * 256) / 256
  197. rgb_seg_map[:, :, 2] = seg_map % 256
  198. self.seg_map_path = './1.png'
  199. mmcv.imwrite(rgb_seg_map, self.seg_map_path)
  200. self.seg_map = seg_map
  201. self.rgb_seg_map = rgb_seg_map
  202. self.results = {
  203. 'ori_shape': (10, 10),
  204. 'instances': [{
  205. 'bbox': [0, 0, 10, 5],
  206. 'bbox_label': 0,
  207. 'ignore_flag': 0,
  208. }, {
  209. 'bbox': [0, 5, 5, 10],
  210. 'bbox_label': 1,
  211. 'ignore_flag': 1,
  212. }],
  213. 'segments_info': [
  214. {
  215. 'id': 1 + 10 * INSTANCE_OFFSET,
  216. 'category': 0,
  217. 'is_thing': True,
  218. },
  219. {
  220. 'id': 4 + 11 * INSTANCE_OFFSET,
  221. 'category': 1,
  222. 'is_thing': True,
  223. },
  224. {
  225. 'id': 6 + 0 * INSTANCE_OFFSET,
  226. 'category': 2,
  227. 'is_thing': False,
  228. },
  229. ],
  230. 'seg_map_path':
  231. self.seg_map_path
  232. }
  233. self.gt_mask = BitmapMasks([
  234. (seg_map == 1 + 10 * INSTANCE_OFFSET).astype(np.uint8),
  235. (seg_map == 4 + 11 * INSTANCE_OFFSET).astype(np.uint8),
  236. ], 10, 10)
  237. self.gt_bboxes = np.array([[0, 0, 10, 5], [0, 5, 5, 10]],
  238. dtype=np.float32)
  239. self.gt_bboxes_labels = np.array([0, 1], dtype=np.int64)
  240. self.gt_ignore_flags = np.array([0, 1], dtype=bool)
  241. self.gt_seg_map = np.zeros((10, 10), dtype=np.int32)
  242. self.gt_seg_map[:5, :10] = 0
  243. self.gt_seg_map[5:10, :5] = 1
  244. self.gt_seg_map[5:10, 5:10] = 2
  245. def tearDown(self):
  246. os.remove(self.seg_map_path)
  247. @unittest.skipIf(panopticapi is not None, 'panopticapi is installed')
  248. def test_init_without_panopticapi(self):
  249. # test if panopticapi is not installed
  250. from mmdet.datasets.transforms import LoadPanopticAnnotations
  251. with self.assertRaisesRegex(
  252. ImportError,
  253. 'panopticapi is not installed, please install it by'):
  254. LoadPanopticAnnotations()
  255. def test_transform(self):
  256. sys.modules['panopticapi'] = MagicMock()
  257. sys.modules['panopticapi.utils'] = MagicMock()
  258. from mmdet.datasets.transforms import LoadPanopticAnnotations
  259. mock_rgb2id = Mock(return_value=self.seg_map)
  260. with patch('panopticapi.utils.rgb2id', mock_rgb2id):
  261. # test with all False
  262. transform = LoadPanopticAnnotations(
  263. with_bbox=False,
  264. with_label=False,
  265. with_mask=False,
  266. with_seg=False)
  267. results = transform(copy.deepcopy(self.results))
  268. self.assertDictEqual(results, self.results)
  269. # test with with_mask=True
  270. transform = LoadPanopticAnnotations(
  271. with_bbox=False,
  272. with_label=False,
  273. with_mask=True,
  274. with_seg=False)
  275. results = transform(copy.deepcopy(self.results))
  276. self.assertTrue(
  277. (results['gt_masks'].masks == self.gt_mask.masks).all())
  278. # test with with_seg=True
  279. transform = LoadPanopticAnnotations(
  280. with_bbox=False,
  281. with_label=False,
  282. with_mask=False,
  283. with_seg=True)
  284. results = transform(copy.deepcopy(self.results))
  285. self.assertNotIn('gt_masks', results)
  286. self.assertTrue((results['gt_seg_map'] == self.gt_seg_map).all())
  287. # test with all True
  288. transform = LoadPanopticAnnotations(
  289. with_bbox=True,
  290. with_label=True,
  291. with_mask=True,
  292. with_seg=True,
  293. box_type=None)
  294. results = transform(copy.deepcopy(self.results))
  295. self.assertTrue(
  296. (results['gt_masks'].masks == self.gt_mask.masks).all())
  297. self.assertTrue((results['gt_bboxes'] == self.gt_bboxes).all())
  298. self.assertTrue(
  299. (results['gt_bboxes_labels'] == self.gt_bboxes_labels).all())
  300. self.assertTrue(
  301. (results['gt_ignore_flags'] == self.gt_ignore_flags).all())
  302. self.assertTrue((results['gt_seg_map'] == self.gt_seg_map).all())
  303. class TestLoadImageFromNDArray(unittest.TestCase):
  304. def setUp(self):
  305. """Setup the model and optimizer which are used in every test method.
  306. TestCase calls functions in this order: setUp() -> testMethod() ->
  307. tearDown() -> cleanUp()
  308. """
  309. self.results = {'img': np.zeros((256, 256, 3), dtype=np.uint8)}
  310. def test_transform(self):
  311. transform = LoadImageFromNDArray()
  312. results = transform(copy.deepcopy(self.results))
  313. self.assertEqual(results['img'].shape, (256, 256, 3))
  314. self.assertEqual(results['img'].dtype, np.uint8)
  315. self.assertEqual(results['img_shape'], (256, 256))
  316. self.assertEqual(results['ori_shape'], (256, 256))
  317. # to_float32
  318. transform = LoadImageFromNDArray(to_float32=True)
  319. results = transform(copy.deepcopy(results))
  320. self.assertEqual(results['img'].dtype, np.float32)
  321. def test_repr(self):
  322. transform = LoadImageFromNDArray()
  323. self.assertEqual(
  324. repr(transform), ('LoadImageFromNDArray('
  325. 'ignore_empty=False, '
  326. 'to_float32=False, '
  327. "color_type='color', "
  328. "imdecode_backend='cv2', "
  329. 'backend_args=None)'))
  330. class TestLoadMultiChannelImageFromFiles(unittest.TestCase):
  331. def setUp(self):
  332. """Setup the model and optimizer which are used in every test method.
  333. TestCase calls functions in this order: setUp() -> testMethod() ->
  334. tearDown() -> cleanUp()
  335. """
  336. self.img_path = []
  337. for i in range(4):
  338. img_channel_path = f'./part_{i}.jpg'
  339. img_channel = np.zeros((10, 10), dtype=np.uint8)
  340. mmcv.imwrite(img_channel, img_channel_path)
  341. self.img_path.append(img_channel_path)
  342. self.results = {'img_path': self.img_path}
  343. def tearDown(self):
  344. for filename in self.img_path:
  345. os.remove(filename)
  346. def test_transform(self):
  347. transform = LoadMultiChannelImageFromFiles()
  348. results = transform(copy.deepcopy(self.results))
  349. self.assertEqual(results['img'].shape, (10, 10, 4))
  350. self.assertEqual(results['img'].dtype, np.uint8)
  351. self.assertEqual(results['img_shape'], (10, 10))
  352. self.assertEqual(results['ori_shape'], (10, 10))
  353. # to_float32
  354. transform = LoadMultiChannelImageFromFiles(to_float32=True)
  355. results = transform(copy.deepcopy(results))
  356. self.assertEqual(results['img'].dtype, np.float32)
  357. def test_rper(self):
  358. transform = LoadMultiChannelImageFromFiles()
  359. self.assertEqual(
  360. repr(transform), ('LoadMultiChannelImageFromFiles('
  361. 'to_float32=False, '
  362. "color_type='unchanged', "
  363. "imdecode_backend='cv2', "
  364. 'backend_args=None)'))
  365. class TestLoadProposals(unittest.TestCase):
  366. def test_transform(self):
  367. transform = LoadProposals()
  368. results = {
  369. 'proposals':
  370. dict(
  371. bboxes=np.zeros((5, 4), dtype=np.int64),
  372. scores=np.zeros((5, ), dtype=np.int64))
  373. }
  374. results = transform(results)
  375. self.assertEqual(results['proposals'].dtype, np.float32)
  376. self.assertEqual(results['proposals'].shape[-1], 4)
  377. self.assertEqual(results['proposals_scores'].dtype, np.float32)
  378. # bboxes.shape[1] should be 4
  379. results = {'proposals': dict(bboxes=np.zeros((5, 5), dtype=np.int64))}
  380. with self.assertRaises(AssertionError):
  381. transform(results)
  382. # bboxes.shape[0] should equal to scores.shape[0]
  383. results = {
  384. 'proposals':
  385. dict(
  386. bboxes=np.zeros((5, 4), dtype=np.int64),
  387. scores=np.zeros((3, ), dtype=np.int64))
  388. }
  389. with self.assertRaises(AssertionError):
  390. transform(results)
  391. # empty bboxes
  392. results = {
  393. 'proposals': dict(bboxes=np.zeros((0, 4), dtype=np.float32))
  394. }
  395. results = transform(results)
  396. excepted_proposals = np.zeros((0, 4), dtype=np.float32)
  397. excepted_proposals_scores = np.zeros(0, dtype=np.float32)
  398. self.assertTrue((results['proposals'] == excepted_proposals).all())
  399. self.assertTrue(
  400. (results['proposals_scores'] == excepted_proposals_scores).all())
  401. transform = LoadProposals(num_max_proposals=2)
  402. results = {
  403. 'proposals':
  404. dict(
  405. bboxes=np.zeros((5, 4), dtype=np.int64),
  406. scores=np.zeros((5, ), dtype=np.int64))
  407. }
  408. results = transform(results)
  409. self.assertEqual(results['proposals'].shape[0], 2)
  410. def test_repr(self):
  411. transform = LoadProposals()
  412. self.assertEqual(
  413. repr(transform), 'LoadProposals(num_max_proposals=None)')
  414. class TestLoadEmptyAnnotations(unittest.TestCase):
  415. def test_transform(self):
  416. transform = LoadEmptyAnnotations(
  417. with_bbox=True, with_label=True, with_mask=True, with_seg=True)
  418. results = {'img_shape': (224, 224)}
  419. results = transform(results)
  420. self.assertEqual(results['gt_bboxes'].dtype, np.float32)
  421. self.assertEqual(results['gt_bboxes'].shape[-1], 4)
  422. self.assertEqual(results['gt_ignore_flags'].dtype, bool)
  423. self.assertEqual(results['gt_bboxes_labels'].dtype, np.int64)
  424. self.assertEqual(results['gt_masks'].masks.dtype, np.uint8)
  425. self.assertEqual(results['gt_masks'].masks.shape[-2:],
  426. results['img_shape'])
  427. self.assertEqual(results['gt_seg_map'].dtype, np.uint8)
  428. self.assertEqual(results['gt_seg_map'].shape, results['img_shape'])
  429. def test_repr(self):
  430. transform = LoadEmptyAnnotations()
  431. self.assertEqual(
  432. repr(transform), 'LoadEmptyAnnotations(with_bbox=True, '
  433. 'with_label=True, '
  434. 'with_mask=False, '
  435. 'with_seg=False, '
  436. 'seg_ignore_label=255)')
  437. class TestLoadTrackAnnotations(unittest.TestCase):
  438. def setUp(self):
  439. data_prefix = osp.join(osp.dirname(__file__), '../data')
  440. seg_map = osp.join(data_prefix, 'grayscale.jpg')
  441. self.results = {
  442. 'seg_map_path':
  443. seg_map,
  444. 'instances': [{
  445. 'bbox': [0, 0, 10, 20],
  446. 'bbox_label': 1,
  447. 'instance_id': 100,
  448. 'keypoints': [1, 2, 3]
  449. }, {
  450. 'bbox': [10, 10, 110, 120],
  451. 'bbox_label': 2,
  452. 'instance_id': 102,
  453. 'keypoints': [4, 5, 6]
  454. }]
  455. }
  456. def test_load_instances_id(self):
  457. transform = LoadTrackAnnotations(
  458. with_bbox=False,
  459. with_label=True,
  460. with_seg=False,
  461. with_keypoints=False,
  462. )
  463. results = transform(copy.deepcopy(self.results))
  464. assert 'gt_instances_ids' in results
  465. assert (results['gt_instances_ids'] == np.array([100, 102])).all()
  466. assert results['gt_instances_ids'].dtype == np.int32
  467. def test_repr(self):
  468. transform = LoadTrackAnnotations(
  469. with_bbox=True, with_label=False, with_seg=False, with_mask=False)
  470. assert repr(transform) == ('LoadTrackAnnotations(with_bbox=True, '
  471. 'with_label=False, with_mask=False,'
  472. ' with_seg=False, poly2mask=True,'
  473. " imdecode_backend='cv2', "
  474. 'file_client_args=None)')