base_video_metric.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os.path as osp
  3. import pickle
  4. import shutil
  5. import tempfile
  6. import warnings
  7. from typing import Optional, Sequence
  8. import torch
  9. from mmengine.dist import (barrier, broadcast, broadcast_object_list,
  10. get_dist_info, is_main_process)
  11. from mmengine.evaluator import BaseMetric
  12. from mmengine.utils import mkdir_or_exist
  13. class BaseVideoMetric(BaseMetric):
  14. """Base class for a metric in video task.
  15. The metric first processes each batch of data_samples and predictions,
  16. and appends the processed results to the results list. Then it
  17. collects all results together from all ranks if distributed training
  18. is used. Finally, it computes the metrics of the entire dataset.
  19. A subclass of class:`BaseVideoMetric` should assign a meaningful value
  20. to the class attribute `default_prefix`. See the argument `prefix` for
  21. details.
  22. """
  23. def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
  24. """Process one batch of data samples and predictions.
  25. The processed results should be stored in ``self.results``, which will
  26. be used to compute the metrics when all batches have been processed.
  27. Args:
  28. data_batch (dict): A batch of data from the dataloader.
  29. data_samples (Sequence[dict]): A batch of data samples that
  30. contain annotations and predictions.
  31. """
  32. for track_data_sample in data_samples:
  33. video_data_samples = track_data_sample['video_data_samples']
  34. ori_video_len = video_data_samples[0].ori_video_length
  35. if ori_video_len == len(video_data_samples):
  36. # video process
  37. self.process_video(video_data_samples)
  38. else:
  39. # image process
  40. self.process_image(video_data_samples, ori_video_len)
  41. def evaluate(self, size: int = 1) -> dict:
  42. """Evaluate the model performance of the whole dataset after processing
  43. all batches.
  44. Args:
  45. size (int): Length of the entire validation dataset.
  46. Returns:
  47. dict: Evaluation metrics dict on the val dataset. The keys are the
  48. names of the metrics, and the values are corresponding results.
  49. """
  50. if len(self.results) == 0:
  51. warnings.warn(
  52. f'{self.__class__.__name__} got empty `self.results`. Please '
  53. 'ensure that the processed results are properly added into '
  54. '`self.results` in `process` method.')
  55. results = collect_tracking_results(self.results, self.collect_device)
  56. if is_main_process():
  57. _metrics = self.compute_metrics(results) # type: ignore
  58. # Add prefix to metric names
  59. if self.prefix:
  60. _metrics = {
  61. '/'.join((self.prefix, k)): v
  62. for k, v in _metrics.items()
  63. }
  64. metrics = [_metrics]
  65. else:
  66. metrics = [None] # type: ignore
  67. broadcast_object_list(metrics)
  68. # reset the results list
  69. self.results.clear()
  70. return metrics[0]
  71. def collect_tracking_results(results: list,
  72. device: str = 'cpu',
  73. tmpdir: Optional[str] = None) -> Optional[list]:
  74. """Collected results in distributed environments. different from the
  75. function mmengine.dist.collect_results, tracking compute metrics don't use
  76. paramenter size, which means length of the entire validation dataset.
  77. because it's equal to video num, but compute metrics need image num.
  78. Args:
  79. results (list): Result list containing result parts to be
  80. collected. Each item of ``result_part`` should be a picklable
  81. object.
  82. device (str): Device name. Optional values are 'cpu' and 'gpu'.
  83. tmpdir (str | None): Temporal directory for collected results to
  84. store. If set to None, it will create a temporal directory for it.
  85. ``tmpdir`` should be None when device is 'gpu'. Defaults to None.
  86. Returns:
  87. list or None: The collected results.
  88. """
  89. if device not in ['gpu', 'cpu']:
  90. raise NotImplementedError(
  91. f"device must be 'cpu' or 'gpu', but got {device}")
  92. if device == 'gpu':
  93. assert tmpdir is None, 'tmpdir should be None when device is "gpu"'
  94. raise NotImplementedError('GPU collecting has not been supported yet')
  95. else:
  96. return collect_tracking_results_cpu(results, tmpdir)
  97. def collect_tracking_results_cpu(result_part: list,
  98. tmpdir: Optional[str] = None
  99. ) -> Optional[list]:
  100. """Collect results on cpu mode.
  101. Saves the results on different gpus to 'tmpdir' and collects them by the
  102. rank 0 worker.
  103. Args:
  104. result_part (list): The part of prediction results.
  105. tmpdir (str): Path of directory to save the temporary results from
  106. different gpus under cpu mode. If is None, use `tempfile.mkdtemp()`
  107. to make a temporary path. Defaults to None.
  108. Returns:
  109. list or None: The collected results.
  110. """
  111. rank, world_size = get_dist_info()
  112. if world_size == 1:
  113. return result_part
  114. # create a tmp dir if it is not specified
  115. if tmpdir is None:
  116. MAX_LEN = 512
  117. # 32 is whitespace
  118. dir_tensor = torch.full((MAX_LEN, ), 32, dtype=torch.uint8)
  119. if rank == 0:
  120. mkdir_or_exist('.dist_test')
  121. tmpdir = tempfile.mkdtemp(dir='.dist_test')
  122. tmpdir = torch.tensor(
  123. bytearray(tmpdir.encode()), dtype=torch.uint8)
  124. dir_tensor[:len(tmpdir)] = tmpdir
  125. broadcast(dir_tensor, 0)
  126. tmpdir = dir_tensor.numpy().tobytes().decode().rstrip()
  127. else:
  128. mkdir_or_exist(tmpdir)
  129. # dump the part result to the dir
  130. with open(osp.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f: # type: ignore
  131. pickle.dump(result_part, f, protocol=2)
  132. barrier()
  133. # collect all parts
  134. if rank != 0:
  135. return None
  136. else:
  137. # load results of all parts from tmp dir
  138. part_list = []
  139. for i in range(world_size):
  140. path = osp.join(tmpdir, f'part_{i}.pkl') # type: ignore
  141. with open(path, 'rb') as f:
  142. part_list.extend(pickle.load(f))
  143. shutil.rmtree(tmpdir)
  144. return part_list