"""Edge Proxy WebSocket client.

This module provides the EdgeProxyClient for bidirectional WebSocket communication
with the robot Edge Proxy, following the Edge Proxy Design spec.
See: docs/plans/2026-02-04-edge-proxy-design.md (in the HRI repository)

Topology (current implementation):
- Edge Proxy runs as a WebSocket *server* on the robot (default :8080/edge).
- Orchestrator connects to it as a WebSocket *client*.
"""

from __future__ import annotations

import asyncio
import json
import logging
from collections import deque
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

import websockets
from websockets.client import WebSocketClientProtocol

from .messages import (
    CancelNavigationCommand,
    CaptureFrameCommand,
    ErrorMessage,
    EventLogMessage,
    FRDetectionsMessage,
    FrameResponseMessage,
    GetStateCommand,
    MessageType,
    NavigateCommand,
    NavStatusMessage,
    PongMessage,
    RobotState,
    Waypoint,
    WaypointListMessage,
    parse_edge_message,
)


# Callback signatures accepted by the EdgeProxyClient ``on_*`` registration
# methods. Each callback receives the parsed message (or, for connection
# changes, a bool). It may be a plain function or a coroutine function; a
# returned coroutine is awaited, and any other return value is ignored.

# Connection lifecycle
ConnectionChangeHandler = Callable[[bool], Any]

# Navigation and robot status
NavStatusHandler = Callable[[NavStatusMessage], Any]
RobotStateHandler = Callable[[RobotState], Any]
WaypointListHandler = Callable[[WaypointListMessage], Any]

# Diagnostics and keepalive
ErrorHandler = Callable[[ErrorMessage], Any]
PongHandler = Callable[[PongMessage], Any]
EventLogHandler = Callable[[EventLogMessage], Any]

# Perception (camera frames, face recognition)
FrameResponseHandler = Callable[[FrameResponseMessage], Any]
FRDetectionsHandler = Callable[[FRDetectionsMessage], Any]


class EdgeProxyClientError(Exception):
    """Root of the exception hierarchy raised by the Edge Proxy client.

    Catch this type to handle any client-side failure regardless of cause.
    """


class EdgeProxyConnectionError(EdgeProxyClientError):
    """Signals that the Edge Proxy link could not be used.

    Raised when a connection attempt fails (with reconnect disabled) and when
    sending is attempted without a live connection or the send itself fails.
    """


class MessageError(EdgeProxyClientError):
    """Signals that a message could not be parsed or failed validation.

    Note: EdgeProxyClient itself logs and drops malformed incoming messages
    rather than raising this type.
    """


@dataclass
class ClientConfig:
    """Connection and keepalive settings for EdgeProxyClient.

    Spec: Section 3.1 of Edge Proxy Design

    Attributes:
        host: Hostname or IP address of the robot running the Edge Proxy.
        port: TCP port of the Edge Proxy WebSocket server.
        ws_path: URL path of the WebSocket endpoint.
        reconnect: Whether the client retries after a failed or dropped connection.
        reconnect_delay: Initial retry delay in seconds; doubled after each failure.
        max_reconnect_delay: Upper bound, in seconds, on the retry delay.
        ping_interval: Passed to ``websockets.connect`` as ``ping_interval``.
        ping_timeout: Passed to ``websockets.connect`` as ``ping_timeout``.
    """

    # Endpoint: the Edge Proxy runs as a WS server on the robot (default :8080/edge).
    host: str = "localhost"
    port: int = 8080
    ws_path: str = "/edge"

    # Reconnection with exponential backoff.
    reconnect: bool = True
    reconnect_delay: float = 1.0
    max_reconnect_delay: float = 30.0

    # Keepalive; the spec calls for a 10-second ping.
    ping_interval: float = 10.0
    ping_timeout: float = 10.0


class EdgeProxyClient:
    """WebSocket client for Edge Proxy communication.

    This client connects to the robot's Edge Proxy server and handles
    bidirectional communication with automatic reconnection.

    Example:
        client = EdgeProxyClient(ClientConfig(host="robot.local", port=8080))

        @client.on_nav_status
        def handle_nav_status(msg):
            print(f"Nav status: {msg.status}, progress: {msg.progress}")

        await client.connect()
        await client.send_navigate_waypoint("waypoint1", request_id="nav_001")
        await client.disconnect()
    """

    def __init__(
        self,
        config: Optional[ClientConfig] = None,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """Initialize the Edge Proxy client.

        Args:
            config: Client configuration. Uses defaults if not provided.
            logger: Logger instance. Creates default logger if not provided.
        """
        self.config = config or ClientConfig()
        self.logger = logger or logging.getLogger(__name__)

        # WebSocket connection
        self._ws: Optional[WebSocketClientProtocol] = None
        self._connected = False
        # Set by disconnect(); suppresses reconnection.
        self._should_stop = False

        # Event handlers
        self._nav_status_handlers: List[NavStatusHandler] = []
        self._robot_state_handlers: List[RobotStateHandler] = []
        self._waypoint_list_handlers: List[WaypointListHandler] = []
        self._error_handlers: List[ErrorHandler] = []
        self._pong_handlers: List[PongHandler] = []
        self._frame_response_handlers: List[FrameResponseHandler] = []
        self._fr_detections_handlers: List[FRDetectionsHandler] = []
        self._event_log_handlers: List[EventLogHandler] = []
        self._connection_change_handlers: List[ConnectionChangeHandler] = []

        # Background task
        self._receiver_task: Optional[asyncio.Task[None]] = None

        # Last received robot state (updated on every robot_state message)
        self.last_robot_state: Optional[RobotState] = None

        # Last received FR detections (updated on every fr_detections message)
        self.last_fr_detections: Optional[FRDetectionsMessage] = None
        self.last_event_log: Optional[EventLogMessage] = None

        # Deduplicate replayed edge events across reconnects. The set gives O(1)
        # membership; the deque remembers insertion order so the oldest id can
        # be evicted in O(1) once the limit is reached.
        self._seen_event_ids: set[str] = set()
        self._seen_event_ids_order: deque[str] = deque()
        self._seen_event_ids_limit = 2000

        # Pending frame-capture futures keyed by request_id.
        # Populated by callers of send_capture_frame(); resolved by _dispatch_message().
        self._pending_frames: Dict[str, asyncio.Future] = {}

    def on_nav_status(self, handler: NavStatusHandler) -> NavStatusHandler:
        """Register a navigation status event handler.

        Spec: Section 5.1 - Nav status messages

        Args:
            handler: Async or sync function that takes NavStatusMessage.

        Returns:
            The handler function for use as a decorator.
        """
        self._nav_status_handlers.append(handler)
        return handler

    def on_robot_state(self, handler: RobotStateHandler) -> RobotStateHandler:
        """Register a robot state event handler.

        Spec: Section 5.2 - Robot state messages

        Args:
            handler: Async or sync function that takes RobotState.

        Returns:
            The handler function for use as a decorator.
        """
        self._robot_state_handlers.append(handler)
        return handler

    def on_waypoint_list(self, handler: WaypointListHandler) -> WaypointListHandler:
        """Register a waypoint list event handler.

        Spec: Section 5.3 - Waypoint list messages

        Args:
            handler: Async or sync function that takes WaypointListMessage.

        Returns:
            The handler function for use as a decorator.
        """
        self._waypoint_list_handlers.append(handler)
        return handler

    def on_error(self, handler: ErrorHandler) -> ErrorHandler:
        """Register an error event handler.

        Spec: Section 5.4 - Error messages

        Args:
            handler: Async or sync function that takes ErrorMessage.

        Returns:
            The handler function for use as a decorator.
        """
        self._error_handlers.append(handler)
        return handler

    def on_pong(self, handler: PongHandler) -> PongHandler:
        """Register a pong event handler.

        Spec: Section 5.5 - Pong messages

        Args:
            handler: Async or sync function that takes PongMessage.

        Returns:
            The handler function for use as a decorator.
        """
        self._pong_handlers.append(handler)
        return handler

    def on_frame_response(self, handler: FrameResponseHandler) -> FrameResponseHandler:
        """Register a frame response event handler.

        Called when the Edge Proxy sends a ``frame_response`` message.
        Note: futures registered via ``send_capture_frame`` are resolved
        *before* these handlers are called, so handlers see all responses
        including those matched to pending futures.

        Args:
            handler: Async or sync function that takes FrameResponseMessage.

        Returns:
            The handler function for use as a decorator.
        """
        self._frame_response_handlers.append(handler)
        return handler

    def on_fr_detections(self, handler: FRDetectionsHandler) -> FRDetectionsHandler:
        """Register a face recognition detections event handler.

        Called when the Edge Proxy broadcasts ``fr_detections`` messages
        from its FR loop.

        Args:
            handler: Async or sync function that takes FRDetectionsMessage.

        Returns:
            The handler function for use as a decorator.
        """
        self._fr_detections_handlers.append(handler)
        return handler

    def on_event_log(self, handler: EventLogHandler) -> EventLogHandler:
        """Register an edge ``event_log`` message handler.

        Events whose ``event_id`` was already seen (among the most recent
        ``_seen_event_ids_limit`` ids) are dropped before handlers run, so
        events replayed after a reconnect are delivered once.

        Args:
            handler: Async or sync function that takes EventLogMessage.

        Returns:
            The handler function for use as a decorator.
        """
        self._event_log_handlers.append(handler)
        return handler

    def on_connection_change(self, handler: ConnectionChangeHandler) -> ConnectionChangeHandler:
        """Register a connection state change handler.

        Args:
            handler: Async or sync function that takes bool (connected).

        Returns:
            The handler function for use as a decorator.
        """
        self._connection_change_handlers.append(handler)
        return handler

    @property
    def is_connected(self) -> bool:
        """Check if the client is currently connected.

        Returns:
            True if connected, False otherwise.
        """
        return self._connected and self._ws is not None and not self._ws.closed

    async def connect(self) -> None:
        """Connect to the Edge Proxy server.

        Starts the connection and message receiver task.
        If reconnect is enabled in config, will automatically reconnect on disconnect.
        Calling this while already connected is a no-op. Otherwise a second
        socket and a duplicate receiver task would be created.

        Raises:
            EdgeProxyConnectionError: If initial connection fails and reconnect is disabled.
        """
        if self.is_connected:
            self.logger.debug("connect() called while already connected; ignoring")
            return
        self._should_stop = False
        await self._connect_with_retry()

    async def disconnect(self) -> None:
        """Disconnect from the Edge Proxy server.

        Stops the receiver task and closes the WebSocket connection.
        """
        self._should_stop = True

        if self._receiver_task:
            self._receiver_task.cancel()
            try:
                await self._receiver_task
            except asyncio.CancelledError:
                pass
            except Exception as exc:
                # The task may already have finished with an error; it is being
                # discarded, so log instead of failing the disconnect.
                self.logger.debug("Receiver task ended with error: %s", exc)
            self._receiver_task = None

        if self._ws and not self._ws.closed:
            await self._ws.close()

        was_connected = self._connected
        self._connected = False

        if was_connected:
            await self._notify_connection_change(False)

        self.logger.info("Disconnected from Edge Proxy")

    # ========================================================================
    # Navigation Commands (Spec: Section 4)
    # ========================================================================

    async def send_navigate_waypoint(
        self,
        name: str,
        request_id: str = "",
        speed: str = "normal"
    ) -> None:
        """Send a waypoint navigation command.

        Spec: Section 4.1 - Navigate to Waypoint

        Args:
            name: Waypoint name.
            request_id: Optional correlation ID.
            speed: Navigation speed ("slow", "normal", "fast").

        Raises:
            EdgeProxyConnectionError: If not connected.
        """
        cmd = NavigateCommand.to_waypoint(name, request_id, speed)
        await self._send_message(cmd.to_dict())

    async def send_navigate_pose(
        self,
        x: float,
        y: float,
        theta: float = 0.0,
        request_id: str = "",
        speed: str = "normal"
    ) -> None:
        """Send a pose navigation command.

        Spec: Section 4.2 - Navigate to Pose

        Args:
            x: X coordinate in map frame (meters).
            y: Y coordinate in map frame (meters).
            theta: Orientation in radians.
            request_id: Optional correlation ID.
            speed: Navigation speed ("slow", "normal", "fast").

        Raises:
            EdgeProxyConnectionError: If not connected.
        """
        cmd = NavigateCommand.to_pose(x, y, theta, request_id, speed)
        await self._send_message(cmd.to_dict())

    async def send_navigate_relative(
        self,
        direction: str,
        distance: float,
        request_id: str = "",
        speed: str = "slow"
    ) -> None:
        """Send a relative navigation command.

        Spec: Section 4.3 - Navigate Relative

        Args:
            direction: Direction ("forward", "backward", "left", "right").
            distance: Distance in meters.
            request_id: Optional correlation ID.
            speed: Navigation speed (default "slow" for relative).

        Raises:
            EdgeProxyConnectionError: If not connected.
        """
        cmd = NavigateCommand.to_relative(direction, distance, request_id, speed)
        await self._send_message(cmd.to_dict())

    async def send_cancel_navigation(
        self,
        request_id: str = "",
        reason: str = ""
    ) -> None:
        """Send a cancel navigation command.

        Spec: Section 4.4 - Cancel Navigation

        Args:
            request_id: Optional correlation ID.
            reason: Optional reason for cancellation.

        Raises:
            EdgeProxyConnectionError: If not connected.
        """
        cmd = CancelNavigationCommand(
            type=MessageType.CANCEL_NAVIGATION,
            request_id=request_id,
            reason=reason
        )
        await self._send_message(cmd.to_dict())

    async def send_get_state(self, request_id: str = "") -> None:
        """Request the current robot state.

        Spec: Section 4.5 - Request State

        The response will be delivered via on_robot_state handlers.

        Args:
            request_id: Optional correlation ID.

        Raises:
            EdgeProxyConnectionError: If not connected.
        """
        cmd = GetStateCommand(type=MessageType.GET_STATE, request_id=request_id)
        await self._send_message(cmd.to_dict())

    async def send_ping(self) -> None:
        """Send a ping message for keepalive.

        Spec: Section 4.6 - Ping

        Raises:
            EdgeProxyConnectionError: If not connected.
        """
        from .messages import PingMessage
        cmd = PingMessage()
        await self._send_message(cmd.to_dict())

    async def send_capture_frame(self, request_id: str = "") -> None:
        """Send a capture_frame command to the Edge Proxy.

        The Edge Proxy will capture one JPEG frame from the camera and reply
        with a ``frame_response`` message. To await the response, create an
        ``asyncio.Future`` and store it in ``self._pending_frames[request_id]``
        *before* calling this method, then ``await`` the future.

        Futures left in ``_pending_frames`` are not resolved or failed if the
        connection drops. Await them with a timeout.

        In practice callers should use ``ActionExecutor._capture_frame()``
        which handles the Future lifecycle automatically.

        Args:
            request_id: Correlation ID used to match the response.

        Raises:
            EdgeProxyConnectionError: If not connected.
        """
        cmd = CaptureFrameCommand(request_id=request_id)
        await self._send_message(cmd.to_dict())

    # ========================================================================
    # Internal Methods
    # ========================================================================

    async def _connect_with_retry(self) -> None:
        """Connect, retrying with exponential backoff when reconnect is enabled.

        Spec: Section 7.3 - Reconnection

        Returns without connecting if a stop is requested between attempts.

        Raises:
            EdgeProxyConnectionError: If an attempt fails and reconnect is
                disabled or a stop has been requested.
        """
        delay = self.config.reconnect_delay

        while not self._should_stop:
            try:
                await self._connect_once()
                return
            # asyncio.TimeoutError is listed explicitly because on Python < 3.11 it is
            # not an OSError subclass, and a handshake timeout would otherwise escape
            # (and silently end a background reconnect loop).
            except (OSError, asyncio.TimeoutError, websockets.exceptions.WebSocketException) as exc:
                self.logger.warning("Connection failed: %s", exc)

                if not self.config.reconnect or self._should_stop:
                    raise EdgeProxyConnectionError(f"Failed to connect: {exc}") from exc

                # Exponential backoff: 1s, 2s, 4s, 8s... capped at max_reconnect_delay.
                # Capping the stored value keeps it from growing without bound.
                delay = min(delay, self.config.max_reconnect_delay)
                await asyncio.sleep(delay)
                delay *= 2

    async def _connect_once(self) -> None:
        """Perform a single connection attempt and start the receiver task.

        Spec: Section 3.1 - Connection
        Endpoint: ws://<robot-ip>:8080/edge (defaults; configurable)
        """
        uri = f"ws://{self.config.host}:{self.config.port}{self.config.ws_path}"
        self.logger.info("Connecting to Edge Proxy at %s", uri)

        self._ws = await websockets.connect(
            uri,
            ping_interval=self.config.ping_interval,
            ping_timeout=self.config.ping_timeout,
        )

        was_not_connected = not self._connected
        self._connected = True

        if was_not_connected:
            await self._notify_connection_change(True)

        self.logger.info("Connected to Edge Proxy")

        # Start message receiver
        self._receiver_task = asyncio.create_task(self._receive_messages())

    async def _receive_messages(self) -> None:
        """Receive and handle incoming messages until the connection ends.

        Runs as a background task. When the loop ends for any reason other than
        ``disconnect()`` or task cancellation, the old socket is closed and, if
        enabled, a reconnect is attempted from this task.
        """
        ws = self._ws
        if ws is None:
            return

        cancelled = False
        try:
            async for message in ws:
                await self._handle_message(message)
        except asyncio.CancelledError:
            # Cancellation means shutdown (disconnect() or loop teardown); never
            # reconnect in response to it.
            cancelled = True
            raise
        except websockets.exceptions.ConnectionClosed as exc:
            self.logger.warning("Connection closed: %s", exc)
        except Exception as exc:
            self.logger.error("Error receiving message: %s", exc, exc_info=True)
        finally:
            # Single place that flips the state, so handlers receive exactly one
            # False per lost connection.
            was_connected = self._connected
            self._connected = False
            if was_connected:
                await self._notify_connection_change(False)

            if not self._should_stop and not cancelled:
                # The loop can end on an unexpected error while the socket is still
                # open; close it so it does not linger beside a replacement connection.
                await self._close_ws_quietly(ws)
                if self.config.reconnect:
                    try:
                        await self._connect_with_retry()
                    except EdgeProxyConnectionError as exc:
                        # Raised when a stop was requested mid-retry; nothing awaits
                        # this task's result, so log instead of leaking the exception.
                        self.logger.warning("Reconnect aborted: %s", exc)

    async def _close_ws_quietly(self, ws: WebSocketClientProtocol) -> None:
        """Close ``ws`` if it is still open, logging (not raising) any failure."""
        if ws.closed:
            return
        try:
            await ws.close()
        except Exception as exc:
            self.logger.debug("Error closing stale connection: %s", exc)

    async def _handle_message(self, message: Any) -> None:
        """Decode, parse, and dispatch one incoming message.

        Malformed messages are logged and dropped rather than raised.

        Args:
            message: Raw message (bytes or str).
        """
        if isinstance(message, (bytes, bytearray)):
            # errors="replace" substitutes U+FFFD for invalid bytes, so this cannot fail.
            message = message.decode("utf-8", errors="replace")

        if not isinstance(message, str):
            self.logger.error("Unsupported message type: %s", type(message))
            return

        # Parse JSON
        try:
            data = json.loads(message)
        except json.JSONDecodeError as exc:
            self.logger.error("Invalid JSON message: %s", exc)
            return

        if not isinstance(data, dict):
            self.logger.error("Message payload is not a dict")
            return

        # Parse and dispatch. A payload with a missing field or wrong type must not
        # escape into the receive loop, where any exception ends the connection.
        try:
            msg = parse_edge_message(data)
        except (ValueError, KeyError, TypeError) as exc:
            self.logger.error("Failed to parse message: %s", exc)
            return

        if msg is None:
            self.logger.debug("Unknown message type: %s", data.get("type"))
            return

        await self._dispatch_message(msg)

    async def _dispatch_message(self, msg: Any) -> None:
        """Dispatch message to registered handlers.

        Args:
            msg: Parsed message object.
        """
        if isinstance(msg, NavStatusMessage):
            await self._call_handlers(self._nav_status_handlers, msg)

        elif isinstance(msg, RobotState):
            self.last_robot_state = msg
            await self._call_handlers(self._robot_state_handlers, msg)

        elif isinstance(msg, WaypointListMessage):
            await self._call_handlers(self._waypoint_list_handlers, msg)

        elif isinstance(msg, ErrorMessage):
            await self._call_handlers(self._error_handlers, msg)

        elif isinstance(msg, PongMessage):
            await self._call_handlers(self._pong_handlers, msg)

        elif isinstance(msg, FrameResponseMessage):
            # Resolve a pending Future if one is registered for this request_id
            future = self._pending_frames.pop(msg.request_id, None)
            if future is not None and not future.done():
                future.set_result(msg)
            await self._call_handlers(self._frame_response_handlers, msg)

        elif isinstance(msg, FRDetectionsMessage):
            self.last_fr_detections = msg
            await self._call_handlers(self._fr_detections_handlers, msg)

        elif isinstance(msg, EventLogMessage):
            event_id = msg.event_id.strip() if msg.event_id else ""
            if event_id:
                if event_id in self._seen_event_ids:
                    return
                self._seen_event_ids.add(event_id)
                self._seen_event_ids_order.append(event_id)
                if len(self._seen_event_ids_order) > self._seen_event_ids_limit:
                    oldest = self._seen_event_ids_order.popleft()
                    self._seen_event_ids.discard(oldest)
            self.last_event_log = msg
            await self._call_handlers(self._event_log_handlers, msg)

    async def _call_handlers(self, handlers: List[Any], msg: Any) -> None:
        """Call all handlers for a message type, isolating handler failures.

        A handler that raises is logged and does not prevent later handlers
        from running.

        Args:
            handlers: List of handler callables.
            msg: Message to pass to handlers.
        """
        # Iterate a snapshot so a handler that registers another handler does not
        # change the list being dispatched mid-loop.
        for handler in list(handlers):
            try:
                result = handler(msg)
                if asyncio.iscoroutine(result):
                    await result
            except Exception as exc:
                # functools.partial objects and callable instances have no __name__;
                # a bare attribute access here would raise out of this except block.
                name = getattr(handler, "__name__", repr(handler))
                self.logger.error("Handler %s failed: %s", name, exc, exc_info=True)

    async def _send_message(self, data: Dict[str, Any]) -> None:
        """Send a message to the Edge Proxy.

        Args:
            data: Message dictionary to send as JSON.

        Raises:
            EdgeProxyConnectionError: If not connected, or if serialization or
                sending fails.
        """
        if not self.is_connected or not self._ws:
            raise EdgeProxyConnectionError("Not connected to Edge Proxy")

        try:
            message = json.dumps(data)
            await self._ws.send(message)
        except Exception as exc:
            self.logger.error("Failed to send message: %s", exc)
            raise EdgeProxyConnectionError(f"Failed to send message: {exc}") from exc

    async def _notify_connection_change(self, connected: bool) -> None:
        """Notify all connection change handlers, isolating handler failures.

        Args:
            connected: New connection state.
        """
        # Snapshot for the same reason as in _call_handlers.
        for handler in list(self._connection_change_handlers):
            try:
                result = handler(connected)
                if asyncio.iscoroutine(result):
                    await result
            except Exception as exc:
                self.logger.error(
                    "Connection change handler failed: %s", exc, exc_info=True
                )
