import json

import numpy as np
import pytest

from app.face import EmbeddingGallery
from app.track import (
    EmbeddingAggregationIdentityResolver,
    FaceTracker,
    VotingIdentityResolver,
    WeightedVotingIdentityResolver,
    get_identity_resolver_class,
)


@pytest.fixture()
def gallery(tmp_path):
    """Write a v1-schema gallery file with three identities and load it.

    Alice, Bob and Charlie each get one random 512-d embedding drawn from a
    seeded generator, so the resulting gallery is the same on every run.
    """
    rng = np.random.default_rng(0)
    people = (
        ("Alice", "/fake/alice.jpg"),
        ("Bob", "/fake/bob.jpg"),
        ("Charlie", "/fake/charlie.jpg"),
    )
    # Draw vectors in this fixed order so every name always gets the same one.
    identities = {}
    for name, source in people:
        identities[name] = [
            {
                "embedding": rng.standard_normal(512).tolist(),
                "source": source,
                "landmark_alignment": True,
            }
        ]
    payload = {
        "version": 1,
        "detection_model": "scrfd_s",
        "recognition_model": "ms1mv3_vit_run_4",
        "identities": identities,
    }
    gallery_file = tmp_path / "gallery.json"
    gallery_file.write_text(json.dumps(payload))
    loaded = EmbeddingGallery(embedding_dim=512)
    loaded.load(str(gallery_file))
    return loaded


@pytest.fixture()
def alice_embedding(gallery):
    """Return Alice's embedding row exactly as stored in the loaded gallery.

    NOTE(review): the gallery presumably normalizes vectors on load; confirm
    in EmbeddingGallery.load if a test relies on unit norm.
    """
    alice_idx = gallery.names_to_idx["Alice"]
    return gallery.embeddings[alice_idx]


# --- FaceTracker ---


def test_face_tracker_empty_detections():
    """An update with zero boxes and zero scores produces no tracks."""
    no_boxes = np.empty((0, 5))
    no_scores = np.empty((0,))
    tracks = FaceTracker().update(no_boxes, no_scores)
    assert len(tracks) == 0


def test_face_tracker_strips_score_column():
    """Tracker should handle (N, 5) bboxes by stripping the score column.

    The same detection goes to two independent trackers: once with a trailing
    score column (N, 5) and once as plain (N, 4) boxes. If the score column is
    stripped correctly, both inputs give the same number of tracked detections.
    """
    scores = np.array([0.9], dtype=np.float32)
    with_score = np.array([[10, 20, 100, 200, 0.9]], dtype=np.float32)
    # Copy so neither tracker can see the other's input through a shared view.
    without_score = with_score[:, :4].copy()

    result = FaceTracker().update(with_score, scores)
    reference = FaceTracker().update(without_score, scores.copy())

    assert result is not None
    assert len(result) == len(reference)
    # A single input detection can never yield more than one track.
    assert len(result) <= 1


def test_face_tracker_none_input():
    """Passing None for both boxes and scores must not raise and yields no tracks."""
    tracks = FaceTracker().update(None, None)
    assert len(tracks) == 0


# --- EmbeddingAggregationIdentityResolver ---


def test_embedding_resolver_returns_identity(gallery, alice_embedding):
    """One frame of Alice's own embedding on track 1 resolves to Alice with high score."""
    resolver = EmbeddingAggregationIdentityResolver(gallery=gallery, max_history=5)
    track_id = 1
    resolved = resolver.update_identities([alice_embedding], [track_id])
    assert track_id in resolved
    match = resolved[track_id]
    assert match["name"] == "Alice"
    assert match["score"] > 0.9


def test_embedding_resolver_empty_input(gallery):
    """No embeddings and no track ids in means an empty mapping out."""
    resolved = EmbeddingAggregationIdentityResolver(gallery=gallery).update_identities([], [])
    assert resolved == {}


def test_embedding_resolver_aggregates_over_time(gallery, alice_embedding):
    """Repeated updates for one track accumulate one history entry per call."""
    n_frames = 5
    resolver = EmbeddingAggregationIdentityResolver(gallery=gallery, max_history=10)
    for _frame in range(n_frames):
        latest = resolver.update_identities([alice_embedding], [42])
    assert latest[42]["name"] == "Alice"
    assert len(resolver.embedding_histories[42]) == n_frames


def test_embedding_resolver_max_history(gallery, alice_embedding):
    """Feeding more frames than max_history caps the stored history at max_history."""
    cap = 3
    resolver = EmbeddingAggregationIdentityResolver(gallery=gallery, max_history=cap)
    for _step in range(10):
        resolver.update_identities([alice_embedding], [1])
    assert len(resolver.embedding_histories[1]) == cap


# --- VotingIdentityResolver ---


def test_voting_resolver_returns_identity(gallery, alice_embedding):
    """Top-1 voting settles on Alice after five frames of her embedding."""
    resolver = VotingIdentityResolver(gallery=gallery, max_history=10, top_k=1)
    for _frame in range(5):
        latest = resolver.update_identities([alice_embedding], [1])
    winner = latest[1]
    assert winner["name"] == "Alice"
    assert winner["score"] > 0.0


def test_voting_resolver_empty_input(gallery):
    """With nothing to vote on, the voting resolver returns an empty mapping."""
    resolved = VotingIdentityResolver(gallery=gallery).update_identities([], [])
    assert resolved == {}


def test_voting_resolver_max_history(gallery, alice_embedding):
    """The per-track vote history never grows beyond max_history entries."""
    cap = 3
    resolver = VotingIdentityResolver(gallery=gallery, max_history=cap, top_k=1)
    for _step in range(10):
        resolver.update_identities([alice_embedding], [1])
    assert len(resolver.identity_histories[1]) == cap


# --- WeightedVotingIdentityResolver ---


def test_weighted_voting_resolver_returns_identity(gallery, alice_embedding):
    """Weighted voting with top_k=2 names Alice after five frames of her embedding."""
    resolver = WeightedVotingIdentityResolver(gallery=gallery, max_history=10, top_k=2)
    for _frame in range(5):
        latest = resolver.update_identities([alice_embedding], [1])
    assert latest[1]["name"] == "Alice"


def test_weighted_voting_resolver_empty_input(gallery):
    """With nothing to vote on, the weighted voting resolver returns an empty mapping."""
    resolved = WeightedVotingIdentityResolver(gallery=gallery).update_identities([], [])
    assert resolved == {}


# --- get_identity_resolver_class ---


def test_get_identity_resolver_class_embedding():
    """The "embedding" strategy maps to EmbeddingAggregationIdentityResolver."""
    assert get_identity_resolver_class("embedding") is EmbeddingAggregationIdentityResolver


def test_get_identity_resolver_class_voting():
    """The "voting" strategy maps to VotingIdentityResolver."""
    assert get_identity_resolver_class("voting") is VotingIdentityResolver


def test_get_identity_resolver_class_weighted_voting():
    """The "weighted_voting" strategy maps to WeightedVotingIdentityResolver."""
    assert get_identity_resolver_class("weighted_voting") is WeightedVotingIdentityResolver


def test_get_identity_resolver_class_unknown_falls_back():
    """An unrecognized strategy name falls back to EmbeddingAggregationIdentityResolver."""
    assert get_identity_resolver_class("nonexistent") is EmbeddingAggregationIdentityResolver
