import json

import numpy as np
import pytest

from app.face import (
    EmbeddingGallery,
    _nms,
    extract_person_keypoints,
    preprocess_input_onlyface,
    preprocess_input_scrfd,
)

# --- EmbeddingGallery ---


@pytest.fixture()
def gallery_path(tmp_path):
    """Write a v1-schema gallery JSON with one 512-d embedding per identity.

    Alice, Bob and Charlie each receive a single embedding drawn from a
    fixed-seed (42) generator, in that order, so the file content is
    reproducible across runs. Returns the file path as a string.
    """
    generator = np.random.default_rng(42)
    sources = (
        ("Alice", "/fake/alice.jpg"),
        ("Bob", "/fake/bob.jpg"),
        ("Charlie", "/fake/charlie.jpg"),
    )
    # Iterate in a fixed order so each identity always gets the same vector.
    identities = {
        person: [
            {
                "embedding": generator.standard_normal(512).tolist(),
                "source": image,
                "landmark_alignment": True,
            }
        ]
        for person, image in sources
    }
    payload = {
        "version": 1,
        "detection_model": "scrfd_s",
        "recognition_model": "ms1mv3_vit_run_4",
        "identities": identities,
    }
    out_file = tmp_path / "gallery.json"
    out_file.write_text(json.dumps(payload))
    return str(out_file)


@pytest.fixture()
def gallery(gallery_path):
    """Return a 512-d EmbeddingGallery populated from the ``gallery_path`` file."""
    loaded = EmbeddingGallery(embedding_dim=512)
    loaded.load(gallery_path)
    return loaded


def test_gallery_load(gallery):
    """Loading registers every identity and L2-normalizes each embedding row."""
    expected = {"Alice", "Bob", "Charlie"}
    assert len(gallery.names) == len(expected)
    assert gallery.embeddings.shape == (len(expected), 512)
    assert set(gallery.names) == expected

    # Each stored row must have unit Euclidean length after load().
    row_norms = np.linalg.norm(gallery.embeddings, axis=1)
    np.testing.assert_allclose(row_norms, 1.0, atol=1e-6)


def test_gallery_match_returns_best(gallery):
    """A stored embedding used as the query must match its own identity."""
    probe = gallery.embeddings[gallery.names_to_idx["Alice"]]
    matches = gallery.match(probe, top_k=1)

    assert len(matches) == 1
    # match() yields (index, name, score) tuples; the index is not checked here.
    _, best_name, best_score = matches[0]
    assert best_name == "Alice"
    assert best_score > 0.9


def test_gallery_match_top_k(gallery):
    """top_k=3 should return all 3 gallery entries, ordered by score.

    The query is the first stored (already normalized) embedding, so its own
    identity must rank first. The other entries are independent random
    vectors from the fixture.
    """
    query = gallery.embeddings[0]
    results = gallery.match(query, top_k=3)

    assert len(results) == 3
    # Each identity must appear exactly once. Without this, a match() that
    # repeated one entry three times would still pass the length check.
    returned_names = [name for _, name, _ in results]
    assert sorted(returned_names) == sorted(gallery.names)
    # NOTE(review): assumes names[i] labels embeddings[i] (parallel arrays),
    # as names_to_idx indexing into embeddings suggests.
    assert returned_names[0] == gallery.names[0]

    scores = [score for _, _, score in results]
    assert scores == sorted(scores, reverse=True)


def test_gallery_match_empty_raises():
    """match() on a gallery with nothing loaded must raise ValueError."""
    empty_gallery = EmbeddingGallery()
    probe = np.random.rand(512)
    with pytest.raises(ValueError, match="No embeddings"):
        empty_gallery.match(probe)


def test_gallery_load_wrong_dim(tmp_path):
    """A 256-d embedding must be rejected by a gallery configured for 512-d."""
    short_entry = {
        "embedding": np.random.rand(256).tolist(),
        "source": "/fake/alice.jpg",
        "landmark_alignment": True,
    }
    bad_payload = {
        "version": 1,
        "detection_model": "scrfd_s",
        "recognition_model": "ms1mv3_vit_run_4",
        "identities": {"Alice": [short_entry]},
    }
    bad_file = tmp_path / "bad_gallery.json"
    bad_file.write_text(json.dumps(bad_payload))

    strict_gallery = EmbeddingGallery(embedding_dim=512)
    with pytest.raises(ValueError, match="Embedding dimension"):
        strict_gallery.load(str(bad_file))


# --- NMS ---


def test_nms_suppresses_overlapping():
    """A lower-score box overlapping a higher-score one is dropped; a distant box survives."""
    boxes = np.array(
        [
            [0, 0, 100, 100, 0.9],      # highest score
            [5, 5, 105, 105, 0.8],      # IoU ~0.82 with box 0, above the 0.3 threshold
            [200, 200, 300, 300, 0.7],  # disjoint from both others
        ],
        dtype=np.float32,
    )
    kept = _nms(boxes, nms_thresh=0.3)
    for survivor in (0, 2):
        assert survivor in kept
    assert 1 not in kept


def test_nms_empty():
    """An input with zero detections yields an empty keep list."""
    no_boxes = np.zeros((0, 5), dtype=np.float32)
    kept = _nms(no_boxes)
    assert kept == []


def test_nms_no_overlap():
    """Mutually disjoint boxes must all survive NMS."""
    # 10x10 boxes spaced 100 px apart along the diagonal, with descending scores.
    boxes = np.array(
        [
            [origin, origin, origin + 10, origin + 10, score]
            for origin, score in ((0, 0.9), (100, 0.8), (200, 0.7))
        ],
        dtype=np.float32,
    )
    assert len(_nms(boxes, nms_thresh=0.3)) == 3


# --- preprocess_input_scrfd ---


def test_preprocess_scrfd_output_shape():
    """A (480, 640, 3) uint8 image becomes a (1, 640, 640, 3) float32 batch.

    The accompanying scale factor must be strictly positive.
    """
    frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
    batch, ratio = preprocess_input_scrfd(frame)

    assert batch.shape == (1, 640, 640, 3)
    assert batch.dtype == np.float32
    assert ratio > 0


def test_preprocess_scrfd_scale_factor():
    """Scale factor should reflect the resize ratio.

    For a (480, 640, 3) input the test accepts either 640/640 or 640/480,
    within an absolute tolerance of 0.01.
    """
    # NOTE(review): accepting two candidates is loose; tighten to one once the
    # scaling convention of preprocess_input_scrfd is pinned down.
    frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
    _, ratio = preprocess_input_scrfd(frame)
    accepted = (640 / 640, 640 / 480)
    assert any(abs(ratio - candidate) < 0.01 for candidate in accepted)


def test_preprocess_scrfd_tall_image():
    """A portrait (1280, 480, 3) input still yields a (1, 640, 640, 3) batch."""
    portrait = np.random.randint(0, 255, (1280, 480, 3), dtype=np.uint8)
    batch, ratio = preprocess_input_scrfd(portrait)

    assert batch.shape == (1, 640, 640, 3)
    assert ratio > 0


# --- preprocess_input_onlyface ---


def test_preprocess_onlyface_output_shape():
    """Without landmarks, a face crop yields a (2, 3, 112, 112) batch.

    The batch holds the original plus a flipped copy; the second return
    value is a (112, 112, 3) preview image.
    """
    crop = np.random.randint(0, 255, (100, 80, 3), dtype=np.uint8)
    batch, preview = preprocess_input_onlyface(crop, None)

    assert batch.shape == (2, 3, 112, 112)
    assert preview.shape == (112, 112, 3)


def test_preprocess_onlyface_with_landmarks():
    """Onlyface with valid landmarks should use affine warp.

    The landmark path must honour the same output contract as the
    no-landmark path: a (2, 3, 112, 112) batch (original + flipped) and a
    (112, 112, 3) preview image.
    """
    face = np.random.randint(0, 255, (200, 200, 3), dtype=np.uint8)
    # Five (x, y) points; presumably eyes, nose tip and mouth corners in the
    # order the alignment template expects. TODO confirm against app.face.
    landmarks = np.array(
        [[60, 100], [130, 100], [95, 140], [65, 180], [125, 180]],
        dtype=np.float32,
    )
    result, img_plot = preprocess_input_onlyface(face, landmarks)

    assert result.shape == (2, 3, 112, 112)
    # Previously img_plot was unpacked but never checked.
    assert img_plot.shape == (112, 112, 3)


# --- extract_person_keypoints ---


def test_extract_person_keypoints():
    """Keypoints should be offset by the detection bbox origin."""
    points = np.array([[150, 200], [160, 210]], dtype=np.float32)
    x_min, y_min = 100, 150
    box = [x_min, y_min, 200, 250]
    # Each keypoint is shifted so the bbox's top-left corner becomes (0, 0).
    expected = [[150 - x_min, 200 - y_min], [160 - x_min, 210 - y_min]]
    assert extract_person_keypoints(points, box) == expected
