#!/usr/bin/env python3
"""Visualise FR detections from the edge proxy's /fr WebSocket.

Connects to the /fr endpoint, receives detection messages (which include
bounding boxes), and also captures the RTSP frame directly to display
with overlaid bounding boxes in an OpenCV window.

Usage:
    python3 tools/visualize_fr.py [--ws ws://localhost:8080/fr] [--rtsp rtsp://192.168.168.105:8554/cam]

Press 'q' to quit.
"""

from __future__ import annotations

import argparse
import asyncio
import json
import threading
import time
from typing import Optional

import cv2
import numpy as np
import websockets


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the visualiser.

    Returns:
        Namespace with ``ws`` (edge proxy /fr WebSocket URL) and ``rtsp``
        (RTSP URL of the live stream), each falling back to its default.
    """
    parser = argparse.ArgumentParser(description="Visualise FR detections")

    # (flag, default value, help text) for every supported option.
    option_specs = (
        ("--ws", "ws://localhost:8080/fr", "Edge proxy /fr WebSocket URL"),
        ("--rtsp", "rtsp://192.168.168.105:8554/cam", "RTSP URL for live frame"),
    )
    for flag, default_value, help_text in option_specs:
        parser.add_argument(flag, default=default_value, help=help_text)

    return parser.parse_args()


# ── Shared state ──────────────────────────────────────────────────────
# Written by the WebSocket listener thread and read by the display code.
# `_lock` is held while the listener updates `_latest_detections`,
# `_detection_time` and `_ws_msg_count`, and while the main loop snapshots
# the detections and their timestamp.
_lock = threading.Lock()
# Detection dicts from the most recent "fr_detections" message.
_latest_detections: list[dict] = []
# time.monotonic() of the most recent detection update; 0.0 means none yet.
_detection_time: float = 0.0
# Connection flag shown in the HUD (written without holding `_lock`).
_ws_connected: bool = False
# Number of "fr_detections" messages processed so far (shown in the HUD).
_ws_msg_count: int = 0


# ── WebSocket listener (runs in background thread) ───────────────────
def ws_listener(ws_url: str) -> None:
    """Connect to /fr and keep the shared detection state up to date.

    Runs its own asyncio event loop and never returns, so it is meant to be
    the target of a daemon thread. Every message of type ``fr_detections``
    replaces ``_latest_detections`` and refreshes ``_detection_time`` and
    ``_ws_msg_count`` under ``_lock``. Whenever the connection ends — by an
    error or by a clean close from the server — ``_ws_connected`` is cleared
    and a new connection is attempted after 2 seconds.

    Args:
        ws_url: WebSocket URL of the edge proxy's /fr endpoint.
    """

    async def _listen() -> None:
        global _latest_detections, _detection_time, _ws_connected, _ws_msg_count
        while True:
            try:
                async with websockets.connect(ws_url) as ws:
                    _ws_connected = True
                    print(f"[WS] Connected to {ws_url}")
                    async for raw in ws:
                        # A single bad payload must not tear down the
                        # connection, so decode errors are handled per message.
                        try:
                            msg = json.loads(raw)
                        except ValueError as exc:
                            print(f"[WS] Ignoring malformed message: {exc}")
                            continue
                        if not isinstance(msg, dict) or msg.get("type") != "fr_detections":
                            continue
                        dets = msg.get("detections", [])
                        if not isinstance(dets, list):
                            continue
                        with _lock:
                            _latest_detections = dets
                            _detection_time = time.monotonic()
                            _ws_msg_count += 1
                # The message iterator finished without raising: the peer
                # closed the connection normally.
                reason: object = "connection closed"
            except Exception as exc:
                reason = exc

            # Shared path for clean and abnormal disconnects, so the HUD flag
            # is always reset and reconnect attempts are always rate-limited.
            _ws_connected = False
            print(f"[WS] Disconnected: {reason} — reconnecting in 2s")
            await asyncio.sleep(2)

    asyncio.run(_listen())


def draw_detections(
    frame: np.ndarray,
    detections: list[dict],
    det_age_ms: float,
) -> np.ndarray:
    """Draw bounding boxes, labels and a status HUD on the frame.

    The frame is annotated in place and also returned. Detections that are
    not dicts, or whose ``bbox`` cannot be read as four integers, are skipped
    instead of raising.

    Args:
        frame: Image to draw on; modified in place.
        detections: Detection dicts. Each is read for ``bbox`` (four numbers
            used as the x1, y1, x2, y2 corners), ``identity`` and
            ``confidence``.
        det_age_ms: Age of the detections in milliseconds; selects the HUD
            colour (< 200 green, < 500 orange, otherwise red).

    Returns:
        The same ``frame`` object that was passed in.
    """
    h, w = frame.shape[:2]

    # Label text settings are identical for every detection.
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    thickness = 2

    for det in detections:
        if not isinstance(det, dict):
            continue

        try:
            x1, y1, x2, y2 = (int(v) for v in det.get("bbox", [0, 0, 0, 0]))
        except (TypeError, ValueError, OverflowError):
            # Wrong length, null, non-numeric or non-finite coordinates.
            continue

        identity = det.get("identity")
        confidence = det.get("confidence", 0.0)
        if not isinstance(confidence, (int, float)):
            # e.g. an explicit JSON null; the % format below needs a number.
            confidence = 0.0

        # Clamp to frame bounds
        x1 = max(0, min(x1, w - 1))
        y1 = max(0, min(y1, h - 1))
        x2 = max(0, min(x2, w - 1))
        y2 = max(0, min(y2, h - 1))

        # Color: green if identified, yellow if unknown
        color = (0, 255, 0) if identity else (0, 255, 255)

        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

        # Label: normally sits just above the box. When the box is so close
        # to the top edge that the label would fall outside the frame, place
        # it just inside the top of the box instead.
        label = f"{identity or 'Unknown'} ({confidence:.0%})"
        (tw, th), _ = cv2.getTextSize(label, font, font_scale, thickness)
        label_top = y1 - th - 8
        if label_top < 0:
            label_top = y1
        label_bottom = label_top + th + 8
        cv2.rectangle(frame, (x1, label_top), (x1 + tw + 4, label_bottom), color, -1)
        cv2.putText(frame, label, (x1 + 2, label_bottom - 4), font, font_scale, (0, 0, 0), thickness)

    # HUD: detection age and count
    hud_color = (0, 255, 0) if det_age_ms < 200 else (0, 165, 255) if det_age_ms < 500 else (0, 0, 255)
    cv2.putText(
        frame,
        f"FR det age: {det_age_ms:.0f}ms | faces: {len(detections)} | WS msgs: {_ws_msg_count}",
        (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        hud_color,
        2,
    )
    ws_status = "CONNECTED" if _ws_connected else "DISCONNECTED"
    cv2.putText(frame, f"WS: {ws_status}", (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.6, hud_color, 2)

    return frame


def main() -> None:
    """Run the visualiser until 'q' is pressed.

    Starts the WebSocket listener and an RTSP grabber in daemon threads, then
    repeatedly overlays the latest detections on a copy of the latest frame
    and shows the result in an OpenCV window. Returns early if the RTSP
    stream cannot be opened (either the open fails or takes longer than 10s).
    """
    args = parse_args()

    # Start WebSocket listener in background thread
    ws_thread = threading.Thread(target=ws_listener, args=(args.ws,), daemon=True)
    ws_thread.start()

    # Open RTSP stream — use a grabber thread that always keeps only the
    # latest frame, exactly like the edge proxy's FrameCapture does.
    print(f"[RTSP] Opening {args.rtsp} ...")

    _frame_lock = threading.Lock()
    _latest_rtsp: dict = {"frame": None, "time": 0.0, "count": 0}
    _rtsp_ok = threading.Event()    # set when the stream opened successfully
    _open_done = threading.Event()  # set when the open attempt has finished
    _stop = threading.Event()       # asks the grabber to exit and release

    def grabber() -> None:
        """Continuously grab frames, only keep the latest."""
        # Use FFMPEG backend with environment-level low-latency options
        import os
        os.environ["OPENCV_FFMPEG_CAPTURE_OPTIONS"] = "rtsp_transport;tcp|fflags;nobuffer|flags;low_delay"
        cap = cv2.VideoCapture(args.rtsp, cv2.CAP_FFMPEG)
        opened = cap.isOpened()
        if opened:
            _rtsp_ok.set()
        # Wake the main thread on failure too, so it need not wait out the
        # whole timeout when the open fails quickly.
        _open_done.set()
        try:
            if not opened:
                print("[RTSP] Failed to open stream")
                return
            while not _stop.is_set():
                ret, frame = cap.read()
                if not ret:
                    time.sleep(0.05)
                    continue
                with _frame_lock:
                    _latest_rtsp["frame"] = frame
                    _latest_rtsp["time"] = time.monotonic()
                    _latest_rtsp["count"] += 1
        finally:
            cap.release()

    grab_thread = threading.Thread(target=grabber, daemon=True)
    grab_thread.start()

    _open_done.wait(timeout=10)
    if not _rtsp_ok.is_set():
        if not _open_done.is_set():
            # The grabber reports an outright open failure itself; only the
            # timeout case needs a message from here.
            print("[RTSP] Failed to open stream within 10s")
        _stop.set()
        return

    print("[RTSP] Stream opened — press 'q' to quit")

    # Baseline for the grab-rate figure: frames grabbed before this point
    # must not count towards the rate measured from fps_start.
    with _frame_lock:
        base_count = _latest_rtsp["count"]
    fps_start = time.monotonic()

    try:
        while True:
            # Always get the LATEST frame (grabber thread overwrites continuously)
            with _frame_lock:
                latest = _latest_rtsp["frame"]
                rtsp_count = _latest_rtsp["count"]
                rtsp_age = (time.monotonic() - _latest_rtsp["time"]) * 1000 if _latest_rtsp["time"] > 0 else 9999

            if latest is None:
                time.sleep(0.01)
                continue

            # Draw on a private copy: this loop can run several times per
            # grabbed frame, and annotating the shared array in place would
            # pile overlays from earlier iterations on top of each other.
            frame = latest.copy()

            # Get latest detections
            with _lock:
                detections = list(_latest_detections)
                det_time = _detection_time

            det_age_ms = (time.monotonic() - det_time) * 1000 if det_time > 0 else 9999

            # Draw
            frame = draw_detections(frame, detections, det_age_ms)

            # FPS + frame age counters (rate of frames delivered by the
            # grabber, not of display-loop iterations)
            elapsed = time.monotonic() - fps_start
            fps = (rtsp_count - base_count) / elapsed if elapsed > 0 else 0
            cv2.putText(frame, f"RTSP grab FPS: {fps:.1f} | frame age: {rtsp_age:.0f}ms | grabbed: {rtsp_count}", (10, 90),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

            # Resize for display
            display = cv2.resize(frame, (960, 540))
            cv2.imshow("FR Visualizer (RTSP + Edge Proxy Detections)", display)

            key = cv2.waitKey(1) & 0xFF
            if key == ord("q"):
                break
    finally:
        # Let the grabber release the capture; bounded wait because
        # cap.read() may block on a stalled stream.
        _stop.set()
        grab_thread.join(timeout=1.0)
        cv2.destroyAllWindows()

    print("Done.")


# Run the visualiser only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
