1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859 |
- import unittest
- import torch
- from mmengine.config import Config
- from mmengine.structures import InstanceData
- from mmengine.testing import assert_allclose
- from mmdet.evaluation import INSTANCE_OFFSET
- from mmdet.models.seg_heads.panoptic_fusion_heads import HeuristicFusionHead
class TestHeuristicFusionHead(unittest.TestCase):
    """Unit tests for ``HeuristicFusionHead`` (2 thing + 2 stuff classes)."""

    @staticmethod
    def _make_seg_preds():
        """Return a fresh one-image list with a (3, 2, 2) score tensor.

        Both ``test_predict`` cases feed the same semantic scores; building
        the tensor on every call keeps the cases independent of each other.
        """
        return [
            torch.tensor([[[0.2, 0.7], [0.3, 0.1]], [[0.2, 0.2], [0.6, 0.1]],
                          [[0.6, 0.1], [0.1, 0.8]]])
        ]

    def _assert_panoptic_equal(self, results_list, target_list):
        """Compare every predicted map with its expected target.

        Args:
            results_list: Value returned by ``head.predict``; each item must
                expose ``sem_seg`` whose first entry is the map to compare.
            target_list: Expected tensors, one per result.
        """
        # ``zip`` stops at the shorter input, so without this check an empty
        # ``results_list`` would let the test pass without comparing anything.
        self.assertEqual(len(results_list), len(target_list))
        for target, result in zip(target_list, results_list):
            assert_allclose(result.sem_seg[0], target)

    def test_loss(self):
        """``loss()`` returns an empty dict and ``with_loss`` is falsy."""
        head = HeuristicFusionHead(num_things_classes=2, num_stuff_classes=2)
        result = head.loss()
        self.assertFalse(head.with_loss)
        self.assertDictEqual(result, dict())

    def test_predict(self):
        """Check fused outputs with two instances and with none."""
        test_cfg = Config(dict(mask_overlap=0.5, stuff_area_limit=1))

        # Case 1: two single-pixel instances in opposite corners.
        head = HeuristicFusionHead(
            num_things_classes=2, num_stuff_classes=2, test_cfg=test_cfg)
        mask_results = InstanceData()
        mask_results.bboxes = torch.tensor([[0, 0, 1, 1], [1, 1, 2, 2]])
        mask_results.labels = torch.tensor([0, 1])
        mask_results.scores = torch.tensor([0.8, 0.7])
        mask_results.masks = torch.tensor([[[1, 0], [0, 0]],
                                           [[0, 0], [0, 1]]]).bool()
        # The masked corners are expected to carry the instance label plus a
        # multiple of INSTANCE_OFFSET; the other two pixels carry ids 2 and 3.
        target_list = [
            torch.tensor([[0 + 1 * INSTANCE_OFFSET, 2],
                          [3, 1 + 2 * INSTANCE_OFFSET]])
        ]
        results_list = head.predict([mask_results], self._make_seg_preds())
        self._assert_panoptic_equal(results_list, target_list)

        # Case 2: no thing instances, every InstanceData field is empty.
        head = HeuristicFusionHead(
            num_things_classes=2, num_stuff_classes=2, test_cfg=test_cfg)
        mask_results = InstanceData()
        mask_results.bboxes = torch.zeros((0, 4))
        mask_results.labels = torch.zeros((0, )).long()
        mask_results.scores = torch.zeros((0, ))
        mask_results.masks = torch.zeros((0, 2, 2), dtype=torch.bool)
        # Same scores as case 1, but the corner pixels are now expected to be
        # 4 (presumably the void id, num_things + num_stuff -- TODO confirm).
        target_list = [torch.tensor([[4, 2], [3, 4]])]
        results_list = head.predict([mask_results], self._make_seg_preds())
        self._assert_panoptic_equal(results_list, target_list)
|