123456789101112131415161718192021222324252627 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from unittest import TestCase
- import torch
- from mmdet.models import GlobalAveragePooling
class TestGlobalAveragePooling(TestCase):
    """Output-shape checks for the ``GlobalAveragePooling`` neck."""

    def test_forward(self):
        """Verify the flattened output shape for each pooling configuration.

        A single random ``(32, 128, 14, 14)`` input is pushed through three
        differently configured necks; only the output shapes are checked,
        never the values.
        """
        batch_size, channels = 32, 128
        inputs = torch.rand(batch_size, channels, 14, 14)

        # Default construction (the AdaptiveAvgPool2d case): the output is
        # expected to hold one value per channel.
        neck = GlobalAveragePooling()
        outputs = neck(inputs)
        self.assertEqual(tuple(outputs.shape), (batch_size, channels))

        # kernel_size only: a 2x2 map per channel is expected, flattened.
        # NOTE(review): presumably stride defaults to kernel_size, as in
        # torch.nn.AvgPool2d (14 / 7 = 2) -- confirm against the neck.
        neck = GlobalAveragePooling(kernel_size=7)
        outputs = neck(inputs)
        self.assertEqual(
            tuple(outputs.shape), (batch_size, channels * 2 * 2))

        # kernel_size and stride: a 4x4 map per channel is expected,
        # flattened. NOTE(review): matches floor((14 - 7) / 2) + 1 = 4 if
        # the neck wraps a plain average pool -- confirm.
        neck = GlobalAveragePooling(kernel_size=7, stride=2)
        outputs = neck(inputs)
        self.assertEqual(
            tuple(outputs.shape), (batch_size, channels * 4 * 4))
|