12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import os.path as osp
- from typing import List, Union
- from mmdet.registry import DATASETS
- from .base_video_dataset import BaseVideoDataset
- @DATASETS.register_module()
- class MOTChallengeDataset(BaseVideoDataset):
- """Dataset for MOTChallenge.
- Args:
- visibility_thr (float, optional): The minimum visibility
- for the objects during training. Default to -1.
- """
- METAINFO = {
- 'classes':
- ('pedestrian', 'person_on_vehicle', 'car', 'bicycle', 'motorbike',
- 'non_mot_vehicle', 'static_person', 'distractor', 'occluder',
- 'occluder_on_ground', 'occluder_full', 'reflection', 'crowd')
- }
- def __init__(self, visibility_thr: float = -1, *args, **kwargs):
- self.visibility_thr = visibility_thr
- super().__init__(*args, **kwargs)
- def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
- """Parse raw annotation to target format. The difference between this
- function and the one in ``BaseVideoDataset`` is that the parsing here
- adds ``visibility`` and ``mot_conf``.
- Args:
- raw_data_info (dict): Raw data information load from ``ann_file``
- Returns:
- Union[dict, List[dict]]: Parsed annotation.
- """
- img_info = raw_data_info['raw_img_info']
- ann_info = raw_data_info['raw_ann_info']
- data_info = {}
- data_info.update(img_info)
- if self.data_prefix.get('img_path', None) is not None:
- img_path = osp.join(self.data_prefix['img_path'],
- img_info['file_name'])
- else:
- img_path = img_info['file_name']
- data_info['img_path'] = img_path
- instances = []
- for i, ann in enumerate(ann_info):
- instance = {}
- if (not self.test_mode) and (ann['visibility'] <
- self.visibility_thr):
- continue
- if ann.get('ignore', False):
- continue
- x1, y1, w, h = ann['bbox']
- inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
- inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
- if inter_w * inter_h == 0:
- continue
- if ann['area'] <= 0 or w < 1 or h < 1:
- continue
- if ann['category_id'] not in self.cat_ids:
- continue
- bbox = [x1, y1, x1 + w, y1 + h]
- if ann.get('iscrowd', False):
- instance['ignore_flag'] = 1
- else:
- instance['ignore_flag'] = 0
- instance['bbox'] = bbox
- instance['bbox_label'] = self.cat2label[ann['category_id']]
- instance['instance_id'] = ann['instance_id']
- instance['category_id'] = ann['category_id']
- instance['mot_conf'] = ann['mot_conf']
- instance['visibility'] = ann['visibility']
- if len(instance) > 0:
- instances.append(instance)
- if not self.test_mode:
- assert len(instances) > 0, f'No valid instances found in ' \
- f'image {data_info["img_path"]}!'
- data_info['instances'] = instances
- return data_info
|