# Copyright (c) OpenMMLab. All rights reserved. import io import json import logging import os from urllib.parse import urlparse import boto3 from botocore.exceptions import ClientError from label_studio_ml.model import LabelStudioMLBase from label_studio_ml.utils import (DATA_UNDEFINED_NAME, get_image_size, get_single_tag_keys) from label_studio_tools.core.utils.io import get_data_dir from mmdet.apis import inference_detector, init_detector logger = logging.getLogger(__name__) class MMDetection(LabelStudioMLBase): """Object detector based on https://github.com/open-mmlab/mmdetection.""" def __init__(self, config_file=None, checkpoint_file=None, image_dir=None, labels_file=None, score_threshold=0.5, device='cpu', **kwargs): super(MMDetection, self).__init__(**kwargs) config_file = config_file or os.environ['config_file'] checkpoint_file = checkpoint_file or os.environ['checkpoint_file'] self.config_file = config_file self.checkpoint_file = checkpoint_file self.labels_file = labels_file # default Label Studio image upload folder upload_dir = os.path.join(get_data_dir(), 'media', 'upload') self.image_dir = image_dir or upload_dir logger.debug( f'{self.__class__.__name__} reads images from {self.image_dir}') if self.labels_file and os.path.exists(self.labels_file): self.label_map = json_load(self.labels_file) else: self.label_map = {} self.from_name, self.to_name, self.value, self.labels_in_config = get_single_tag_keys( # noqa E501 self.parsed_label_config, 'RectangleLabels', 'Image') schema = list(self.parsed_label_config.values())[0] self.labels_in_config = set(self.labels_in_config) # Collect label maps from `predicted_values="airplane,car"` attribute in