test_gap.py 742 B

123456789101112131415161718192021222324252627
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmdet.models import GlobalAveragePooling
  5. class TestGlobalAveragePooling(TestCase):
  6. def test_forward(self):
  7. inputs = torch.rand(32, 128, 14, 14)
  8. # test AdaptiveAvgPool2d
  9. neck = GlobalAveragePooling()
  10. outputs = neck(inputs)
  11. assert outputs.shape == (32, 128)
  12. # test kernel_size
  13. neck = GlobalAveragePooling(kernel_size=7)
  14. outputs = neck(inputs)
  15. assert outputs.shape == (32, 128 * 2 * 2)
  16. # test kenel_size and stride
  17. neck = GlobalAveragePooling(kernel_size=7, stride=2)
  18. outputs = neck(inputs)
  19. assert outputs.shape == (32, 128 * 4 * 4)