"""Tests for the FR loop and /fr WebSocket endpoint."""

import asyncio
import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import websockets

from edge_proxy.fr_loop import FRLoop
from edge_proxy.server import EdgeProxyServer


# ---------------------------------------------------------------------------
# FRLoop unit tests
# ---------------------------------------------------------------------------


class TestFRLoop:
    """Unit tests for FRLoop with mocked FrameCapture and FR server."""

    def _make_frame_capture(self, frame: bytes | None = b"\xff\xd8fake_jpeg"):
        """Return a MagicMock FrameCapture whose ``capture()`` returns *frame*."""
        fc = MagicMock()
        fc.capture.return_value = frame
        return fc

    @pytest.mark.asyncio
    async def test_callback_receives_detections(self):
        """FRLoop should invoke registered callbacks with detections."""
        fc = self._make_frame_capture()
        detections_received = []

        async def on_det(dets):
            detections_received.append(dets)

        fr_response = json.dumps({
            "detections": [{"identity": "Alice", "confidence": 0.95, "bbox": [10, 20, 100, 200]}],
            "timestamp": 1234567890.0,
        })

        loop = FRLoop(frame_capture=fc, fr_endpoint="ws://localhost:9999/", fps=100)
        loop.on_detections(on_det)

        # Children of an AsyncMock (e.g. ``send``) are AsyncMocks already, so
        # only ``recv`` needs a canned reply.
        mock_ws = AsyncMock()
        mock_ws.recv = AsyncMock(return_value=fr_response)

        call_count = 0

        async def fake_capture_loop(ws):
            # Run exactly one capture -> FR -> notify iteration, then stop.
            nonlocal call_count
            frame = fc.capture()
            if frame is not None:
                dets = await loop._send_to_fr(ws, frame)
                if dets is not None:
                    await loop._notify(dets)
            call_count += 1
            loop.stop()

        with patch.object(loop, "_capture_loop", fake_capture_loop):
            with patch("edge_proxy.fr_loop.websockets.connect") as mock_connect:
                mock_connect.return_value.__aenter__ = AsyncMock(return_value=mock_ws)
                mock_connect.return_value.__aexit__ = AsyncMock(return_value=False)
                await loop.start()
                task = loop._task
                try:
                    # Wait (bounded) for the background task instead of a fixed
                    # sleep; the patches must remain active while it runs.
                    await asyncio.wait({task}, timeout=1.0)
                finally:
                    # Never leak the background task into later tests.
                    if not task.done():
                        task.cancel()

        assert call_count == 1
        mock_ws.send.assert_called_once_with(b"\xff\xd8fake_jpeg")
        assert len(detections_received) == 1
        assert detections_received[0][0]["identity"] == "Alice"

    @pytest.mark.asyncio
    async def test_skip_frame_on_capture_failure(self):
        """FRLoop should skip frames when capture returns None."""
        fc = self._make_frame_capture(frame=None)

        callback = AsyncMock()
        loop = FRLoop(frame_capture=fc, fr_endpoint="ws://localhost:9999/", fps=100)
        loop.on_detections(callback)

        def capture_once():
            # Clear the running flag so the capture loop exits after this
            # single failed capture. Assumes _capture_loop re-checks
            # ``self._running`` each iteration; wait_for below bounds the
            # run if it does not.
            loop._running = False
            return None

        fc.capture.side_effect = capture_once
        loop._running = True

        with patch.object(loop, "_send_to_fr", new_callable=AsyncMock) as mock_send:
            # Drive the real capture loop rather than re-implementing it here.
            await asyncio.wait_for(loop._capture_loop(AsyncMock()), timeout=2.0)

        fc.capture.assert_called()
        mock_send.assert_not_called()
        callback.assert_not_called()

    @pytest.mark.asyncio
    async def test_send_to_fr_sends_binary(self):
        """_send_to_fr should send binary JPEG and parse JSON response."""
        fc = self._make_frame_capture()
        loop = FRLoop(frame_capture=fc, fr_endpoint="ws://localhost:9999/", fps=5)

        mock_ws = AsyncMock()
        mock_ws.recv = AsyncMock(return_value=json.dumps({
            "detections": [{"identity": None, "confidence": 0.3, "bbox": [0, 0, 50, 50]}],
        }))

        result = await loop._send_to_fr(mock_ws, b"\xff\xd8jpeg_data")

        mock_ws.send.assert_called_once_with(b"\xff\xd8jpeg_data")
        assert result is not None
        assert len(result) == 1
        assert result[0]["identity"] is None

    @pytest.mark.asyncio
    async def test_stop_cancels_task(self):
        """stop() should cancel the background task."""
        fc = self._make_frame_capture()
        loop = FRLoop(frame_capture=fc, fr_endpoint="ws://localhost:9999/", fps=5)

        # Replace _run with a coroutine that would otherwise never finish.
        async def fake_run():
            await asyncio.sleep(999)

        with patch.object(loop, "_run", fake_run):
            await loop.start()
            assert loop._running is True
            task = loop._task
            assert task is not None

            loop.stop()
            assert loop._running is False
            # Wait (bounded) for the cancellation to land instead of a fixed sleep.
            await asyncio.wait({task}, timeout=1.0)
            assert task.done()

    @pytest.mark.asyncio
    async def test_on_detections_sync_callback(self):
        """on_detections should support synchronous callbacks too."""
        fc = self._make_frame_capture()
        loop = FRLoop(frame_capture=fc, fr_endpoint="ws://localhost:9999/", fps=5)

        results = []

        def sync_cb(dets):
            results.append(dets)

        loop.on_detections(sync_cb)
        await loop._notify([{"identity": "Bob", "confidence": 0.8, "bbox": [0, 0, 1, 1]}])
        assert len(results) == 1
        assert results[0][0]["identity"] == "Bob"


# ---------------------------------------------------------------------------
# Server integration tests — /fr endpoint
# ---------------------------------------------------------------------------


@pytest.fixture
def waypoints_file(tmp_path):
    """Write a single-waypoint ("lobby") YAML file and return its path as a str."""
    yaml_text = (
        "waypoints:\n"
        "  - name: lobby\n"
        "    x: 0\n"
        "    y: 0\n"
        "    theta: 0\n"
    )
    path = tmp_path / "wps.yaml"
    path.write_text(yaml_text)
    return str(path)


def _make_server(waypoints_file: str) -> EdgeProxyServer:
    """Build a mock-backend EdgeProxyServer whose FR loop is a stub.

    ``edge_proxy.server.FRLoop`` is patched only while the server is being
    constructed, so the instance it stores is the stub below. Its ``start`` is
    an AsyncMock and ``stop`` / ``on_detections`` are plain MagicMocks.
    """
    stub_fr_loop = MagicMock()
    stub_fr_loop.start = AsyncMock()
    stub_fr_loop.stop = MagicMock()
    stub_fr_loop.on_detections = MagicMock()

    with patch("edge_proxy.server.FRLoop", return_value=stub_fr_loop):
        return EdgeProxyServer(
            host="127.0.0.1",
            port=0,
            ws_path="/edge",
            health_path="/health",
            backend="mock",
            waypoints_path=waypoints_file,
        )


@pytest.mark.asyncio
async def test_fr_endpoint_receives_detections(waypoints_file):
    """Check that an fr_detections broadcast reaches a client subscribed on /fr."""
    server = _make_server(waypoints_file)
    await server._backend.start()

    ws_server = await websockets.serve(
        server._route_handler,
        server.host,
        0,
        process_request=server._process_request,
    )
    try:
        bound_port = ws_server.sockets[0].getsockname()[1]
        fr_url = f"ws://127.0.0.1:{bound_port}/fr"
        sent = [{"identity": "Alice", "confidence": 0.9, "bbox": [10, 20, 100, 200]}]

        async with websockets.connect(fr_url) as client:
            # Let the server register the new /fr subscriber before broadcasting.
            await asyncio.sleep(0.05)
            await server._broadcast_fr_detections(sent)

            raw = await asyncio.wait_for(client.recv(), timeout=2.0)
            payload = json.loads(raw)

            assert payload["type"] == "fr_detections"
            detections = payload["detections"]
            assert len(detections) == 1
            assert detections[0]["identity"] == "Alice"
            assert "timestamp" in payload
    finally:
        ws_server.close()
        await ws_server.wait_closed()


@pytest.mark.asyncio
async def test_fr_detections_also_sent_to_edge_clients(waypoints_file):
    """FR detections should be broadcast to /edge clients as well.

    The /edge channel may interleave other message types, so the client keeps
    reading until it sees ``fr_detections``. The whole search is bounded by a
    single 2-second deadline.
    """
    server = _make_server(waypoints_file)
    await server._backend.start()

    ws_server = await websockets.serve(
        server._route_handler, server.host, 0,
        process_request=server._process_request,
    )
    try:
        port = ws_server.sockets[0].getsockname()[1]

        async with websockets.connect(f"ws://127.0.0.1:{port}/edge") as ws:
            # Drain initial sync messages (waypoint_list + robot_state)
            await asyncio.wait_for(ws.recv(), timeout=2.0)
            await asyncio.wait_for(ws.recv(), timeout=2.0)

            test_detections = [{"identity": "Bob", "confidence": 0.85, "bbox": [5, 5, 50, 50]}]
            await server._broadcast_fr_detections(test_detections)

            clock = asyncio.get_running_loop()
            deadline = clock.time() + 2.0
            while True:
                # Each recv gets only the time left before the overall deadline,
                # so the search can never exceed 2 seconds in total.
                remaining = deadline - clock.time()
                if remaining <= 0:
                    raise AssertionError("Timed out waiting for fr_detections message")
                try:
                    raw = await asyncio.wait_for(ws.recv(), timeout=remaining)
                except asyncio.TimeoutError:
                    raise AssertionError("Timed out waiting for fr_detections message") from None
                msg = json.loads(raw)
                if msg.get("type") == "fr_detections":
                    break
            assert msg["detections"][0]["identity"] == "Bob"
    finally:
        ws_server.close()
        await ws_server.wait_closed()


@pytest.mark.asyncio
async def test_fr_endpoint_ignores_incoming_messages(waypoints_file):
    """Check that messages sent to /fr are tolerated without closing the socket."""
    server = _make_server(waypoints_file)
    await server._backend.start()

    ws_server = await websockets.serve(
        server._route_handler,
        server.host,
        0,
        process_request=server._process_request,
    )
    try:
        bound_port = ws_server.sockets[0].getsockname()[1]

        async with websockets.connect(f"ws://127.0.0.1:{bound_port}/fr") as client:
            # Neither plain text nor a JSON frame should provoke an error.
            for outgoing in ("hello", json.dumps({"type": "ping"})):
                await client.send(outgoing)
            # Give the server a moment to process, then confirm we are still connected.
            await asyncio.sleep(0.05)
            assert client.open
    finally:
        ws_server.close()
        await ws_server.wait_closed()


@pytest.mark.asyncio
async def test_process_request_accepts_fr_path(waypoints_file):
    """Verify that /fr passes the HTTP pre-check instead of receiving a 404."""
    server = _make_server(waypoints_file)
    # A None result means "proceed with WebSocket upgrade".
    assert server._process_request("/fr", {}) is None


@pytest.mark.asyncio
async def test_process_request_rejects_unknown_path(waypoints_file):
    """Verify that an unrecognised path is refused with HTTP 404 before the upgrade."""
    server = _make_server(waypoints_file)
    response = server._process_request("/unknown", {})
    assert response is not None
    status = response[0]
    assert status.value == 404
