from pathlib import Path

import cv2
import numpy as np
import pytest

from app.face import get_onnx_session, get_positive_faces, preprocess_input_scrfd

# Paths are anchored two directory levels above this file so the tests do not
# depend on the current working directory pytest is launched from.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
TEST_IMAGE = PROJECT_ROOT.joinpath("data", "10people.jpg")
FD_MODEL = PROJECT_ROOT.joinpath("app", "data", "models", "scrfd_s.onnx")


@pytest.fixture(scope="module")
def fd_session():
    """Build the face-detection ONNX session once and share it across this module.

    Returns:
        A ``(session, input_name, output_names)`` tuple obtained from
        ``get_onnx_session`` for the model file at ``FD_MODEL``.
    """
    model_path = str(FD_MODEL)
    onnx_session, in_name, out_names = get_onnx_session(model_path)
    return onnx_session, in_name, out_names


@pytest.fixture(scope="module")
def test_image():
    """Read the 10people.jpg group photo once per module.

    Returns:
        The image array produced by ``cv2.imread`` (OpenCV BGR channel order).
    """
    image_path = str(TEST_IMAGE)
    bgr = cv2.imread(image_path)
    # cv2.imread reports a missing or unreadable file by returning None
    # instead of raising, so fail loudly here rather than in every test.
    assert bgr is not None, f"Could not load test image: {TEST_IMAGE}"
    return bgr


def _run_detection(fd_session, frame_bgr, threshold=0.5, max_faces=20):
    """Run the face detector on one BGR frame.

    The frame is converted to RGB and preprocessed with ``preprocess_input_scrfd``.
    It is then reordered to channels-first float32 and run through the ONNX
    session. The first output is thresholded, and ``get_positive_faces`` decodes
    the surviving entries.

    Args:
        fd_session: ``(session, input_name, output_names)`` tuple, as produced
            by the ``fd_session`` fixture.
        frame_bgr: Image array in OpenCV BGR channel order.
        threshold: Entries of ``outputs[0][0]`` greater than or equal to this
            value are passed on as positives (presumably per-anchor scores;
            confirm against ``app.face``).
        max_faces: Face cap forwarded unchanged to ``get_positive_faces``.

    Returns:
        The value returned by ``get_positive_faces``. Callers in this module
        unpack it as ``(bboxes, keypoints, scores)``.
    """
    onnx_session, in_name, out_names = fd_session
    rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    preprocessed, scale = preprocess_input_scrfd(rgb)

    # Axes (0, 3, 1, 2) move the last axis to position 1. This is presumably
    # NHWC -> NCHW, the layout the ONNX model expects; confirm against the model.
    nchw = preprocessed.astype(np.float32).transpose(0, 3, 1, 2)
    outputs = onnx_session.run(out_names, {in_name: nchw})

    # Index tuple of the entries that meet the threshold in the first output.
    keep = np.nonzero(outputs[0][0] >= threshold)
    return get_positive_faces(outputs, keep, max_faces, scale, device="cpu")


def test_detects_faces_in_group_photo(fd_session, test_image):
    """10people.jpg should yield approximately 10 face detections.

    A tolerance band of 8-15 faces is accepted at threshold 0.5.
    """
    # Only the boxes are needed here; keypoints and scores are discarded.
    bboxes, _, _ = _run_detection(fd_session, test_image, threshold=0.5)
    n_faces = len(bboxes)
    assert n_faces >= 8, f"Expected at least 8 faces, got {n_faces}"
    assert n_faces <= 15, f"Expected at most 15 faces, got {n_faces}"


def test_detection_scores_are_valid(fd_session, test_image):
    """All detection scores should be between 0 and 1.

    A lower threshold (0.3) is used so weaker detections are also checked.
    At least one detection is required so the range checks cannot pass
    vacuously.
    """
    _, _, scores = _run_detection(fd_session, test_image, threshold=0.3)
    scores_flat = np.asarray(scores).reshape(-1)
    # np.all() over an empty array is True, so without this guard a detector
    # returning nothing would silently pass the range checks below.
    assert scores_flat.size > 0, "Expected at least one detection at threshold 0.3"
    assert np.all(scores_flat >= 0.0), f"Score below 0 found: {scores_flat.min()}"
    assert np.all(scores_flat <= 1.0), f"Score above 1 found: {scores_flat.max()}"


def test_bboxes_within_image_bounds(fd_session, test_image):
    """All bounding boxes should be within the image dimensions (1 px slack).

    Also fails when nothing is detected, so the per-box checks cannot pass
    vacuously.
    """
    h, w = test_image.shape[:2]
    bboxes, _, _ = _run_detection(fd_session, test_image, threshold=0.5)
    # An empty result would skip the loop entirely and pass trivially.
    assert len(bboxes) > 0, "Expected at least one face to check bounds against"

    for bbox in bboxes:
        # The first four entries are the (x1, y1, x2, y2) corners; any further
        # entries are ignored here.
        x1, y1, x2, y2 = bbox[:4]
        # A 1 px tolerance on each edge, presumably to absorb rounding when
        # boxes are rescaled back to the original resolution.
        assert x1 >= -1 and y1 >= -1, f"Bbox top-left out of bounds: ({x1}, {y1})"
        assert x2 <= w + 1 and y2 <= h + 1, f"Bbox bottom-right out of bounds: ({x2}, {y2})"
        assert x2 > x1 and y2 > y1, f"Degenerate bbox: ({x1}, {y1}, {x2}, {y2})"


def test_no_detections_on_blank_image(fd_session):
    """A uniformly white 640x480 frame contains no faces, so detection must be empty."""
    white_frame = np.full((480, 640, 3), 255, dtype=np.uint8)
    detected_boxes, _kps, _scores = _run_detection(fd_session, white_frame, threshold=0.5)
    n_found = len(detected_boxes)
    assert n_found == 0, f"Expected 0 faces on blank image, got {n_found}"
