"""Tests for Edge Proxy WebSocket client.

Tests follow the Edge Proxy Design spec at:
/home/nelsen/Projects/HRI/docs/plans/2026-02-04-edge-proxy-design.md
"""

from __future__ import annotations

import asyncio
import json
import sys
from pathlib import Path
from unittest.mock import AsyncMock, patch

import pytest

# Put the directory three levels above this file's folder on sys.path so the
# absolute ``src.…`` imports below resolve no matter where pytest is launched
# from (presumably this is the repository root -- confirm if the test layout
# changes). Must run before those imports.
ROOT = Path(__file__).resolve().parents[3]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from src.edge_proxy.client import (
    ClientConfig,
    EdgeProxyClient,
    EdgeProxyClientError,
    EdgeProxyConnectionError,
)
from src.edge_proxy.messages import (
    Battery,
    CancelNavigationCommand,
    ErrorMessage,
    EventLogMessage,
    GetStateCommand,
    MessageType,
    NavigateCommand,
    NavStatus,
    NavErrorCode,
    NavStatusMessage,
    PongMessage,
    Pose,
    RobotState,
    Speed,
    Waypoint,
    WaypointListMessage,
    parse_edge_message,
)


# =============================================================================
# Fixtures
# =============================================================================


@pytest.fixture
def client_config():
    """Provide a ClientConfig pointing at a local endpoint for the tests."""
    settings = {
        "host": "localhost",
        "port": 8080,
        "ws_path": "/edge",
        # Auto-reconnect is switched off for tests.
        "reconnect": False,
    }
    return ClientConfig(**settings)


@pytest.fixture
def client(client_config):
    """Provide an EdgeProxyClient built from the ``client_config`` fixture."""
    edge_client = EdgeProxyClient(config=client_config)
    return edge_client


# =============================================================================
# Message Dataclass Tests (Spec: Section 4-5)
# =============================================================================


class TestNavigateCommand:
    """Dict form produced by the NavigateCommand factories (Spec: Section 4.1-4.3)."""

    def test_to_waypoint(self):
        """A waypoint command serializes with the waypoint name as its goal."""
        payload = NavigateCommand.to_waypoint("lobby", "nav_001", "normal").to_dict()

        assert (payload["type"], payload["request_id"]) == ("navigate", "nav_001")
        goal = payload["goal"]
        assert (goal["type"], goal["name"]) == ("waypoint", "lobby")
        assert payload["speed"] == "normal"

    def test_to_pose(self):
        """A pose command serializes with x, y and theta in its goal."""
        payload = NavigateCommand.to_pose(5.2, 3.1, 1.57, "nav_002", "fast").to_dict()

        assert (payload["type"], payload["request_id"]) == ("navigate", "nav_002")
        goal = payload["goal"]
        assert goal["type"] == "pose"
        assert (goal["x"], goal["y"], goal["theta"]) == (5.2, 3.1, 1.57)
        assert payload["speed"] == "fast"

    def test_to_relative(self):
        """A relative command serializes with direction and distance in its goal."""
        payload = NavigateCommand.to_relative("forward", 1.0, "nav_003", "slow").to_dict()

        assert (payload["type"], payload["request_id"]) == ("navigate", "nav_003")
        goal = payload["goal"]
        assert goal["type"] == "relative"
        assert (goal["direction"], goal["distance"]) == ("forward", 1.0)
        assert payload["speed"] == "slow"


class TestClientConfigDefaults:
    """Defaults of a ClientConfig constructed without arguments."""

    def test_default_port(self):
        """The port default is expected to equal the robot-side Edge Proxy server default."""
        default_config = ClientConfig()

        assert default_config.port == 8080


class TestCancelNavigationCommand:
    """Dict form of CancelNavigationCommand (Spec: Section 4.4)."""

    def test_to_dict(self):
        """to_dict() exposes the message type, request id and cancel reason."""
        fields = {
            "type": MessageType.CANCEL_NAVIGATION,
            "request_id": "cancel_001",
            "reason": "barge_in",
        }
        serialized = CancelNavigationCommand(**fields).to_dict()

        expected = {
            "type": "cancel_navigation",
            "request_id": "cancel_001",
            "reason": "barge_in",
        }
        for key, value in expected.items():
            assert serialized[key] == value


class TestGetStateCommand:
    """Dict form of GetStateCommand (Spec: Section 4.5)."""

    def test_to_dict(self):
        """to_dict() exposes the message type and request id."""
        command = GetStateCommand(type=MessageType.GET_STATE, request_id="state_001")
        serialized = command.to_dict()

        assert (serialized["type"], serialized["request_id"]) == (
            "get_state",
            "state_001",
        )


class TestNavStatusMessage:
    """NavStatusMessage.from_dict on representative payloads (Spec: Section 5.1)."""

    def test_from_dict_navigating(self):
        """An in-progress status keeps its destination, progress and ETA."""
        msg = NavStatusMessage.from_dict(
            {
                "type": "nav_status",
                "request_id": "nav_001",
                "status": "navigating",
                "destination": "lobby",
                "progress": 0.45,
                "eta_sec": 12,
            }
        )

        expected = {
            "status": "navigating",
            "destination": "lobby",
            "progress": 0.45,
            "eta_sec": 12,
        }
        for attr, value in expected.items():
            assert getattr(msg, attr) == value

    def test_from_dict_failed(self):
        """A failed status keeps its reason and error code."""
        msg = NavStatusMessage.from_dict(
            {
                "type": "nav_status",
                "request_id": "nav_001",
                "status": "failed",
                "reason": "obstacle_detected",
                "error_code": "NAV_BLOCKED",
            }
        )

        expected = {
            "status": "failed",
            "reason": "obstacle_detected",
            "error_code": "NAV_BLOCKED",
        }
        for attr, value in expected.items():
            assert getattr(msg, attr) == value


class TestRobotState:
    """RobotState.from_dict on a fully populated payload (Spec: Section 5.2)."""

    def test_from_dict_full(self):
        """Every field of a complete robot_state payload is carried over."""
        state = RobotState.from_dict(
            {
                "type": "robot_state",
                "timestamp": 1706500000.123,
                "pose": {"x": 5.2, "y": 3.1, "theta": 1.57},
                "location": "near_lobby",
                "battery": {"level": 85, "charging": False},
                "nav_state": "navigating",
                "nav_progress": 0.45,
                "nav_destination": "lobby",
            }
        )

        assert state.timestamp == 1706500000.123

        pose = state.pose
        assert (pose.x, pose.y, pose.theta) == (5.2, 3.1, 1.57)

        assert state.location == "near_lobby"

        battery = state.battery
        assert battery.level == 85
        assert battery.charging is False

        assert (state.nav_state, state.nav_progress, state.nav_destination) == (
            "navigating",
            0.45,
            "lobby",
        )


class TestWaypointListMessage:
    """WaypointListMessage.from_dict (Spec: Section 5.3)."""

    def test_from_dict(self):
        """Both waypoints are parsed and keep their names in order."""
        lobby = {"name": "lobby", "x": 0.0, "y": 0.0, "theta": 0}
        checkpoint = {"name": "checkpoint_1", "x": 5.0, "y": 2.0, "theta": 1.57}
        parsed = WaypointListMessage.from_dict(
            {"type": "waypoint_list", "waypoints": [lobby, checkpoint]}
        )

        assert len(parsed.waypoints) == 2
        first, second = parsed.waypoints[0], parsed.waypoints[1]
        assert (first.name, second.name) == ("lobby", "checkpoint_1")


class TestErrorMessage:
    """ErrorMessage.from_dict (Spec: Section 5.4)."""

    def test_from_dict(self):
        """Request id, error identifier and human-readable text are carried over."""
        err = ErrorMessage.from_dict(
            {
                "type": "error",
                "request_id": "nav_001",
                "error": "waypoint_not_found",
                "message": "Waypoint 'unknown_location' not in registry",
            }
        )

        expected = {
            "request_id": "nav_001",
            "error": "waypoint_not_found",
            "message": "Waypoint 'unknown_location' not in registry",
        }
        for attr, value in expected.items():
            assert getattr(err, attr) == value


class TestPongMessage:
    """PongMessage.from_dict (Spec: Section 5.5)."""

    def test_from_dict(self):
        """The state reported in a pong payload is carried over."""
        pong = PongMessage.from_dict({"type": "pong", "state": "idle"})

        assert pong.state == "idle"


class TestParseEdgeMessage:
    """Dispatch performed by parse_edge_message on the ``type`` field."""

    def test_parse_nav_status(self):
        """A nav_status payload becomes a NavStatusMessage."""
        parsed = parse_edge_message(
            {
                "type": "nav_status",
                "request_id": "nav_001",
                "status": "arrived",
                "destination": "goal",
            }
        )

        assert isinstance(parsed, NavStatusMessage)
        assert parsed.status == "arrived"

    def test_parse_robot_state(self):
        """A robot_state payload becomes a RobotState."""
        parsed = parse_edge_message(
            {
                "type": "robot_state",
                "timestamp": 1706500000.0,
                "pose": {"x": 1.0, "y": 2.0, "theta": 0.0},
                "location": "home",
                "battery": {"level": 90, "charging": False},
            }
        )

        assert isinstance(parsed, RobotState)
        assert parsed.location == "home"

    def test_parse_waypoint_list(self):
        """A waypoint_list payload becomes a WaypointListMessage."""
        only_waypoint = {"name": "wp1", "x": 1.0, "y": 2.0, "theta": 0.0}
        parsed = parse_edge_message(
            {"type": "waypoint_list", "waypoints": [only_waypoint]}
        )

        assert isinstance(parsed, WaypointListMessage)
        assert len(parsed.waypoints) == 1

    def test_parse_error(self):
        """An error payload becomes an ErrorMessage."""
        parsed = parse_edge_message(
            {
                "type": "error",
                "request_id": "req_001",
                "error": "test_error",
                "message": "Test error message",
            }
        )

        assert isinstance(parsed, ErrorMessage)
        assert parsed.error == "test_error"

    def test_parse_pong(self):
        """A pong payload becomes a PongMessage."""
        parsed = parse_edge_message({"type": "pong", "state": "idle"})

        assert isinstance(parsed, PongMessage)
        assert parsed.state == "idle"

    def test_parse_event_log(self):
        """An event_log replay payload becomes an EventLogMessage."""
        parsed = parse_edge_message(
            {
                "type": "event_log",
                "event_id": "edge-1",
                "event_type": "nav_status",
                "request_id": "nav_001",
                "status": "failed",
                "timestamp": 1706500000.0,
                "replay": True,
                "payload": {"error_code": "NAV_LOCALIZATION_LOST"},
            }
        )

        assert isinstance(parsed, EventLogMessage)
        assert parsed.event_id == "edge-1"
        assert parsed.replay is True

    def test_parse_unknown_type(self):
        """An unrecognized ``type`` value yields None."""
        parsed = parse_edge_message({"type": "unknown_type"})

        assert parsed is None

    def test_parse_missing_type(self):
        """A payload lacking ``type`` is rejected with ValueError."""
        typeless = {"status": "idle"}

        with pytest.raises(ValueError, match="Missing 'type' field"):
            parse_edge_message(typeless)


# =============================================================================
# Client Initialization Tests
# =============================================================================


class TestEdgeProxyClientInit:
    """Construction of EdgeProxyClient with and without an explicit config."""

    def test_init_default_config(self):
        """Without a config the client uses the defaults and starts disconnected."""
        default_client = EdgeProxyClient()
        cfg = default_client.config

        assert (cfg.host, cfg.port, cfg.ws_path) == ("localhost", 8080, "/edge")
        assert cfg.ping_interval == 10.0
        assert default_client.is_connected is False

    def test_init_custom_config(self):
        """A supplied config's host and port are visible on the client."""
        custom_client = EdgeProxyClient(
            config=ClientConfig(host="robot.local", port=9000)
        )
        cfg = custom_client.config

        assert (cfg.host, cfg.port) == ("robot.local", 9000)


# =============================================================================
# Async Tests
# =============================================================================


@pytest.mark.asyncio
class TestEdgeProxyClientConnection:
    """Connect/disconnect behaviour of EdgeProxyClient with ``websockets.connect`` patched."""

    @staticmethod
    async def _open_socket(*args, **kwargs):
        """Stand in for ``websockets.connect``: return a fresh mock socket marked open."""
        socket = AsyncMock()
        socket.closed = False
        return socket

    @staticmethod
    async def _refuse(*args, **kwargs):
        """Stand in for ``websockets.connect``: fail the way a refused connection would."""
        raise OSError("Connection refused")

    async def test_connect_success(self, client):
        """connect() leaves the client reporting itself as connected."""
        with patch("websockets.connect", side_effect=self._open_socket):
            await client.connect()
            assert client.is_connected
            await client.disconnect()

    async def test_connect_failure_no_reconnect(self, client):
        """With reconnect disabled, a failing connect raises EdgeProxyConnectionError."""
        with patch("websockets.connect", side_effect=self._refuse):
            with pytest.raises(EdgeProxyConnectionError):
                await client.connect()

    async def test_disconnect(self, client):
        """disconnect() flips is_connected back to False."""
        with patch("websockets.connect", side_effect=self._open_socket):
            await client.connect()
            assert client.is_connected

            await client.disconnect()
            assert client.is_connected is False


@pytest.mark.asyncio
class TestEdgeProxyClientCommands:
    """Test EdgeProxyClient command sending (Spec: Section 4).

    Every test that connects goes through ``_send_and_capture``, which also
    disconnects the client afterwards so no connected client (or anything it
    started on connect) outlives the test.
    """

    @staticmethod
    async def _send_and_capture(client, send_command, *args):
        """Connect to a mock socket, send one command and return its JSON payload.

        Args:
            client: The EdgeProxyClient under test.
            send_command: Bound ``client.send_*`` coroutine function to call.
            *args: Positional arguments forwarded to ``send_command``.

        Returns:
            The decoded JSON object passed to the mock socket's most recent
            ``send`` call.
        """
        mock_ws = AsyncMock()
        mock_ws.closed = False
        mock_ws.send = AsyncMock()

        async def mock_connect(*connect_args, **connect_kwargs):
            return mock_ws

        with patch("websockets.connect", side_effect=mock_connect):
            await client.connect()
            try:
                await send_command(*args)
                # Decode before disconnecting so nothing sent during teardown
                # can replace the recorded call.
                return json.loads(mock_ws.send.call_args[0][0])
            finally:
                # Always release the connection, even when the send fails.
                await client.disconnect()

    async def test_send_navigate_waypoint(self, client):
        """Test sending waypoint navigation command."""
        sent_data = await self._send_and_capture(
            client, client.send_navigate_waypoint, "waypoint1", "nav_001", "fast"
        )

        assert sent_data["type"] == "navigate"
        assert sent_data["request_id"] == "nav_001"
        assert sent_data["goal"]["type"] == "waypoint"
        assert sent_data["goal"]["name"] == "waypoint1"
        assert sent_data["speed"] == "fast"

    async def test_send_navigate_pose(self, client):
        """Test sending pose navigation command."""
        sent_data = await self._send_and_capture(
            client, client.send_navigate_pose, 5.2, 3.1, 1.57, "nav_002"
        )

        assert sent_data["goal"]["type"] == "pose"
        assert sent_data["goal"]["x"] == 5.2
        assert sent_data["goal"]["y"] == 3.1
        assert sent_data["goal"]["theta"] == 1.57

    async def test_send_navigate_relative(self, client):
        """Test sending relative navigation command (speed left at its default)."""
        sent_data = await self._send_and_capture(
            client, client.send_navigate_relative, "forward", 1.0
        )

        assert sent_data["goal"]["type"] == "relative"
        assert sent_data["goal"]["direction"] == "forward"
        assert sent_data["goal"]["distance"] == 1.0
        assert sent_data["speed"] == "slow"

    async def test_send_cancel_navigation(self, client):
        """Test sending cancel navigation command."""
        sent_data = await self._send_and_capture(
            client, client.send_cancel_navigation, "cancel_001", "barge_in"
        )

        assert sent_data["type"] == "cancel_navigation"
        assert sent_data["request_id"] == "cancel_001"
        assert sent_data["reason"] == "barge_in"

    async def test_send_get_state(self, client):
        """Test sending get_state command."""
        sent_data = await self._send_and_capture(
            client, client.send_get_state, "state_001"
        )

        assert sent_data["type"] == "get_state"
        assert sent_data["request_id"] == "state_001"

    async def test_send_ping(self, client):
        """Test sending ping command."""
        sent_data = await self._send_and_capture(client, client.send_ping)

        assert sent_data["type"] == "ping"

    async def test_send_when_disconnected_raises(self, client):
        """Test that sending when disconnected raises EdgeProxyConnectionError."""
        with pytest.raises(EdgeProxyConnectionError, match="Not connected"):
            await client.send_navigate_waypoint("waypoint1")


@pytest.mark.asyncio
class TestEdgeProxyClientHandlers:
    """Test EdgeProxyClient event handlers (Spec: Section 5).

    Each test registers a handler on the client, then uses ``_deliver`` to
    connect through a patched ``websockets.connect``, feed raw JSON messages
    straight into ``client._handle_message`` and disconnect again.
    """

    @staticmethod
    async def _deliver(client, *raw_messages):
        """Connect ``client`` to a mock socket, handle ``raw_messages``, then disconnect.

        Args:
            client: The EdgeProxyClient under test, with handlers already registered.
            *raw_messages: JSON strings passed one by one, in order, to
                ``client._handle_message``. May be empty to exercise only
                connect/disconnect.
        """
        async def mock_connect(*args, **kwargs):
            mock_ws = AsyncMock()
            mock_ws.closed = False
            return mock_ws

        with patch("websockets.connect", side_effect=mock_connect):
            await client.connect()
            try:
                for raw in raw_messages:
                    await client._handle_message(raw)
            finally:
                # Disconnect even if handling a message raised, so a failing
                # test does not leave the client connected.
                await client.disconnect()

    async def test_nav_status_handler(self, client):
        """Test nav_status event handler."""
        received_messages = []

        @client.on_nav_status
        def handle_nav_status(msg):
            received_messages.append(msg)

        message = json.dumps({
            "type": "nav_status",
            "request_id": "nav_001",
            "status": "arrived",
            "destination": "goal",
            "progress": 1.0,
        })
        await self._deliver(client, message)

        assert len(received_messages) == 1
        assert received_messages[0].status == "arrived"

    async def test_robot_state_handler(self, client):
        """Test robot_state event handler."""
        received_states = []

        @client.on_robot_state
        def handle_robot_state(state):
            received_states.append(state)

        message = json.dumps({
            "type": "robot_state",
            "timestamp": 1706500000.0,
            "pose": {"x": 1.0, "y": 2.0, "theta": 0.0},
            "location": "waypoint3",
            "battery": {"level": 75, "charging": False},
        })
        await self._deliver(client, message)

        assert len(received_states) == 1
        assert received_states[0].battery.level == 75

    async def test_waypoint_list_handler(self, client):
        """Test waypoint_list event handler."""
        received_lists = []

        @client.on_waypoint_list
        def handle_waypoint_list(msg):
            received_lists.append(msg)

        message = json.dumps({
            "type": "waypoint_list",
            "waypoints": [
                {"name": "wp1", "x": 1.0, "y": 2.0, "theta": 0.0},
                {"name": "wp2", "x": 3.0, "y": 4.0, "theta": 0.0},
            ],
        })
        await self._deliver(client, message)

        assert len(received_lists) == 1
        assert len(received_lists[0].waypoints) == 2

    async def test_error_handler(self, client):
        """Test error event handler."""
        received_errors = []

        @client.on_error
        def handle_error(err):
            received_errors.append(err)

        message = json.dumps({
            "type": "error",
            "request_id": "req_001",
            "error": "test_error",
            "message": "Test error message",
        })
        await self._deliver(client, message)

        assert len(received_errors) == 1
        assert received_errors[0].error == "test_error"

    async def test_pong_handler(self, client):
        """Test pong event handler."""
        received_pongs = []

        @client.on_pong
        def handle_pong(msg):
            received_pongs.append(msg)

        message = json.dumps({
            "type": "pong",
            "state": "idle",
        })
        await self._deliver(client, message)

        assert len(received_pongs) == 1
        assert received_pongs[0].state == "idle"

    async def test_connection_change_handler(self, client):
        """Test connection state change handler."""
        changes = []

        @client.on_connection_change
        def handle_connection_change(connected):
            changes.append(connected)

        # No messages: only the connect and disconnect transitions are exercised.
        await self._deliver(client)

        assert changes == [True, False]

    async def test_async_handler(self, client):
        """Test async event handler."""
        received = []

        @client.on_nav_status
        async def handle_nav_status(msg):
            # Yield to the event loop once so the handler really is asynchronous.
            await asyncio.sleep(0)
            received.append(msg)

        message = json.dumps({
            "type": "nav_status",
            "request_id": "nav_001",
            "status": "navigating",
            "destination": "goal",
            "progress": 0.5,
        })
        await self._deliver(client, message)

        assert len(received) == 1

    async def test_event_log_handler_deduplicates_replays(self, client):
        """Repeated event_id should only be handled once."""
        received = []

        @client.on_event_log
        def handle_event_log(msg):
            received.append(msg)

        payload = json.dumps(
            {
                "type": "event_log",
                "event_id": "edge-123",
                "event_type": "nav_status",
                "request_id": "nav_001",
                "status": "failed",
                "timestamp": 1706500000.0,
                "replay": True,
                "payload": {"reason": "localization_unavailable"},
            }
        )

        # The identical message is delivered twice on the same connection.
        await self._deliver(client, payload, payload)

        assert len(received) == 1
        assert received[0].event_id == "edge-123"
