ytviseval.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623
  1. # Copyright (c) Github URL
  2. # Copied from
  3. # https://github.com/youtubevos/cocoapi/blob/master/PythonAPI/pycocotools/ytvoseval.py
  4. __author__ = 'ychfan'
  5. import copy
  6. import datetime
  7. import time
  8. from collections import defaultdict
  9. import numpy as np
  10. from pycocotools import mask as maskUtils
  11. class YTVISeval:
  12. # Interface for evaluating video instance segmentation on
  13. # the YouTubeVIS dataset.
  14. #
  15. # The usage for YTVISeval is as follows:
  16. # cocoGt=..., cocoDt=... # load dataset and results
  17. # E = YTVISeval(cocoGt,cocoDt); # initialize YTVISeval object
  18. # E.params.recThrs = ...; # set parameters as desired
  19. # E.evaluate(); # run per image evaluation
  20. # E.accumulate(); # accumulate per image results
  21. # E.summarize(); # display summary metrics of results
  22. # For example usage see evalDemo.m and http://mscoco.org/.
  23. #
  24. # The evaluation parameters are as follows (defaults in brackets):
  25. # imgIds - [all] N img ids to use for evaluation
  26. # catIds - [all] K cat ids to use for evaluation
  27. # iouThrs - [.5:.05:.95] T=10 IoU thresholds for evaluation
  28. # recThrs - [0:.01:1] R=101 recall thresholds for evaluation
  29. # areaRng - [...] A=4 object area ranges for evaluation
  30. # maxDets - [1 10 100] M=3 thresholds on max detections per image
  31. # iouType - ['segm'] set iouType to 'segm', 'bbox' or 'keypoints'
  32. # iouType replaced the now DEPRECATED useSegm parameter.
  33. # useCats - [1] if true use category labels for evaluation
  34. # Note: if useCats=0 category labels are ignored as in proposal scoring.
  35. # Note: multiple areaRngs [Ax2] and maxDets [Mx1] can be specified.
  36. #
  37. # evaluate(): evaluates detections on every image and every category and
  38. # concats the results into the "evalImgs" with fields:
  39. # dtIds - [1xD] id for each of the D detections (dt)
  40. # gtIds - [1xG] id for each of the G ground truths (gt)
  41. # dtMatches - [TxD] matching gt id at each IoU or 0
  42. # gtMatches - [TxG] matching dt id at each IoU or 0
  43. # dtScores - [1xD] confidence of each dt
  44. # gtIgnore - [1xG] ignore flag for each gt
  45. # dtIgnore - [TxD] ignore flag for each dt at each IoU
  46. #
  47. # accumulate(): accumulates the per-image, per-category evaluation
  48. # results in "evalImgs" into the dictionary "eval" with fields:
  49. # params - parameters used for evaluation
  50. # date - date evaluation was performed
  51. # counts - [T,R,K,A,M] parameter dimensions (see above)
  52. # precision - [TxRxKxAxM] precision for every evaluation setting
  53. # recall - [TxKxAxM] max recall for every evaluation setting
  54. # Note: precision and recall==-1 for settings with no gt objects.
  55. #
  56. # See also coco, mask, pycocoDemo, pycocoEvalDemo
  57. #
  58. # Microsoft COCO Toolbox. version 2.0
  59. # Data, paper, and tutorials available at: http://mscoco.org/
  60. # Code written by Piotr Dollar and Tsung-Yi Lin, 2015.
  61. # Licensed under the Simplified BSD License [see coco/license.txt]
  62. def __init__(self, cocoGt=None, cocoDt=None, iouType='segm'):
  63. """Initialize CocoEval using coco APIs for gt and dt.
  64. :param cocoGt: coco object with ground truth annotations
  65. :param cocoDt: coco object with detection results
  66. :return: None
  67. """
  68. if not iouType:
  69. print('iouType not specified. use default iouType segm')
  70. self.cocoGt = cocoGt # ground truth COCO API
  71. self.cocoDt = cocoDt # detections COCO API
  72. self.params = {} # evaluation parameters
  73. self.evalVids = defaultdict(
  74. list) # per-image per-category evaluation results [KxAxI] elements
  75. self.eval = {} # accumulated evaluation results
  76. self._gts = defaultdict(list) # gt for evaluation
  77. self._dts = defaultdict(list) # dt for evaluation
  78. self.params = Params(iouType=iouType) # parameters
  79. self._paramsEval = {} # parameters for evaluation
  80. self.stats = [] # result summarization
  81. self.ious = {} # ious between all gts and dts
  82. if cocoGt is not None:
  83. self.params.vidIds = sorted(cocoGt.getVidIds())
  84. self.params.catIds = sorted(cocoGt.getCatIds())
  85. def _prepare(self):
  86. '''
  87. Prepare ._gts and ._dts for evaluation based on params
  88. :return: None
  89. '''
  90. def _toMask(anns, coco):
  91. # modify ann['segmentation'] by reference
  92. for ann in anns:
  93. for i, a in enumerate(ann['segmentations']):
  94. if a:
  95. rle = coco.annToRLE(ann, i)
  96. ann['segmentations'][i] = rle
  97. l_ori = [a for a in ann['areas'] if a]
  98. if len(l_ori) == 0:
  99. ann['avg_area'] = 0
  100. else:
  101. ann['avg_area'] = np.array(l_ori).mean()
  102. p = self.params
  103. if p.useCats:
  104. gts = self.cocoGt.loadAnns(
  105. self.cocoGt.getAnnIds(vidIds=p.vidIds, catIds=p.catIds))
  106. dts = self.cocoDt.loadAnns(
  107. self.cocoDt.getAnnIds(vidIds=p.vidIds, catIds=p.catIds))
  108. else:
  109. gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(vidIds=p.vidIds))
  110. dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(vidIds=p.vidIds))
  111. # convert ground truth to mask if iouType == 'segm'
  112. if p.iouType == 'segm':
  113. _toMask(gts, self.cocoGt)
  114. _toMask(dts, self.cocoDt)
  115. # set ignore flag
  116. for gt in gts:
  117. gt['ignore'] = gt['ignore'] if 'ignore' in gt else 0
  118. gt['ignore'] = 'iscrowd' in gt and gt['iscrowd']
  119. if p.iouType == 'keypoints':
  120. gt['ignore'] = (gt['num_keypoints'] == 0) or gt['ignore']
  121. self._gts = defaultdict(list) # gt for evaluation
  122. self._dts = defaultdict(list) # dt for evaluation
  123. for gt in gts:
  124. self._gts[gt['video_id'], gt['category_id']].append(gt)
  125. for dt in dts:
  126. self._dts[dt['video_id'], dt['category_id']].append(dt)
  127. self.evalVids = defaultdict(
  128. list) # per-image per-category evaluation results
  129. self.eval = {} # accumulated evaluation results
  130. def evaluate(self):
  131. '''
  132. Run per image evaluation on given images and store
  133. results (a list of dict) in self.evalVids
  134. :return: None
  135. '''
  136. tic = time.time()
  137. print('Running per image evaluation...')
  138. p = self.params
  139. # add backward compatibility if useSegm is specified in params
  140. if p.useSegm is not None:
  141. p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
  142. print('useSegm (deprecated) is not None. Running {} evaluation'.
  143. format(p.iouType))
  144. print('Evaluate annotation type *{}*'.format(p.iouType))
  145. p.vidIds = list(np.unique(p.vidIds))
  146. if p.useCats:
  147. p.catIds = list(np.unique(p.catIds))
  148. p.maxDets = sorted(p.maxDets)
  149. self.params = p
  150. self._prepare()
  151. # loop through images, area range, max detection number
  152. catIds = p.catIds if p.useCats else [-1]
  153. if p.iouType == 'segm' or p.iouType == 'bbox':
  154. computeIoU = self.computeIoU
  155. elif p.iouType == 'keypoints':
  156. computeIoU = self.computeOks
  157. self.ious = {(vidId, catId): computeIoU(vidId, catId)
  158. for vidId in p.vidIds for catId in catIds}
  159. evaluateVid = self.evaluateVid
  160. maxDet = p.maxDets[-1]
  161. self.evalImgs = [
  162. evaluateVid(vidId, catId, areaRng, maxDet) for catId in catIds
  163. for areaRng in p.areaRng for vidId in p.vidIds
  164. ]
  165. self._paramsEval = copy.deepcopy(self.params)
  166. toc = time.time()
  167. print('DONE (t={:0.2f}s).'.format(toc - tic))
  168. def computeIoU(self, vidId, catId):
  169. p = self.params
  170. if p.useCats:
  171. gt = self._gts[vidId, catId]
  172. dt = self._dts[vidId, catId]
  173. else:
  174. gt = [_ for cId in p.catIds for _ in self._gts[vidId, cId]]
  175. dt = [_ for cId in p.catIds for _ in self._dts[vidId, cId]]
  176. if len(gt) == 0 and len(dt) == 0:
  177. return []
  178. inds = np.argsort([-d['score'] for d in dt], kind='mergesort')
  179. dt = [dt[i] for i in inds]
  180. if len(dt) > p.maxDets[-1]:
  181. dt = dt[0:p.maxDets[-1]]
  182. if p.iouType == 'segm':
  183. g = [g['segmentations'] for g in gt]
  184. d = [d['segmentations'] for d in dt]
  185. elif p.iouType == 'bbox':
  186. g = [g['bboxes'] for g in gt]
  187. d = [d['bboxes'] for d in dt]
  188. else:
  189. raise Exception('unknown iouType for iou computation')
  190. # compute iou between each dt and gt region
  191. def iou_seq(d_seq, g_seq):
  192. i = .0
  193. u = .0
  194. for d, g in zip(d_seq, g_seq):
  195. if d and g:
  196. i += maskUtils.area(maskUtils.merge([d, g], True))
  197. u += maskUtils.area(maskUtils.merge([d, g], False))
  198. elif not d and g:
  199. u += maskUtils.area(g)
  200. elif d and not g:
  201. u += maskUtils.area(d)
  202. if not u > .0:
  203. print('Mask sizes in video {} and category {} may not match!'.
  204. format(vidId, catId))
  205. iou = i / u if u > .0 else .0
  206. return iou
  207. ious = np.zeros([len(d), len(g)])
  208. for i, j in np.ndindex(ious.shape):
  209. ious[i, j] = iou_seq(d[i], g[j])
  210. return ious
  211. def computeOks(self, imgId, catId):
  212. p = self.params
  213. gts = self._gts[imgId, catId]
  214. dts = self._dts[imgId, catId]
  215. inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
  216. dts = [dts[i] for i in inds]
  217. if len(dts) > p.maxDets[-1]:
  218. dts = dts[0:p.maxDets[-1]]
  219. # if len(gts) == 0 and len(dts) == 0:
  220. if len(gts) == 0 or len(dts) == 0:
  221. return []
  222. ious = np.zeros((len(dts), len(gts)))
  223. sigmas = np.array([
  224. .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07,
  225. .87, .87, .89, .89
  226. ]) / 10.0
  227. vars = (sigmas * 2)**2
  228. k = len(sigmas)
  229. # compute oks between each detection and ground truth object
  230. for j, gt in enumerate(gts):
  231. # create bounds for ignore regions(double the gt bbox)
  232. g = np.array(gt['keypoints'])
  233. xg = g[0::3]
  234. yg = g[1::3]
  235. vg = g[2::3]
  236. k1 = np.count_nonzero(vg > 0)
  237. bb = gt['bbox']
  238. x0 = bb[0] - bb[2]
  239. x1 = bb[0] + bb[2] * 2
  240. y0 = bb[1] - bb[3]
  241. y1 = bb[1] + bb[3] * 2
  242. for i, dt in enumerate(dts):
  243. d = np.array(dt['keypoints'])
  244. xd = d[0::3]
  245. yd = d[1::3]
  246. if k1 > 0:
  247. # measure the per-keypoint distance if keypoints visible
  248. dx = xd - xg
  249. dy = yd - yg
  250. else:
  251. # measure minimum distance to keypoints
  252. z = np.zeros((k))
  253. dx = np.max((z, x0 - xd), axis=0) + np.max(
  254. (z, xd - x1), axis=0)
  255. dy = np.max((z, y0 - yd), axis=0) + np.max(
  256. (z, yd - y1), axis=0)
  257. e = (dx**2 + dy**2) / vars / (gt['avg_area'] +
  258. np.spacing(1)) / 2
  259. if k1 > 0:
  260. e = e[vg > 0]
  261. ious[i, j] = np.sum(np.exp(-e)) / e.shape[0]
  262. return ious
  263. def evaluateVid(self, vidId, catId, aRng, maxDet):
  264. '''
  265. perform evaluation for single category and image
  266. :return: dict (single image results)
  267. '''
  268. p = self.params
  269. if p.useCats:
  270. gt = self._gts[vidId, catId]
  271. dt = self._dts[vidId, catId]
  272. else:
  273. gt = [_ for cId in p.catIds for _ in self._gts[vidId, cId]]
  274. dt = [_ for cId in p.catIds for _ in self._dts[vidId, cId]]
  275. if len(gt) == 0 and len(dt) == 0:
  276. return None
  277. for g in gt:
  278. if g['ignore'] or (g['avg_area'] < aRng[0]
  279. or g['avg_area'] > aRng[1]):
  280. g['_ignore'] = 1
  281. else:
  282. g['_ignore'] = 0
  283. # sort dt highest score first, sort gt ignore last
  284. gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort')
  285. gt = [gt[i] for i in gtind]
  286. dtind = np.argsort([-d['score'] for d in dt], kind='mergesort')
  287. dt = [dt[i] for i in dtind[0:maxDet]]
  288. iscrowd = [int(o['iscrowd']) for o in gt]
  289. # load computed ious
  290. ious = self.ious[vidId, catId][:, gtind] if len(
  291. self.ious[vidId, catId]) > 0 else self.ious[vidId, catId]
  292. T = len(p.iouThrs)
  293. G = len(gt)
  294. D = len(dt)
  295. gtm = np.zeros((T, G))
  296. dtm = np.zeros((T, D))
  297. gtIg = np.array([g['_ignore'] for g in gt])
  298. dtIg = np.zeros((T, D))
  299. if not len(ious) == 0:
  300. for tind, t in enumerate(p.iouThrs):
  301. for dind, d in enumerate(dt):
  302. # information about best match so far (m=-1 -> unmatched)
  303. iou = min([t, 1 - 1e-10])
  304. m = -1
  305. for gind, g in enumerate(gt):
  306. # if this gt already matched, and not a crowd, continue
  307. if gtm[tind, gind] > 0 and not iscrowd[gind]:
  308. continue
  309. # if dt matched to reg gt, and on ignore gt, stop
  310. if m > -1 and gtIg[m] == 0 and gtIg[gind] == 1:
  311. break
  312. # continue to next gt unless better match made
  313. if ious[dind, gind] < iou:
  314. continue
  315. # if match successful and best so far,
  316. # store appropriately
  317. iou = ious[dind, gind]
  318. m = gind
  319. # if match made store id of match for both dt and gt
  320. if m == -1:
  321. continue
  322. dtIg[tind, dind] = gtIg[m]
  323. dtm[tind, dind] = gt[m]['id']
  324. gtm[tind, m] = d['id']
  325. # set unmatched detections outside of area range to ignore
  326. a = np.array([
  327. d['avg_area'] < aRng[0] or d['avg_area'] > aRng[1] for d in dt
  328. ]).reshape((1, len(dt)))
  329. dtIg = np.logical_or(dtIg, np.logical_and(dtm == 0, np.repeat(a, T,
  330. 0)))
  331. # store results for given image and category
  332. return {
  333. 'video_id': vidId,
  334. 'category_id': catId,
  335. 'aRng': aRng,
  336. 'maxDet': maxDet,
  337. 'dtIds': [d['id'] for d in dt],
  338. 'gtIds': [g['id'] for g in gt],
  339. 'dtMatches': dtm,
  340. 'gtMatches': gtm,
  341. 'dtScores': [d['score'] for d in dt],
  342. 'gtIgnore': gtIg,
  343. 'dtIgnore': dtIg,
  344. }
  345. def accumulate(self, p=None):
  346. """Accumulate per image evaluation results and store the result in
  347. self.eval.
  348. :param p: input params for evaluation
  349. :return: None
  350. """
  351. print('Accumulating evaluation results...')
  352. tic = time.time()
  353. if not self.evalImgs:
  354. print('Please run evaluate() first')
  355. # allows input customized parameters
  356. if p is None:
  357. p = self.params
  358. p.catIds = p.catIds if p.useCats == 1 else [-1]
  359. T = len(p.iouThrs)
  360. R = len(p.recThrs)
  361. K = len(p.catIds) if p.useCats else 1
  362. A = len(p.areaRng)
  363. M = len(p.maxDets)
  364. precision = -np.ones(
  365. (T, R, K, A, M)) # -1 for the precision of absent categories
  366. recall = -np.ones((T, K, A, M))
  367. scores = -np.ones((T, R, K, A, M))
  368. # create dictionary for future indexing
  369. _pe = self._paramsEval
  370. catIds = _pe.catIds if _pe.useCats else [-1]
  371. setK = set(catIds)
  372. setA = set(map(tuple, _pe.areaRng))
  373. setM = set(_pe.maxDets)
  374. setI = set(_pe.vidIds)
  375. # get inds to evaluate
  376. k_list = [n for n, k in enumerate(p.catIds) if k in setK]
  377. m_list = [m for n, m in enumerate(p.maxDets) if m in setM]
  378. a_list = [
  379. n for n, a in enumerate(map(lambda x: tuple(x), p.areaRng))
  380. if a in setA
  381. ]
  382. i_list = [n for n, i in enumerate(p.vidIds) if i in setI]
  383. I0 = len(_pe.vidIds)
  384. A0 = len(_pe.areaRng)
  385. # retrieve E at each category, area range, and max number of detections
  386. for k, k0 in enumerate(k_list):
  387. Nk = k0 * A0 * I0
  388. for a, a0 in enumerate(a_list):
  389. Na = a0 * I0
  390. for m, maxDet in enumerate(m_list):
  391. E = [self.evalImgs[Nk + Na + i] for i in i_list]
  392. E = [e for e in E if e is not None]
  393. if len(E) == 0:
  394. continue
  395. dtScores = np.concatenate(
  396. [e['dtScores'][0:maxDet] for e in E])
  397. inds = np.argsort(-dtScores, kind='mergesort')
  398. dtScoresSorted = dtScores[inds]
  399. dtm = np.concatenate(
  400. [e['dtMatches'][:, 0:maxDet] for e in E], axis=1)[:,
  401. inds]
  402. dtIg = np.concatenate(
  403. [e['dtIgnore'][:, 0:maxDet] for e in E], axis=1)[:,
  404. inds]
  405. gtIg = np.concatenate([e['gtIgnore'] for e in E])
  406. npig = np.count_nonzero(gtIg == 0)
  407. if npig == 0:
  408. continue
  409. tps = np.logical_and(dtm, np.logical_not(dtIg))
  410. fps = np.logical_and(
  411. np.logical_not(dtm), np.logical_not(dtIg))
  412. tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float)
  413. fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float)
  414. for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
  415. tp = np.array(tp)
  416. fp = np.array(fp)
  417. nd_ori = len(tp)
  418. rc = tp / npig
  419. pr = tp / (fp + tp + np.spacing(1))
  420. q = np.zeros((R, ))
  421. ss = np.zeros((R, ))
  422. if nd_ori:
  423. recall[t, k, a, m] = rc[-1]
  424. else:
  425. recall[t, k, a, m] = 0
  426. # use python array gets significant speed improvement
  427. pr = pr.tolist()
  428. q = q.tolist()
  429. for i in range(nd_ori - 1, 0, -1):
  430. if pr[i] > pr[i - 1]:
  431. pr[i - 1] = pr[i]
  432. inds = np.searchsorted(rc, p.recThrs, side='left')
  433. try:
  434. for ri, pi in enumerate(inds):
  435. q[ri] = pr[pi]
  436. ss[ri] = dtScoresSorted[pi]
  437. except Exception:
  438. pass
  439. precision[t, :, k, a, m] = np.array(q)
  440. scores[t, :, k, a, m] = np.array(ss)
  441. self.eval = {
  442. 'params': p,
  443. 'counts': [T, R, K, A, M],
  444. 'date': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
  445. 'precision': precision,
  446. 'recall': recall,
  447. 'scores': scores,
  448. }
  449. toc = time.time()
  450. print('DONE (t={:0.2f}s).'.format(toc - tic))
  451. def summarize(self):
  452. """Compute and display summary metrics for evaluation results.
  453. Note this function can *only* be applied on the default parameter
  454. setting
  455. """
  456. def _summarize(ap=1, iouThr=None, areaRng='all', maxDets=100):
  457. p = self.params
  458. iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | ' \
  459. 'maxDets={:>3d} ] = {:0.3f}'
  460. titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
  461. typeStr = '(AP)' if ap == 1 else '(AR)'
  462. iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
  463. if iouThr is None else '{:0.2f}'.format(iouThr)
  464. aind = [
  465. i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng
  466. ]
  467. mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
  468. if ap == 1:
  469. # dimension of precision: [TxRxKxAxM]
  470. s = self.eval['precision']
  471. # IoU
  472. if iouThr is not None:
  473. t = np.where(iouThr == p.iouThrs)[0]
  474. s = s[t]
  475. s = s[:, :, :, aind, mind]
  476. else:
  477. # dimension of recall: [TxKxAxM]
  478. s = self.eval['recall']
  479. if iouThr is not None:
  480. t = np.where(iouThr == p.iouThrs)[0]
  481. s = s[t]
  482. s = s[:, :, aind, mind]
  483. if len(s[s > -1]) == 0:
  484. mean_s = -1
  485. else:
  486. mean_s = np.mean(s[s > -1])
  487. print(
  488. iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets,
  489. mean_s))
  490. return mean_s
  491. def _summarizeDets():
  492. stats = np.zeros((12, ))
  493. stats[0] = _summarize(1)
  494. stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
  495. stats[2] = _summarize(
  496. 1, iouThr=.75, maxDets=self.params.maxDets[2])
  497. stats[3] = _summarize(
  498. 1, areaRng='small', maxDets=self.params.maxDets[2])
  499. stats[4] = _summarize(
  500. 1, areaRng='medium', maxDets=self.params.maxDets[2])
  501. stats[5] = _summarize(
  502. 1, areaRng='large', maxDets=self.params.maxDets[2])
  503. stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
  504. stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
  505. stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
  506. stats[9] = _summarize(
  507. 0, areaRng='small', maxDets=self.params.maxDets[2])
  508. stats[10] = _summarize(
  509. 0, areaRng='medium', maxDets=self.params.maxDets[2])
  510. stats[11] = _summarize(
  511. 0, areaRng='large', maxDets=self.params.maxDets[2])
  512. return stats
  513. def _summarizeKps():
  514. stats = np.zeros((10, ))
  515. stats[0] = _summarize(1, maxDets=20)
  516. stats[1] = _summarize(1, maxDets=20, iouThr=.5)
  517. stats[2] = _summarize(1, maxDets=20, iouThr=.75)
  518. stats[3] = _summarize(1, maxDets=20, areaRng='medium')
  519. stats[4] = _summarize(1, maxDets=20, areaRng='large')
  520. stats[5] = _summarize(0, maxDets=20)
  521. stats[6] = _summarize(0, maxDets=20, iouThr=.5)
  522. stats[7] = _summarize(0, maxDets=20, iouThr=.75)
  523. stats[8] = _summarize(0, maxDets=20, areaRng='medium')
  524. stats[9] = _summarize(0, maxDets=20, areaRng='large')
  525. return stats
  526. if not self.eval:
  527. raise Exception('Please run accumulate() first')
  528. iouType = self.params.iouType
  529. if iouType == 'segm' or iouType == 'bbox':
  530. summarize = _summarizeDets
  531. elif iouType == 'keypoints':
  532. summarize = _summarizeKps
  533. self.stats = summarize()
  534. def __str__(self):
  535. self.summarize()
  536. class Params:
  537. """Params for coco evaluation api."""
  538. def setDetParams(self):
  539. self.vidIds = []
  540. self.catIds = []
  541. # np.arange causes trouble. the data point on arange
  542. # is slightly larger than the true value
  543. self.iouThrs = np.linspace(
  544. .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
  545. self.recThrs = np.linspace(
  546. .0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True)
  547. self.maxDets = [1, 10, 100]
  548. self.areaRng = [[0**2, 1e5**2], [0**2, 128**2], [128**2, 256**2],
  549. [256**2, 1e5**2]]
  550. self.areaRngLbl = ['all', 'small', 'medium', 'large']
  551. self.useCats = 1
  552. def setKpParams(self):
  553. self.vidIds = []
  554. self.catIds = []
  555. # np.arange causes trouble. the data point on arange
  556. # is slightly larger than the true value
  557. self.iouThrs = np.linspace(
  558. .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
  559. self.recThrs = np.linspace(
  560. .0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True)
  561. self.maxDets = [20]
  562. self.areaRng = [[0**2, 1e5**2], [32**2, 96**2], [96**2, 1e5**2]]
  563. self.areaRngLbl = ['all', 'medium', 'large']
  564. self.useCats = 1
  565. def __init__(self, iouType='segm'):
  566. if iouType == 'segm' or iouType == 'bbox':
  567. self.setDetParams()
  568. elif iouType == 'keypoints':
  569. self.setKpParams()
  570. else:
  571. raise Exception('iouType not supported')
  572. self.iouType = iouType
  573. # useSegm is deprecated
  574. self.useSegm = None