mot_challenge_dataset.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os.path as osp
  3. from typing import List, Union
  4. from mmdet.registry import DATASETS
  5. from .base_video_dataset import BaseVideoDataset
  6. @DATASETS.register_module()
  7. class MOTChallengeDataset(BaseVideoDataset):
  8. """Dataset for MOTChallenge.
  9. Args:
  10. visibility_thr (float, optional): The minimum visibility
  11. for the objects during training. Default to -1.
  12. """
  13. METAINFO = {
  14. 'classes':
  15. ('pedestrian', 'person_on_vehicle', 'car', 'bicycle', 'motorbike',
  16. 'non_mot_vehicle', 'static_person', 'distractor', 'occluder',
  17. 'occluder_on_ground', 'occluder_full', 'reflection', 'crowd')
  18. }
  19. def __init__(self, visibility_thr: float = -1, *args, **kwargs):
  20. self.visibility_thr = visibility_thr
  21. super().__init__(*args, **kwargs)
  22. def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
  23. """Parse raw annotation to target format. The difference between this
  24. function and the one in ``BaseVideoDataset`` is that the parsing here
  25. adds ``visibility`` and ``mot_conf``.
  26. Args:
  27. raw_data_info (dict): Raw data information load from ``ann_file``
  28. Returns:
  29. Union[dict, List[dict]]: Parsed annotation.
  30. """
  31. img_info = raw_data_info['raw_img_info']
  32. ann_info = raw_data_info['raw_ann_info']
  33. data_info = {}
  34. data_info.update(img_info)
  35. if self.data_prefix.get('img_path', None) is not None:
  36. img_path = osp.join(self.data_prefix['img_path'],
  37. img_info['file_name'])
  38. else:
  39. img_path = img_info['file_name']
  40. data_info['img_path'] = img_path
  41. instances = []
  42. for i, ann in enumerate(ann_info):
  43. instance = {}
  44. if (not self.test_mode) and (ann['visibility'] <
  45. self.visibility_thr):
  46. continue
  47. if ann.get('ignore', False):
  48. continue
  49. x1, y1, w, h = ann['bbox']
  50. inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
  51. inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
  52. if inter_w * inter_h == 0:
  53. continue
  54. if ann['area'] <= 0 or w < 1 or h < 1:
  55. continue
  56. if ann['category_id'] not in self.cat_ids:
  57. continue
  58. bbox = [x1, y1, x1 + w, y1 + h]
  59. if ann.get('iscrowd', False):
  60. instance['ignore_flag'] = 1
  61. else:
  62. instance['ignore_flag'] = 0
  63. instance['bbox'] = bbox
  64. instance['bbox_label'] = self.cat2label[ann['category_id']]
  65. instance['instance_id'] = ann['instance_id']
  66. instance['category_id'] = ann['category_id']
  67. instance['mot_conf'] = ann['mot_conf']
  68. instance['visibility'] = ann['visibility']
  69. if len(instance) > 0:
  70. instances.append(instance)
  71. if not self.test_mode:
  72. assert len(instances) > 0, f'No valid instances found in ' \
  73. f'image {data_info["img_path"]}!'
  74. data_info['instances'] = instances
  75. return data_info