gap.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. import torch.nn as nn
  4. from mmengine.model import BaseModule
  5. from mmdet.registry import MODELS
  6. @MODELS.register_module()
  7. class GlobalAveragePooling(BaseModule):
  8. """Global Average Pooling neck.
  9. Note that we use `view` to remove extra channel after pooling. We do not
  10. use `squeeze` as it will also remove the batch dimension when the tensor
  11. has a batch dimension of size 1, which can lead to unexpected errors.
  12. """
  13. def __init__(self, kernel_size=None, stride=None):
  14. super(GlobalAveragePooling, self).__init__()
  15. if kernel_size is None and stride is None:
  16. self.gap = nn.AdaptiveAvgPool2d((1, 1))
  17. else:
  18. self.gap = nn.AvgPool2d(kernel_size, stride)
  19. def forward(self, inputs):
  20. if isinstance(inputs, tuple):
  21. outs = tuple([self.gap(x) for x in inputs])
  22. outs = tuple([
  23. out.view(x.size(0),
  24. torch.tensor(out.size()[1:]).prod())
  25. for out, x in zip(outs, inputs)
  26. ])
  27. elif isinstance(inputs, torch.Tensor):
  28. outs = self.gap(inputs)
  29. outs = outs.view(
  30. inputs.size(0),
  31. torch.tensor(outs.size()[1:]).prod())
  32. else:
  33. raise TypeError('neck inputs should be tuple or torch.tensor')
  34. return outs