test_panoptic_fpn_head.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import unittest
  2. import torch
  3. from mmengine.structures import PixelData
  4. from mmengine.testing import assert_allclose
  5. from mmdet.models.seg_heads import PanopticFPNHead
  6. from mmdet.structures import DetDataSample
  7. class TestPanopticFPNHead(unittest.TestCase):
  8. def test_init_weights(self):
  9. head = PanopticFPNHead(
  10. num_things_classes=2,
  11. num_stuff_classes=2,
  12. in_channels=32,
  13. inner_channels=32)
  14. head.init_weights()
  15. assert_allclose(head.conv_logits.bias.data,
  16. torch.zeros_like(head.conv_logits.bias.data))
  17. def test_loss(self):
  18. head = PanopticFPNHead(
  19. num_things_classes=2,
  20. num_stuff_classes=2,
  21. in_channels=32,
  22. inner_channels=32,
  23. start_level=0,
  24. end_level=1)
  25. x = [torch.rand((2, 32, 8, 8)), torch.rand((2, 32, 4, 4))]
  26. data_sample1 = DetDataSample()
  27. data_sample1.gt_sem_seg = PixelData(
  28. sem_seg=torch.randint(0, 4, (1, 7, 8)))
  29. data_sample2 = DetDataSample()
  30. data_sample2.gt_sem_seg = PixelData(
  31. sem_seg=torch.randint(0, 4, (1, 7, 8)))
  32. batch_data_samples = [data_sample1, data_sample2]
  33. results = head.loss(x, batch_data_samples)
  34. self.assertIsInstance(results, dict)
  35. def test_predict(self):
  36. head = PanopticFPNHead(
  37. num_things_classes=2,
  38. num_stuff_classes=2,
  39. in_channels=32,
  40. inner_channels=32,
  41. start_level=0,
  42. end_level=1)
  43. x = [torch.rand((2, 32, 8, 8)), torch.rand((2, 32, 4, 4))]
  44. img_meta1 = {
  45. 'batch_input_shape': (16, 16),
  46. 'img_shape': (14, 14),
  47. 'ori_shape': (12, 12),
  48. }
  49. img_meta2 = {
  50. 'batch_input_shape': (16, 16),
  51. 'img_shape': (16, 16),
  52. 'ori_shape': (16, 16),
  53. }
  54. batch_img_metas = [img_meta1, img_meta2]
  55. head.eval()
  56. with torch.no_grad():
  57. seg_preds = head.predict(x, batch_img_metas, rescale=False)
  58. self.assertTupleEqual(seg_preds[0].shape[-2:], (16, 16))
  59. self.assertTupleEqual(seg_preds[1].shape[-2:], (16, 16))
  60. seg_preds = head.predict(x, batch_img_metas, rescale=True)
  61. self.assertTupleEqual(seg_preds[0].shape[-2:], (12, 12))
  62. self.assertTupleEqual(seg_preds[1].shape[-2:], (16, 16))