#!/usr/bin/env python
import argparse
import json
import sys
import time
from pathlib import Path

import cv2
import numpy as np
import onnxruntime

# Directory containing this script; model and data paths below are resolved relative to it.
_APP_DIR: Path = Path(__file__).parent


def _create_onnx_session(model_rel_path: str) -> onnxruntime.InferenceSession:
    """Open an ONNX model located relative to this script's directory.

    Execution providers are picked from those the installed onnxruntime
    reports: CUDA first when present, then CPU. If neither is reported, CPU is
    requested anyway.

    Args:
        model_rel_path: Model path relative to the script directory.

    Returns:
        A ready-to-run ``onnxruntime.InferenceSession``.

    Raises:
        FileNotFoundError: If the model file does not exist.
    """
    reported = onnxruntime.get_available_providers()
    preference = ("CUDAExecutionProvider", "CPUExecutionProvider")
    providers: list[str] = [
        name for name in preference if name in reported
    ] or ["CPUExecutionProvider"]

    model_path = _APP_DIR / model_rel_path
    if not model_path.exists():
        raise FileNotFoundError(f"Model not found at {model_path}")
    return onnxruntime.InferenceSession(str(model_path), providers=providers)


def get_scrfd_onnx_session():
    """Create the SCRFD face-detector session and look up its I/O tensor names.

    The model file name is built from ``_FD_MODEL_NAME``. The model that is
    loaded and the ``detection_model`` recorded in the gallery output therefore
    come from a single source of truth.

    Returns:
        (session, input_name, output_names):
        - the inference session;
        - the name of its first input;
        - the names of all its outputs, in model order.
    """
    scrfd_session = _create_onnx_session(f"../data/models/{_FD_MODEL_NAME}.onnx")
    scrfd_input_name = scrfd_session.get_inputs()[0].name
    scrfd_output_names = [o.name for o in scrfd_session.get_outputs()]
    return scrfd_session, scrfd_input_name, scrfd_output_names


def get_onlyface_onnx_session():
    """Create the onlyface recognition session and look up its I/O tensor names.

    The model file name is built from ``_FR_MODEL_NAME``. The model that is
    loaded and the ``recognition_model`` recorded in the gallery output
    therefore come from a single source of truth.

    Returns:
        (session, input_name, output_names):
        - the inference session;
        - the name of its first input;
        - the names of all its outputs, in model order.
    """
    onlyface_session = _create_onnx_session(f"../data/models/{_FR_MODEL_NAME}.onnx")
    onlyface_input_name = onlyface_session.get_inputs()[0].name
    onlyface_output_names = [o.name for o in onlyface_session.get_outputs()]
    return onlyface_session, onlyface_input_name, onlyface_output_names


def _is_image_file(path: Path) -> bool:
    """Check whether *path* ends in a recognised raster-image extension.

    The comparison is case-insensitive and only looks at the final suffix
    (so ``photo.JPG`` matches but ``photo.jpg.gz`` does not).
    """
    suffix = path.suffix.lower()
    return suffix in (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")


def preprocess_input_scrfd(
    img: np.ndarray, input_size: tuple[int, int] = (640, 640),
) -> tuple[np.ndarray, float]:
    """Prepare an image of arbitrary size for the SCRFD face detector.

    Steps:
    - Resize the image to fit inside ``input_size``, keeping its aspect ratio.
    - Paste it into the top-left corner of a zero-filled canvas of exactly
      ``input_size``.
    - Normalise with ``(pixel - 127.5) / 128``.

    A scale factor is also returned so detections made on the canvas can be
    mapped back to the original image.

    Args:
        img: 3-channel image array of shape (H, W, 3). ``cv2.imread`` returns
            BGR, so callers convert to RGB with ``cv2.cvtColor`` first.
        input_size: Model input size as ``(width, height)``. Defaults to
            (640, 640).

    Returns:
        det_img: float32 array of shape (1, height, width, 3), i.e. NHWC
            (the caller in this file transposes it to NCHW before inference).
        det_scale: Ratio resized/original. Divide detections by it to map them
            back to ``img`` coordinates.
    """
    in_width, in_height = input_size
    im_ratio = float(img.shape[0]) / img.shape[1]
    model_ratio = float(in_height) / in_width
    if im_ratio > model_ratio:
        # Image is relatively taller than the model input: height is the limiting side.
        new_height = in_height
        new_width = max(1, int(new_height / im_ratio))
    else:
        # Width is the limiting side.
        new_width = in_width
        new_height = max(1, int(new_width * im_ratio))

    det_scale = float(new_height) / img.shape[0]
    resized_img = cv2.resize(img, (new_width, new_height))

    # Canvas size is derived from input_size (not hard-coded) so the two cannot disagree.
    det_img = np.zeros((in_height, in_width, 3), dtype=np.float32)
    det_img[:new_height, :new_width, :] = resized_img

    det_img = np.subtract(det_img, [127.5, 127.5, 127.5])
    det_img = np.divide(det_img, [128.0, 128.0, 128.0])

    # Add the batch dimension expected by the model.
    det_img = np.expand_dims(det_img, axis=0)

    return det_img.astype("float32"), det_scale


def preprocess_input_onlyface(
    img: np.ndarray, landmark5: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Prepare a face image for the onlyface recognition model.

    How the 112x112 face is produced:
    - With ``landmark5``: the face is warped onto a fixed 112x112 landmark
      template using a 4-DOF similarity transform (rotation, uniform scale,
      translation).
    - Without landmarks, or if the transform cannot be estimated: the image is
      simply resized to 112x112.

    The face and its horizontal mirror are then normalised to [-1, 1] with
    ``(pixel - 127.5) / 127.5`` and stacked into a batch of two.

    Args:
        img: 3-channel face image (a crop or a whole image).
        landmark5: Optional (5, 2) array of facial landmarks in ``img`` pixel
            coordinates. It is used to align the face.

    Returns:
        img: float64 array of shape (2, 112, 112, 3). Index 0 is the face and
            index 1 its horizontal flip.
        img_plot: The 112x112 face before normalisation, handy for plotting.
    """
    # Reference landmark positions (x, y) inside the 112x112 aligned crop.
    src = np.array(
        [
            [30.2946, 51.6963],
            [65.5318, 51.5014],
            [48.0252, 71.7366],
            [33.5493, 92.3655],
            [62.7299, 92.2041],
        ],
        dtype=np.float32,
    )

    img_plot = None
    if landmark5 is not None:
        # Stands in for skimage's SimilarityTransform().estimate(landmark5, src).
        # OpenCV fits the transform robustly (RANSAC by default), not by plain
        # least squares. It returns None when estimation fails.
        M, _ = cv2.estimateAffinePartial2D(
            np.asarray(landmark5, dtype=np.float32), src,
        )
        if M is not None:
            img_plot = cv2.warpAffine(img, M, (112, 112), borderValue=0.0)
    if img_plot is None:
        # No landmarks, or the transform could not be estimated: plain resize.
        img_plot = cv2.resize(img, (112, 112))
    img_plot_flip = cv2.flip(img_plot, 1)

    mean = 255 * np.array([0.5, 0.5, 0.5])
    std = 255 * np.array([0.5, 0.5, 0.5])

    img = (img_plot - mean[None, None, :]) / std[None, None, :]
    img_flip = (img_plot_flip - mean[None, None, :]) / std[None, None, :]

    # Batch the original and mirrored face so the caller can combine both embeddings.
    img = np.concatenate(
        [np.expand_dims(img, 0), np.expand_dims(img_flip, 0)], axis=0,
    )
    return img, img_plot


def nms(dets: np.ndarray, nms_thresh: float = 0.3) -> list[int]:
    """Greedy non-maximum suppression over scored bounding boxes.

    Boxes are visited in descending score order. Each kept box suppresses every
    remaining box whose IoU with it exceeds ``nms_thresh``. Box extents use the
    inclusive-pixel convention (``x2 - x1 + 1``) of the reference
    implementation.

    Reference Code: https://github.com/deepinsight/insightface/blob/cc64315087304e3eda99559a5539f5d0b62512c5/detection/scrfd/tools/scrfd.py#L263

    Args:
        dets: (N, 5) array whose rows are ``[x1, y1, x2, y2, score]``.
        nms_thresh: IoU above which a lower-scoring box is suppressed.

    Returns:
        Row indices into ``dets`` of the kept boxes, highest score first.
        Empty input yields an empty list.
    """
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep: list[int] = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]

        # Intersection of the kept box with every remaining candidate.
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[rest] - inter)

        # Only candidates that do not overlap the kept box too much survive.
        order = rest[ovr <= nms_thresh]

    return keep


def get_positive_faces(
    output: list[np.ndarray],
    pos_inds: tuple[np.ndarray, ...],
    num_detect_faces: int,
    det_scale: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Rank, de-duplicate and rescale the detections that passed the score threshold.

    ``output[0]``, ``output[1]`` and ``output[2]`` are the detector's score,
    box and keypoint outputs. Only batch element 0 is used, indexed by
    ``pos_inds``. NOTE(review): presumably shaped (1, N), (1, N, 4) and
    (1, N, K, 2); confirm against the exported model.

    Steps:
    - Sort candidates by descending score.
    - Divide boxes and keypoints by ``det_scale`` to map them to
      original-image coordinates.
    - Apply ``nms``.
    - Keep at most ``num_detect_faces`` detections.

    Returns:
        det: float32 rows ``[x1, y1, x2, y2, score]`` of the kept detections.
        kpss: Keypoints of the kept detections, in the same order as ``det``.
        scores: Column vector of all thresholded scores, in their original
            (unsorted) order.
    """
    cand_scores, cand_boxes, cand_kps = (head[0][pos_inds] for head in output[:3])

    scores = np.vstack(cand_scores)
    ranking = np.argsort(scores.ravel())[::-1]

    boxes = np.vstack(cand_boxes) / det_scale
    ranked = np.hstack((boxes, scores)).astype(np.float32, copy=False)[ranking, :]
    survivors = nms(ranked)

    det = ranked[survivors, :][:num_detect_faces, :]
    kpss = (cand_kps[ranking, :] / det_scale)[survivors, :][:num_detect_faces, :]
    return det, kpss, scores


# Model identifiers written into the gallery JSON ("detection_model" / "recognition_model").
_FD_MODEL_NAME: str = "scrfd_s"  # face detector
_FR_MODEL_NAME: str = "ms1mv3_vit_run_4"  # face recognizer


def _derive_identity_name(image_path: Path, input_dir: Path) -> str:
    """Return the identity label for a gallery image.

    Images below a sub-directory of the gallery root take the name of their
    immediate parent directory (``gallery/alice/01.jpg`` → ``alice``).

    Images placed directly in the root use their filename stem, with a single
    trailing ``_<digits>`` suffix removed if present:
    - ``alice_01.jpg`` → ``alice``
    - ``alice.jpg`` → ``alice``

    Args:
        image_path: Path to the source image.
        input_dir: Root gallery directory passed on the CLI.

    Returns:
        Identity name string.
    """
    folder = image_path.parent
    if folder.resolve() != input_dir.resolve():
        return folder.name

    stem = image_path.stem
    base, sep, tail = stem.rpartition("_")
    return base if sep and tail.isdigit() else stem


# Side lengths of the low-resolution round trips used when ``resample`` is enabled.
_RESAMPLE_SIZES = (56, 28, 14)


def _embed_batch(
    session, input_name: str, output_names: list[str], batch: np.ndarray,
) -> np.ndarray:
    """Run the recognition model on a (face, flipped face) batch; return the summed embedding."""
    outputs = session.run(output_names, {input_name: batch.astype("float32")})
    # Rows 0 and 1 correspond to the face and its mirror as stacked by
    # preprocess_input_onlyface.
    return outputs[0][0].astype("float") + outputs[0][1].astype("float")


def _l2_normalize(vec: np.ndarray) -> np.ndarray:
    """Scale *vec* to unit Euclidean norm."""
    return vec / np.linalg.norm(vec)


def _crop_face(
    image: np.ndarray, bbox: np.ndarray, landmarks: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Crop ``bbox`` from ``image`` and express ``landmarks`` relative to the crop.

    Coordinates are truncated toward zero with ``int`` and clipped to the
    image bounds. Detector boxes can extend past the image edge, and a
    negative start index would otherwise wrap around in NumPy slicing,
    producing a wrong or empty crop.
    """
    height, width = image.shape[:2]
    x0, x1 = (min(max(int(v), 0), width) for v in (bbox[0], bbox[2]))
    y0, y1 = (min(max(int(v), 0), height) for v in (bbox[1], bbox[3]))
    crop = image[y0:y1, x0:x1]
    kps = np.asarray(landmarks).astype(int) - np.array([x0, y0])
    return crop, kps


def _resampled_embedding(
    session, input_name: str, output_names: list[str], face: np.ndarray, size: int,
) -> np.ndarray:
    """Embed *face* after a round trip through a ``size`` x ``size`` thumbnail.

    The result is unit-normalised.
    """
    # dsize alone determines the output size (fx/fy would be ignored), so none are passed.
    low_res = cv2.resize(face, (size, size), interpolation=cv2.INTER_CUBIC)
    restored = cv2.resize(low_res, (112, 112), interpolation=cv2.INTER_CUBIC)
    batch, _ = preprocess_input_onlyface(restored)
    return _l2_normalize(_embed_batch(session, input_name, output_names, batch))


def generate_embeddings(
    image_paths: list[Path], input_dir: Path, resample: bool,
) -> dict:
    """Generate embeddings for all image files and return a v1 gallery dict.

    For each readable image:
    - Take the highest-ranked SCRFD detection (score >= 0.3).
    - Crop it and align it with its landmarks; if no face is found, the whole
      image is used unaligned.
    - Embed it with the recognition model and L2-normalise the result.

    Unreadable images are reported on stderr and skipped.

    Args:
        image_paths: Paths to gallery images.
        input_dir: Root gallery directory (used for identity name derivation).
        resample: If True, average the unit-normalised embedding with
            unit-normalised embeddings of downsampled variants of the face.

    Returns:
        Gallery dict conforming to the v1 schema.
    """
    identities: dict[str, list[dict]] = {}
    onlyface_session, onlyface_input_name, onlyface_output_names = (
        get_onlyface_onnx_session()
    )
    scrfd_session, scrfd_input_name, scrfd_output_names = get_scrfd_onnx_session()

    for image_path in image_paths:
        # str() keeps this working on OpenCV builds that do not accept os.PathLike.
        bgr = cv2.imread(str(image_path))
        if bgr is None:
            print(f"Skipping unreadable image: {image_path}", file=sys.stderr)
            continue
        original_image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        det_img, det_scale = preprocess_input_scrfd(original_image)
        output = scrfd_session.run(
            scrfd_output_names,
            {scrfd_input_name: np.transpose(det_img.astype("float32"), [0, 3, 1, 2])},
        )

        # Default: embed the whole image without landmark alignment.
        onlyface_input = original_image
        kps = None
        pos_inds = np.where(output[0][0] >= 0.3)  # detector score threshold
        if len(pos_inds[0]) != 0:
            det, kpss, _ = get_positive_faces(output, pos_inds, 1, det_scale)
            crop, crop_kps = _crop_face(original_image, det[0], kpss[0])
            if crop.size:
                onlyface_input, kps = crop, crop_kps

        landmark_alignment_used = kps is not None
        batch, _ = preprocess_input_onlyface(onlyface_input, kps)
        person_embedding = _l2_normalize(
            _embed_batch(onlyface_session, onlyface_input_name, onlyface_output_names, batch),
        )

        if resample:
            # NOTE(review): the low-resolution variants are produced from the
            # unaligned crop and embedded without landmark alignment, while the
            # primary embedding is aligned when landmarks exist; confirm intended.
            variants = [person_embedding]
            variants.extend(
                _resampled_embedding(
                    onlyface_session,
                    onlyface_input_name,
                    onlyface_output_names,
                    onlyface_input,
                    size,
                )
                for size in _RESAMPLE_SIZES
            )
            # Every variant is unit-normalised, so each contributes equally to the mean.
            person_embedding = _l2_normalize(np.mean(variants, axis=0))

        name = _derive_identity_name(image_path, input_dir)
        entry = {
            "embedding": np.squeeze(person_embedding).tolist(),
            "source": str(image_path.resolve()),
            "landmark_alignment": landmark_alignment_used,
        }
        identities.setdefault(name, []).append(entry)

    return {
        "version": 1,
        "detection_model": _FD_MODEL_NAME,
        "recognition_model": _FR_MODEL_NAME,
        "identities": identities,
    }


if __name__ == "__main__":
    arg_parse = argparse.ArgumentParser(
        description="Generate a v1 face-embedding gallery JSON from a directory of images.",
    )
    arg_parse.add_argument(
        "-input_dir",
        dest="dir",
        type=Path,
        help="Path to directory containing gallery images",
        default=_APP_DIR / "../data/sample_gallery/high_res_gallery/",
    )
    arg_parse.add_argument(
        "-output_path",
        dest="out_path",
        type=Path,
        help="Output json file path.",
        default=_APP_DIR / "../data/gallery_embeddings/sample_gallery_embeddings.json",
    )
    arg_parse.add_argument(
        "-resample",
        dest="resample",
        action="store_true",
        help="whether to resample image, defaults to False",
        default=False,
    )

    opt_args, unknown_args = arg_parse.parse_known_args()
    # Any unrecognised argument is a usage error.
    if unknown_args:
        arg_parse.print_help()
        sys.exit(1)

    input_dir: Path = opt_args.dir
    if not input_dir.is_dir():
        print(f"Input directory does not exist: {input_dir}", file=sys.stderr)
        sys.exit(1)

    # Sorted so the processing order (and hence the JSON layout) is deterministic;
    # is_file() excludes directories whose names happen to end in an image suffix.
    image_paths = sorted(
        p for p in input_dir.rglob("*") if p.is_file() and _is_image_file(p)
    )
    print("========================================")
    print(f"Gallery Images Detected : {len(image_paths)}")
    print(f"Resample Flag is {opt_args.resample}")
    print("Generating Embeddings...")
    start_time = time.perf_counter()

    gallery = generate_embeddings(image_paths, input_dir, opt_args.resample)

    time_taken = time.perf_counter() - start_time
    identity_count = len(gallery["identities"])
    embedding_count = sum(len(v) for v in gallery["identities"].values())
    print(
        f"Completed Generating {embedding_count} Embeddings for {identity_count} Identities",
    )
    print(f"Time Taken : {round(time_taken, 2)} Seconds")
    print(f"Saving gallery embeddings at {opt_args.out_path}")

    out_path: Path = opt_args.out_path
    # Create the destination directory if needed so open() does not fail.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as file:
        json.dump(gallery, file, indent=2)

    print("Job Completed Successfully!")
    print("========================================")
