import json
from typing import List, Tuple

import cv2
import numpy as np
import onnxruntime
from onnxruntime import InferenceSession

# Silence ONNX Runtime logs below ERROR level (severity 3 = ERROR and above).
onnxruntime.set_default_logger_severity(3)


class EmbeddingGallery:
    """In-memory gallery of L2-normalized face embeddings with cosine matching."""

    def __init__(self, embedding_dim: int = 512):
        """
        Initialize an empty EmbeddingGallery.

        Args:
            embedding_dim: Dimension of face embeddings (default 512)
        """
        self.embedding_dim = embedding_dim
        # One row per embedding; rows are L2-normalized by load().
        self.embeddings = np.empty((0, embedding_dim), dtype=np.float32)
        # names[i] is the identity of embeddings[i]; a name may repeat across rows.
        self.names = []
        # Maps identity name -> row index of its first embedding.
        self.names_to_idx = {}

    def match(
        self,
        query_embedding: np.ndarray,
        normalize: bool = True,
        device: object = None,
        top_k: int = 1,
    ) -> List[Tuple[int, str, float]]:
        """
        Return the top-K matching faces by cosine similarity (numpy).

        Scores are cosine similarity rescaled from [-1, 1] to [0, 1].

        Args:
            query_embedding (np.ndarray): Query embedding vector
            normalize (bool): Whether or not to normalize the query embedding
            device: Unused; kept for backward compatibility with callers
            top_k (int): Number of top matches to return (default: 1)

        Returns:
            List of (idx, name, score) candidates, ordered best to worst.

        Raises:
            ValueError: If the gallery contains no embeddings.
        """
        if len(self.embeddings) == 0:
            raise ValueError("No embeddings in gallery")

        query = np.asarray(query_embedding, dtype=np.float32).reshape(-1)
        if normalize:
            qn = np.linalg.norm(query)
            # Leave an all-zero query untouched to avoid division by zero.
            if qn > 0:
                query = query / qn

        # Gallery rows are pre-normalized on load(), so a dot product is
        # cosine similarity; rescale from [-1, 1] to [0, 1].
        similarity = self.embeddings @ query
        similarity = (similarity + 1.0) / 2.0

        top_k = min(int(top_k), len(self.embeddings))
        if top_k <= 0:
            return []

        if top_k == 1:
            idx = int(np.argmax(similarity))
            return [(idx, self.names[idx], float(similarity[idx]))]

        # argpartition gives top-k unordered in O(n); sort only those k.
        idxs = np.argpartition(-similarity, top_k - 1)[:top_k]
        idxs = idxs[np.argsort(-similarity[idxs])]
        return [(int(i), self.names[int(i)], float(similarity[int(i)])) for i in idxs]

    def load(self, path: str):
        """Load gallery from a JSON file in the v1 schema.

        Flattens all per-identity embeddings into a single matrix, repeating
        the identity name for each embedding row. Max-pooling over multiple
        embeddings per identity is implicit in match() via argmax.

        Args:
            path: Path to the gallery JSON file.

        Raises:
            ValueError: If the file is not version 1, or the embedding
                dimension does not match ``embedding_dim``.
        """
        embeddings = []
        self.names = []
        self.names_to_idx = {}

        with open(path, "r") as f:
            data = json.load(f)

        if data.get("version") != 1:
            raise ValueError(
                f"Unsupported gallery format (expected version=1, got version={data.get('version')}). "
                "Use the migration notebook to convert old galleries.",
            )

        row = 0
        for name, entries in data["identities"].items():
            for entry in entries:
                embeddings.append(entry["embedding"])
                self.names.append(name)
            # Record where this identity's contiguous block of rows starts.
            self.names_to_idx[name] = row
            row += len(entries)

        if not embeddings:
            # Empty gallery: keep a well-formed (0, dim) matrix instead of
            # crashing on shape[1] below; match() raises on first use.
            self.embeddings = np.empty((0, self.embedding_dim), dtype=np.float32)
            return

        self.embeddings = np.array(embeddings, dtype=np.float32)
        embedding_length = self.embeddings.shape[1]
        if embedding_length != self.embedding_dim:
            raise ValueError(
                f"Embedding dimension {embedding_length} doesn't match "
                f"expected dimension {self.embedding_dim}",
            )

        # Pre-normalize rows so match() can use a plain dot product;
        # clamp norms to avoid division by zero on degenerate rows.
        norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True)
        norms = np.maximum(norms, 1e-12)
        self.embeddings = self.embeddings / norms


def get_onnx_session(model_path: str):
    """Create an ONNX Runtime session, preferring CUDA and falling back to CPU.

    Args:
        model_path: Path to the .onnx model file.

    Returns:
        Tuple of (session, input_name, output_names): the InferenceSession,
        the name of its first input, and the names of all outputs.
    """
    installed = set(onnxruntime.get_available_providers())
    providers = [
        provider
        for provider in ("CUDAExecutionProvider", "CPUExecutionProvider")
        if provider in installed
    ]
    # Defensive fallback in case neither preferred provider is reported.
    if not providers:
        providers = ["CPUExecutionProvider"]

    model_session: InferenceSession = onnxruntime.InferenceSession(
        model_path,
        providers=providers,
    )
    print(f"Using providers {model_session.get_providers()} for {model_path}")
    input_name = model_session.get_inputs()[0].name
    output_names = [output.name for output in model_session.get_outputs()]
    return model_session, input_name, output_names


def _nms(dets: np.ndarray, nms_thresh: float = 0.3) -> List[int]:
    """Classic NMS for dets with columns [x1,y1,x2,y2,score]."""
    if dets.size == 0:
        return []

    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]

    areas = (x2 - x1 + 1.0) * (y2 - y1 + 1.0)
    order = scores.argsort()[::-1]

    keep: List[int] = []
    while order.size > 0:
        i = int(order[0])
        keep.append(i)
        if order.size == 1:
            break

        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1.0)
        h = np.maximum(0.0, yy2 - yy1 + 1.0)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= nms_thresh)[0]
        order = order[inds + 1]

    return keep


def preprocess_input_scrfd(img: np.ndarray) -> Tuple[np.ndarray, float]:
    """
    Taken from fr_in_the_wild. See https://vc1.klass.dev/queenstown/platform/opportunities/fr_in_the_wild

    Prepare input for scrfd model (Face Detector). Takes an image with abitrary height/width and resize and normalize it for SCRFD consumption.
    Also calculates a scaling factor such that the model output can be scaled back to the original image size.

    Args:
        img (np.ndarray): A numpy array representing the image read using cv2.imread.
              Note that cv2.imread reads BGR so it is necessary to convert to RGB using cv2.cvtColor

    Returns:
        np.ndarray: pre-processed input for scrfd model consumption
        float: factor to scale detection by to fit back original image size.
    """
    target_w, target_h = 640, 640
    aspect = float(img.shape[0]) / img.shape[1]
    model_aspect = float(target_h) / target_w

    # Fit the image inside the 640x640 canvas while preserving aspect ratio.
    if aspect > model_aspect:
        new_h = target_h
        new_w = int(new_h / aspect)
    else:
        new_w = target_w
        new_h = int(new_w * aspect)

    # Factor that maps detections back to original-image coordinates.
    det_scale = float(new_h) / img.shape[0]

    # Paste the resized image into the top-left of a zero-padded canvas.
    canvas = np.zeros((640, 640, 3), dtype=np.float32)
    canvas[:new_h, :new_w, :] = cv2.resize(img, (new_w, new_h))

    # Normalize to roughly [-1, 1] with the SCRFD mean/std of 127.5/128.
    canvas = np.subtract(canvas, [127.5, 127.5, 127.5])
    canvas = np.divide(canvas, [128.0, 128.0, 128.0])

    # Add the batch dimension expected by the model.
    return np.expand_dims(canvas, axis=0).astype("float32"), det_scale


def preprocess_input_onlyface(img, landmark5=None):
    """
    Taken from fr_in_the_wild. See https://vc1.klass.dev/queenstown/platform/opportunities/fr_in_the_wild

    Prepare input for onlyface model. Takes an image with abitrary height/width and resize and normalize it for onlyface consumption.

    Args:
        img       : A numpy array of the face to be passed through onlyface
        landmark5 : a (5,2) np array representing the 5 facial landmarks. to be used to align the face using affine warping.

    Returns:
        img      : pre-processed input for onlyface model, shape (2, 3, 112, 112)
                   (original + horizontally flipped, channels-first).
        img_plot : the resized face prior to normalising. allows easier plotting
    """
    # Canonical 5-point landmark positions for a 112x112 aligned face.
    src = np.array(
        [
            [30.2946, 51.6963],
            [65.5318, 51.5014],
            [48.0252, 71.7366],
            [33.5493, 92.3655],
            [62.7299, 92.2041],
        ],
        dtype=np.float32,
    )
    landmark5_arr = (
        np.asarray(landmark5, dtype=np.float32) if landmark5 is not None else None
    )

    M = None
    if landmark5_arr is not None and landmark5_arr.shape == (5, 2):
        # The following line of code is a drop-in replacement for these lines
        # tform = trans.SimilarityTransform()
        # tform.estimate(landmark5, src)
        # M = tform.params[0:2, :]
        # NOTE: estimateAffinePartial2D may return M=None for degenerate
        # landmark configurations; that case falls through to a plain resize.
        M, _ = cv2.estimateAffinePartial2D(landmark5_arr, src)

    if M is not None:
        img_plot = cv2.warpAffine(img, M, (112, 112), borderValue=0.0)
    else:
        # No (usable) landmarks: fall back to a plain resize.
        img_plot = cv2.resize(img, (112, 112))
    img_plot_flip = cv2.flip(img_plot, 1)

    # Normalize to [-1, 1] with mean/std of 0.5 on a 0-255 pixel range.
    mean = 255 * np.array([0.5, 0.5, 0.5])
    std = 255 * np.array([0.5, 0.5, 0.5])

    img = (img_plot - mean[None, None, :]) / std[None, None, :]
    img_flip = (img_plot_flip - mean[None, None, :]) / std[None, None, :]

    img = np.expand_dims(img, 0)
    img_flip = np.expand_dims(img_flip, 0)

    # Batch the face and its mirror so both embeddings come from one pass.
    img = np.concatenate([img, img_flip], axis=0)
    # transpose for open-source fr model (NHWC -> NCHW)
    img = np.transpose(img, [0, 3, 1, 2])
    return img, img_plot


def get_positive_faces(
    output: list[np.ndarray],
    pos_inds: tuple[np.ndarray, ...],
    num_detect_faces: int,
    det_scale: float,
    nms_thresh: float = 0.3,
    device: str | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Face detection post-processing: gather positive detections, run NMS,
    rescale to original-image coordinates, and cap the number of faces.

    Taken from fr_in_the_wild. (Pure numpy; despite earlier wording there is
    no PyTorch or GPU involvement here.)

    Args:
        output: Model output: output[0] scores, output[1] bboxes, output[2] keypoints
        pos_inds: Indices of positive detections (np.where-style tuple)
        num_detect_faces: Maximum number of faces to return
        det_scale: Scale factor to convert back to original image coordinates
        nms_thresh: NMS IoU threshold for overlap suppression
        device: Unused; kept for backward compatibility with callers

    Returns:
        det: Kept bounding boxes with scores [N, 5]
        kpss: Kept keypoints [N, 10]
        scores: Scores of the kept detections [N, 1]
    """
    # Early return for empty detections - avoids unnecessary computation
    if len(pos_inds[0]) == 0:
        return (
            np.empty((0, 5), dtype=np.float32),  # det: empty bboxes with scores
            np.empty((0, 10), dtype=np.float32),  # kpss: empty keypoints
            np.empty((0, 1), dtype=np.float32),  # scores: empty scores
        )

    pos_scores = output[0][0][pos_inds]
    pos_bboxes = output[1][0][pos_inds]
    pos_kps = output[2][0][pos_inds]
    scores_ravel = np.vstack(pos_scores).ravel()

    # Scale boxes back to original-image coordinates before NMS.
    bboxes_np = np.vstack(pos_bboxes) / det_scale
    pre_det = np.hstack((bboxes_np, scores_ravel.reshape(-1, 1))).astype(
        np.float32,
        copy=False,
    )
    keep_np = np.array(_nms(pre_det, nms_thresh=nms_thresh), dtype=np.int64)

    bboxes_kept = bboxes_np[keep_np]
    scores_kept = scores_ravel[keep_np].reshape(-1, 1)
    det = np.hstack((bboxes_kept, scores_kept)).astype(np.float32, copy=False)

    kpss = pos_kps[keep_np, :] / det_scale

    # Limit to the requested number of faces; NMS output is already in
    # descending score order, so this keeps the best detections.
    det = det[:num_detect_faces, :]
    kpss = kpss[:num_detect_faces, :]
    scores_kept = scores_kept[:num_detect_faces, :]

    return det, kpss, scores_kept


def extract_person_keypoints(keypoints, detection_bbox):
    """Shift keypoints into the local coordinate frame of a detection box.

    Args:
        keypoints: Iterable of (x, y) numpy arrays in image coordinates.
        detection_bbox: Box whose first two entries are the (x1, y1) origin.

    Returns:
        List of [x, y] pairs relative to the box's top-left corner.
    """
    origin_x, origin_y = detection_bbox[0], detection_bbox[1]
    return [
        [point[0] - origin_x, point[1] - origin_y]
        for point in (kp.astype("int") for kp in keypoints)
    ]
