"""Edge Proxy WebSocket message definitions.

This module defines the message types and dataclasses for communication
with the cloud Orchestrator over WebSocket.

It is intentionally kept in sync with `orchestrator/src/edge_proxy/messages.py`; mirror any protocol change in both files.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Union


class MessageType(str, Enum):
    """Discriminator values for the ``type`` field of each WebSocket message.

    Members subclass ``str``, so they compare equal to their plain string
    values; ``parse_edge_message`` relies on this when matching the raw
    ``type`` string of an incoming payload.
    """

    # Sent by the Edge Proxy to the Orchestrator
    NAV_STATUS = "nav_status"
    ROBOT_STATE = "robot_state"
    WAYPOINT_LIST = "waypoint_list"
    ERROR = "error"
    PONG = "pong"
    FRAME_RESPONSE = "frame_response"
    FR_DETECTIONS = "fr_detections"
    EVENT_LOG = "event_log"

    # Sent by the Orchestrator to the Edge Proxy
    NAVIGATE = "navigate"
    CANCEL_NAVIGATION = "cancel_navigation"
    GET_STATE = "get_state"
    PING = "ping"
    CAPTURE_FRAME = "capture_frame"


class NavStatus(str, Enum):
    """Navigation lifecycle states.

    NOTE(review): presumably the values carried in ``NavStatusMessage.status``;
    that field is a plain ``str`` and nothing in this module enforces it.
    """

    ACCEPTED = "accepted"
    NAVIGATING = "navigating"
    ARRIVED = "arrived"
    FAILED = "failed"
    CANCELLED = "cancelled"


class NavErrorCode(str, Enum):
    """Machine-readable reasons a navigation can fail.

    Unlike the other enums in this module, values are upper-case and equal to
    the member names. NOTE(review): presumably carried in
    ``NavStatusMessage.error_code``, which is a plain ``Optional[str]`` and is
    not validated against this enum.
    """

    NAV_BLOCKED = "NAV_BLOCKED"
    NAV_TIMEOUT = "NAV_TIMEOUT"
    NAV_LOCALIZATION_LOST = "NAV_LOCALIZATION_LOST"
    NAV_INVALID_GOAL = "NAV_INVALID_GOAL"
    NAV_WAYPOINT_NOT_FOUND = "NAV_WAYPOINT_NOT_FOUND"


class Speed(str, Enum):
    """Speed presets for ``NavigateCommand``.

    The command and its ``to_waypoint`` / ``to_pose`` factories default to
    ``NORMAL``; ``to_relative`` defaults to ``SLOW``.
    """

    SLOW = "slow"
    NORMAL = "normal"
    FAST = "fast"


class GoalType(str, Enum):
    """Value stored under ``goal["type"]`` in a ``NavigateCommand``.

    Each member has a matching factory on ``NavigateCommand``:
    ``to_waypoint``, ``to_pose`` and ``to_relative``.
    """

    WAYPOINT = "waypoint"
    POSE = "pose"
    RELATIVE = "relative"


class RelativeDirection(str, Enum):
    """Directions for a relative move.

    NOTE(review): presumably the values for the ``direction`` argument of
    ``NavigateCommand.to_relative``, which is typed ``str`` and is not
    validated against this enum.
    """

    FORWARD = "forward"
    BACKWARD = "backward"
    LEFT = "left"
    RIGHT = "right"


# ==========================================================================
# Orchestrator -> Edge Proxy
# ==========================================================================


@dataclass
class NavigateCommand:
    """Orchestrator -> Edge Proxy request to start navigating toward a goal.

    Prefer the ``to_waypoint`` / ``to_pose`` / ``to_relative`` factories, which
    build a ``goal`` dict tagged with the matching ``GoalType``.

    Attributes:
        type: Message discriminator; defaults to ``MessageType.NAVIGATE``.
        request_id: Request identifier (presumably echoed back in replies such
            as ``NavStatusMessage.request_id``).
        goal: Goal description; the factories store a ``GoalType`` under
            ``goal["type"]``.
        speed: Speed preset, normally a ``Speed`` value.
    """

    type: str = MessageType.NAVIGATE
    request_id: str = ""
    # default_factory gives each instance its own dict; an explicit ``None``
    # is still accepted and normalized in ``__post_init__``.
    goal: Dict[str, Any] = field(default_factory=dict)
    speed: str = Speed.NORMAL

    def __post_init__(self) -> None:
        # Backward compatibility: callers may still pass ``goal=None``.
        if self.goal is None:
            self.goal = {}

    def to_dict(self) -> Dict[str, Any]:
        """Return the wire payload for this command.

        ``goal`` is shallow-copied so that editing the returned payload does
        not silently mutate this command.
        """
        return {
            "type": self.type,
            "request_id": self.request_id,
            "goal": dict(self.goal),
            "speed": self.speed,
        }

    @classmethod
    def to_waypoint(cls, name: str, request_id: str = "", speed: str = Speed.NORMAL) -> "NavigateCommand":
        """Build a command that navigates to the waypoint called ``name``."""
        return cls(
            type=MessageType.NAVIGATE,
            request_id=request_id,
            goal={"type": GoalType.WAYPOINT, "name": name},
            speed=speed,
        )

    @classmethod
    def to_pose(
        cls,
        x: float,
        y: float,
        theta: float = 0.0,
        request_id: str = "",
        speed: str = Speed.NORMAL,
    ) -> "NavigateCommand":
        """Build a command that navigates to the pose ``(x, y, theta)``."""
        return cls(
            type=MessageType.NAVIGATE,
            request_id=request_id,
            goal={"type": GoalType.POSE, "x": x, "y": y, "theta": theta},
            speed=speed,
        )

    @classmethod
    def to_relative(
        cls,
        direction: str,
        distance: float,
        request_id: str = "",
        speed: str = Speed.SLOW,
    ) -> "NavigateCommand":
        """Build a relative-move command of ``distance`` in ``direction``.

        Defaults to ``Speed.SLOW``. ``direction`` is not validated here
        (presumably a ``RelativeDirection`` value); units of ``distance`` are
        not specified in this module.
        """
        return cls(
            type=MessageType.NAVIGATE,
            request_id=request_id,
            goal={"type": GoalType.RELATIVE, "direction": direction, "distance": distance},
            speed=speed,
        )


@dataclass
class CancelNavigationCommand:
    """Orchestrator -> Edge Proxy request to abort navigation.

    ``reason`` is free text and is passed through unvalidated.
    """

    type: str = MessageType.CANCEL_NAVIGATION
    request_id: str = ""
    reason: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return the wire payload (type, request_id, reason)."""
        return dict(type=self.type, request_id=self.request_id, reason=self.reason)


@dataclass
class GetStateCommand:
    """Orchestrator -> Edge Proxy request for a state snapshot.

    NOTE(review): presumably answered with a ``RobotState`` message.
    """

    type: str = MessageType.GET_STATE
    request_id: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return the wire payload (type, request_id)."""
        return dict(type=self.type, request_id=self.request_id)


@dataclass
class PingMessage:
    """Orchestrator -> Edge Proxy heartbeat carrying only its ``type``.

    NOTE(review): presumably answered by a ``PongMessage``.
    """

    type: str = MessageType.PING

    def to_dict(self) -> Dict[str, Any]:
        """Return the wire payload (type only)."""
        return dict(type=self.type)


@dataclass
class CaptureFrameCommand:
    """Orchestrator -> Edge Proxy request to capture a frame.

    NOTE(review): presumably answered by a ``FrameResponseMessage`` carrying
    the same ``request_id``.
    """

    type: str = MessageType.CAPTURE_FRAME
    request_id: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return the wire payload (type, request_id)."""
        wire: Dict[str, Any] = {"type": self.type}
        wire["request_id"] = self.request_id
        return wire


# ==========================================================================
# Edge Proxy -> Orchestrator
# ==========================================================================


@dataclass
class NavStatusMessage:
    """Edge Proxy -> Orchestrator progress update for a navigation request.

    Attributes:
        request_id: Id of the request this update refers to (presumably the
            originating ``NavigateCommand.request_id``).
        status: Presumably a ``NavStatus`` value; not enforced here.
        destination: Destination name as a plain string.
        progress: Progress value; its scale is not specified here.
        eta_sec: Optional estimated seconds remaining.
        reason: Optional free-text explanation.
        error_code: Presumably a ``NavErrorCode`` value; not enforced here.
    """

    type: str = MessageType.NAV_STATUS
    request_id: str = ""
    status: str = ""
    destination: str = ""
    progress: float = 0.0
    eta_sec: Optional[float] = None
    reason: Optional[str] = None
    error_code: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the wire payload; optional fields that are ``None`` are omitted."""
        payload: Dict[str, Any] = {
            "type": self.type,
            "request_id": self.request_id,
            "status": self.status,
            "destination": self.destination,
            "progress": self.progress,
        }
        if self.eta_sec is not None:
            payload["eta_sec"] = self.eta_sec
        if self.reason is not None:
            payload["reason"] = self.reason
        if self.error_code is not None:
            payload["error_code"] = self.error_code
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "NavStatusMessage":
        """Parse a ``nav_status`` payload.

        ``progress`` and ``eta_sec`` are coerced to ``float``. A missing or
        ``None`` ``progress`` becomes ``0.0`` and a missing or ``None``
        ``eta_sec`` stays ``None``, so a JSON ``null`` does not crash parsing.

        Raises:
            ValueError, TypeError: If ``progress`` or ``eta_sec`` is present
                but cannot be converted with ``float()``.
        """
        progress = data.get("progress")
        eta_sec = data.get("eta_sec")
        return cls(
            type=data.get("type", MessageType.NAV_STATUS),
            request_id=data.get("request_id", ""),
            status=data.get("status", ""),
            destination=data.get("destination", ""),
            progress=0.0 if progress is None else float(progress),
            eta_sec=None if eta_sec is None else float(eta_sec),
            reason=data.get("reason"),
            error_code=data.get("error_code"),
        )


@dataclass
class Pose:
    """Planar pose: position ``x``/``y`` and heading ``theta``.

    Units are not specified in this module.
    """

    x: float = 0.0
    y: float = 0.0
    theta: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return the pose as a ``{"x", "y", "theta"}`` dict."""
        return {"x": self.x, "y": self.y, "theta": self.theta}

    @classmethod
    def from_dict(cls, data: Optional[Dict[str, Any]]) -> "Pose":
        """Build a ``Pose`` from a wire dict.

        Values are coerced with ``float()``. A ``None`` dict, or a missing or
        ``None`` coordinate, falls back to ``0.0`` so a JSON ``null`` does not
        crash parsing.

        Raises:
            ValueError, TypeError: If a present coordinate cannot be converted
                with ``float()``.
        """
        source = data or {}

        def coord(key: str) -> float:
            value = source.get(key)
            return 0.0 if value is None else float(value)

        return cls(x=coord("x"), y=coord("y"), theta=coord("theta"))


@dataclass
class Battery:
    """Battery status: charge ``level`` and ``charging`` flag.

    ``level`` defaults to 100 (presumably a percentage — confirm with the
    Edge Proxy).
    """

    level: int = 100
    charging: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return the status as a ``{"level", "charging"}`` dict."""
        return {"level": self.level, "charging": self.charging}

    @classmethod
    def from_dict(cls, data: Optional[Dict[str, Any]]) -> "Battery":
        """Build a ``Battery`` from a wire dict.

        ``level`` is coerced with ``int()`` and ``charging`` with ``bool()``.
        A ``None`` dict, or a missing or ``None`` ``level``, falls back to the
        defaults instead of raising. A real ``0`` level is preserved.
        """
        source = data or {}
        level = source.get("level")
        return cls(
            # Explicit None check: ``or 100`` would wrongly turn level 0 into 100.
            level=100 if level is None else int(level),
            charging=bool(source.get("charging", False)),
        )


@dataclass
class RobotState:
    """Edge Proxy -> Orchestrator snapshot of pose, battery and navigation state."""

    type: str = MessageType.ROBOT_STATE
    timestamp: float = 0.0
    # default_factory gives each instance its own Pose/Battery; an explicit
    # ``None`` is still accepted and replaced in ``__post_init__``.
    pose: Pose = field(default_factory=Pose)
    location: str = ""
    battery: Battery = field(default_factory=Battery)
    nav_state: str = ""
    nav_progress: float = 0.0
    nav_destination: str = ""

    def __post_init__(self) -> None:
        # Backward compatibility: callers may still pass pose=None / battery=None.
        if self.pose is None:
            self.pose = Pose()
        if self.battery is None:
            self.battery = Battery()

    def to_dict(self) -> Dict[str, Any]:
        """Return the wire payload with ``pose`` and ``battery`` as nested dicts."""
        return {
            "type": self.type,
            "timestamp": self.timestamp,
            "pose": self.pose.to_dict(),
            "location": self.location,
            "battery": self.battery.to_dict(),
            "nav_state": self.nav_state,
            "nav_progress": self.nav_progress,
            "nav_destination": self.nav_destination,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "RobotState":
        """Parse a ``robot_state`` payload.

        A missing or ``None`` ``pose``/``battery`` is parsed from ``{}`` (giving
        default values). A missing or ``None`` ``timestamp``/``nav_progress``
        becomes ``0.0``. As a result, JSON ``null`` values no longer crash
        parsing.
        """
        timestamp = data.get("timestamp")
        nav_progress = data.get("nav_progress")
        return cls(
            type=data.get("type", MessageType.ROBOT_STATE),
            timestamp=0.0 if timestamp is None else float(timestamp),
            pose=Pose.from_dict(data.get("pose") or {}),
            location=data.get("location", ""),
            battery=Battery.from_dict(data.get("battery") or {}),
            nav_state=data.get("nav_state", ""),
            nav_progress=0.0 if nav_progress is None else float(nav_progress),
            nav_destination=data.get("nav_destination", ""),
        )


@dataclass
class Waypoint:
    """Named map location with an ``x``/``y``/``theta`` pose."""

    name: str = ""
    x: float = 0.0
    y: float = 0.0
    theta: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a dict, in declaration order."""
        return {attr: getattr(self, attr) for attr in ("name", "x", "y", "theta")}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Waypoint":
        """Build a ``Waypoint``; coordinates are coerced with ``float()``."""
        coords = {axis: float(data.get(axis, 0.0)) for axis in ("x", "y", "theta")}
        return cls(name=data.get("name", ""), **coords)


@dataclass
class WaypointListMessage:
    """Edge Proxy -> Orchestrator list of named waypoints."""

    type: str = MessageType.WAYPOINT_LIST
    # default_factory gives each instance its own list; an explicit ``None``
    # is still accepted and normalized in ``__post_init__``.
    waypoints: List[Waypoint] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Backward compatibility: callers may still pass ``waypoints=None``.
        if self.waypoints is None:
            self.waypoints = []

    def to_dict(self) -> Dict[str, Any]:
        """Return the wire payload with each waypoint serialized to a dict."""
        return {
            "type": self.type,
            "waypoints": [wp.to_dict() for wp in self.waypoints],
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "WaypointListMessage":
        """Parse a ``waypoint_list`` payload.

        A missing or ``None`` ``waypoints`` field yields an empty list. Before
        this guard, a JSON ``null`` raised ``TypeError`` during iteration.
        """
        waypoints = [Waypoint.from_dict(wp) for wp in data.get("waypoints") or []]
        return cls(type=data.get("type", MessageType.WAYPOINT_LIST), waypoints=waypoints)


@dataclass
class ErrorMessage:
    """Edge Proxy -> Orchestrator error report for a request.

    ``error`` and ``message`` are free-form strings here. NOTE(review):
    presumably ``error`` is a short code and ``message`` human-readable text.
    """

    type: str = MessageType.ERROR
    request_id: str = ""
    error: str = ""
    message: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a dict, in declaration order."""
        return {name: getattr(self, name) for name in ("type", "request_id", "error", "message")}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ErrorMessage":
        """Build an ``ErrorMessage``, using per-field defaults for missing keys."""
        fallbacks = {"type": MessageType.ERROR, "request_id": "", "error": "", "message": ""}
        return cls(**{name: data.get(name, default) for name, default in fallbacks.items()})


@dataclass
class PongMessage:
    """Edge Proxy -> Orchestrator heartbeat reply.

    NOTE(review): presumably answers a ``PingMessage``. ``state`` is an opaque
    string here.
    """

    type: str = MessageType.PONG
    state: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return the wire payload (type, state)."""
        return dict(type=self.type, state=self.state)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PongMessage":
        """Build a ``PongMessage``, defaulting missing keys."""
        msg_type = data.get("type", MessageType.PONG)
        state = data.get("state", "")
        return cls(type=msg_type, state=state)


@dataclass
class FrameResponseMessage:
    """Edge Proxy -> Orchestrator frame-capture result.

    ``jpeg_b64`` (presumably base64-encoded JPEG data, per its name) and
    ``error`` are both optional. Each is omitted from ``to_dict`` output when
    it is ``None``.
    """

    type: str = MessageType.FRAME_RESPONSE
    request_id: str = ""
    jpeg_b64: Optional[str] = None
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the wire payload, skipping optional fields that are ``None``."""
        payload: Dict[str, Any] = {"type": self.type, "request_id": self.request_id}
        optional_fields = (("jpeg_b64", self.jpeg_b64), ("error", self.error))
        payload.update((key, value) for key, value in optional_fields if value is not None)
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "FrameResponseMessage":
        """Build a ``FrameResponseMessage``; absent optional keys become ``None``."""
        msg_type = data.get("type", MessageType.FRAME_RESPONSE)
        request_id = data.get("request_id", "")
        jpeg_b64 = data.get("jpeg_b64")
        error = data.get("error")
        return cls(type=msg_type, request_id=request_id, jpeg_b64=jpeg_b64, error=error)


@dataclass
class FRDetectionsMessage:
    """Edge Proxy -> Orchestrator batch of ``FR`` detections.

    NOTE(review): "FR" presumably means face recognition. Each detection is an
    opaque dict here. ``metrics`` is optional and omitted from ``to_dict``
    output when ``None``.
    """

    type: str = MessageType.FR_DETECTIONS
    detections: List[Dict[str, Any]] = None  # type: ignore[assignment]
    timestamp: float = 0.0
    metrics: Optional[Dict[str, Any]] = None

    def __post_init__(self) -> None:
        # The default (or an explicit None) becomes an empty list.
        self.detections = [] if self.detections is None else self.detections

    def to_dict(self) -> Dict[str, Any]:
        """Return the wire payload; ``metrics`` is included only when set."""
        base: Dict[str, Any] = {
            "type": self.type,
            "detections": self.detections,
            "timestamp": self.timestamp,
        }
        if self.metrics is None:
            return base
        return {**base, "metrics": self.metrics}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "FRDetectionsMessage":
        """Build an ``FRDetectionsMessage``; ``timestamp`` is coerced with ``float()``."""
        stamp = float(data.get("timestamp", 0.0))
        return cls(
            type=data.get("type", MessageType.FR_DETECTIONS),
            detections=data.get("detections", []),
            timestamp=stamp,
            metrics=data.get("metrics"),
        )


@dataclass
class EventLogMessage:
    """Edge Proxy -> Orchestrator event-log record.

    ``payload`` is an arbitrary dict, normalized to ``{}`` when ``None``.
    NOTE(review): ``replay`` presumably marks a re-sent (not live) event —
    confirm with the producer.
    """

    type: str = MessageType.EVENT_LOG
    event_id: str = ""
    event_type: str = ""
    request_id: str = ""
    status: str = ""
    timestamp: float = 0.0
    replay: bool = False
    payload: Dict[str, Any] = None

    def __post_init__(self) -> None:
        self.payload = {} if self.payload is None else self.payload

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a dict, in declaration order."""
        field_names = (
            "type",
            "event_id",
            "event_type",
            "request_id",
            "status",
            "timestamp",
            "replay",
            "payload",
        )
        return {name: getattr(self, name) for name in field_names}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "EventLogMessage":
        """Build an ``EventLogMessage``, coercing ``timestamp`` and ``replay``."""
        # The plain string fields all share the same "" fallback.
        text_fields = {
            name: data.get(name, "")
            for name in ("event_id", "event_type", "request_id", "status")
        }
        return cls(
            type=data.get("type", MessageType.EVENT_LOG),
            timestamp=float(data.get("timestamp", 0.0)),
            replay=bool(data.get("replay", False)),
            payload=data.get("payload", {}),
            **text_fields,
        )


# Union of every message type the Edge Proxy sends to the Orchestrator, i.e.
# everything ``parse_edge_message`` can return apart from ``None``.
IncomingMessage = Union[
    NavStatusMessage,
    RobotState,
    WaypointListMessage,
    ErrorMessage,
    PongMessage,
    FrameResponseMessage,
    FRDetectionsMessage,
    EventLogMessage,
]


def parse_edge_message(data: Dict[str, Any]) -> Optional[IncomingMessage]:
    """Decode a raw Edge Proxy payload into its typed message dataclass.

    Args:
        data: Message dict whose ``type`` field selects the parser.

    Returns:
        The parsed message, or ``None`` when ``type`` does not name a known
        Edge Proxy -> Orchestrator message.

    Raises:
        ValueError: If ``type`` is missing or falsy.
    """
    kind = data.get("type")
    if not kind:
        raise ValueError("Missing 'type' field in message")

    # A non-string discriminator can never equal a MessageType value. Checking
    # this first also keeps unhashable values (lists, dicts) out of the lookup.
    if not isinstance(kind, str):
        return None

    # Keyed by the plain string values. MessageType members are str subclasses,
    # so they hit the same entries as raw JSON strings.
    parsers = {
        MessageType.NAV_STATUS.value: NavStatusMessage.from_dict,
        MessageType.ROBOT_STATE.value: RobotState.from_dict,
        MessageType.WAYPOINT_LIST.value: WaypointListMessage.from_dict,
        MessageType.ERROR.value: ErrorMessage.from_dict,
        MessageType.PONG.value: PongMessage.from_dict,
        MessageType.FRAME_RESPONSE.value: FrameResponseMessage.from_dict,
        MessageType.FR_DETECTIONS.value: FRDetectionsMessage.from_dict,
        MessageType.EVENT_LOG.value: EventLogMessage.from_dict,
    }
    parser = parsers.get(kind)
    return None if parser is None else parser(data)
