import inspect
import json
import time

import cv2
import numpy as np
from sanic import Request, Sanic, Websocket
from sanic.config import Config

from app.config import Defaults, load_dotenv
from app.face import (
    EmbeddingGallery,
    get_onnx_session,
    get_positive_faces,
    preprocess_input_onlyface,
    preprocess_input_scrfd,
)
from app.models import FaceDetection, InferenceResponse
from app.track import FaceTracker, get_identity_resolver_class

# Load environment variables from a local .env file before the Sanic config is
# built below. With override=False, variables already present in the real
# environment win over .env values — presumably mirroring python-dotenv
# semantics (this is the project's own loader; confirm in app.config).
load_dotenv(".env", override=False)

# Application config is seeded from the project Defaults and can be overridden
# at startup by environment variables carrying the "FR_" prefix
# (e.g. FR_DETECTOR_THRESHOLD).
app = Sanic(
    "fr-in-the-cloud",
    config=Config(defaults=Defaults.dump(), env_prefix="FR_"),
)


@app.before_server_start
async def initialize_models(app: Sanic):
    """Create the ONNX runtime sessions for face detection (FD) and face
    recognition (FR) and stash them, with their input/output names, on
    ``app.ctx`` before the server starts accepting requests."""
    detector = get_onnx_session(app.config.FD_ONNX_PATH)
    recognizer = get_onnx_session(app.config.FR_ONNX_PATH)
    app.ctx.fd_session, app.ctx.fd_input_name, app.ctx.fd_output_names = detector
    app.ctx.fr_session, app.ctx.fr_input_name, app.ctx.fr_output_names = recognizer


@app.before_server_start
async def load_gallery(app: Sanic):
    """Load the reference embedding gallery from disk and attach it to
    ``app.ctx`` so each connection's identity resolver can query it."""
    app.ctx.gallery = EmbeddingGallery()
    app.ctx.gallery.load(app.config.GALLERY_EMBEDDING_PATH)


@app.websocket("/")
async def infer(request: Request, ws: Websocket):
    """WebSocket endpoint for face detection and recognition.

    Receives encoded image bytes (JPEG/PNG), runs the full FD/FR/tracking
    pipeline, and returns a JSON response with detected face identities.

    One tracker and identity resolver is maintained per WebSocket connection.

    Input:  Binary WebSocket message containing encoded image bytes (JPEG/PNG).
    Output: JSON-serialized `InferenceResponse` with a timestamp and a list of
            `FaceDetection` items (identity, confidence, bbox) for each tracked
            face, plus a `metrics` dict of per-stage timings (milliseconds)
            and counters.
    """
    tracker = FaceTracker()

    # Resolver classes differ in their constructor signatures depending on the
    # aggregation strategy, so pass only the kwargs this one actually accepts.
    IdentityResolver = get_identity_resolver_class(app.config.TRACK_AGGREGATION_STRATEGY)
    resolver_kwargs = {
        "gallery": app.ctx.gallery,
        "top_k": app.config.TRACK_VOTING_TOP_K,
        "max_history": app.config.TRACK_MAX_HISTORY,
        "device": app.config.TORCH_DEVICE,
    }
    accepted_params = set(inspect.signature(IdentityResolver.__init__).parameters)
    resolver_kwargs = {k: v for k, v in resolver_kwargs.items() if k in accepted_params}
    resolver = IdentityResolver(**resolver_kwargs)

    async for msg in ws:
        perf_start = time.perf_counter()
        receive_timestamp = time.time()
        metrics: dict[str, float | int] = {}

        # --- Decode incoming frame -------------------------------------
        t_stage = time.perf_counter()
        frame_bytes = np.frombuffer(msg, dtype=np.uint8)
        frame = cv2.imdecode(frame_bytes, cv2.IMREAD_COLOR)
        if frame is None:
            # cv2.imdecode signals corrupt/unsupported image data by
            # returning None rather than raising. Without this guard the
            # cvtColor call below raises and tears down the whole websocket
            # connection on a single bad frame — skip the frame instead.
            continue
        metrics["decode_ms"] = (time.perf_counter() - t_stage) * 1000.0

        # --- Preprocess frame for face detection -----------------------
        t_stage = time.perf_counter()
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        detection_input, detection_scale = preprocess_input_scrfd(frame_rgb)
        metrics["fd_preprocess_ms"] = (time.perf_counter() - t_stage) * 1000.0

        # --- Run face detection ----------------------------------------
        # The detector expects NCHW float32 input; preprocess returns NHWC,
        # hence the transpose.
        t_stage = time.perf_counter()
        detection_output = app.ctx.fd_session.run(
            app.ctx.fd_output_names,
            {
                app.ctx.fd_input_name: np.transpose(
                    detection_input.astype("float32"),
                    [0, 3, 1, 2],
                ),
            },
        )
        metrics["fd_infer_ms"] = (time.perf_counter() - t_stage) * 1000.0

        # --- Detection post-processing ---------------------------------
        t_stage = time.perf_counter()
        # detection_output[0][0] holds per-anchor confidence scores; keep
        # anchors above the configured threshold (assumed — confirm against
        # the SCRFD model's output layout).
        positive_indices = np.where(
            detection_output[0][0] >= app.config.DETECTOR_THRESHOLD,
        )
        detection_bboxes, all_keypoints, detection_scores = get_positive_faces(
            detection_output,
            positive_indices,
            app.config.DETECT_NUM_FACES,
            detection_scale,
            device=app.config.TORCH_DEVICE,
        )

        # Filter detections that are too small to recognize reliably.
        if len(detection_bboxes) > 0:
            bboxes = detection_bboxes[:, :4]
            widths = bboxes[:, 2] - bboxes[:, 0]
            heights = bboxes[:, 3] - bboxes[:, 1]
            keep = (widths >= app.config.TRACK_MIN_FACE_SIZE) & (
                heights >= app.config.TRACK_MIN_FACE_SIZE
            )
            detection_bboxes = detection_bboxes[keep]
            all_keypoints = all_keypoints[keep]
            detection_scores = detection_scores[keep]
        metrics["fd_postprocess_ms"] = (time.perf_counter() - t_stage) * 1000.0

        # --- Tracking --------------------------------------------------
        t_stage = time.perf_counter()
        tracked = tracker.update(detection_bboxes, detection_scores, keypoints=all_keypoints)
        metrics["track_ms"] = (time.perf_counter() - t_stage) * 1000.0

        # --- Crop faces from tracked bboxes and preprocess for FR ------
        h, w = frame_rgb.shape[:2]
        fr_inputs = []
        fr_track_ids = []

        t_stage = time.perf_counter()
        for i in range(len(tracked)):
            x1, y1, x2, y2 = tracked.xyxy[i].astype(int)
            track_id = tracked.tracker_id[i]

            # Clamp the tracker's (possibly extrapolated) box to the frame.
            x1 = max(0, min(x1, w - 1))
            y1 = max(0, min(y1, h - 1))
            x2 = max(0, min(x2, w))
            y2 = max(0, min(y2, h))

            # Skip degenerate boxes — the track keeps its previous identity
            # (or None) via the resolver's default in the response loop below.
            if x2 <= x1 or y2 <= y1:
                continue

            face_crop = frame_rgb[y1:y2, x1:x2, :]

            landmark5 = None
            if app.config.USE_LANDMARK_ALIGNMENT and "keypoints" in tracked.data:
                # Keypoints are stored flat as (10,) = five (x, y) pairs in
                # frame coordinates; shift them into crop coordinates.
                kps_flat = tracked.data["keypoints"][i]
                landmark5 = kps_flat.reshape(5, 2) - np.array([x1, y1], dtype=np.float32)

            preprocessed, _ = preprocess_input_onlyface(face_crop, landmark5)

            fr_inputs.append(preprocessed)
            fr_track_ids.append(track_id)
        metrics["fr_prepare_ms"] = (time.perf_counter() - t_stage) * 1000.0

        # --- Batched FR inference and embedding normalization ----------
        batch_embeddings = []
        t_stage = time.perf_counter()
        if len(fr_inputs) > 0:
            batch_input = np.concatenate(fr_inputs, axis=0)

            raw_output = app.ctx.fr_session.run(
                app.ctx.fr_output_names,
                {app.ctx.fr_input_name: batch_input.astype("float32")},
            )

            # Model outputs original+flipped embeddings interleaved per face
            # (assumed from the (num_faces, 2, -1) reshape — confirm against
            # the FR model); sum the pair and L2-normalize in float64.
            raw_embeddings = raw_output[0].astype("float")
            num_faces = len(fr_inputs)
            summed = raw_embeddings.reshape(num_faces, 2, -1).sum(axis=1)
            # Floor the norm to avoid division by zero on a degenerate
            # all-zero embedding.
            norms = np.maximum(np.linalg.norm(summed, axis=1, keepdims=True), 1e-12)
            batch_embeddings = list(summed / norms)
        metrics["fr_infer_ms"] = (time.perf_counter() - t_stage) * 1000.0

        # --- Identity resolution ---------------------------------------
        t_stage = time.perf_counter()
        identities = resolver.update_identities(batch_embeddings, fr_track_ids)
        metrics["identity_resolve_ms"] = (time.perf_counter() - t_stage) * 1000.0

        # --- Build and send the response -------------------------------
        detections = []
        for i in range(len(tracked)):
            track_id = tracked.tracker_id[i]
            bbox = tracked.xyxy[i].tolist()
            # Tracks without a resolved identity (e.g. skipped crops) fall
            # back to an anonymous detection with confidence 0.0.
            identity_info = identities.get(track_id, {"name": None, "score": None})

            detections.append(
                FaceDetection(
                    identity=identity_info["name"],
                    confidence=identity_info["score"] or 0.0,
                    bbox=bbox,
                ),
            )

        response = InferenceResponse(timestamp=receive_timestamp, detections=detections)
        metrics["tracked_faces"] = int(len(tracked))
        metrics["fr_inputs"] = int(len(fr_inputs))
        metrics["total_ms"] = (time.perf_counter() - perf_start) * 1000.0

        payload = response.model_dump()
        payload["metrics"] = metrics
        await ws.send(json.dumps(payload))
