ytvis.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. # Copyright (c) Github URL
  2. # Copied from
  3. # https://github.com/youtubevos/cocoapi/blob/master/PythonAPI/pycocotools/ytvos.py
  4. __author__ = 'ychfan'
# Interface for accessing the YouTubeVIS dataset.
# The following API functions are defined:
#  YTVIS      - YTVIS api class that loads a YouTubeVIS annotation file
#               and prepares data structures.
#  decodeMask - Decode binary mask M encoded via run-length encoding.
#  encodeMask - Encode binary mask M using run-length encoding.
#  getAnnIds  - Get ann ids that satisfy given filter conditions.
#  getCatIds  - Get cat ids that satisfy given filter conditions.
#  getVidIds  - Get vid ids that satisfy given filter conditions.
#  loadAnns   - Load anns with the specified ids.
#  loadCats   - Load cats with the specified ids.
#  loadVids   - Load vids with the specified ids.
#  annToMask  - Convert segmentation in an annotation to binary mask.
#  loadRes    - Load algorithm results and create API for accessing them.
# Microsoft COCO Toolbox.      version 2.0
# Data, paper, and tutorials available at:  http://mscoco.org/
# Code written by Piotr Dollar and Tsung-Yi Lin, 2014.
# Licensed under the Simplified BSD License [see bsd.txt]
  23. import copy
  24. import itertools
  25. import json
  26. import sys
  27. import time
  28. from collections import defaultdict
  29. import numpy as np
  30. from pycocotools import mask as maskUtils
  31. PYTHON_VERSION = sys.version_info[0]
  32. def _isArrayLike(obj):
  33. return hasattr(obj, '__iter__') and hasattr(obj, '__len__')
  34. class YTVIS:
  35. def __init__(self, annotation_file=None):
  36. """Constructor of Microsoft COCO helper class for reading and
  37. visualizing annotations.
  38. :param annotation_file (str | dict): location of annotation file or
  39. dict results.
  40. :param image_folder (str): location to the folder that hosts images.
  41. :return:
  42. """
  43. # load dataset
  44. self.dataset, self.anns, self.cats, self.vids = dict(), dict(), dict(
  45. ), dict()
  46. self.vidToAnns, self.catToVids = defaultdict(list), defaultdict(list)
  47. if annotation_file is not None:
  48. print('loading annotations into memory...')
  49. tic = time.time()
  50. if type(annotation_file) == str:
  51. dataset = json.load(open(annotation_file, 'r'))
  52. else:
  53. dataset = annotation_file
  54. assert type(
  55. dataset
  56. ) == dict, 'annotation file format {} not supported'.format(
  57. type(dataset))
  58. print('Done (t={:0.2f}s)'.format(time.time() - tic))
  59. self.dataset = dataset
  60. self.createIndex()
  61. def createIndex(self):
  62. # create index
  63. print('creating index...')
  64. anns, cats, vids = {}, {}, {}
  65. vidToAnns, catToVids = defaultdict(list), defaultdict(list)
  66. if 'annotations' in self.dataset:
  67. for ann in self.dataset['annotations']:
  68. vidToAnns[ann['video_id']].append(ann)
  69. anns[ann['id']] = ann
  70. if 'videos' in self.dataset:
  71. for vid in self.dataset['videos']:
  72. vids[vid['id']] = vid
  73. if 'categories' in self.dataset:
  74. for cat in self.dataset['categories']:
  75. cats[cat['id']] = cat
  76. if 'annotations' in self.dataset and 'categories' in self.dataset:
  77. for ann in self.dataset['annotations']:
  78. catToVids[ann['category_id']].append(ann['video_id'])
  79. print('index created!')
  80. # create class members
  81. self.anns = anns
  82. self.vidToAnns = vidToAnns
  83. self.catToVids = catToVids
  84. self.vids = vids
  85. self.cats = cats
  86. def getAnnIds(self, vidIds=[], catIds=[], areaRng=[], iscrowd=None):
  87. """Get ann ids that satisfy given filter conditions. default skips that
  88. filter.
  89. :param vidIds (int array) : get anns for given vids
  90. catIds (int array) : get anns for given cats
  91. areaRng (float array) : get anns for given area range
  92. iscrowd (boolean) : get anns for given crowd label
  93. :return: ids (int array) : integer array of ann ids
  94. """
  95. vidIds = vidIds if _isArrayLike(vidIds) else [vidIds]
  96. catIds = catIds if _isArrayLike(catIds) else [catIds]
  97. if len(vidIds) == len(catIds) == len(areaRng) == 0:
  98. anns = self.dataset['annotations']
  99. else:
  100. if not len(vidIds) == 0:
  101. lists = [
  102. self.vidToAnns[vidId] for vidId in vidIds
  103. if vidId in self.vidToAnns
  104. ]
  105. anns = list(itertools.chain.from_iterable(lists))
  106. else:
  107. anns = self.dataset['annotations']
  108. anns = anns if len(catIds) == 0 else [
  109. ann for ann in anns if ann['category_id'] in catIds
  110. ]
  111. anns = anns if len(areaRng) == 0 else [
  112. ann for ann in anns if ann['avg_area'] > areaRng[0]
  113. and ann['avg_area'] < areaRng[1]
  114. ]
  115. if iscrowd is not None:
  116. ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
  117. else:
  118. ids = [ann['id'] for ann in anns]
  119. return ids
  120. def getCatIds(self, catNms=[], supNms=[], catIds=[]):
  121. """filtering parameters. default skips that filter.
  122. :param catNms (str array) : get cats for given cat names
  123. :param supNms (str array) : get cats for given supercategory names
  124. :param catIds (int array) : get cats for given cat ids
  125. :return: ids (int array) : integer array of cat ids
  126. """
  127. catNms = catNms if _isArrayLike(catNms) else [catNms]
  128. supNms = supNms if _isArrayLike(supNms) else [supNms]
  129. catIds = catIds if _isArrayLike(catIds) else [catIds]
  130. if len(catNms) == len(supNms) == len(catIds) == 0:
  131. cats = self.dataset['categories']
  132. else:
  133. cats = self.dataset['categories']
  134. cats = cats if len(catNms) == 0 else [
  135. cat for cat in cats if cat['name'] in catNms
  136. ]
  137. cats = cats if len(supNms) == 0 else [
  138. cat for cat in cats if cat['supercategory'] in supNms
  139. ]
  140. cats = cats if len(catIds) == 0 else [
  141. cat for cat in cats if cat['id'] in catIds
  142. ]
  143. ids = [cat['id'] for cat in cats]
  144. return ids
  145. def getVidIds(self, vidIds=[], catIds=[]):
  146. """Get vid ids that satisfy given filter conditions.
  147. :param vidIds (int array) : get vids for given ids
  148. :param catIds (int array) : get vids with all given cats
  149. :return: ids (int array) : integer array of vid ids
  150. """
  151. vidIds = vidIds if _isArrayLike(vidIds) else [vidIds]
  152. catIds = catIds if _isArrayLike(catIds) else [catIds]
  153. if len(vidIds) == len(catIds) == 0:
  154. ids = self.vids.keys()
  155. else:
  156. ids = set(vidIds)
  157. for i, catId in enumerate(catIds):
  158. if i == 0 and len(ids) == 0:
  159. ids = set(self.catToVids[catId])
  160. else:
  161. ids &= set(self.catToVids[catId])
  162. return list(ids)
  163. def loadAnns(self, ids=[]):
  164. """Load anns with the specified ids.
  165. :param ids (int array) : integer ids specifying anns
  166. :return: anns (object array) : loaded ann objects
  167. """
  168. if _isArrayLike(ids):
  169. return [self.anns[id] for id in ids]
  170. elif type(ids) == int:
  171. return [self.anns[ids]]
  172. def loadCats(self, ids=[]):
  173. """Load cats with the specified ids.
  174. :param ids (int array) : integer ids specifying cats
  175. :return: cats (object array) : loaded cat objects
  176. """
  177. if _isArrayLike(ids):
  178. return [self.cats[id] for id in ids]
  179. elif type(ids) == int:
  180. return [self.cats[ids]]
  181. def loadVids(self, ids=[]):
  182. """Load anns with the specified ids.
  183. :param ids (int array) : integer ids specifying vid
  184. :return: vids (object array) : loaded vid objects
  185. """
  186. if _isArrayLike(ids):
  187. return [self.vids[id] for id in ids]
  188. elif type(ids) == int:
  189. return [self.vids[ids]]
  190. def loadRes(self, resFile):
  191. """Load result file and return a result api object.
  192. :param resFile (str) : file name of result file
  193. :return: res (obj) : result api object
  194. """
  195. res = YTVIS()
  196. res.dataset['videos'] = [img for img in self.dataset['videos']]
  197. print('Loading and preparing results...')
  198. tic = time.time()
  199. if type(resFile) == str or (PYTHON_VERSION == 2
  200. and type(resFile) == str):
  201. anns = json.load(open(resFile))
  202. elif type(resFile) == np.ndarray:
  203. anns = self.loadNumpyAnnotations(resFile)
  204. else:
  205. anns = resFile
  206. assert type(anns) == list, 'results in not an array of objects'
  207. annsVidIds = [ann['video_id'] for ann in anns]
  208. assert set(annsVidIds) == (set(annsVidIds) & set(self.getVidIds())), \
  209. 'Results do not correspond to current coco set'
  210. if 'segmentations' in anns[0]:
  211. res.dataset['categories'] = copy.deepcopy(
  212. self.dataset['categories'])
  213. for id, ann in enumerate(anns):
  214. ann['areas'] = []
  215. if 'bboxes' not in ann:
  216. ann['bboxes'] = []
  217. for seg in ann['segmentations']:
  218. # now only support compressed RLE format
  219. # as segmentation results
  220. if seg:
  221. ann['areas'].append(maskUtils.area(seg))
  222. if len(ann['bboxes']) < len(ann['areas']):
  223. ann['bboxes'].append(maskUtils.toBbox(seg))
  224. else:
  225. ann['areas'].append(None)
  226. if len(ann['bboxes']) < len(ann['areas']):
  227. ann['bboxes'].append(None)
  228. ann['id'] = id + 1
  229. l_ori = [a for a in ann['areas'] if a]
  230. if len(l_ori) == 0:
  231. ann['avg_area'] = 0
  232. else:
  233. ann['avg_area'] = np.array(l_ori).mean()
  234. ann['iscrowd'] = 0
  235. print('DONE (t={:0.2f}s)'.format(time.time() - tic))
  236. res.dataset['annotations'] = anns
  237. res.createIndex()
  238. return res
  239. def annToRLE(self, ann, frameId):
  240. """Convert annotation which can be polygons, uncompressed RLE to RLE.
  241. :return: binary mask (numpy 2D array)
  242. """
  243. t = self.vids[ann['video_id']]
  244. h, w = t['height'], t['width']
  245. segm = ann['segmentations'][frameId]
  246. if type(segm) == list:
  247. # polygon -- a single object might consist of multiple parts
  248. # we merge all parts into one mask rle code
  249. rles = maskUtils.frPyObjects(segm, h, w)
  250. rle = maskUtils.merge(rles)
  251. elif type(segm['counts']) == list:
  252. # uncompressed RLE
  253. rle = maskUtils.frPyObjects(segm, h, w)
  254. else:
  255. # rle
  256. rle = segm
  257. return rle
  258. def annToMask(self, ann, frameId):
  259. """Convert annotation which can be polygons, uncompressed RLE, or RLE
  260. to binary mask.
  261. :return: binary mask (numpy 2D array)
  262. """
  263. rle = self.annToRLE(ann, frameId)
  264. m = maskUtils.decode(rle)
  265. return m