# Copyright (c) OpenMMLab. All rights reserved.
import os
import tempfile
from unittest import TestCase

import numpy as np
import pycocotools.mask as mask_util
import torch
from mmengine.registry import init_default_scope
from mmengine.structures import BaseDataElement, InstanceData

from mmdet.registry import METRICS
from mmdet.structures import DetDataSample, TrackDataSample


class TestYouTubeVISMetric(TestCase):

    @classmethod
    def setUpClass(cls):
        init_default_scope('mmdet')

    def setUp(self):
        self.tmp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.tmp_dir.cleanup()
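
    # Build a one-instance dummy prediction: a 50x50 box at (100, 100), full
    # confidence, label 0, the given track id and a matching binary mask.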
    def _create_dummy_results(self, track_id):
        bboxes = np.array([[100, 100, 150, 150]])
        scores = np.array([1.0])
        labels = np.array([0])
        instance_id = np.array([track_id])
        dummy_mask = np.zeros((1, 720, 1280), dtype=np.uint8)
        dummy_mask[:, 100:150, 100:150] = 1
        return dict(
            bboxes=torch.from_numpy(bboxes),
            scores=torch.from_numpy(scores),
            labels=torch.from_numpy(labels),
            instances_id=torch.from_numpy(instance_id),
            masks=torch.from_numpy(dummy_mask))

    def test_format_only(self):
        outfile_prefix = f'{self.tmp_dir.name}/result'
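        # With format_only=True the metric skips evaluation and only writes
        # the formatted results and a submission file.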
        vis_metric = METRICS.build(
            dict(
                type='YouTubeVISMetric',
                format_only=True,
                outfile_prefix=outfile_prefix,
            ))
        dummy_pred = self._create_dummy_results(track_id=0)
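        # Encode the same 50x50 region as COCO RLE so the ground truth can be
        # stored in the sample's metainfo; decode counts for JSON dumping.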
        dummy_mask = np.zeros((720, 1280), order='F', dtype=np.uint8)
        dummy_mask[100:150, 100:150] = 1
        rle_mask = mask_util.encode(dummy_mask)
        rle_mask['counts'] = rle_mask['counts'].decode('utf-8')
        instances = [{
            'bbox_label': 0,
            'bbox': [100, 100, 150, 150],
            'ignore_flag': 0,
            'instance_id': 1,
            'mask': rle_mask,
        }]
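        # The metric only needs the dataset classes from the meta info; the
        # data batch itself is not inspected by process().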
        vis_metric.dataset_meta = dict(classes=['car', 'train'])
        data_batch = dict(inputs=None, data_samples=None)
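        # Nest the prediction the way the metric expects it:
        # InstanceData (per frame) -> DetDataSample -> TrackDataSample.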
        pred_instance = InstanceData(**dummy_pred)
        img_data_sample = DetDataSample()
        img_data_sample.pred_track_instances = pred_instance
        img_data_sample.set_metainfo(
            dict(
                img_id=0,
                video_id=1,
                ori_video_length=1,
                ori_shape=(720, 1280),
                instances=instances))
        track_data_sample = TrackDataSample()
        track_data_sample.video_data_samples = [img_data_sample]
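        # process() consumes plain dicts, so serialize the sample first.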
        predictions = []
        if isinstance(track_data_sample, BaseDataElement):
            predictions.append(track_data_sample.to_dict())
        vis_metric.process(data_batch, predictions)
        vis_metric.evaluate(size=1)
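        # format_only should have dumped both the raw result json and the
        # zipped submission file.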
        assert os.path.exists(f'{outfile_prefix}.json')
        assert os.path.exists(f'{outfile_prefix}.submission_file.zip')

    def test_evaluate(self):
        """Test using the metric in the same way as Evaluator."""
        dummy_pred_1 = self._create_dummy_results(track_id=1)
        dummy_pred_2 = self._create_dummy_results(track_id=1)
        dummy_pred_3 = self._create_dummy_results(track_id=2)
        dummy_mask = np.zeros((720, 1280), order='F', dtype=np.uint8)
        dummy_mask[100:150, 100:150] = 1
        rle_mask = mask_util.encode(dummy_mask)
        rle_mask['counts'] = rle_mask['counts'].decode('utf-8')
        instances_1 = [{
            'bbox_label': 0,
            'bbox': [100, 100, 150, 150],
            'ignore_flag': 0,
            'instance_id': 1,
            'mask': rle_mask,
        }]
        instances_2 = [{
            'bbox_label': 0,
            'bbox': [100, 100, 150, 150],
            'ignore_flag': 0,
            'instance_id': 2,
            'mask': rle_mask,
        }]
        vis_metric = METRICS.build(
            dict(
                type='YouTubeVISMetric',
                outfile_prefix=f'{self.tmp_dir.name}/test',
            ))
        vis_metric.dataset_meta = dict(classes=['car', 'train'])
        data_batch = dict(inputs=None, data_samples=None)
        pred_instance = InstanceData(**dummy_pred_1)
        img_data_sample = DetDataSample()
        img_data_sample.pred_track_instances = pred_instance
        img_data_sample.set_metainfo(
            dict(
                img_id=1,
                video_id=1,
                ori_video_length=2,
                ori_shape=(720, 1280),
                instances=instances_1))
        pred_instance_2 = InstanceData(**dummy_pred_2)
        img_data_sample_2 = DetDataSample()
        img_data_sample_2.pred_track_instances = pred_instance_2
        img_data_sample_2.set_metainfo(
            dict(
                img_id=2,
                video_id=1,
                ori_video_length=2,
                ori_shape=(720, 1280),
                instances=instances_1))
        track_data_sample = TrackDataSample()
        track_data_sample.video_data_samples = [
            img_data_sample, img_data_sample_2
        ]
        predictions = []
        if isinstance(track_data_sample, BaseDataElement):
            predictions.append(track_data_sample.to_dict())
        vis_metric.process(data_batch, predictions)
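        # Process the second video separately to exercise accumulation of
        # results across videos.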
        pred_instance = InstanceData(**dummy_pred_3)
        img_data_sample = DetDataSample()
        img_data_sample.pred_track_instances = pred_instance
        img_data_sample.set_metainfo(
            dict(
                img_id=3,
                video_id=2,
                ori_video_length=1,
                ori_shape=(720, 1280),
                instances=instances_2))
        track_data_sample = TrackDataSample()
        track_data_sample.video_data_samples = [img_data_sample]
        predictions = []
        if isinstance(track_data_sample, BaseDataElement):
            predictions.append(track_data_sample.to_dict())
        vis_metric.process(data_batch, predictions)
        eval_results = vis_metric.evaluate(size=3)
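        # Predictions match the ground truth exactly, so segm mAP is 1.0 at
        # every IoU threshold; the single 50x50 instance counts as small, and
        # area ranges without ground truth report -1.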
        target = {
            'youtube_vis/segm_mAP': 1.0,
            'youtube_vis/segm_mAP_50': 1.0,
            'youtube_vis/segm_mAP_75': 1.0,
            'youtube_vis/segm_mAP_s': 1.0,
            'youtube_vis/segm_mAP_m': -1.0,
            'youtube_vis/segm_mAP_l': -1.0,
        }
        self.assertDictEqual(eval_results, target)