123456789101112131415161718192021222324252627 |
- from unittest import TestCase
- import torch
- from mmdet.models import GlobalAveragePooling
class TestGlobalAveragePooling(TestCase):
    """Output-shape tests for the ``GlobalAveragePooling`` neck."""

    def test_forward(self):
        """Check the forward output shape for several pooling configurations.

        A random ``(32, 128, 14, 14)`` tensor is passed through necks built
        with different constructor arguments, and the shape of each result is
        compared with the expected flattened ``(batch, features)`` shape.
        """
        inputs = torch.rand(32, 128, 14, 14)

        # (constructor kwargs, expected output shape).
        # NOTE(review): the 2x2 and 4x4 factors match ordinary average-pooling
        # arithmetic on a 14x14 map, floor((14 - 7) / stride) + 1, presumably
        # with stride defaulting to the kernel size -- confirm against the
        # GlobalAveragePooling implementation.
        cases = (
            ({}, (32, 128)),
            ({'kernel_size': 7}, (32, 128 * 2 * 2)),
            ({'kernel_size': 7, 'stride': 2}, (32, 128 * 4 * 4)),
        )

        for kwargs, expected_shape in cases:
            # subTest lets every configuration be reported on its own instead
            # of the first failure masking the remaining ones.
            with self.subTest(**kwargs):
                neck = GlobalAveragePooling(**kwargs)
                outputs = neck(inputs)
                # assertEqual is not stripped under ``python -O`` (unlike a
                # bare ``assert``) and prints both shapes on mismatch.
                self.assertEqual(tuple(outputs.shape), expected_shape)
|