test_roi_embed_head.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import mmengine
  4. import torch
  5. from mmengine.structures import InstanceData
  6. from mmdet.models.tracking_heads import RoIEmbedHead
  7. from mmdet.registry import TASK_UTILS
  8. def _dummy_bbox_sampling(rpn_results_list, batch_gt_instances):
  9. """Create sample results that can be passed to Head.get_targets."""
  10. num_imgs = len(rpn_results_list)
  11. feat = torch.rand(1, 1, 3, 3)
  12. assign_config = dict(
  13. type='MaxIoUAssigner',
  14. pos_iou_thr=0.5,
  15. neg_iou_thr=0.5,
  16. min_pos_iou=0.5,
  17. ignore_iof_thr=-1)
  18. sampler_config = dict(
  19. type='RandomSampler',
  20. num=512,
  21. pos_fraction=0.25,
  22. neg_pos_ub=-1,
  23. add_gt_as_proposals=False)
  24. bbox_assigner = TASK_UTILS.build(assign_config)
  25. bbox_sampler = TASK_UTILS.build(sampler_config)
  26. sampling_results = []
  27. for i in range(num_imgs):
  28. assign_result = bbox_assigner.assign(rpn_results_list[i],
  29. batch_gt_instances[i])
  30. sampling_result = bbox_sampler.sample(
  31. assign_result,
  32. rpn_results_list[i],
  33. batch_gt_instances[i],
  34. feats=feat)
  35. sampling_results.append(sampling_result)
  36. return sampling_results
  37. class TestRoIEmbedHead(TestCase):
  38. def test_roi_embed_head_loss(self):
  39. """Test roi embed head loss when truth is non-empty."""
  40. cfg = mmengine.Config(
  41. dict(
  42. num_convs=2,
  43. num_fcs=2,
  44. roi_feat_size=7,
  45. in_channels=16,
  46. fc_out_channels=32))
  47. embed_head = RoIEmbedHead(**cfg)
  48. x = torch.rand(1, 16, 7, 7)
  49. ref_x = torch.rand(1, 16, 7, 7)
  50. num_x_per_img = [1]
  51. num_x_per_ref_img = [1]
  52. x_split, ref_x_split = embed_head.forward(x, ref_x, num_x_per_img,
  53. num_x_per_ref_img)
  54. gt_instance_ids = [torch.LongTensor([2])]
  55. ref_gt_instance_ids = [torch.LongTensor([2])]
  56. rpn_results = InstanceData()
  57. rpn_results.labels = torch.LongTensor([2])
  58. rpn_results.priors = torch.Tensor(
  59. [[23.6667, 23.8757, 238.6326, 151.8874]])
  60. rpn_results_list = [rpn_results]
  61. gt_instance = InstanceData()
  62. gt_instance.labels = torch.LongTensor([2])
  63. gt_instance.bboxes = torch.Tensor(
  64. [[23.6667, 23.8757, 238.6326, 151.8874]])
  65. gt_instance.instances_id = torch.LongTensor([2])
  66. batch_gt_instances = [gt_instance]
  67. sampling_results = _dummy_bbox_sampling(rpn_results_list,
  68. batch_gt_instances)
  69. gt_losses = embed_head.loss_by_feat(x_split, ref_x_split,
  70. sampling_results, gt_instance_ids,
  71. ref_gt_instance_ids)
  72. assert gt_losses['loss_match'] > 0, 'match loss should be non-zero'
  73. assert gt_losses[
  74. 'match_accuracy'] >= 0, 'match accuracy should be non-zero or zero'
  75. def test_roi_embed_head_predict(self):
  76. cfg = mmengine.Config(
  77. dict(
  78. num_convs=2,
  79. num_fcs=2,
  80. roi_feat_size=7,
  81. in_channels=16,
  82. fc_out_channels=32))
  83. embed_head = RoIEmbedHead(**cfg)
  84. x = torch.rand(1, 16, 7, 7)
  85. ref_x = torch.rand(1, 16, 7, 7)
  86. similarity_logits = embed_head.predict(x, ref_x)
  87. assert isinstance(similarity_logits, list)
  88. assert len(similarity_logits) == 1