from collections import Counter, defaultdict, deque
from typing import Protocol

import numpy as np
import supervision as sv
from trackers import ByteTrackTracker

from app.face import EmbeddingGallery


class IdentityResolver(Protocol):
    """Structural interface that every identity-resolution strategy satisfies."""

    def update_identities(
        self,
        embeddings: list[np.ndarray],
        track_ids: list[int],
    ) -> dict[int, dict]:
        """Resolve identities for the given per-track face embeddings.

        Args:
            embeddings: Face embeddings extracted from high-quality detections.
            track_ids: Track ID corresponding to each embedding, in order.

        Returns:
            Mapping from track_id to an identity dict with 'name' and 'score'.
        """
        ...


class FaceTracker:
    """Thin wrapper around ByteTrack that turns per-frame face detections into tracks."""

    def __init__(
        self,
        lost_track_buffer: int = 30,
        frame_rate: float = 30.0,
        track_activation_threshold: float = 0.7,
        minimum_consecutive_frames: int = 2,
        minimum_iou_threshold: float = 0.1,
        high_conf_det_threshold: float = 0.6,
    ) -> None:
        # All knobs are forwarded verbatim to the underlying ByteTrack tracker.
        self.tracker = ByteTrackTracker(
            lost_track_buffer=lost_track_buffer,
            frame_rate=frame_rate,
            track_activation_threshold=track_activation_threshold,
            minimum_consecutive_frames=minimum_consecutive_frames,
            minimum_iou_threshold=minimum_iou_threshold,
            high_conf_det_threshold=high_conf_det_threshold,
        )

    def update(
        self,
        detection_bboxes: np.ndarray,
        detection_scores: np.ndarray,
        keypoints: np.ndarray | None = None,
    ) -> sv.Detections:
        """Update tracks with new detections and return tracked detections.

        Returns a `sv.Detections` instance with `tracker_id` assigned.
        If `keypoints` is provided (shape [N, 10]), it is stored in `data['keypoints']`
        before passing to ByteTrack so the correspondence is preserved after tracking.
        """
        # No detections this frame: feed ByteTrack an empty set so lost-track
        # bookkeeping still advances.
        if detection_bboxes is None or len(detection_bboxes) == 0:
            return self.tracker.update(sv.Detections.empty())

        xyxy = np.asarray(detection_bboxes, dtype=np.float32)
        if xyxy.ndim == 2 and xyxy.shape[1] > 4:
            # Keep only the xyxy columns; drop any trailing per-box extras.
            xyxy = xyxy[:, :4]
        confidence = np.asarray(detection_scores).reshape(-1).astype(np.float32)

        extra = (
            {"keypoints": np.asarray(keypoints, dtype=np.float32)}
            if keypoints is not None
            else {}
        )
        detections = sv.Detections(xyxy=xyxy, confidence=confidence, data=extra)
        return self.tracker.update(detections)


class EmbeddingAggregationIdentityResolver:
    """Identity resolver that averages embeddings over time before gallery matching.

    Strategy (temporal smoothing):
    1. Each track accumulates its face embeddings in a sliding window.
    2. The window is averaged into a single, more stable embedding.
    3. The averaged embedding is L2-normalized and matched against the gallery.
    4. Identities are therefore stable but take a few frames to converge.

    Prefer this resolver when minimizing identity flicker matters more than
    instant recognition.
    """

    def __init__(
        self,
        gallery: EmbeddingGallery,
        max_history: int = 30,
        device: str = "cpu",
    ) -> None:
        self.gallery = gallery
        self.device = device
        # Per-track sliding window of raw embeddings (oldest dropped first).
        self.embedding_histories: dict[int, deque] = defaultdict(
            lambda: deque(maxlen=max_history),
        )

    def update_identities(
        self,
        embeddings: list[np.ndarray],
        track_ids: list[int],
    ) -> dict[int, dict]:
        """Update identities by aggregating embeddings over time.

        Args:
            embeddings: Face embeddings for high-quality detections.
            track_ids: Corresponding track IDs for each embedding.
        """
        resolved: dict[int, dict] = {}
        for tid, emb in zip(track_ids, embeddings):
            history = self.embedding_histories[tid]
            # Copy so later in-place edits by the caller can't corrupt history.
            history.append(emb.copy())

            mean_emb = np.vstack(history).mean(axis=0)
            magnitude = float(np.linalg.norm(mean_emb))
            if magnitude > 0:
                mean_emb = mean_emb / magnitude

            # Only the single best gallery candidate is used.
            _, name, score = self.gallery.match(
                mean_emb,
                device=self.device,
                top_k=1,
            )[0]
            resolved[tid] = {"name": name, "score": score}

        return resolved


class WeightedVotingIdentityResolver:
    """Identity resolver using rank-weighted voting over top-K gallery matches.

    Each frame, the top-K gallery candidates earn points by rank: the best
    match earns ``top_k`` points, the next ``top_k - 1``, and so on. Votes are
    accumulated over a sliding window of frames; the name with the highest
    point total wins, and its reported score is the mean of the raw match
    scores it received within the window.
    """

    def __init__(
        self,
        gallery: EmbeddingGallery,
        max_history: int = 30,
        top_k: int = 3,
        device: str = "cpu",
    ) -> None:
        self.gallery = gallery
        self.top_k = top_k
        self.device = device
        # Per-track sliding window of {name: (points, score)} frame votes.
        self.frame_histories: dict[int, deque] = defaultdict(
            lambda: deque(maxlen=max_history),
        )

    def update_identities(
        self,
        embeddings: list[np.ndarray],
        track_ids: list[int],
    ) -> dict[int, dict]:
        """Match each embedding, record rank-weighted votes, and tally the window."""
        results: dict[int, dict] = {}
        for tid, emb in zip(track_ids, embeddings):
            matches = self.gallery.match(
                emb,
                device=self.device,
                top_k=self.top_k,
            )

            # Rank-weighted votes for this frame; higher rank -> more points.
            votes: dict = {}
            for rank, (_, name, score) in enumerate(matches):
                weight = self.top_k - rank
                if weight <= 0:
                    break
                votes[name] = (weight, score)

            if votes:
                self.frame_histories[tid].append(votes)

            window = self.frame_histories[tid]
            if not window:
                results[tid] = {"name": None, "score": None}
                continue

            points_by_name = defaultdict(int)
            score_total = defaultdict(float)
            score_hits = defaultdict(int)
            for frame_votes in window:
                for name, (weight, score) in frame_votes.items():
                    points_by_name[name] += weight
                    score_total[name] += float(score)
                    score_hits[name] += 1

            if not points_by_name:
                results[tid] = {"name": None, "score": None}
                continue

            winner = max(points_by_name, key=points_by_name.get)
            hits = score_hits.get(winner, 0)
            avg_score = (score_total[winner] / hits) if hits > 0 else None
            results[tid] = {"name": winner, "score": avg_score}

        return results


class VotingIdentityResolver:
    """Identity resolver that matches each embedding immediately, then votes on identities.

    This resolver implements an immediate matching strategy where:
    1. Each new embedding is matched against the face gallery immediately
    2. The resulting top-K identities (names, scores) are added to a sliding window history
    3. The final identity is determined by voting - the most frequent identity wins
    4. This approach provides faster initial identification but may be less stable

    Use this when you want rapid face recognition at the cost of potential
    flickering between identities during the initial frames.

    The top_k parameter allows considering multiple matches per embedding:
    - top_k=1: Only the best match is considered (default, backward compatible)
    - top_k>1: Top-K matches are all added to the voting pool, making the system
                more robust to near-confusions between similar faces
    """

    def __init__(
        self,
        gallery: EmbeddingGallery,
        max_history: int = 30,
        top_k: int = 1,
        device: str = "cpu",
    ) -> None:
        self.gallery = gallery
        self.top_k = top_k
        self.device = device
        # Per-track sliding window of (name, score) votes (oldest dropped first).
        self.identity_histories: dict[int, deque] = defaultdict(
            lambda: deque(maxlen=max_history),
        )

    def update_identities(
        self,
        embeddings: list[np.ndarray],
        track_ids: list[int],
    ) -> dict[int, dict]:
        """Update identities by immediate matching and voting

        Args:
            embeddings: Face embeddings for high-quality detections
            track_ids: Corresponding track IDs for each embedding

        Returns:
            Dict mapping track_id to {'name': ..., 'score': ...}; both values
            are None when no gallery candidates have ever been seen for a track.
        """
        identities = {}
        for track_id, embedding in zip(track_ids, embeddings):
            candidates = self.gallery.match(
                embedding,
                device=self.device,
                top_k=self.top_k,
            )
            for _, name, score in candidates:
                self.identity_histories[track_id].append((name, score))

            identity_history = self.identity_histories[track_id]
            if not identity_history:
                # BUGFIX: an empty gallery (or zero candidates) previously made
                # most_common(1)[0] raise IndexError here. Report "unknown"
                # instead, mirroring WeightedVotingIdentityResolver's
                # empty-history path.
                identities[track_id] = {"name": None, "score": None}
                continue

            name_counts = Counter(name for name, _ in identity_history)
            identity_name = name_counts.most_common(1)[0][0]
            # Report the mean raw match score of the winning name's votes.
            scores = [
                score for name, score in identity_history if name == identity_name
            ]
            identity_score = sum(scores) / len(scores) if scores else None
            identities[track_id] = {"name": identity_name, "score": identity_score}

        return identities


def get_identity_resolver_class(
    track_aggregation_strategy: str,
) -> type[IdentityResolver]:
    """Look up an identity-resolver class by strategy name (case-insensitive).

    Unknown strategy names fall back to ``EmbeddingAggregationIdentityResolver``.
    """
    strategy = track_aggregation_strategy.lower()
    if strategy == "voting":
        return VotingIdentityResolver
    if strategy == "weighted_voting":
        return WeightedVotingIdentityResolver
    # "embedding" and anything unrecognized resolve to the default strategy.
    return EmbeddingAggregationIdentityResolver


# Explicit public API: the names exported by `from <module> import *`.
__all__ = [
    "FaceTracker",
    "IdentityResolver",
    "EmbeddingAggregationIdentityResolver",
    "VotingIdentityResolver",
    "WeightedVotingIdentityResolver",
    "get_identity_resolver_class",
]
