test_sync_norm_hook.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. from unittest.mock import Mock, patch
  4. import torch.nn as nn
  5. from mmdet.engine.hooks import SyncNormHook
  class TestSyncNormHook(TestCase):
      """Smoke tests for ``SyncNormHook.before_val_epoch``.

      Each test patches ``get_dist_info`` to return ``(0, 1)`` or ``(0, 2)``
      (presumably ``(rank, world_size)``; the test names call these the
      non-distributed and distributed cases). It then runs the hook once on a
      small ``nn.Sequential`` model held by a ``Mock`` runner.

      There are no explicit assertions: a test passes as long as the hook call
      completes without raising.
      """

      # Patch the name as it is looked up inside the hook's own module, so the
      # hook sees the faked return value.
      _DIST_INFO_TARGET = 'mmdet.engine.hooks.sync_norm_hook.get_dist_info'

      @staticmethod
      def _run_hook(*layers):
          """Wrap ``layers`` in ``nn.Sequential`` on a Mock runner and run the hook."""
          fake_runner = Mock()
          fake_runner.model = nn.Sequential(*layers)
          SyncNormHook().before_val_epoch(fake_runner)

      @patch(_DIST_INFO_TARGET, return_value=(0, 1))
      def test_before_val_epoch_non_dist(self, mock_get_dist_info):
          # Model with a BatchNorm layer, single-process dist info.
          self._run_hook(
              nn.Conv2d(1, 5, kernel_size=3),
              nn.BatchNorm2d(5, momentum=0.3),
              nn.Linear(5, 10),
          )

      @patch(_DIST_INFO_TARGET, return_value=(0, 2))
      def test_before_val_epoch_dist(self, mock_get_dist_info):
          # Same model as above, but dist info reports two processes.
          self._run_hook(
              nn.Conv2d(1, 5, kernel_size=3),
              nn.BatchNorm2d(5, momentum=0.3),
              nn.Linear(5, 10),
          )

      @patch(_DIST_INFO_TARGET, return_value=(0, 2))
      def test_before_val_epoch_dist_no_norm(self, mock_get_dist_info):
          # Two-process dist info with a model that has no norm layer at all.
          self._run_hook(
              nn.Conv2d(1, 5, kernel_size=3),
              nn.Linear(5, 10),
          )