loops.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from mmengine.model import is_model_wrapper
  3. from mmengine.runner import ValLoop
  4. from mmdet.registry import LOOPS
  5. @LOOPS.register_module()
  6. class TeacherStudentValLoop(ValLoop):
  7. """Loop for validation of model teacher and student."""
  8. def run(self):
  9. """Launch validation for model teacher and student."""
  10. self.runner.call_hook('before_val')
  11. self.runner.call_hook('before_val_epoch')
  12. self.runner.model.eval()
  13. model = self.runner.model
  14. if is_model_wrapper(model):
  15. model = model.module
  16. assert hasattr(model, 'teacher')
  17. assert hasattr(model, 'student')
  18. predict_on = model.semi_test_cfg.get('predict_on', None)
  19. multi_metrics = dict()
  20. for _predict_on in ['teacher', 'student']:
  21. model.semi_test_cfg['predict_on'] = _predict_on
  22. for idx, data_batch in enumerate(self.dataloader):
  23. self.run_iter(idx, data_batch)
  24. # compute metrics
  25. metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
  26. multi_metrics.update(
  27. {'/'.join((_predict_on, k)): v
  28. for k, v in metrics.items()})
  29. model.semi_test_cfg['predict_on'] = predict_on
  30. self.runner.call_hook('after_val_epoch', metrics=multi_metrics)
  31. self.runner.call_hook('after_val')