from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
import time
from typing import Any, Callable, Dict, List, Optional


class OrchestratorState(str, Enum):
    """States the orchestrator state machine can be in.

    Mixing in ``str`` makes every member a ``str`` whose value equals its
    name, so e.g. ``OrchestratorState.IDLE == "IDLE"`` holds.
    """

    IDLE = "IDLE"
    LISTENING = "LISTENING"
    PROCESSING = "PROCESSING"
    AWAITING_CONFIRM = "AWAITING_CONFIRM"
    EXECUTING_ACTIONS = "EXECUTING_ACTIONS"
    # NOTE(review): no entry in DEFAULT_TRANSITIONS leads into ACTING, and it
    # looks like it overlaps EXECUTING_ACTIONS — confirm whether it is still used.
    ACTING = "ACTING"
    SPEAKING = "SPEAKING"


class InvalidTransition(RuntimeError):
    """Raised when an event is not accepted from the machine's current state."""


@dataclass(frozen=True)
class TransitionRecord:
    """Immutable log entry describing one completed state transition.

    Created by ``OrchestratorStateMachine.handle_event``, appended to the
    machine's history and passed to each registered listener. ``frozen=True``
    prevents reassigning fields, but ``payload`` is stored by reference, so
    the dict it points to can still be mutated.
    """

    timestamp: float  # wall-clock time.time() value taken when the transition happened
    previous: OrchestratorState  # state before the transition
    event: str  # whitespace-stripped name of the event that triggered it
    next: OrchestratorState  # state after the transition
    payload: Optional[Dict[str, Any]] = None  # optional caller data given to handle_event


# Callback type accepted by OrchestratorStateMachine.add_listener. It is called
# with the TransitionRecord of each completed transition; its return value is
# ignored.
TransitionListener = Callable[[TransitionRecord], None]


# Default per-state transition table: state -> {event name: next state}.
# An event that is not listed for the current state is rejected unless it
# appears in GLOBAL_TRANSITIONS. OrchestratorStateMachine uses this dict by
# reference (not a copy) when no custom table is supplied.
# NOTE(review): no entry targets OrchestratorState.ACTING, so ACTING can only
# be entered through a custom table or an explicit initial_state — confirm.
# NOTE(review): apart from IDLE and LISTENING, PROCESSING is the only state
# without a "barge_in" edge — confirm this is deliberate.
DEFAULT_TRANSITIONS: Dict[OrchestratorState, Dict[str, OrchestratorState]] = {
    OrchestratorState.IDLE: {
        "wake_word_detected": OrchestratorState.LISTENING,
    },
    OrchestratorState.LISTENING: {
        "final_transcript": OrchestratorState.PROCESSING,
        # NOTE(review): presumably SPEAKING here voices a "didn't catch that"
        # style prompt — confirm against the caller.
        "timeout": OrchestratorState.SPEAKING,
    },
    OrchestratorState.PROCESSING: {
        "llm_response": OrchestratorState.AWAITING_CONFIRM,
        "processing_failed": OrchestratorState.SPEAKING,
    },
    OrchestratorState.AWAITING_CONFIRM: {
        "user_confirmed": OrchestratorState.EXECUTING_ACTIONS,
        "user_rejected": OrchestratorState.IDLE,
        "barge_in": OrchestratorState.LISTENING,
    },
    OrchestratorState.EXECUTING_ACTIONS: {
        # Success and failure both move to SPEAKING.
        "actions_complete": OrchestratorState.SPEAKING,
        "actions_failed": OrchestratorState.SPEAKING,
        "barge_in": OrchestratorState.LISTENING,
    },
    OrchestratorState.ACTING: {
        "action_complete": OrchestratorState.SPEAKING,
        "barge_in": OrchestratorState.LISTENING,
    },
    OrchestratorState.SPEAKING: {
        "tts_complete": OrchestratorState.IDLE,
        "barge_in": OrchestratorState.LISTENING,
    },
}

# Events accepted from every state: event name -> next state.
# OrchestratorStateMachine consults this only after the current state's own
# table, so a per-state entry with the same event name takes precedence.
GLOBAL_TRANSITIONS: Dict[str, OrchestratorState] = {
    "reset": OrchestratorState.IDLE,
}


class OrchestratorStateMachine:
    """Table-driven state machine for the orchestrator.

    An event is resolved first against the current state's entry in the
    per-state transition table, then against ``GLOBAL_TRANSITIONS``, which
    applies from every state. Each successful transition is appended to an
    in-memory history and passed to every registered listener.
    """

    def __init__(
        self,
        initial_state: OrchestratorState = OrchestratorState.IDLE,
        transitions: Optional[Dict[OrchestratorState, Dict[str, OrchestratorState]]] = None,
    ) -> None:
        """Create a machine positioned at *initial_state*.

        Args:
            initial_state: State the machine starts in.
            transitions: Per-state table mapping ``state -> {event: next_state}``.
                ``None`` (the default) selects ``DEFAULT_TRANSITIONS``. An
                explicitly supplied empty table is honoured, in which case only
                global events are accepted. The table is used by reference,
                not copied.
        """
        self._state = initial_state
        # Compare against None instead of relying on truthiness. An explicit
        # empty table ({}) is a valid configuration and must not be silently
        # replaced by the defaults.
        self._transitions = DEFAULT_TRANSITIONS if transitions is None else transitions
        self._history: List[TransitionRecord] = []
        self._listeners: List[TransitionListener] = []

    @property
    def state(self) -> OrchestratorState:
        """The current state. It is read-only; change it via :meth:`handle_event`."""
        return self._state

    def add_listener(self, listener: TransitionListener) -> None:
        """Register *listener* to be called with each new TransitionRecord.

        Listeners are called synchronously, in registration order, after the
        state change and history append have been committed.
        """
        self._listeners.append(listener)

    def _resolve(self, event_key: str) -> Optional[OrchestratorState]:
        """Return the target state for *event_key* from the current state, or None.

        The per-state table is consulted before ``GLOBAL_TRANSITIONS``, so a
        per-state entry overrides a global event of the same name. Both
        ``can_handle`` and ``handle_event`` go through here so they cannot
        disagree.
        """
        target = self._transitions.get(self._state, {}).get(event_key)
        if target is None:
            target = GLOBAL_TRANSITIONS.get(event_key)
        return target

    def can_handle(self, event: str) -> bool:
        """Return True if *event* (whitespace-stripped) is accepted in the current state."""
        return self._resolve(event.strip()) is not None

    def valid_events(self) -> List[str]:
        """Return the sorted, de-duplicated event names accepted in the current state."""
        local = self._transitions.get(self._state, {})
        return sorted(set(local) | set(GLOBAL_TRANSITIONS))

    def handle_event(self, event: str, payload: Optional[Dict[str, Any]] = None) -> OrchestratorState:
        """Apply *event* to the machine and return the resulting state.

        Args:
            event: Event name. Surrounding whitespace is stripped before lookup.
            payload: Optional data stored, by reference, on the TransitionRecord.

        Returns:
            The machine's state after all listeners have run. A listener that
            itself calls ``handle_event`` can make this differ from the
            immediate transition target.

        Raises:
            InvalidTransition: If the event is not accepted from the current
                state. In that case the state is left unchanged.
            Exception: Anything a listener raises propagates to the caller.
                By then the transition has already been committed and
                recorded, and any remaining listeners are not called.
        """
        event_key = event.strip()
        target = self._resolve(event_key)
        if target is None:
            raise InvalidTransition(
                f"Cannot handle event '{event_key}' from state {self._state.value}."
            )

        previous = self._state
        self._state = target
        record = TransitionRecord(
            timestamp=time.time(),
            previous=previous,
            event=event_key,
            next=target,
            payload=payload,
        )
        self._history.append(record)
        # Iterate over a snapshot so a listener that registers another
        # listener does not alter this dispatch.
        for listener in list(self._listeners):
            listener(record)
        return self._state

    def history(self) -> List[TransitionRecord]:
        """Return a copy of all transition records, oldest first."""
        return list(self._history)
