"""Continuous face recognition loop.

Captures JPEG frames from the robot camera via ``FrameCapture`` and sends them
to the face-recognition (FR) server over a persistent WebSocket connection.
Detection results are forwarded to registered callbacks.
"""

from __future__ import annotations

import asyncio
import json
import logging
import os
import time
from typing import Any, Callable, Dict, List, Optional

import websockets

from .frame_capture import FrameCapture

# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(name=__name__)

# Seconds to back off before re-dialling the FR server after a connection failure.
_RECONNECT_DELAY_SEC: float = 2.0


class FRLoop:
    """Continuously capture frames and run face recognition.

    Each cycle grabs one JPEG frame from ``frame_capture`` (in the default
    executor, because ``capture`` is called synchronously), sends it as a
    binary message over a persistent WebSocket to the FR server, parses the
    JSON reply and hands its ``detections`` list to every registered
    callback.  Any failure of the WebSocket exchange tears down the
    connection and the outer loop reconnects after ``_RECONNECT_DELAY_SEC``.

    Args:
        frame_capture: A ``FrameCapture`` instance for grabbing JPEG frames.
        fr_endpoint: WebSocket URL of the FR server
            (e.g. ``ws://kluster.klass.dev:42067/``).
        fps: Target frames per second to send to FR.  Values below 0.1 are
            clamped to 0.1.
        response_timeout: Seconds to wait for the FR server's reply to one
            frame before treating the connection as broken and reconnecting.
    """

    def __init__(
        self,
        frame_capture: FrameCapture,
        fr_endpoint: str,
        fps: float = 5.0,
        response_timeout: float = 5.0,
    ) -> None:
        self._frame_capture = frame_capture
        self._fr_endpoint = fr_endpoint
        self._interval = 1.0 / max(fps, 0.1)
        self._response_timeout = response_timeout
        self._callbacks: List[Callable[[List[Dict[str, Any]]], Any]] = []
        self._running = False
        self._task: Optional[asyncio.Task] = None
        # Timing metrics of the most recent cycle (see get_last_cycle_metrics).
        self._last_cycle_metrics: Dict[str, Any] = {}
        # Last JSON object received from the FR server.
        self._last_fr_payload: Dict[str, Any] = {}
        # Number of completed cycles; drives periodic metric logging.
        self._metric_count = 0
        self._metrics_log_every = self._read_metrics_log_every()

    @staticmethod
    def _read_metrics_log_every(default: int = 10) -> int:
        """Return ``FR_METRICS_LOG_EVERY`` from the environment, at least 1.

        A malformed value falls back to *default* with a warning instead of
        raising ``ValueError`` out of the constructor.
        """
        raw = os.getenv("FR_METRICS_LOG_EVERY", str(default))
        try:
            value = int(raw)
        except ValueError:
            logger.warning(
                "FRLoop: invalid FR_METRICS_LOG_EVERY=%r, using %d", raw, default
            )
            value = default
        return max(value, 1)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def on_detections(self, callback: Callable[[List[Dict[str, Any]]], Any]) -> None:
        """Register a callback invoked with a list of detections.

        The callback may be a plain function or return a coroutine, which is
        awaited.  Exceptions it raises are logged and otherwise ignored.
        """
        self._callbacks.append(callback)

    def get_last_cycle_metrics(self) -> Dict[str, Any]:
        """Return a shallow copy of the latest FR loop timing metrics."""
        return dict(self._last_cycle_metrics)

    async def start(self) -> None:
        """Start the FR loop as a background task (no-op if already running)."""
        # Checking ``_running`` alone is not enough: if the task was cancelled
        # from outside (not via stop()), ``_running`` is still True and the
        # loop could otherwise never be restarted.
        if self._running and self._task is not None and not self._task.done():
            return
        self._running = True
        self._task = asyncio.create_task(self._run())
        logger.info("FRLoop: started (endpoint=%s, interval=%.2fs)", self._fr_endpoint, self._interval)

    def stop(self) -> None:
        """Stop the FR loop: clear the run flag and request task cancellation.

        Does not wait for the task to finish.
        """
        self._running = False
        if self._task and not self._task.done():
            self._task.cancel()
        logger.info("FRLoop: stopped")

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    async def _run(self) -> None:
        """Outer loop: maintain a persistent WS connection to FR server."""
        while self._running:
            try:
                async with websockets.connect(self._fr_endpoint) as ws:
                    logger.info("FRLoop: connected to FR server at %s", self._fr_endpoint)
                    await self._capture_loop(ws)
            except asyncio.CancelledError:
                # Propagate cancellation so the task ends in the cancelled
                # state instead of silently returning.
                raise
            except Exception as exc:
                if not self._running:
                    break
                logger.warning("FRLoop: FR connection error: %s — reconnecting in %.0fs", exc, _RECONNECT_DELAY_SEC)
                await asyncio.sleep(_RECONNECT_DELAY_SEC)

    async def _capture_loop(self, ws: Any) -> None:
        """Inner loop: capture frame -> send to FR -> notify callbacks.

        Runs until ``_running`` becomes False.  Errors from the FR exchange
        propagate to ``_run`` so that it reconnects.
        """
        loop = asyncio.get_running_loop()
        while self._running:
            t0 = time.monotonic()
            metrics: Dict[str, Any] = {"loop_started_at": time.time()}
            detections_count = 0

            # capture() is synchronous; run it off the event loop thread.
            capture_t0 = time.monotonic()
            try:
                frame: Optional[bytes] = await loop.run_in_executor(
                    None, self._frame_capture.capture
                )
            except Exception as exc:
                logger.debug("FRLoop: frame capture error: %s — skipping", exc)
                frame = None
            metrics["capture_ms"] = (time.monotonic() - capture_t0) * 1000.0

            self._record_capture_stats(metrics)

            if frame is not None:
                fr_t0 = time.monotonic()
                detections = await self._send_to_fr(ws, frame)
                metrics["fr_roundtrip_ms"] = (time.monotonic() - fr_t0) * 1000.0
                self._record_fr_payload_metrics(metrics)

                # Published before notifying — presumably so callbacks that
                # call get_last_cycle_metrics() see this cycle's FR timings.
                self._last_cycle_metrics = dict(metrics)
                detections_count = len(detections)
                notify_t0 = time.monotonic()
                await self._notify(detections)
                metrics["notify_ms"] = (time.monotonic() - notify_t0) * 1000.0

            elapsed = time.monotonic() - t0
            metrics["loop_ms"] = elapsed * 1000.0
            self._last_cycle_metrics = dict(metrics)
            self._maybe_log_metrics(metrics, detections_count)

            # Sleep for the remainder of the interval to approximate the target FPS.
            sleep_time = self._interval - elapsed
            if sleep_time > 0:
                await asyncio.sleep(sleep_time)

    def _record_capture_stats(self, metrics: Dict[str, Any]) -> None:
        """Copy ``FrameCapture.get_stats()`` (and its ``frame_age_ms``) into *metrics*.

        Best-effort: a failing ``get_stats`` is logged at debug level and
        never breaks the FR loop.
        """
        try:
            capture_stats = self._frame_capture.get_stats()
        except Exception as exc:
            logger.debug("FRLoop: capture stats unavailable: %s", exc)
            return
        if isinstance(capture_stats, dict):
            metrics["capture_stats"] = capture_stats
            frame_age_ms = capture_stats.get("frame_age_ms")
            if isinstance(frame_age_ms, (int, float)):
                metrics["frame_age_ms"] = float(frame_age_ms)

    def _record_fr_payload_metrics(self, metrics: Dict[str, Any]) -> None:
        """Copy timing fields from the last FR server reply into *metrics*.

        The reply's ``timestamp`` is compared against ``time.time()``, so it
        is assumed to be a Unix epoch in seconds (TODO confirm with the FR
        server).  ``metrics.total_ms`` is surfaced as ``fr_server_total_ms``.
        """
        payload = self._last_fr_payload

        fr_recv_ts = payload.get("timestamp")
        if isinstance(fr_recv_ts, (int, float)) and fr_recv_ts > 0:
            metrics["fr_server_receive_timestamp"] = float(fr_recv_ts)
            # Clamp at zero in case edge and server clocks disagree.
            metrics["edge_since_fr_receive_ms"] = max(
                0.0, (time.time() - float(fr_recv_ts)) * 1000.0
            )

        fr_server_metrics = payload.get("metrics")
        if isinstance(fr_server_metrics, dict):
            metrics["fr_server_metrics"] = fr_server_metrics
            total_ms = fr_server_metrics.get("total_ms")
            if isinstance(total_ms, (int, float)):
                metrics["fr_server_total_ms"] = float(total_ms)

    async def _send_to_fr(self, ws: Any, jpeg_bytes: bytes) -> List[Dict[str, Any]]:
        """Send one binary JPEG frame to the FR server and return its detections.

        The reply (text, or bytes decoded as UTF-8) must be a JSON object.  It
        is stored in ``_last_fr_payload`` and its ``detections`` entry is
        returned, or ``[]`` when that entry is missing or not a list.

        Raises:
            asyncio.TimeoutError: no reply within ``response_timeout`` seconds.
            Exception: any send/recv/JSON error, or ``ValueError`` for a
                non-object reply.  Errors are logged and re-raised so that
                ``_run`` reconnects.
        """
        try:
            await ws.send(jpeg_bytes)
            response = await asyncio.wait_for(ws.recv(), timeout=self._response_timeout)
            if isinstance(response, (bytes, bytearray)):
                response = response.decode("utf-8", errors="replace")
            data = json.loads(response)
            if not isinstance(data, dict):
                raise ValueError("FR response must be a JSON object")
            self._last_fr_payload = data
            detections = data.get("detections", [])
            return detections if isinstance(detections, list) else []
        except asyncio.TimeoutError:
            logger.warning("FRLoop: FR server response timed out")
            raise  # let outer loop reconnect
        except Exception as exc:
            logger.warning("FRLoop: FR exchange error: %s", exc)
            raise  # let outer loop reconnect

    async def _notify(self, detections: List[Dict[str, Any]]) -> None:
        """Invoke all registered callbacks with the detections list.

        Coroutine results are awaited; one failing callback is logged and does
        not prevent the remaining callbacks from running.
        """
        for cb in self._callbacks:
            try:
                result = cb(detections)
                if asyncio.iscoroutine(result):
                    await result
            except Exception as exc:
                logger.warning("FRLoop: callback error: %s", exc)

    def _maybe_log_metrics(self, metrics: Dict[str, Any], detections_count: int) -> None:
        """Log a one-line metrics summary every ``_metrics_log_every`` cycles."""
        self._metric_count += 1
        if self._metric_count % self._metrics_log_every != 0:
            return

        def _fmt(value: Any) -> str:
            if isinstance(value, (int, float)):
                return f"{float(value):.1f}ms"
            return "n/a"

        logger.info(
            "FRLoop metrics #%d: capture=%s frame_age=%s fr_rtt=%s fr_server=%s notify=%s loop=%s detections=%d",
            self._metric_count,
            _fmt(metrics.get("capture_ms")),
            _fmt(metrics.get("frame_age_ms")),
            _fmt(metrics.get("fr_roundtrip_ms")),
            _fmt(metrics.get("fr_server_total_ms")),
            _fmt(metrics.get("notify_ms")),
            _fmt(metrics.get("loop_ms")),
            detections_count,
        )
