# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

try:
    import lap
except ImportError:
    lap = None
import numpy as np
import torch
from addict import Dict
from mmengine.structures import InstanceData

from mmdet.registry import MODELS
from mmdet.structures import DetDataSample
from mmdet.structures.bbox import (bbox_cxcyah_to_xyxy, bbox_overlaps,
                                   bbox_xyxy_to_cxcyah)
from .sort_tracker import SORTTracker


@MODELS.register_module()
class OCSORTTracker(SORTTracker):
    """Tracker for OC-SORT.

    Args:
        motion (dict): Configuration of motion. Defaults to None.
        obj_score_thrs (float): Detection score threshold for matching objects.
            Defaults to 0.3.
        init_track_thr (float): Detection score threshold for initializing a
            new tracklet. Defaults to 0.7.
        weight_iou_with_det_scores (bool): Whether using detection scores to
            weight IOU which is used for matching. Defaults to True.
        match_iou_thr (float): IOU distance threshold for matching between two
            frames. Defaults to 0.3.
        num_tentatives (int, optional): Number of continuous frames to confirm
            a track. Defaults to 3.
        vel_consist_weight (float): Weight of the velocity consistency term in
            association (OCM term in the paper).
        vel_delta_t (int): The difference of time step for calculating of the
            velocity direction of tracklets.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 motion: Optional[dict] = None,
                 obj_score_thr: float = 0.3,
                 init_track_thr: float = 0.7,
                 weight_iou_with_det_scores: bool = True,
                 match_iou_thr: float = 0.3,
                 num_tentatives: int = 3,
                 vel_consist_weight: float = 0.2,
                 vel_delta_t: int = 3,
                 **kwargs):
        if lap is None:
            raise RuntimeError('lap is not installed,\
                 please install it by: pip install lap')
        super().__init__(motion=motion, **kwargs)
        self.obj_score_thr = obj_score_thr
        self.init_track_thr = init_track_thr

        self.weight_iou_with_det_scores = weight_iou_with_det_scores
        self.match_iou_thr = match_iou_thr
        self.vel_consist_weight = vel_consist_weight
        self.vel_delta_t = vel_delta_t

        self.num_tentatives = num_tentatives

    @property
    def unconfirmed_ids(self):
        """Unconfirmed ids in the tracker."""
        ids = [id for id, track in self.tracks.items() if track.tentative]
        return ids

    def init_track(self, id: int, obj: Tuple[torch.Tensor]):
        """Initialize a track."""
        super().init_track(id, obj)
        if self.tracks[id].frame_ids[-1] == 0:
            self.tracks[id].tentative = False
        else:
            self.tracks[id].tentative = True
        bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1])  # size = (1, 4)
        assert bbox.ndim == 2 and bbox.shape[0] == 1
        bbox = bbox.squeeze(0).cpu().numpy()
        self.tracks[id].mean, self.tracks[id].covariance = self.kf.initiate(
            bbox)
        # track.obs maintains the history associated detections to this track
        self.tracks[id].obs = []
        bbox_id = self.memo_items.index('bboxes')
        self.tracks[id].obs.append(obj[bbox_id])
        # a placefolder to save mean/covariance before losing tracking it
        # parameters to save: mean, covariance, measurement
        self.tracks[id].tracked = True
        self.tracks[id].saved_attr = Dict()
        self.tracks[id].velocity = torch.tensor(
            (-1, -1)).to(obj[bbox_id].device)  # placeholder

    def update_track(self, id: int, obj: Tuple[torch.Tensor]):
        """Update a track."""
        super().update_track(id, obj)
        if self.tracks[id].tentative:
            if len(self.tracks[id]['bboxes']) >= self.num_tentatives:
                self.tracks[id].tentative = False
        bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1])  # size = (1, 4)
        assert bbox.ndim == 2 and bbox.shape[0] == 1
        bbox = bbox.squeeze(0).cpu().numpy()
        self.tracks[id].mean, self.tracks[id].covariance = self.kf.update(
            self.tracks[id].mean, self.tracks[id].covariance, bbox)
        self.tracks[id].tracked = True
        bbox_id = self.memo_items.index('bboxes')
        self.tracks[id].obs.append(obj[bbox_id])

        bbox1 = self.k_step_observation(self.tracks[id])
        bbox2 = obj[bbox_id]
        self.tracks[id].velocity = self.vel_direction(bbox1, bbox2).to(
            obj[bbox_id].device)

    def vel_direction(self, bbox1: torch.Tensor, bbox2: torch.Tensor):
        """Estimate the direction vector between two boxes."""
        if bbox1.sum() < 0 or bbox2.sum() < 0:
            return torch.tensor((-1, -1))
        cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
        cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0
        speed = torch.tensor([cy2 - cy1, cx2 - cx1])
        norm = torch.sqrt((speed[0])**2 + (speed[1])**2) + 1e-6
        return speed / norm

    def vel_direction_batch(self, bboxes1: torch.Tensor,
                            bboxes2: torch.Tensor):
        """Estimate the direction vector given two batches of boxes."""
        cx1, cy1 = (bboxes1[:, 0] + bboxes1[:, 2]) / 2.0, (bboxes1[:, 1] +
                                                           bboxes1[:, 3]) / 2.0
        cx2, cy2 = (bboxes2[:, 0] + bboxes2[:, 2]) / 2.0, (bboxes2[:, 1] +
                                                           bboxes2[:, 3]) / 2.0
        speed_diff_y = cy2[None, :] - cy1[:, None]
        speed_diff_x = cx2[None, :] - cx1[:, None]
        speed = torch.cat((speed_diff_y[..., None], speed_diff_x[..., None]),
                          dim=-1)
        norm = torch.sqrt((speed[:, :, 0])**2 + (speed[:, :, 1])**2) + 1e-6
        speed[:, :, 0] /= norm
        speed[:, :, 1] /= norm
        return speed

    def k_step_observation(self, track: Dict):
        """return the observation k step away before."""
        obs_seqs = track.obs
        num_obs = len(obs_seqs)
        if num_obs == 0:
            return torch.tensor((-1, -1, -1, -1)).to(track.obs[0].device)
        elif num_obs > self.vel_delta_t:
            if obs_seqs[num_obs - 1 - self.vel_delta_t] is not None:
                return obs_seqs[num_obs - 1 - self.vel_delta_t]
            else:
                return self.last_obs(track)
        else:
            return self.last_obs(track)

    def ocm_assign_ids(self,
                       ids: List[int],
                       det_bboxes: torch.Tensor,
                       det_labels: torch.Tensor,
                       det_scores: torch.Tensor,
                       weight_iou_with_det_scores: Optional[bool] = False,
                       match_iou_thr: Optional[float] = 0.5):
        """Apply Observation-Centric Momentum (OCM) to assign ids.

        OCM adds movement direction consistency into the association cost
        matrix. This term requires no additional assumption but from the
        same linear motion assumption as the canonical Kalman Filter in SORT.

        Args:
            ids (list[int]): Tracking ids.
            det_bboxes (Tensor): of shape (N, 4)
            det_labels (Tensor): of shape (N,)
            det_scores (Tensor): of shape (N,)
            weight_iou_with_det_scores (bool, optional): Whether using
                detection scores to weight IOU which is used for matching.
                Defaults to False.
            match_iou_thr (float, optional): Matching threshold.
                Defaults to 0.5.

        Returns:
            tuple(int): The assigning ids.

        OC-SORT uses velocity consistency besides IoU for association
        """
        # get track_bboxes
        track_bboxes = np.zeros((0, 4))
        for id in ids:
            track_bboxes = np.concatenate(
                (track_bboxes, self.tracks[id].mean[:4][None]), axis=0)
        track_bboxes = torch.from_numpy(track_bboxes).to(det_bboxes)
        track_bboxes = bbox_cxcyah_to_xyxy(track_bboxes)

        # compute distance
        ious = bbox_overlaps(track_bboxes, det_bboxes)
        if weight_iou_with_det_scores:
            ious *= det_scores

        # support multi-class association
        track_labels = torch.tensor([
            self.tracks[id]['labels'][-1] for id in ids
        ]).to(det_bboxes.device)
        cate_match = det_labels[None, :] == track_labels[:, None]
        # to avoid det and track of different categories are matched
        cate_cost = (1 - cate_match.int()) * 1e6

        dists = (1 - ious + cate_cost).cpu().numpy()

        if len(ids) > 0 and len(det_bboxes) > 0:
            track_velocities = torch.stack(
                [self.tracks[id].velocity for id in ids]).to(det_bboxes.device)
            k_step_observations = torch.stack([
                self.k_step_observation(self.tracks[id]) for id in ids
            ]).to(det_bboxes.device)
            # valid1: if the track has previous observations to estimate speed
            # valid2: if the associated observation k steps ago is a detection
            valid1 = track_velocities.sum(dim=1) != -2
            valid2 = k_step_observations.sum(dim=1) != -4
            valid = valid1 & valid2

            vel_to_match = self.vel_direction_batch(k_step_observations,
                                                    det_bboxes)
            track_velocities = track_velocities[:, None, :].repeat(
                1, det_bboxes.shape[0], 1)

            angle_cos = (vel_to_match * track_velocities).sum(dim=-1)
            angle_cos = torch.clamp(angle_cos, min=-1, max=1)
            angle = torch.acos(angle_cos)  # [0, pi]
            norm_angle = (angle - np.pi / 2.) / np.pi  # [-0.5, 0.5]
            valid_matrix = valid[:, None].int().repeat(1, det_bboxes.shape[0])
            # set non-valid entries 0
            valid_norm_angle = norm_angle * valid_matrix

            dists += valid_norm_angle.cpu().numpy() * self.vel_consist_weight

        # bipartite match
        if dists.size > 0:
            cost, row, col = lap.lapjv(
                dists, extend_cost=True, cost_limit=1 - match_iou_thr)
        else:
            row = np.zeros(len(ids)).astype(np.int32) - 1
            col = np.zeros(len(det_bboxes)).astype(np.int32) - 1
        return row, col

    def last_obs(self, track: Dict):
        """extract the last associated observation."""
        for bbox in track.obs[::-1]:
            if bbox is not None:
                return bbox

    def ocr_assign_ids(self,
                       track_obs: torch.Tensor,
                       last_track_labels: torch.Tensor,
                       det_bboxes: torch.Tensor,
                       det_labels: torch.Tensor,
                       det_scores: torch.Tensor,
                       weight_iou_with_det_scores: Optional[bool] = False,
                       match_iou_thr: Optional[float] = 0.5):
        """association for Observation-Centric Recovery.

        As try to recover tracks from being lost whose estimated velocity is
        out- to-date, we use IoU-only matching strategy.

        Args:
            track_obs (Tensor): the list of historical associated
                detections of tracks
            det_bboxes (Tensor): of shape (N, 5), unmatched detections
            det_labels (Tensor): of shape (N,)
            det_scores (Tensor): of shape (N,)
            weight_iou_with_det_scores (bool, optional): Whether using
                detection scores to weight IOU which is used for matching.
                Defaults to False.
            match_iou_thr (float, optional): Matching threshold.
                Defaults to 0.5.

        Returns:
            tuple(int): The assigning ids.
        """
        # compute distance
        ious = bbox_overlaps(track_obs, det_bboxes)
        if weight_iou_with_det_scores:
            ious *= det_scores

        # support multi-class association
        cate_match = det_labels[None, :] == last_track_labels[:, None]
        # to avoid det and track of different categories are matched
        cate_cost = (1 - cate_match.int()) * 1e6

        dists = (1 - ious + cate_cost).cpu().numpy()

        # bipartite match
        if dists.size > 0:
            cost, row, col = lap.lapjv(
                dists, extend_cost=True, cost_limit=1 - match_iou_thr)
        else:
            row = np.zeros(len(track_obs)).astype(np.int32) - 1
            col = np.zeros(len(det_bboxes)).astype(np.int32) - 1
        return row, col

    def online_smooth(self, track: Dict, obj: torch.Tensor):
        """Once a track is recovered from being lost, online smooth its
        parameters to fix the error accumulated during being lost.

        NOTE: you can use different virtual trajectory generation
        strategies, we adopt the naive linear interpolation as default
        """
        last_match_bbox = self.last_obs(track)
        new_match_bbox = obj
        unmatch_len = 0
        for bbox in track.obs[::-1]:
            if bbox is None:
                unmatch_len += 1
            else:
                break
        bbox_shift_per_step = (new_match_bbox - last_match_bbox) / (
            unmatch_len + 1)
        track.mean = track.saved_attr.mean
        track.covariance = track.saved_attr.covariance
        for i in range(unmatch_len):
            virtual_bbox = last_match_bbox + (i + 1) * bbox_shift_per_step
            virtual_bbox = bbox_xyxy_to_cxcyah(virtual_bbox[None, :])
            virtual_bbox = virtual_bbox.squeeze(0).cpu().numpy()
            track.mean, track.covariance = self.kf.update(
                track.mean, track.covariance, virtual_bbox)

    def track(self, data_sample: DetDataSample, **kwargs) -> InstanceData:
        """Tracking forward function.
        NOTE: this implementation is slightly different from the original
        OC-SORT implementation (https://github.com/noahcao/OC_SORT)that we
        do association between detections and tentative/non-tentative tracks
        independently while the original implementation combines them together.

        Args:
            data_sample (:obj:`DetDataSample`): The data sample.
                It includes information such as `pred_instances`.

        Returns:
            :obj:`InstanceData`: Tracking results of the input images.
            Each InstanceData usually contains ``bboxes``, ``labels``,
            ``scores`` and ``instances_id``.
        """
        metainfo = data_sample.metainfo
        bboxes = data_sample.pred_instances.bboxes
        labels = data_sample.pred_instances.labels
        scores = data_sample.pred_instances.scores
        frame_id = metainfo.get('frame_id', -1)
        if frame_id == 0:
            self.reset()
        if not hasattr(self, 'kf'):
            self.kf = self.motion

        if self.empty or bboxes.size(0) == 0:
            valid_inds = scores > self.init_track_thr
            scores = scores[valid_inds]
            bboxes = bboxes[valid_inds]
            labels = labels[valid_inds]
            num_new_tracks = bboxes.size(0)
            ids = torch.arange(self.num_tracks,
                               self.num_tracks + num_new_tracks).to(labels)
            self.num_tracks += num_new_tracks
        else:
            # 0. init
            ids = torch.full((bboxes.size(0), ),
                             -1,
                             dtype=labels.dtype,
                             device=labels.device)

            # get the detection bboxes for the first association
            det_inds = scores > self.obj_score_thr
            det_bboxes = bboxes[det_inds]
            det_labels = labels[det_inds]
            det_scores = scores[det_inds]
            det_ids = ids[det_inds]

            # 1. predict by Kalman Filter
            for id in self.confirmed_ids:
                # track is lost in previous frame
                if self.tracks[id].frame_ids[-1] != frame_id - 1:
                    self.tracks[id].mean[7] = 0
                if self.tracks[id].tracked:
                    self.tracks[id].saved_attr.mean = self.tracks[id].mean
                    self.tracks[id].saved_attr.covariance = self.tracks[
                        id].covariance
                (self.tracks[id].mean,
                 self.tracks[id].covariance) = self.kf.predict(
                     self.tracks[id].mean, self.tracks[id].covariance)

            # 2. match detections and tracks' predicted locations
            match_track_inds, raw_match_det_inds = self.ocm_assign_ids(
                self.confirmed_ids, det_bboxes, det_labels, det_scores,
                self.weight_iou_with_det_scores, self.match_iou_thr)
            # '-1' mean a detection box is not matched with tracklets in
            # previous frame
            valid = raw_match_det_inds > -1
            det_ids[valid] = torch.tensor(
                self.confirmed_ids)[raw_match_det_inds[valid]].to(labels)

            match_det_bboxes = det_bboxes[valid]
            match_det_labels = det_labels[valid]
            match_det_scores = det_scores[valid]
            match_det_ids = det_ids[valid]
            assert (match_det_ids > -1).all()

            # unmatched tracks and detections
            unmatch_det_bboxes = det_bboxes[~valid]
            unmatch_det_labels = det_labels[~valid]
            unmatch_det_scores = det_scores[~valid]
            unmatch_det_ids = det_ids[~valid]
            assert (unmatch_det_ids == -1).all()

            # 3. use unmatched detection bboxes from the first match to match
            # the unconfirmed tracks
            (tentative_match_track_inds,
             tentative_match_det_inds) = self.ocm_assign_ids(
                 self.unconfirmed_ids, unmatch_det_bboxes, unmatch_det_labels,
                 unmatch_det_scores, self.weight_iou_with_det_scores,
                 self.match_iou_thr)
            valid = tentative_match_det_inds > -1
            unmatch_det_ids[valid] = torch.tensor(self.unconfirmed_ids)[
                tentative_match_det_inds[valid]].to(labels)

            match_det_bboxes = torch.cat(
                (match_det_bboxes, unmatch_det_bboxes[valid]), dim=0)
            match_det_labels = torch.cat(
                (match_det_labels, unmatch_det_labels[valid]), dim=0)
            match_det_scores = torch.cat(
                (match_det_scores, unmatch_det_scores[valid]), dim=0)
            match_det_ids = torch.cat((match_det_ids, unmatch_det_ids[valid]),
                                      dim=0)
            assert (match_det_ids > -1).all()

            unmatch_det_bboxes = unmatch_det_bboxes[~valid]
            unmatch_det_labels = unmatch_det_labels[~valid]
            unmatch_det_scores = unmatch_det_scores[~valid]
            unmatch_det_ids = unmatch_det_ids[~valid]
            assert (unmatch_det_ids == -1).all()

            all_track_ids = [id for id, _ in self.tracks.items()]
            unmatched_track_inds = torch.tensor(
                [ind for ind in all_track_ids if ind not in match_det_ids])

            if len(unmatched_track_inds) > 0:
                # 4. still some tracks not associated yet, perform OCR
                last_observations = []
                for id in unmatched_track_inds:
                    last_box = self.last_obs(self.tracks[id.item()])
                    last_observations.append(last_box)
                last_observations = torch.stack(last_observations)
                last_track_labels = torch.tensor([
                    self.tracks[id.item()]['labels'][-1]
                    for id in unmatched_track_inds
                ]).to(det_bboxes.device)

                remain_det_ids = torch.full((unmatch_det_bboxes.size(0), ),
                                            -1,
                                            dtype=labels.dtype,
                                            device=labels.device)

                _, ocr_match_det_inds = self.ocr_assign_ids(
                    last_observations, last_track_labels, unmatch_det_bboxes,
                    unmatch_det_labels, unmatch_det_scores,
                    self.weight_iou_with_det_scores, self.match_iou_thr)

                valid = ocr_match_det_inds > -1
                remain_det_ids[valid] = unmatched_track_inds.clone()[
                    ocr_match_det_inds[valid]].to(labels)

                ocr_match_det_bboxes = unmatch_det_bboxes[valid]
                ocr_match_det_labels = unmatch_det_labels[valid]
                ocr_match_det_scores = unmatch_det_scores[valid]
                ocr_match_det_ids = remain_det_ids[valid]
                assert (ocr_match_det_ids > -1).all()

                ocr_unmatch_det_bboxes = unmatch_det_bboxes[~valid]
                ocr_unmatch_det_labels = unmatch_det_labels[~valid]
                ocr_unmatch_det_scores = unmatch_det_scores[~valid]
                ocr_unmatch_det_ids = remain_det_ids[~valid]
                assert (ocr_unmatch_det_ids == -1).all()

                unmatch_det_bboxes = ocr_unmatch_det_bboxes
                unmatch_det_labels = ocr_unmatch_det_labels
                unmatch_det_scores = ocr_unmatch_det_scores
                unmatch_det_ids = ocr_unmatch_det_ids
                match_det_bboxes = torch.cat(
                    (match_det_bboxes, ocr_match_det_bboxes), dim=0)
                match_det_labels = torch.cat(
                    (match_det_labels, ocr_match_det_labels), dim=0)
                match_det_scores = torch.cat(
                    (match_det_scores, ocr_match_det_scores), dim=0)
                match_det_ids = torch.cat((match_det_ids, ocr_match_det_ids),
                                          dim=0)

            # 5. summarize the track results
            for i in range(len(match_det_ids)):
                det_bbox = match_det_bboxes[i]
                track_id = match_det_ids[i].item()
                if not self.tracks[track_id].tracked:
                    # the track is lost before this step
                    self.online_smooth(self.tracks[track_id], det_bbox)

            for track_id in all_track_ids:
                if track_id not in match_det_ids:
                    self.tracks[track_id].tracked = False
                    self.tracks[track_id].obs.append(None)

            bboxes = torch.cat((match_det_bboxes, unmatch_det_bboxes), dim=0)
            labels = torch.cat((match_det_labels, unmatch_det_labels), dim=0)
            scores = torch.cat((match_det_scores, unmatch_det_scores), dim=0)
            ids = torch.cat((match_det_ids, unmatch_det_ids), dim=0)
            # 6. assign new ids
            new_track_inds = ids == -1

            ids[new_track_inds] = torch.arange(
                self.num_tracks,
                self.num_tracks + new_track_inds.sum()).to(labels)
            self.num_tracks += new_track_inds.sum()

        self.update(
            ids=ids,
            bboxes=bboxes,
            labels=labels,
            scores=scores,
            frame_ids=frame_id)

        # update pred_track_instances
        pred_track_instances = InstanceData()
        pred_track_instances.bboxes = bboxes
        pred_track_instances.labels = labels
        pred_track_instances.scores = scores
        pred_track_instances.instances_id = ids
        return pred_track_instances