"""Edge Proxy WebSocket message definitions.

This module defines the message types and dataclasses used to communicate
with the robot Edge Proxy over WebSocket, following the Edge Proxy Design spec.
See: docs/plans/2026-02-04-edge-proxy-design.md in the HRI project repository.
"""

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Union


class MessageType(str, Enum):
    """Values carried in the ``type`` field of every Edge Proxy message.

    The enum mixes in ``str``, so each member compares equal to its raw wire
    string (e.g. ``MessageType.PING == "ping"``). Member names mirror the
    Edge Proxy Design spec.
    """

    # --- inbound: Edge Proxy -> Orchestrator ---
    NAV_STATUS = "nav_status"
    ROBOT_STATE = "robot_state"
    WAYPOINT_LIST = "waypoint_list"
    ERROR = "error"
    PONG = "pong"
    FRAME_RESPONSE = "frame_response"
    FR_DETECTIONS = "fr_detections"
    EVENT_LOG = "event_log"

    # --- outbound: Orchestrator -> Edge Proxy ---
    NAVIGATE = "navigate"
    CANCEL_NAVIGATION = "cancel_navigation"
    GET_STATE = "get_state"
    PING = "ping"
    CAPTURE_FRAME = "capture_frame"


class NavStatus(str, Enum):
    """Navigation status values reported by the Edge Proxy.

    Members are ``str`` instances and compare equal to their wire strings.
    """

    ACCEPTED = "accepted"
    NAVIGATING = "navigating"
    ARRIVED = "arrived"
    FAILED = "failed"
    CANCELLED = "cancelled"


class NavErrorCode(str, Enum):
    """Navigation error codes sent by the Edge Proxy.

    Each wire value is spelled exactly like its member name (upper case).
    """

    NAV_BLOCKED = "NAV_BLOCKED"
    NAV_TIMEOUT = "NAV_TIMEOUT"
    NAV_LOCALIZATION_LOST = "NAV_LOCALIZATION_LOST"
    NAV_INVALID_GOAL = "NAV_INVALID_GOAL"
    NAV_WAYPOINT_NOT_FOUND = "NAV_WAYPOINT_NOT_FOUND"


class Speed(str, Enum):
    """Speed presets accepted by navigation commands (sent as plain strings)."""

    SLOW = "slow"
    NORMAL = "normal"
    FAST = "fast"


class GoalType(str, Enum):
    """Kinds of navigation goal; stored under a goal dict's ``type`` key."""

    WAYPOINT = "waypoint"
    POSE = "pose"
    RELATIVE = "relative"


class RelativeDirection(str, Enum):
    """Directions understood by relative-motion navigation goals."""

    FORWARD = "forward"
    BACKWARD = "backward"
    LEFT = "left"
    RIGHT = "right"


# ============================================================================
# Outgoing Messages (Orchestrator -> Edge Proxy)
# ============================================================================


@dataclass
class NavigateCommand:
    """Navigate command message (Orchestrator -> Edge Proxy).

    Spec: Section 4.1-4.3 of Edge Proxy Design

    Attributes:
        type: Message discriminator; defaults to ``MessageType.NAVIGATE``.
        request_id: Identifier for this request.
        goal: Goal description as a plain dict, e.g.
            ``{"type": GoalType.WAYPOINT, "name": "dock"}``; see the
            ``to_waypoint`` / ``to_pose`` / ``to_relative`` factories for the
            three shapes built here. ``None`` is replaced by an empty dict.
        speed: Speed preset; defaults to ``Speed.NORMAL``.
    """

    type: str = MessageType.NAVIGATE
    request_id: str = ""
    goal: Optional[Dict[str, Any]] = None  # WaypointGoal, PoseGoal, or RelativeGoal as dict
    speed: str = Speed.NORMAL

    def __post_init__(self) -> None:
        # A literal ``{}`` default would be shared between instances, so the
        # dataclass defaults to None and each instance gets its own dict here.
        if self.goal is None:
            self.goal = {}

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization.

        The ``goal`` mapping is shallow-copied so that mutating the returned
        payload does not modify this command (and vice versa).
        """
        goal = self.goal
        return {
            "type": self.type,
            "request_id": self.request_id,
            "goal": None if goal is None else dict(goal),
            "speed": self.speed,
        }

    @classmethod
    def to_waypoint(cls, name: str, request_id: str = "", speed: str = Speed.NORMAL) -> NavigateCommand:
        """Create a command that navigates to a named waypoint.

        Args:
            name: Waypoint name.
            request_id: Identifier for this request.
            speed: Speed preset (default ``Speed.NORMAL``).
        """
        return cls(
            type=MessageType.NAVIGATE,
            request_id=request_id,
            goal={"type": GoalType.WAYPOINT, "name": name},
            speed=speed,
        )

    @classmethod
    def to_pose(cls, x: float, y: float, theta: float = 0.0,
                request_id: str = "", speed: str = Speed.NORMAL) -> NavigateCommand:
        """Create a command that navigates to an explicit ``(x, y, theta)`` pose.

        Args:
            x: Target x coordinate.
            y: Target y coordinate.
            theta: Target heading (default 0.0).
            request_id: Identifier for this request.
            speed: Speed preset (default ``Speed.NORMAL``).
        """
        return cls(
            type=MessageType.NAVIGATE,
            request_id=request_id,
            goal={"type": GoalType.POSE, "x": x, "y": y, "theta": theta},
            speed=speed,
        )

    @classmethod
    def to_relative(cls, direction: str, distance: float,
                    request_id: str = "", speed: str = Speed.SLOW) -> NavigateCommand:
        """Create a relative-motion command.

        Args:
            direction: Movement direction, e.g. a ``RelativeDirection`` value.
            distance: Distance to travel (units per the Edge Proxy spec).
            request_id: Identifier for this request.
            speed: Speed preset (default ``Speed.SLOW``).
        """
        return cls(
            type=MessageType.NAVIGATE,
            request_id=request_id,
            goal={"type": GoalType.RELATIVE, "direction": direction, "distance": distance},
            speed=speed,
        )


@dataclass
class CancelNavigationCommand:
    """Ask the Edge Proxy to cancel the current navigation.

    Spec: Section 4.4 of Edge Proxy Design
    """

    type: str = MessageType.CANCEL_NAVIGATION
    request_id: str = ""
    reason: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-ready dict with ``type``, ``request_id`` and ``reason``."""
        return dict(
            type=self.type,
            request_id=self.request_id,
            reason=self.reason,
        )


@dataclass
class GetStateCommand:
    """Ask the Edge Proxy for its current robot state.

    Spec: Section 4.5 of Edge Proxy Design
    """

    type: str = MessageType.GET_STATE
    request_id: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-ready dict with ``type`` and ``request_id``."""
        return dict(type=self.type, request_id=self.request_id)


@dataclass
class PingMessage:
    """Keepalive ping sent to the Edge Proxy.

    Spec: Section 4.6 of Edge Proxy Design
    """

    type: str = MessageType.PING

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-ready dict containing only the ``type`` key."""
        return dict(type=self.type)


@dataclass
class CaptureFrameCommand:
    """Ask the Edge Proxy to grab a single camera frame.

    The reply is expected as a FrameResponseMessage carrying either a
    base64-encoded JPEG or an error field.
    """

    type: str = MessageType.CAPTURE_FRAME
    request_id: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-ready dict with ``type`` and ``request_id``."""
        return dict(type=self.type, request_id=self.request_id)


# ============================================================================
# Incoming Messages (Edge Proxy -> Orchestrator)
# ============================================================================


@dataclass
class NavStatusMessage:
    """Navigation status update message (Edge Proxy -> Orchestrator).

    Spec: Section 5.1 of Edge Proxy Design
    """

    type: str = MessageType.NAV_STATUS
    request_id: str = ""
    status: str = ""
    destination: str = ""
    progress: float = 0.0
    eta_sec: Optional[float] = None
    reason: Optional[str] = None  # For failed status
    error_code: Optional[str] = None  # For failed status

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> NavStatusMessage:
        """Create from a decoded JSON dictionary.

        Missing keys fall back to the field defaults. ``progress`` is coerced
        to ``float``, and a JSON ``null`` is treated as ``0.0`` instead of
        raising. ``eta_sec`` is coerced to ``float`` when present and
        non-null; otherwise it stays ``None``.

        Raises:
            ValueError: If ``progress`` or ``eta_sec`` is a non-numeric string.
        """
        progress = data.get("progress")
        eta_sec = data.get("eta_sec")
        return cls(
            type=data.get("type", MessageType.NAV_STATUS),
            request_id=data.get("request_id", ""),
            status=data.get("status", ""),
            destination=data.get("destination", ""),
            progress=0.0 if progress is None else float(progress),
            eta_sec=None if eta_sec is None else float(eta_sec),
            reason=data.get("reason"),
            error_code=data.get("error_code"),
        )


@dataclass
class Pose:
    """Robot pose expressed as ``x``, ``y`` and ``theta``."""

    x: float = 0.0
    y: float = 0.0
    theta: float = 0.0

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> Pose:
        """Build a Pose from a dict, using 0.0 for any missing coordinate."""
        x, y, theta = (float(data.get(key, 0.0)) for key in ("x", "y", "theta"))
        return cls(x=x, y=y, theta=theta)


@dataclass
class Battery:
    """Battery level and charging flag."""

    level: int = 100
    charging: bool = False

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> Battery:
        """Build a Battery from a dict; absent keys give level 100, not charging."""
        level = int(data.get("level", 100))
        charging = bool(data.get("charging", False))
        return cls(level=level, charging=charging)


@dataclass
class RobotState:
    """Robot state message (Edge Proxy -> Orchestrator).

    Spec: Section 5.2 of Edge Proxy Design

    ``pose`` and ``battery`` are always populated after construction. Passing
    ``None`` (the default) substitutes a fresh ``Pose()`` / ``Battery()``.
    """

    type: str = MessageType.ROBOT_STATE
    timestamp: float = 0.0
    pose: Pose = None  # type: ignore[assignment]
    location: str = ""
    battery: Battery = None  # type: ignore[assignment]
    nav_state: str = ""
    nav_progress: float = 0.0
    nav_destination: str = ""

    def __post_init__(self) -> None:
        # Per-instance defaults; avoids sharing one mutable Pose/Battery.
        if self.pose is None:
            self.pose = Pose()
        if self.battery is None:
            self.battery = Battery()

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> RobotState:
        """Create from a decoded JSON dictionary.

        Missing keys fall back to field defaults. A missing or JSON ``null``
        ``pose``/``battery`` object yields a default ``Pose``/``Battery``.
        A ``null`` ``timestamp``/``nav_progress`` is read as ``0.0``.
        Previously these null cases raised AttributeError/TypeError.
        """
        return cls(
            type=data.get("type", MessageType.ROBOT_STATE),
            timestamp=float(data.get("timestamp") or 0.0),
            pose=Pose.from_dict(data.get("pose") or {}),
            location=data.get("location", ""),
            battery=Battery.from_dict(data.get("battery") or {}),
            nav_state=data.get("nav_state", ""),
            nav_progress=float(data.get("nav_progress") or 0.0),
            nav_destination=data.get("nav_destination", ""),
        )


@dataclass
class Waypoint:
    """Named location in the robot's map, with x, y and theta.

    Spec: Section 5.3 of Edge Proxy Design
    """

    name: str = ""
    x: float = 0.0
    y: float = 0.0
    theta: float = 0.0

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> Waypoint:
        """Build a Waypoint from a dict; missing name is "" and missing coordinates 0.0."""
        name = data.get("name", "")
        coords = {axis: float(data.get(axis, 0.0)) for axis in ("x", "y", "theta")}
        return cls(name=name, **coords)


@dataclass
class WaypointListMessage:
    """Waypoint list message (Edge Proxy -> Orchestrator).

    Spec: Section 5.3 of Edge Proxy Design
    """

    type: str = MessageType.WAYPOINT_LIST
    waypoints: List[Waypoint] = None  # type: ignore[assignment]

    def __post_init__(self) -> None:
        # Per-instance list instead of a shared mutable default.
        if self.waypoints is None:
            self.waypoints = []

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> WaypointListMessage:
        """Create from a decoded JSON dictionary.

        Each entry of ``waypoints`` is parsed with ``Waypoint.from_dict``. A
        missing or JSON ``null`` ``waypoints`` value yields an empty list;
        previously a ``null`` raised TypeError.
        """
        waypoints = [Waypoint.from_dict(wp) for wp in data.get("waypoints") or []]
        return cls(
            type=data.get("type", MessageType.WAYPOINT_LIST),
            waypoints=waypoints,
        )


@dataclass
class ErrorMessage:
    """Error reported by the Edge Proxy.

    Spec: Section 5.4 of Edge Proxy Design
    """

    type: str = MessageType.ERROR
    request_id: str = ""
    error: str = ""
    message: str = ""

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> ErrorMessage:
        """Build an ErrorMessage from a dict; absent text fields become ``""``."""
        msg_type = data.get("type", MessageType.ERROR)
        texts = {key: data.get(key, "") for key in ("request_id", "error", "message")}
        return cls(type=msg_type, **texts)


@dataclass
class PongMessage:
    """Keepalive reply from the Edge Proxy.

    Spec: Section 5.5 of Edge Proxy Design
    """

    type: str = MessageType.PONG
    state: str = ""

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> PongMessage:
        """Build a PongMessage from a dict; a missing ``state`` becomes ``""``."""
        msg_type = data.get("type", MessageType.PONG)
        state = data.get("state", "")
        return cls(type=msg_type, state=state)


@dataclass
class FrameResponseMessage:
    """Edge Proxy reply to a CaptureFrameCommand.

    The proxy is expected to set exactly one of ``jpeg_b64`` (a
    base64-encoded JPEG) or ``error``. This class does not enforce that.
    """

    type: str = MessageType.FRAME_RESPONSE
    request_id: str = ""
    jpeg_b64: Optional[str] = None
    error: Optional[str] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> FrameResponseMessage:
        """Build a FrameResponseMessage from a dict; absent optional fields stay ``None``."""
        msg_type = data.get("type", MessageType.FRAME_RESPONSE)
        request_id = data.get("request_id", "")
        jpeg_b64 = data.get("jpeg_b64")
        error = data.get("error")
        return cls(type=msg_type, request_id=request_id, jpeg_b64=jpeg_b64, error=error)


@dataclass
class FRDetectionsMessage:
    """Face recognition detections broadcast from the edge proxy FR loop.

    Attributes:
        type: Message discriminator; defaults to ``"fr_detections"``.
        detections: List of detection dicts, each with at least ``identity``
            and ``confidence`` keys. ``None`` is replaced by an empty list.
        timestamp: Unix epoch seconds of the capture, as ``float``.
    """

    type: str = MessageType.FR_DETECTIONS
    detections: List[Dict[str, Any]] = None  # type: ignore[assignment]
    timestamp: float = 0.0

    def __post_init__(self) -> None:
        # Per-instance list instead of a shared mutable default.
        if self.detections is None:
            self.detections = []

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> FRDetectionsMessage:
        """Create from a decoded JSON dictionary.

        ``timestamp`` is coerced to ``float``, matching the other parsers in
        this module. A missing or JSON ``null`` value is read as ``0.0``. A
        ``null`` ``detections`` value becomes ``[]`` via ``__post_init__``.

        Raises:
            ValueError: If ``timestamp`` is a non-numeric string.
        """
        return cls(
            type=data.get("type", MessageType.FR_DETECTIONS),
            detections=data.get("detections", []),
            timestamp=float(data.get("timestamp") or 0.0),
        )


@dataclass
class EventLogMessage:
    """Replay-safe edge event log message from Edge Proxy.

    Attributes:
        type: Message discriminator; defaults to ``MessageType.EVENT_LOG``.
        event_id, event_type, request_id, status: String fields; ``""`` when
            absent.
        timestamp: Coerced to ``float``; ``0.0`` when absent or ``null``.
        replay: Truthiness of the incoming ``replay`` key. Presumably marks
            re-sent events; TODO confirm against the Edge Proxy spec.
        payload: Event-specific dict; ``{}`` when absent or ``None``.
    """

    type: str = MessageType.EVENT_LOG
    event_id: str = ""
    event_type: str = ""
    request_id: str = ""
    status: str = ""
    timestamp: float = 0.0
    replay: bool = False
    payload: Dict[str, Any] = None  # type: ignore[assignment]

    def __post_init__(self) -> None:
        # Per-instance dict instead of a shared mutable default.
        if self.payload is None:
            self.payload = {}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> EventLogMessage:
        """Create from a decoded JSON dictionary.

        Missing keys fall back to field defaults. A JSON ``null``
        ``timestamp`` is read as ``0.0`` instead of raising TypeError. A
        ``null`` ``payload`` becomes ``{}`` via ``__post_init__``.

        Raises:
            ValueError: If ``timestamp`` is a non-numeric string.
        """
        return cls(
            type=data.get("type", MessageType.EVENT_LOG),
            event_id=data.get("event_id", ""),
            event_type=data.get("event_type", ""),
            request_id=data.get("request_id", ""),
            status=data.get("status", ""),
            timestamp=float(data.get("timestamp") or 0.0),
            replay=bool(data.get("replay", False)),
            payload=data.get("payload", {}),
        )


# ============================================================================
# Message Parser
# ============================================================================

# Every dataclass that parse_edge_message() can return for a recognised type.
IncomingMessage = Union[
    NavStatusMessage, RobotState, WaypointListMessage, ErrorMessage,
    PongMessage, FrameResponseMessage, FRDetectionsMessage, EventLogMessage,
]


# Dispatch table: inbound message type -> parser. MessageType mixes in str, so
# these keys match the raw strings decoded from JSON (e.g. "nav_status").
_INCOMING_PARSERS = {
    MessageType.NAV_STATUS: NavStatusMessage.from_dict,
    MessageType.ROBOT_STATE: RobotState.from_dict,
    MessageType.WAYPOINT_LIST: WaypointListMessage.from_dict,
    MessageType.ERROR: ErrorMessage.from_dict,
    MessageType.PONG: PongMessage.from_dict,
    MessageType.FRAME_RESPONSE: FrameResponseMessage.from_dict,
    MessageType.FR_DETECTIONS: FRDetectionsMessage.from_dict,
    MessageType.EVENT_LOG: EventLogMessage.from_dict,
}


def parse_edge_message(data: Dict[str, Any]) -> Optional[IncomingMessage]:
    """Parse an incoming Edge Proxy message.

    Args:
        data: Dictionary parsed from JSON message.

    Returns:
        The appropriate message dataclass. Returns None if the message type
        is unknown, is an outbound-only type, or is not a string.

    Raises:
        ValueError: If ``data`` is not a mapping (e.g. a JSON array was
            decoded), or its ``type`` field is missing or empty.
    """
    try:
        msg_type = data.get("type")
    except AttributeError:
        # A JSON array/scalar decodes to list/str/int, which has no .get().
        raise ValueError(
            f"Expected a JSON object message, got {type(data).__name__}"
        ) from None

    if not msg_type:
        raise ValueError("Missing 'type' field in message")

    # Only strings can equal a MessageType value. Guarding here also keeps
    # unhashable values (lists, dicts) away from the dict lookup below.
    if not isinstance(msg_type, str):
        return None

    parser = _INCOMING_PARSERS.get(msg_type)
    # Unknown message type - return None instead of raising
    return parser(data) if parser is not None else None
