1234567891011121314151617181920212223242526272829303132333435363738394041 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from unittest import TestCase
- from unittest.mock import Mock, patch
- import torch.nn as nn
- from mmdet.engine.hooks import SyncNormHook
class TestSyncNormHook(TestCase):
    """Smoke tests for ``SyncNormHook.before_val_epoch``.

    Each test patches ``get_dist_info`` in the hook's module so that it
    returns a fixed tuple, then runs the hook once against a small
    ``nn.Sequential`` model attached to a mocked runner. The tests only
    check that the call completes without raising; no assertion is made
    about the model state afterwards.
    """

    # Patch target: the ``get_dist_info`` name as looked up inside the
    # hook's own module (patching it elsewhere would not affect the hook).
    _GET_DIST_INFO = 'mmdet.engine.hooks.sync_norm_hook.get_dist_info'

    @staticmethod
    def _run_hook(*layers):
        """Wrap ``layers`` in ``nn.Sequential`` on a mock runner; run hook."""
        runner = Mock()
        runner.model = nn.Sequential(*layers)
        SyncNormHook().before_val_epoch(runner)

    @patch(_GET_DIST_INFO, return_value=(0, 1))
    def test_before_val_epoch_non_dist(self, mock):
        # Faked get_dist_info() -> (0, 1); presumably (rank, world_size),
        # i.e. a single-process run. Model contains a BatchNorm layer.
        self._run_hook(
            nn.Conv2d(1, 5, kernel_size=3),
            nn.BatchNorm2d(5, momentum=0.3),
            nn.Linear(5, 10),
        )

    @patch(_GET_DIST_INFO, return_value=(0, 2))
    def test_before_val_epoch_dist(self, mock):
        # Faked get_dist_info() -> (0, 2); presumably two processes.
        # Model contains a BatchNorm layer whose state could be synced.
        self._run_hook(
            nn.Conv2d(1, 5, kernel_size=3),
            nn.BatchNorm2d(5, momentum=0.3),
            nn.Linear(5, 10),
        )

    @patch(_GET_DIST_INFO, return_value=(0, 2))
    def test_before_val_epoch_dist_no_norm(self, mock):
        # Same faked distributed setting, but the model has no
        # normalization layer at all (only Conv2d and Linear).
        self._run_hook(nn.Conv2d(1, 5, kernel_size=3), nn.Linear(5, 10))
|