from __future__ import annotations

import asyncio
import math
import threading
from typing import Optional

from .base import NavBackend, NavUpdate, NavUpdateCallback, Pose2D

# Import message types at module level for type hints only (actual imports are lazy inside start())
try:
    from edge_proxy.messages import Pose, Battery
except ImportError:
    pass


class Nav2Backend(NavBackend):
    """Nav2 backend built on an rclpy ``ActionClient`` for ``NavigateToPose``.

    rclpy is spun by a ``SingleThreadedExecutor`` on a background daemon thread.
    The asyncio side never blocks on ROS: it polls rclpy futures with
    ``asyncio.sleep``. Pose and battery telemetry arrive through ROS
    subscriptions and are cached under locks for the synchronous getters.
    """

    # action_msgs/msg/GoalStatus codes used to interpret Nav2 results.
    _GOAL_STATUS_SUCCEEDED = 4
    _GOAL_STATUS_CANCELED = 5

    def __init__(self, action_name: str = "/navigate_to_pose", server_wait_sec: float = 5.0) -> None:
        """Configure the backend; no ROS resources are created until ``start()``.

        Args:
            action_name: Name of the Nav2 ``NavigateToPose`` action server.
            server_wait_sec: How long ``start()`` and ``navigate_to_pose()`` wait
                for the action server to become available.
        """
        self._action_name = action_name
        self._server_wait_sec = float(server_wait_sec)
        self._spin_thread: Optional[threading.Thread] = None
        # Also serves as the "already stopped" flag checked by stop().
        self._stop_event = threading.Event()

        self._rclpy = None
        # True only if start() called rclpy.init(); stop() then owns the shutdown.
        self._owns_rclpy = False
        self._node = None
        self._executor = None
        self._client = None

        # Per-request navigation state. It is reset at the start of every
        # navigate_to_pose() call and cleared when that call finishes.
        self._goal_handle = None
        self._active_request_id: Optional[str] = None
        # Set by cancel() when it arrives before the goal handle exists.
        self._cancel_pending = False

        # Telemetry, written by ROS 2 subscription callbacks on the spin thread.
        self._latest_pose: Optional["Pose"] = None
        self._latest_battery: Optional["Battery"] = None
        self._pose_lock = threading.Lock()
        self._battery_lock = threading.Lock()

    async def start(self) -> None:
        """Initialise rclpy, create the node, client and subscriptions, and start spinning.

        Raises:
            RuntimeError: If the ROS 2 Python packages cannot be imported.
        """
        try:
            import rclpy
            from rclpy.executors import SingleThreadedExecutor
            from rclpy.node import Node
            from rclpy.action import ActionClient
            from nav2_msgs.action import NavigateToPose
            from geometry_msgs.msg import Point
            from sensor_msgs.msg import BatteryState
        except Exception as exc:  # pragma: no cover
            raise RuntimeError(
                f"Nav2 backend requires ROS2 Python packages (rclpy/nav2_msgs): {exc}"
            ) from exc

        # Lazy import of message dataclasses to avoid circular dependency at module load
        from edge_proxy.messages import Pose as PoseMsg, Battery as BatteryMsg

        self._rclpy = rclpy
        # Only initialise (and later shut down) rclpy if nobody else already has.
        self._owns_rclpy = not rclpy.ok()
        if self._owns_rclpy:
            rclpy.init(args=None)

        self._node = Node("edge_proxy")
        self._executor = SingleThreadedExecutor()
        self._executor.add_node(self._node)

        self._client = ActionClient(self._node, NavigateToPose, self._action_name)

        # --- Pose subscriber ---
        # waypoint_handler publishes geometry_msgs/Point on gui/get_robot_pose at ~10 Hz.
        # Convention: Point.x = map x, Point.y = map y, Point.z = heading theta (radians).
        def _pose_cb(msg: Point) -> None:
            with self._pose_lock:
                self._latest_pose = PoseMsg(x=msg.x, y=msg.y, theta=msg.z)

        self._node.create_subscription(Point, "gui/get_robot_pose", _pose_cb, 10)

        # --- Battery subscriber ---
        # /kstack/state/battery publishes sensor_msgs/BatteryState.
        # percentage is in [0.0, 1.0] (NaN when unmeasured); we convert to integer 0–100.
        # power_supply_status == 1 (CHARGING) indicates it is charging.
        _POWER_SUPPLY_STATUS_CHARGING = 1

        def _battery_cb(msg: BatteryState) -> None:
            charging = int(msg.power_supply_status) == _POWER_SUPPLY_STATUS_CHARGING
            fraction = float(msg.percentage)
            with self._battery_lock:
                if math.isfinite(fraction):
                    # Clamp in case the driver reports slightly outside [0, 1].
                    level = max(0, min(100, int(round(fraction * 100.0))))
                elif self._latest_battery is not None:
                    # round(nan) raises ValueError. That would escape spin_once()
                    # and kill the spin thread, so keep the last known level.
                    level = self._latest_battery.level
                else:
                    # Nothing measured yet: let get_battery() keep returning its default.
                    return
                self._latest_battery = BatteryMsg(level=level, charging=charging)

        self._node.create_subscription(BatteryState, "/kstack/state/battery", _battery_cb, 10)

        self._stop_event.clear()
        self._spin_thread = threading.Thread(target=self._spin, daemon=True)
        self._spin_thread.start()

        # Wait briefly for the server without blocking the event loop
        # (ActionClient.wait_for_server() is a blocking call). Not fatal:
        # navigate_to_pose() re-checks and reports a failure if it is still missing.
        await self._wait_for_server(self._server_wait_sec)

    def get_pose(self) -> "Pose":
        """Return the latest robot pose from ROS2, or a default ``Pose()`` if none received yet."""
        from edge_proxy.messages import Pose as PoseMsg

        with self._pose_lock:
            if self._latest_pose is not None:
                return self._latest_pose
        return PoseMsg()

    def get_battery(self) -> "Battery":
        """Return the latest battery state from ROS2, or a default ``Battery()`` if none received yet."""
        from edge_proxy.messages import Battery as BatteryMsg

        with self._battery_lock:
            if self._latest_battery is not None:
                return self._latest_battery
        return BatteryMsg()

    def _spin(self) -> None:
        """Spin the executor on the background thread until stop() sets the stop event."""
        # Capture locally: stop() may reset self._executor while this thread winds down.
        executor = self._executor
        assert executor is not None
        while not self._stop_event.is_set():
            executor.spin_once(timeout_sec=0.1)

    async def _wait_for_server(self, timeout_sec: float) -> bool:
        """Poll the action server's availability without blocking the event loop.

        Returns:
            True if the server became ready within *timeout_sec*. False on
            timeout, or if the backend has not been started.
        """
        client = self._client
        if client is None:
            return False
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout_sec
        while not client.server_is_ready():
            if loop.time() >= deadline:
                return False
            await asyncio.sleep(0.1)
        return True

    @staticmethod
    async def _wait_for_future(future, poll_sec: float):
        """Await an rclpy future from asyncio by polling it, then return its result."""
        while not future.done():
            await asyncio.sleep(poll_sec)
        return future.result()

    async def stop(self) -> None:
        """Stop spinning and release ROS resources. Calling it again is a no-op."""
        if self._stop_event.is_set():
            return
        self._stop_event.set()

        if self._spin_thread is not None:
            # Thread.join() blocks, so run it off the event loop.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._spin_thread.join, 2.0)
            self._spin_thread = None

        if self._client is not None:
            try:
                self._client.destroy()
            except Exception:
                pass

        if self._executor is not None and self._node is not None:
            try:
                self._executor.remove_node(self._node)
            except Exception:
                pass

        if self._node is not None:
            try:
                self._node.destroy_node()
            except Exception:
                pass

        # Forget the destroyed handles so a later navigate_to_pose() reports
        # "not started" instead of using dead ROS objects.
        self._client = None
        self._node = None
        self._executor = None
        self._goal_handle = None
        self._active_request_id = None
        self._cancel_pending = False

        if self._rclpy is not None and self._owns_rclpy:
            try:
                self._rclpy.shutdown()
            except Exception:
                pass
        self._owns_rclpy = False

    async def navigate_to_pose(
        self,
        request_id: str,
        destination: str,
        pose: Pose2D,
        speed: str,
        on_update: NavUpdateCallback,
        frame_id: str = "map",
    ) -> None:
        """Send a NavigateToPose goal and report progress through *on_update*.

        Emits ``accepted`` first, then ``navigating`` on Nav2 feedback, and finally
        one of ``arrived``, ``cancelled`` or ``failed``. *destination* and *speed*
        are not used by this backend.

        Raises:
            RuntimeError: If the backend has not been started.
        """
        client = self._client
        node = self._node
        if client is None or node is None:
            raise RuntimeError("Nav2 backend not started")

        from nav2_msgs.action import NavigateToPose
        from geometry_msgs.msg import PoseStamped

        # Claim the backend for this request and drop any previous goal handle,
        # so cancel(request_id) can never be routed to an older goal.
        self._active_request_id = request_id
        self._goal_handle = None
        self._cancel_pending = False

        try:
            await on_update(NavUpdate(status="accepted", progress=0.0))

            # send_goal_async() never completes if no server answers, so check for
            # a server first instead of waiting forever for goal acceptance.
            if not await self._wait_for_server(self._server_wait_sec):
                await on_update(NavUpdate(status="failed", reason="nav2_unavailable"))
                return

            # cancel() may already have arrived; don't send a goal only to cancel it.
            if self._cancel_pending and self._active_request_id == request_id:
                await on_update(NavUpdate(status="cancelled"))
                return

            loop = asyncio.get_running_loop()

            # Build goal: planar pose, heading theta encoded as a yaw-only quaternion.
            stamped = PoseStamped()
            stamped.header.frame_id = frame_id
            stamped.header.stamp = node.get_clock().now().to_msg()
            stamped.pose.position.x = float(pose.x)
            stamped.pose.position.y = float(pose.y)
            stamped.pose.position.z = 0.0
            half_theta = float(pose.theta) * 0.5
            stamped.pose.orientation.z = math.sin(half_theta)
            stamped.pose.orientation.w = math.cos(half_theta)

            goal_msg = NavigateToPose.Goal()
            goal_msg.pose = stamped

            def feedback_cb(feedback_msg) -> None:
                # Runs on the rclpy executor thread; hand the update to the asyncio loop.
                # Nav2 feedback carries no direct progress value, so only the
                # "navigating" state is emitted.
                if self._active_request_id != request_id:
                    return
                update = None
                try:
                    update = on_update(NavUpdate(status="navigating"))
                    asyncio.run_coroutine_threadsafe(update, loop)
                except Exception:
                    # Best effort (e.g. the loop is already closed). Close an
                    # unscheduled coroutine so Python doesn't warn it was never awaited.
                    if asyncio.iscoroutine(update):
                        update.close()

            send_goal_future = client.send_goal_async(goal_msg, feedback_callback=feedback_cb)
            goal_handle = await self._wait_for_future(send_goal_future, 0.05)

            if not goal_handle.accepted:
                await on_update(
                    NavUpdate(status="failed", reason="goal_rejected", error_code="NAV_INVALID_GOAL")
                )
                return

            if self._active_request_id == request_id:
                self._goal_handle = goal_handle
                if self._cancel_pending:
                    # cancel() ran while the goal was in flight; apply it now. The
                    # result should then come back CANCELED and map to "cancelled".
                    goal_handle.cancel_goal_async()

            result = await self._wait_for_future(goal_handle.get_result_async(), 0.1)
            status_code = getattr(result, "status", None)

            if status_code == self._GOAL_STATUS_SUCCEEDED:
                await on_update(NavUpdate(status="arrived", progress=1.0))
                return

            # ABORTED (6) and any other non-success status are reported as failures.
            if status_code == self._GOAL_STATUS_CANCELED:
                await on_update(NavUpdate(status="cancelled"))
                return

            await on_update(NavUpdate(status="failed", reason="nav2_failed"))
        finally:
            # Release the per-request state unless a newer request has taken over.
            if self._active_request_id == request_id:
                self._active_request_id = None
                self._goal_handle = None
                self._cancel_pending = False

    async def cancel(self, request_id: str, reason: str = "") -> None:
        """Request cancellation of the active navigation if it belongs to *request_id*.

        If the goal has not been accepted yet, the cancel is remembered and
        navigate_to_pose() applies it once the goal handle exists. *reason* is
        not used by this backend.
        """
        if self._active_request_id != request_id:
            return
        goal_handle = self._goal_handle
        if goal_handle is None:
            self._cancel_pending = True
            return
        await self._wait_for_future(goal_handle.cancel_goal_async(), 0.05)
