# test_heuristic_fusion_head.py (2.4 KB)
  1. import unittest
  2. import torch
  3. from mmengine.config import Config
  4. from mmengine.structures import InstanceData
  5. from mmengine.testing import assert_allclose
  6. from mmdet.evaluation import INSTANCE_OFFSET
  7. from mmdet.models.seg_heads.panoptic_fusion_heads import HeuristicFusionHead
  8. class TestHeuristicFusionHead(unittest.TestCase):
  9. def test_loss(self):
  10. head = HeuristicFusionHead(num_things_classes=2, num_stuff_classes=2)
  11. result = head.loss()
  12. self.assertTrue(not head.with_loss)
  13. self.assertDictEqual(result, dict())
  14. def test_predict(self):
  15. test_cfg = Config(dict(mask_overlap=0.5, stuff_area_limit=1))
  16. head = HeuristicFusionHead(
  17. num_things_classes=2, num_stuff_classes=2, test_cfg=test_cfg)
  18. mask_results = InstanceData()
  19. mask_results.bboxes = torch.tensor([[0, 0, 1, 1], [1, 1, 2, 2]])
  20. mask_results.labels = torch.tensor([0, 1])
  21. mask_results.scores = torch.tensor([0.8, 0.7])
  22. mask_results.masks = torch.tensor([[[1, 0], [0, 0]], [[0, 0],
  23. [0, 1]]]).bool()
  24. seg_preds_list = [
  25. torch.tensor([[[0.2, 0.7], [0.3, 0.1]], [[0.2, 0.2], [0.6, 0.1]],
  26. [[0.6, 0.1], [0.1, 0.8]]])
  27. ]
  28. target_list = [
  29. torch.tensor([[0 + 1 * INSTANCE_OFFSET, 2],
  30. [3, 1 + 2 * INSTANCE_OFFSET]])
  31. ]
  32. results_list = head.predict([mask_results], seg_preds_list)
  33. for target, result in zip(target_list, results_list):
  34. assert_allclose(result.sem_seg[0], target)
  35. # test with no thing
  36. head = HeuristicFusionHead(
  37. num_things_classes=2, num_stuff_classes=2, test_cfg=test_cfg)
  38. mask_results = InstanceData()
  39. mask_results.bboxes = torch.zeros((0, 4))
  40. mask_results.labels = torch.zeros((0, )).long()
  41. mask_results.scores = torch.zeros((0, ))
  42. mask_results.masks = torch.zeros((0, 2, 2), dtype=torch.bool)
  43. seg_preds_list = [
  44. torch.tensor([[[0.2, 0.7], [0.3, 0.1]], [[0.2, 0.2], [0.6, 0.1]],
  45. [[0.6, 0.1], [0.1, 0.8]]])
  46. ]
  47. target_list = [torch.tensor([[4, 2], [3, 4]])]
  48. results_list = head.predict([mask_results], seg_preds_list)
  49. for target, result in zip(target_list, results_list):
  50. assert_allclose(result.sem_seg[0], target)