# yapf: disable
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmengine.model import BaseModule

from mmdet.registry import MODELS
  6. @MODELS.register_module()
  7. class FcModule(BaseModule):
  8. """Fully-connected layer module.
  9. Args:
  10. in_channels (int): Input channels.
  11. out_channels (int): Ourput channels.
  12. norm_cfg (dict, optional): Configuration of normlization method
  13. after fc. Defaults to None.
  14. act_cfg (dict, optional): Configuration of activation method after fc.
  15. Defaults to dict(type='ReLU').
  16. inplace (bool, optional): Whether inplace the activatation module.
  17. Defaults to True.
  18. init_cfg (dict, optional): Initialization config dict.
  19. Defaults to dict(type='Kaiming', layer='Linear').
  20. """
  21. def __init__(self,
  22. in_channels: int,
  23. out_channels: int,
  24. norm_cfg: dict = None,
  25. act_cfg: dict = dict(type='ReLU'),
  26. inplace: bool = True,
  27. init_cfg=dict(type='Kaiming', layer='Linear')):
  28. super(FcModule, self).__init__(init_cfg)
  29. assert norm_cfg is None or isinstance(norm_cfg, dict)
  30. assert act_cfg is None or isinstance(act_cfg, dict)
  31. self.norm_cfg = norm_cfg
  32. self.act_cfg = act_cfg
  33. self.inplace = inplace
  34. self.with_norm = norm_cfg is not None
  35. self.with_activation = act_cfg is not None
  36. self.fc = nn.Linear(in_channels, out_channels)
  37. # build normalization layers
  38. if self.with_norm:
  39. self.norm_name, norm = build_norm_layer(norm_cfg, out_channels)
  40. self.add_module(self.norm_name, norm)
  41. # build activation layer
  42. if self.with_activation:
  43. act_cfg_ = act_cfg.copy()
  44. # nn.Tanh has no 'inplace' argument
  45. if act_cfg_['type'] not in [
  46. 'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish'
  47. ]:
  48. act_cfg_.setdefault('inplace', inplace)
  49. self.activate = build_activation_layer(act_cfg_)
  50. @property
  51. def norm(self):
  52. """Normalization."""
  53. return getattr(self, self.norm_name)
  54. def forward(self, x, activate=True, norm=True):
  55. """Model forward."""
  56. x = self.fc(x)
  57. if norm and self.with_norm:
  58. x = self.norm(x)
  59. if activate and self.with_activation:
  60. x = self.activate(x)
  61. return x