from __future__ import annotations

import asyncio
import json
import logging
import math
import ssl
import time
from dataclasses import dataclass
from typing import Any, Dict, Optional
from urllib.parse import urlparse

import websockets
from websockets.client import WebSocketClientProtocol
from websockets.exceptions import ConnectionClosed

from edge_proxy.messages import Battery, Pose

from .base import NavBackend, NavUpdate, NavUpdateCallback, Pose2D


@dataclass(eq=True, frozen=True)
class _TopicConfig:
    """Immutable pairing of a rosbridge topic name with its ROS message type."""

    # Topic name exactly as sent in rosbridge advertise/subscribe ops.
    topic: str
    # Fully-qualified ROS message type string, e.g. "std_msgs/msg/Empty".
    msg_type: str


class RosbridgeBackend(NavBackend):
    """Nav backend using rosbridge WebSocket and waypoint_handler gui/* topics.

    :meth:`start` opens one WebSocket connection to rosbridge, advertises the
    ``gui/*`` publisher topics and subscribes to the feedback topics. A
    navigation request is sent as a one-waypoint ``WaypointArray`` on
    ``gui/add_waypoint_new`` followed by ``gui/execute_plan``. Progress and
    completion come from ``gui/get_waypoints_new`` and
    ``gui/get_plan_feedback``.

    Only one navigation request may be active at a time. Each request gets at
    most one terminal update (``arrived``, ``failed`` or ``cancelled``), and
    no updates are delivered after that terminal update.
    """

    # Status values published on gui/get_plan_feedback (std_msgs/Int8).
    _PLAN_FAILED = -1
    _PLAN_INACTIVE = 0
    _PLAN_ACTIVE = 1
    _PLAN_PAUSED = 2

    # Per-waypoint "status" values in gui/get_waypoints_new feedback.
    _WP_STATUS_FAILED = -1
    _WP_STATUS_NOT_REACHED = 0
    _WP_STATUS_CURRENT = 1
    _WP_STATUS_REACHED = 2

    # sensor_msgs/BatteryState power_supply_status value meaning "charging".
    _POWER_SUPPLY_STATUS_CHARGING = 1

    # Request phases after which no further updates are delivered.
    _TERMINAL_PHASES = frozenset({"complete", "failed", "cancelled"})

    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        *,
        rosbridge_url: str = "wss://127.0.0.1:9090",
        insecure_tls: bool = True,
        nav_timeout_sec: float = 120.0,
    ) -> None:
        """Configure the backend. No connection is made until :meth:`start`.

        Args:
            rosbridge_url: ``ws://`` or ``wss://`` URL of the rosbridge server.
            insecure_tls: For ``wss://`` URLs, disable certificate and
                hostname verification. This is insecure; use it only on
                trusted networks.
            nav_timeout_sec: Maximum time :meth:`navigate_to_pose` waits for a
                terminal status before failing with ``NAV_TIMEOUT``.
        """
        self._rosbridge_url = rosbridge_url
        self._insecure_tls = insecure_tls
        self._nav_timeout_sec = float(nav_timeout_sec)

        self._ws: Optional[WebSocketClientProtocol] = None
        self._receiver_task: Optional[asyncio.Task[None]] = None
        # Serializes frames written to the socket.
        self._send_lock = asyncio.Lock()
        # Guards the _active_* request state below.
        self._state_lock = asyncio.Lock()

        self._latest_pose = Pose()
        self._latest_pose_ts: Optional[float] = None
        self._latest_battery = Battery()

        self._active_request_id: Optional[str] = None
        self._active_on_update: Optional[NavUpdateCallback] = None
        self._active_done: Optional[asyncio.Event] = None
        self._active_phase: str = "idle"  # idle | pending | active | complete | failed | cancelled
        self._active_progress: float = 0.0
        self._active_cancelled: bool = False
        self._last_progress_sent: float = -1.0

        self._publisher_topics = (
            _TopicConfig("gui/add_waypoint_new", "waypoint_handler_idl/msg/WaypointArray"),
            _TopicConfig("gui/execute_plan", "std_msgs/msg/Empty"),
            _TopicConfig("gui/cancel_plan", "std_msgs/msg/Empty"),
        )
        self._subscriber_topics = (
            _TopicConfig("gui/get_plan_feedback", "std_msgs/msg/Int8"),
            _TopicConfig("gui/get_waypoints_new", "waypoint_handler_idl/msg/WaypointArray"),
            _TopicConfig("gui/get_robot_pose", "geometry_msgs/msg/Point"),
            _TopicConfig("/kstack/state/battery", "sensor_msgs/msg/BatteryState"),
        )

    async def start(self) -> None:
        """Connect to rosbridge, register topics and start the receiver task.

        If topic registration fails, the socket is closed before the error
        propagates, so no half-initialised connection is left behind.
        """
        ssl_ctx = self._build_ssl_context()
        self._ws = await websockets.connect(
            self._rosbridge_url,
            ping_interval=20,
            ping_timeout=20,
            ssl=ssl_ctx,
        )
        try:
            await self._register_rosbridge_topics()
        except BaseException:
            ws, self._ws = self._ws, None
            await ws.close()
            raise
        self._receiver_task = asyncio.create_task(self._receiver_loop())

    async def stop(self) -> None:
        """Cancel the receiver task and close the WebSocket, if open.

        Cancelling the receiver fails any active request with
        ``rosbridge_disconnected``.
        """
        if self._receiver_task is not None:
            self._receiver_task.cancel()
            try:
                await self._receiver_task
            except asyncio.CancelledError:
                pass
            self._receiver_task = None

        if self._ws is not None and not self._ws.closed:
            await self._ws.close()
        self._ws = None

    async def navigate_to_pose(
        self,
        request_id: str,
        destination: str,
        pose: Pose2D,
        speed: str,
        on_update: NavUpdateCallback,
        frame_id: str = "map",
    ) -> None:
        """Drive to ``pose`` and return once the request reaches a terminal state.

        ``on_update`` first receives ``accepted``. It then receives zero or
        more ``navigating`` updates and at most one terminal update:
        ``arrived``, ``failed`` or ``cancelled``.

        Args:
            request_id: Identifier that :meth:`cancel` accepts.
            destination: Unused by this backend.
            pose: Target pose. ``theta`` is encoded as a rotation about z.
            speed: Unused by this backend.
            on_update: Async callback receiving every ``NavUpdate``.
            frame_id: Frame written into the waypoint array header.

        Raises:
            RuntimeError: If another request is already active, or if
                rosbridge is not connected when the waypoint is published.
        """
        _ = destination
        _ = speed

        done_evt = asyncio.Event()
        async with self._state_lock:
            if self._active_request_id is not None:
                raise RuntimeError("Another navigation request is already active")

            self._active_request_id = request_id
            self._active_on_update = on_update
            self._active_done = done_evt
            self._active_phase = "pending"
            self._active_progress = 0.0
            self._active_cancelled = False
            self._last_progress_sent = -1.0

        # Everything after claiming the slot sits inside try/finally, so a
        # raising callback cannot leave the backend stuck as "active".
        try:
            await on_update(NavUpdate(status="accepted", progress=0.0))
            await self._publish_single_waypoint_and_execute(pose, frame_id=frame_id)
            await asyncio.wait_for(done_evt.wait(), timeout=self._nav_timeout_sec)
        except asyncio.TimeoutError:
            async with self._state_lock:
                still_running = self._active_phase not in self._TERMINAL_PHASES
            if still_running:
                # The plan may still be executing; ask waypoint_handler to
                # stop it before reporting the failure.
                await self._try_publish_cancel()
                await self._emit_update(
                    NavUpdate(
                        status="failed",
                        progress=self._active_progress,
                        reason="timeout",
                        error_code="NAV_TIMEOUT",
                    ),
                    set_done=True,
                    phase="failed",
                )
        finally:
            async with self._state_lock:
                self._active_request_id = None
                self._active_on_update = None
                self._active_done = None
                self._active_phase = "idle"
                self._active_cancelled = False

    async def cancel(self, request_id: str, reason: str = "") -> None:
        """Cancel ``request_id`` if it is the active, unfinished request.

        Publishes ``gui/cancel_plan`` and then delivers a ``cancelled``
        terminal update, which also unblocks :meth:`navigate_to_pose`.
        Unknown or already-finished request ids are ignored.

        Raises:
            RuntimeError: If rosbridge is not connected.
        """
        _ = reason
        async with self._state_lock:
            if self._active_request_id != request_id:
                return
            if self._active_phase in self._TERMINAL_PHASES:
                return
            self._active_cancelled = True

        await self._publish_message(topic="gui/cancel_plan", msg={})

        # Report the cancellation directly. navigate_to_pose clears the
        # request state as soon as it is woken, so a later INACTIVE plan
        # feedback could not deliver this update anymore.
        await self._emit_update(
            NavUpdate(status="cancelled", progress=self._active_progress),
            set_done=True,
            phase="cancelled",
            request_id=request_id,
        )

    def get_pose(self) -> Pose:
        """Return the most recent pose from ``gui/get_robot_pose``."""
        return self._latest_pose

    def get_pose_age_sec(self) -> Optional[float]:
        """Return age of latest pose sample in seconds, or None if unavailable."""
        if self._latest_pose_ts is None:
            return None
        return max(0.0, time.monotonic() - self._latest_pose_ts)

    def get_battery(self) -> Battery:
        """Return the most recent battery state from ``/kstack/state/battery``."""
        return self._latest_battery

    def _build_ssl_context(self) -> Optional[ssl.SSLContext]:
        """Return an SSL context for ``wss://`` URLs, or None for plain ``ws://``."""
        scheme = urlparse(self._rosbridge_url).scheme.lower()
        if scheme != "wss":
            return None

        ctx = ssl.create_default_context()
        if self._insecure_tls:
            ctx.check_hostname = False
            ctx.verify_mode = ssl.CERT_NONE
        return ctx

    async def _register_rosbridge_topics(self) -> None:
        """Send rosbridge ``advertise`` and ``subscribe`` ops for all configured topics."""
        for cfg in self._publisher_topics:
            await self._send_raw({"op": "advertise", "topic": cfg.topic, "type": cfg.msg_type})

        for cfg in self._subscriber_topics:
            await self._send_raw({"op": "subscribe", "topic": cfg.topic, "type": cfg.msg_type})

    @staticmethod
    def _parse_publish(raw: Any) -> Optional[tuple[str, Dict[str, Any]]]:
        """Return ``(topic, msg)`` for a rosbridge ``publish`` frame, else None."""
        if isinstance(raw, (bytes, bytearray)):
            raw = raw.decode("utf-8", errors="replace")
        if not isinstance(raw, str):
            return None

        try:
            payload = json.loads(raw)
        except json.JSONDecodeError:
            return None

        if not isinstance(payload, dict) or payload.get("op") != "publish":
            return None

        msg = payload.get("msg")
        if not isinstance(msg, dict):
            return None
        return str(payload.get("topic", "")), msg

    async def _receiver_loop(self) -> None:
        """Read frames until the socket closes, dispatching ``publish`` messages.

        Frames that are not well-formed ``publish`` ops are skipped. If
        handling one message raises, the error is logged and the loop keeps
        running, so one bad message cannot stop all later feedback. When the
        loop ends for any reason, including cancellation by :meth:`stop`, the
        active request (if any) is failed with ``rosbridge_disconnected``.
        """
        ws = self._ws
        if ws is None:
            return

        try:
            async for raw in ws:
                parsed = self._parse_publish(raw)
                if parsed is None:
                    continue
                topic, msg = parsed
                try:
                    await self._handle_publish(topic, msg)
                except Exception:
                    self._logger.exception("Failed to handle rosbridge message on topic %r", topic)
        except ConnectionClosed:
            pass
        finally:
            await self._emit_update(
                NavUpdate(
                    status="failed",
                    progress=self._active_progress,
                    reason="rosbridge_disconnected",
                    error_code="NAV_LOCALIZATION_LOST",
                ),
                set_done=True,
                phase="failed",
            )

    async def _handle_publish(self, topic: str, msg: Dict[str, Any]) -> None:
        """Dispatch one rosbridge ``publish`` message by topic.

        Messages on unknown topics are ignored. A field with an unexpected
        type raises TypeError or ValueError, which the receiver loop logs.
        """
        if topic == "gui/get_robot_pose":
            # A geometry_msgs/Point is reused here: z carries the heading (theta).
            self._latest_pose = Pose(
                x=float(msg.get("x", 0.0)),
                y=float(msg.get("y", 0.0)),
                theta=float(msg.get("z", 0.0)),
            )
            self._latest_pose_ts = time.monotonic()
        elif topic == "/kstack/state/battery":
            self._latest_battery = self._parse_battery(msg)
        elif topic == "gui/get_waypoints_new":
            await self._handle_waypoint_feedback(msg)
        elif topic == "gui/get_plan_feedback":
            await self._handle_plan_feedback(msg)

    def _parse_battery(self, msg: Dict[str, Any]) -> Battery:
        """Convert a sensor_msgs/BatteryState dict into a ``Battery``.

        ``percentage`` is accepted either as a 0..1 fraction or on a 0..100
        scale, and the result is clamped to 0..100. A missing key counts as
        full (1.0). A null or non-finite value keeps the previous level;
        BatteryState uses NaN for "unmeasured", and ``int(round(nan))`` would
        raise.
        """
        status = int(msg.get("power_supply_status", 0))
        charging = status == self._POWER_SUPPLY_STATUS_CHARGING

        raw_pct = msg.get("percentage", 1.0)
        percentage = math.nan if raw_pct is None else float(raw_pct)
        if math.isfinite(percentage):
            scaled = percentage * 100.0 if percentage <= 1.0 else percentage
            level = max(0, min(100, int(round(scaled))))
        else:
            level = self._latest_battery.level
        return Battery(level=level, charging=charging)

    async def _handle_waypoint_feedback(self, msg: Dict[str, Any]) -> None:
        """Report progress from a ``gui/get_waypoints_new`` WaypointArray.

        Progress is the fraction of waypoints whose status is REACHED or
        FAILED. It is capped at 0.99 so that only ``arrived`` reports 1.0.
        Values within 0.01 of the last one sent are not re-sent.
        """
        waypoints = msg.get("waypoints")
        if not isinstance(waypoints, list):
            return

        statuses = [
            int(wp.get("status", self._WP_STATUS_NOT_REACHED))
            for wp in waypoints
            if isinstance(wp, dict)
        ]
        if not statuses:
            return

        finished = (self._WP_STATUS_REACHED, self._WP_STATUS_FAILED)
        processed = sum(1 for status in statuses if status in finished)
        progress = max(0.0, min(0.99, processed / float(len(statuses))))
        if math.isclose(progress, self._last_progress_sent, abs_tol=1e-2):
            return

        self._last_progress_sent = progress
        await self._emit_update(NavUpdate(status="navigating", progress=progress))

    async def _handle_plan_feedback(self, msg: Dict[str, Any]) -> None:
        """Translate a ``gui/get_plan_feedback`` status into request updates.

        - ACTIVE or PAUSED: ``navigating``. ACTIVE also moves the request
          from ``pending`` to ``active``.
        - FAILED: a ``failed`` terminal update.
        - INACTIVE: ``cancelled`` if :meth:`cancel` was requested, or
          ``arrived`` if the plan had been reported ACTIVE. Otherwise it is
          ignored (for example INACTIVE received before the plan was first
          reported ACTIVE).
        """
        status = int(msg.get("data", self._PLAN_INACTIVE))

        async with self._state_lock:
            if self._active_request_id is None:
                return
            if status == self._PLAN_ACTIVE and self._active_phase == "pending":
                self._active_phase = "active"
            phase = self._active_phase
            cancelled = self._active_cancelled
            progress = self._active_progress

        if status in (self._PLAN_ACTIVE, self._PLAN_PAUSED):
            await self._emit_update(NavUpdate(status="navigating", progress=progress))
        elif status == self._PLAN_FAILED:
            await self._emit_update(
                NavUpdate(
                    status="failed",
                    progress=progress,
                    reason="waypoint_handler_failed",
                    error_code="NAV_BLOCKED",
                ),
                set_done=True,
                phase="failed",
            )
        elif status == self._PLAN_INACTIVE:
            if cancelled:
                await self._emit_update(
                    NavUpdate(status="cancelled", progress=progress),
                    set_done=True,
                    phase="cancelled",
                )
            elif phase == "active":
                await self._emit_update(
                    NavUpdate(status="arrived", progress=1.0),
                    set_done=True,
                    phase="complete",
                )

    async def _publish_single_waypoint_and_execute(self, pose: Pose2D, frame_id: str = "map") -> None:
        """Publish a one-waypoint plan on ``gui/add_waypoint_new``, then ``gui/execute_plan``.

        ``pose.theta`` is encoded as a rotation about z only:
        (0, 0, sin(theta/2), cos(theta/2)).
        """
        half_yaw = float(pose.theta) * 0.5
        qz = math.sin(half_yaw)
        qw = math.cos(half_yaw)

        sec, nanosec = self._stamp_now()
        waypoint_array_msg = {
            "header": {
                "stamp": {"sec": sec, "nanosec": nanosec},
                "frame_id": frame_id,
            },
            "waypoints": [
                {
                    "pose": {
                        "position": {"x": float(pose.x), "y": float(pose.y), "z": 0.0},
                        "orientation": {"x": 0.0, "y": 0.0, "z": qz, "w": qw},
                    },
                    "mode": 0,
                    "status": 0,
                }
            ],
            "next_waypoint": 0,
        }

        await self._publish_message(topic="gui/add_waypoint_new", msg=waypoint_array_msg)
        await self._publish_message(topic="gui/execute_plan", msg={})

    async def _try_publish_cancel(self) -> None:
        """Publish ``gui/cancel_plan`` on a best-effort basis; log instead of raising."""
        try:
            await self._publish_message(topic="gui/cancel_plan", msg={})
        except (RuntimeError, ConnectionClosed):
            self._logger.warning("Could not publish gui/cancel_plan", exc_info=True)

    async def _publish_message(self, *, topic: str, msg: Dict[str, Any]) -> None:
        """Send a rosbridge ``publish`` op carrying ``msg`` on ``topic``."""
        await self._send_raw({"op": "publish", "topic": topic, "msg": msg})

    async def _emit_update(
        self,
        update: NavUpdate,
        *,
        set_done: bool = False,
        phase: Optional[str] = None,
        request_id: Optional[str] = None,
    ) -> None:
        """Deliver ``update`` to the active request's callback.

        The update is dropped if no request is active, if the active request
        has already reached a terminal phase, or if ``request_id`` is given
        and differs from the active request. ``phase`` is applied in the same
        locked section as that check, so each request gets at most one
        terminal update.

        Args:
            update: Update to deliver. A non-None ``progress`` becomes the
                tracked progress.
            set_done: Wake :meth:`navigate_to_pose` after delivery, even if
                the callback raises.
            phase: New phase to record for the active request.
            request_id: If given, deliver only when it matches the active request.
        """
        async with self._state_lock:
            if self._active_request_id is None:
                return
            if request_id is not None and request_id != self._active_request_id:
                return
            if self._active_phase in self._TERMINAL_PHASES:
                return
            callback = self._active_on_update
            done_evt = self._active_done
            if phase is not None:
                self._active_phase = phase
            if update.progress is not None:
                self._active_progress = float(update.progress)

        try:
            if callback is not None:
                await callback(update)
        finally:
            # A raising callback must not leave the waiter blocked until timeout.
            if set_done and done_evt is not None:
                done_evt.set()

    async def _send_raw(self, payload: Dict[str, Any]) -> None:
        """JSON-encode ``payload`` and send it as one WebSocket frame.

        Raises:
            RuntimeError: If the socket is missing or closed.
        """
        ws = self._ws
        if ws is None or ws.closed:
            raise RuntimeError("rosbridge is not connected")

        async with self._send_lock:
            await ws.send(json.dumps(payload))

    def _stamp_now(self) -> tuple[int, int]:
        """Return the current wall-clock time as ``(sec, nanosec)`` integers."""
        sec, nanosec = divmod(time.time_ns(), 1_000_000_000)
        return sec, nanosec
