import asyncio
import json
from pathlib import Path
from unittest.mock import MagicMock

import pytest
import websockets

from edge_proxy.server import EdgeProxyServer


@pytest.mark.asyncio
async def test_server_mock_navigate(tmp_path):
    """Navigate to a named waypoint on the mock backend; expect accepted, then arrived.

    The server is bound to an ephemeral port. The client drains the two
    messages sent on connect (``waypoint_list`` and ``robot_state``, in either
    order), sends a ``navigate`` request for the ``lobby`` waypoint, and
    collects ``nav_status`` updates for that ``request_id`` until ``arrived``
    is seen or a 5 s deadline passes. Every receive is bounded, so a missing
    status makes the test fail instead of hanging.
    """
    waypoints = tmp_path / "wps.yaml"
    waypoints.write_text(
        """
waypoints:
  - name: lobby
    x: 0
    y: 0
    theta: 0
""".lstrip()
    )

    server = EdgeProxyServer(
        host="127.0.0.1",
        port=0,
        ws_path="/edge",
        health_path="/health",
        backend="mock",
        waypoints_path=str(waypoints),
    )

    await server._backend.start()
    try:
        # Bind to an ephemeral port (0) and read back the port that was assigned.
        ws_server = await websockets.serve(
            server._handler, server.host, 0, process_request=server._process_request
        )
        try:
            port = ws_server.sockets[0].getsockname()[1]
            async with websockets.connect(f"ws://127.0.0.1:{port}/edge") as ws:
                # Drain the initial messages sent on connect.
                init1 = json.loads(await asyncio.wait_for(ws.recv(), timeout=2.0))
                init2 = json.loads(await asyncio.wait_for(ws.recv(), timeout=2.0))
                assert init1["type"] in ("waypoint_list", "robot_state")
                assert init2["type"] in ("waypoint_list", "robot_state")

                await ws.send(
                    json.dumps(
                        {
                            "type": "navigate",
                            "request_id": "r1",
                            "goal": {"type": "waypoint", "name": "lobby"},
                            "speed": "normal",
                        }
                    )
                )

                # Expect accepted then arrived eventually.
                loop = asyncio.get_running_loop()
                statuses = []
                deadline = loop.time() + 5.0
                while True:
                    remaining = deadline - loop.time()
                    if remaining <= 0:
                        break
                    try:
                        # Bound each recv by the remaining budget: a bare recv()
                        # would block forever if "arrived" is never sent.
                        raw = await asyncio.wait_for(ws.recv(), timeout=remaining)
                    except asyncio.TimeoutError:
                        break
                    msg = json.loads(raw)
                    if msg.get("type") == "nav_status" and msg.get("request_id") == "r1":
                        statuses.append(msg.get("status"))
                        if msg.get("status") == "arrived":
                            break

                assert "accepted" in statuses, f"nav_status for r1: {statuses}"
                assert "arrived" in statuses, f"nav_status for r1: {statuses}"
        finally:
            ws_server.close()
            await ws_server.wait_closed()
    finally:
        # Stop the backend even if binding the websocket server failed.
        await server._backend.stop()


@pytest.mark.asyncio
async def test_periodic_state_broadcast(tmp_path):
    """Periodic broadcast delivers robot_state to connected clients every ~2 s.

    The broadcast coroutine is started explicitly because ``serve()`` is not
    used here. After draining the two messages sent on connect, the client
    waits up to 4 s in total for a further ``robot_state`` message. Each
    receive is bounded by the remaining time, so the test fails with the
    assertion message rather than hanging or giving up early.
    """
    waypoints = tmp_path / "wps.yaml"
    waypoints.write_text(
        """
waypoints:
  - name: lobby
    x: 0
    y: 0
    theta: 0
""".lstrip()
    )

    server = EdgeProxyServer(
        host="127.0.0.1",
        port=0,
        ws_path="/edge",
        health_path="/health",
        backend="mock",
        waypoints_path=str(waypoints),
    )

    await server._backend.start()
    broadcast_task = None
    ws_server = None
    try:
        # Start the periodic broadcast task explicitly (serve() is not called here).
        # It is created inside the try so it is always cancelled, even if binding fails.
        broadcast_task = asyncio.create_task(server._periodic_state_broadcast())

        ws_server = await websockets.serve(
            server._handler, server.host, 0, process_request=server._process_request
        )
        port = ws_server.sockets[0].getsockname()[1]
        async with websockets.connect(f"ws://127.0.0.1:{port}/edge") as ws:
            # Drain initial waypoint_list + robot_state sent on connection.
            init1 = json.loads(await asyncio.wait_for(ws.recv(), timeout=2.0))
            init2 = json.loads(await asyncio.wait_for(ws.recv(), timeout=2.0))
            assert init1["type"] in ("waypoint_list", "robot_state")
            assert init2["type"] in ("waypoint_list", "robot_state")

            # Wait for at least one periodic robot_state broadcast within the 4 s budget.
            loop = asyncio.get_running_loop()
            deadline = loop.time() + 4.0
            periodic_received = False
            while True:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    # Bounded by the overall deadline, not a fixed per-message cap,
                    # so a broadcast arriving late in the window still counts.
                    raw = await asyncio.wait_for(ws.recv(), timeout=remaining)
                except asyncio.TimeoutError:
                    break
                if json.loads(raw).get("type") == "robot_state":
                    periodic_received = True
                    break

            assert periodic_received, "No periodic robot_state received within 4 s"
    finally:
        if broadcast_task is not None:
            broadcast_task.cancel()
            try:
                await broadcast_task
            except asyncio.CancelledError:
                pass
        if ws_server is not None:
            ws_server.close()
            await ws_server.wait_closed()
        await server._backend.stop()


@pytest.mark.asyncio
async def test_server_mock_navigate_relative(tmp_path):
    """Navigate by a relative offset on the mock backend; expect accepted, then arrived.

    Same flow as the waypoint test, but the goal is ``relative``: 1.0 forward
    at ``slow`` speed. ``nav_status`` updates for ``request_id`` ``r1`` are
    collected until ``arrived`` is seen or a 5 s deadline passes. Every
    receive is bounded, so the test cannot hang.
    """
    waypoints = tmp_path / "wps.yaml"
    waypoints.write_text(
        """
waypoints:
  - name: lobby
    x: 0
    y: 0
    theta: 0
""".lstrip()
    )

    server = EdgeProxyServer(
        host="127.0.0.1",
        port=0,
        ws_path="/edge",
        health_path="/health",
        backend="mock",
        waypoints_path=str(waypoints),
    )

    await server._backend.start()
    try:
        # Bind to an ephemeral port (0) and read back the port that was assigned.
        ws_server = await websockets.serve(
            server._handler, server.host, 0, process_request=server._process_request
        )
        try:
            port = ws_server.sockets[0].getsockname()[1]
            async with websockets.connect(f"ws://127.0.0.1:{port}/edge") as ws:
                # Drain the initial messages sent on connect.
                init1 = json.loads(await asyncio.wait_for(ws.recv(), timeout=2.0))
                init2 = json.loads(await asyncio.wait_for(ws.recv(), timeout=2.0))
                assert init1["type"] in ("waypoint_list", "robot_state")
                assert init2["type"] in ("waypoint_list", "robot_state")

                await ws.send(
                    json.dumps(
                        {
                            "type": "navigate",
                            "request_id": "r1",
                            "goal": {"type": "relative", "direction": "forward", "distance": 1.0},
                            "speed": "slow",
                        }
                    )
                )

                # Expect accepted then arrived eventually.
                loop = asyncio.get_running_loop()
                statuses = []
                deadline = loop.time() + 5.0
                while True:
                    remaining = deadline - loop.time()
                    if remaining <= 0:
                        break
                    try:
                        # Bound each recv by the remaining budget: a bare recv()
                        # would block forever if "arrived" is never sent.
                        raw = await asyncio.wait_for(ws.recv(), timeout=remaining)
                    except asyncio.TimeoutError:
                        break
                    msg = json.loads(raw)
                    if msg.get("type") == "nav_status" and msg.get("request_id") == "r1":
                        statuses.append(msg.get("status"))
                        if msg.get("status") == "arrived":
                            break

                assert "accepted" in statuses, f"nav_status for r1: {statuses}"
                assert "arrived" in statuses, f"nav_status for r1: {statuses}"
        finally:
            ws_server.close()
            await ws_server.wait_closed()
    finally:
        # Stop the backend even if binding the websocket server failed.
        await server._backend.stop()


@pytest.mark.asyncio
async def test_offline_event_replay_and_arrival_artifact(tmp_path, monkeypatch):
    """Events produced while no client is connected are replayed on reconnect.

    The event journal and artifact directory are pointed into ``tmp_path``
    through environment variables. Frame capture is stubbed to return a
    minimal JPEG byte string (SOI + EOI markers).

    The flow:
    1. A first client sends a ``navigate`` to ``office_scene`` and
       disconnects immediately.
    2. After a pause, a second client connects. It must receive
       ``event_log`` messages, all flagged ``replay=True``, including a
       ``nav_status``/``arrived`` event and a ``scan_area_capture``/``saved``
       event whose artifact file exists on disk.
    3. The journal file must then be empty.
    """
    waypoints = tmp_path / "wps.yaml"
    waypoints.write_text(
        """
waypoints:
  - name: office_scene
    x: 1
    y: 1
    theta: 0
""".lstrip()
    )

    journal_path = tmp_path / "edge-events.jsonl"
    artifact_dir = tmp_path / "artifacts"
    monkeypatch.setenv("EDGE_PROXY_EVENT_JOURNAL_PATH", str(journal_path))
    monkeypatch.setenv("EDGE_PROXY_ARTIFACT_DIR", str(artifact_dir))

    server = EdgeProxyServer(
        host="127.0.0.1",
        port=0,
        ws_path="/edge",
        health_path="/health",
        backend="mock",
        waypoints_path=str(waypoints),
    )
    server._frame_capture.capture = MagicMock(return_value=b"\xff\xd8\xff\xd9")

    await server._backend.start()
    try:
        ws_server = await websockets.serve(
            server._handler, server.host, 0, process_request=server._process_request
        )
        try:
            port = ws_server.sockets[0].getsockname()[1]

            async with websockets.connect(f"ws://127.0.0.1:{port}/edge") as ws:
                # Discard the two messages sent on connect.
                await asyncio.wait_for(ws.recv(), timeout=2.0)
                await asyncio.wait_for(ws.recv(), timeout=2.0)
                await ws.send(
                    json.dumps(
                        {
                            "type": "navigate",
                            "request_id": "offline_nav",
                            "goal": {"type": "waypoint", "name": "office_scene"},
                            "speed": "normal",
                        }
                    )
                )
            # Leaving the ``async with`` closes the socket, simulating the
            # orchestrator disconnecting shortly after issuing the command.

            # NOTE(review): fixed sleep, presumably sized so the mock navigation
            # (and capture) finishes while no client is connected; confirm
            # against the mock backend's timing.
            await asyncio.sleep(3.2)

            replay_events = []
            async with websockets.connect(f"ws://127.0.0.1:{port}/edge") as ws2:
                await asyncio.wait_for(ws2.recv(), timeout=2.0)
                await asyncio.wait_for(ws2.recv(), timeout=2.0)
                loop = asyncio.get_running_loop()
                deadline = loop.time() + 3.0
                have_arrived = False
                have_artifact = False
                while not (have_arrived and have_artifact):
                    remaining = deadline - loop.time()
                    if remaining <= 0:
                        break
                    try:
                        raw = await asyncio.wait_for(ws2.recv(), timeout=min(1.0, remaining))
                    except asyncio.TimeoutError:
                        # Stop waiting; the assertions below report what is missing.
                        break
                    msg = json.loads(raw)
                    if msg.get("type") != "event_log":
                        continue
                    replay_events.append(msg)
                    event_type = msg.get("event_type")
                    status = msg.get("status")
                    if event_type == "nav_status" and status == "arrived":
                        have_arrived = True
                    elif event_type == "scan_area_capture" and status == "saved":
                        have_artifact = True

            assert replay_events, "Expected replayed event_log messages after reconnect"
            assert all(ev.get("replay") is True for ev in replay_events)

            arrived_events = [
                ev for ev in replay_events
                if ev.get("event_type") == "nav_status" and ev.get("status") == "arrived"
            ]
            assert arrived_events, "No replayed nav_status 'arrived' event"

            artifact_events = [
                ev for ev in replay_events
                if ev.get("event_type") == "scan_area_capture" and ev.get("status") == "saved"
            ]
            assert artifact_events, "No replayed scan_area_capture 'saved' event"
            artifact_path = artifact_events[-1]["payload"]["path"]
            assert artifact_path
            assert artifact_dir.exists()
            assert (artifact_dir / Path(artifact_path).name).exists()

            # Replay success should clear persisted queue.
            lines = journal_path.read_text(encoding="utf-8").strip()
            assert lines == ""
        finally:
            ws_server.close()
            await ws_server.wait_closed()
    finally:
        # Stop the backend even if binding the websocket server failed.
        await server._backend.stop()
