test_num_class_check_hook.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from copy import deepcopy
  3. from unittest import TestCase
  4. from unittest.mock import Mock
  5. from mmcv.cnn import VGG
  6. from mmengine.dataset import BaseDataset
  7. from torch import nn
  8. from mmdet.engine.hooks import NumClassCheckHook
  9. from mmdet.models.roi_heads.mask_heads import FusedSemanticHead
  class TestNumClassCheckHook(TestCase):
      """Unit tests for ``NumClassCheckHook``.

      Before both a train epoch and a val epoch, these tests expect the hook to:

      * only log a warning when ``dataset.metainfo['classes']`` is None;
      * raise ``AssertionError`` when ``classes`` is a plain ``str``;
      * accept a model whose checked heads agree with ``len(classes)``;
      * raise ``AssertionError`` when a head's ``num_classes`` does not match
        ``len(classes)``.
      """

      def setUp(self):
          # Hook under test.
          self.hook = NumClassCheckHook()

          # Mocked runner: only the attributes the hook is driven through here.
          runner = Mock()
          runner.model = Mock()
          runner.logger = Mock()
          runner.logger.warning = Mock()
          runner.train_dataloader = Mock()
          runner.val_dataloader = Mock()
          self.runner = runner

          # Datasets whose ``metainfo['classes']`` is None, a str, and a
          # 2-tuple. ``lazy_init=True`` presumably keeps BaseDataset from
          # loading annotations, which these in-memory fixtures do not have.
          self.none_classmeta_dataset = BaseDataset(
              metainfo=dict(classes=None), lazy_init=True)
          self.str_classmeta_dataset = BaseDataset(
              metainfo=dict(classes='class_name'), lazy_init=True)
          self.normal_classmeta_dataset = BaseDataset(
              metainfo=dict(classes=('bus', 'car')), lazy_init=True)

          # Model the tests expect the hook to accept for the 2-class dataset.
          # bbox_head matches (2 classes). rpn_head deliberately has 1 class,
          # and the VGG backbone and FusedSemanticHead are also present, so the
          # hook must not flag any of them as a mismatch.
          valid_model = nn.Module()
          valid_model.add_module('backbone', VGG(depth=11))
          fused_semantic_head = FusedSemanticHead(
              num_ins=1,
              fusion_level=0,
              num_convs=1,
              in_channels=1,
              conv_out_channels=1)
          valid_model.add_module('semantic_head', fused_semantic_head)
          rpn_head = nn.Module()
          rpn_head.num_classes = 1
          valid_model.add_module('rpn_head', rpn_head)
          bbox_head = nn.Module()
          bbox_head.num_classes = 2
          valid_model.add_module('bbox_head', bbox_head)
          self.valid_model = valid_model

          # Model the hook must reject: 4 bbox classes vs. 2 dataset classes.
          invalid_model = nn.Module()
          bbox_head = nn.Module()
          bbox_head.num_classes = 4
          invalid_model.add_module('bbox_head', bbox_head)
          self.invalid_model = invalid_model

      def _check_hook(self, mode):
          """Run the shared hook scenario for one phase.

          Args:
              mode (str): ``'train'`` or ``'val'``. Selects both the runner
                  dataloader (``runner.<mode>_dataloader``) that receives the
                  fixture datasets and the hook method
                  (``hook.before_<mode>_epoch``) that is invoked.
          """
          # Work on a copy so the scenario's mutations stay local to this call.
          runner = deepcopy(self.runner)
          dataloader = getattr(runner, f'{mode}_dataloader')
          before_epoch = getattr(self.hook, f'before_{mode}_epoch')

          # classes is None: the hook only warns, exactly once.
          dataloader.dataset = self.none_classmeta_dataset
          before_epoch(runner)
          runner.logger.warning.assert_called_once()

          # classes is a str: rejected.
          dataloader.dataset = self.str_classmeta_dataset
          with self.assertRaises(AssertionError):
              before_epoch(runner)

          dataloader.dataset = self.normal_classmeta_dataset

          # The model's `num_classes` is compatible with the dataset: no error.
          runner.model = self.valid_model
          before_epoch(runner)

          # The model's `num_classes` is incompatible with the dataset: rejected.
          runner.model = self.invalid_model
          with self.assertRaises(AssertionError):
              before_epoch(runner)

      def test_before_train_epoch(self):
          """Check the hook's behaviour before a training epoch."""
          self._check_hook('train')

      def test_before_val_epoch(self):
          """Check the hook's behaviour before a validation epoch."""
          self._check_hook('val')