from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set

from .dag_parser import PlanGraph


# Action names a plan may use when the verification context does not supply
# its own ``available_actions``.
DEFAULT_ALLOWED_ACTIONS = set(
    (
        # Self-checks and movement.
        "CHECK_BATTERY",
        "NAVIGATE",
        # Perception.
        "SCAN_AREA",
        "VERIFY_OBJECT",
        # Interaction.
        "SPEAK",
        "WAIT",
        "IDENTIFY_PERSON",
        "ALERT_OPERATOR",
    )
)

# A violation with one of these severities makes verification fail.
BLOCKING_SEVERITIES = set(("CRITICAL", "HIGH"))


@dataclass()
class Violation:
    """One failed verification rule.

    Positional field order is ``rule_id``, ``severity``, ``message``,
    ``affected_steps``, ``suggested_fix``.
    """

    # Identifier of the rule that failed (e.g. "no_cycles").
    rule_id: str
    # Severity label such as "CRITICAL" or "HIGH".
    severity: str
    # Human-readable description of what went wrong.
    message: str
    # IDs of the plan steps involved in the violation.
    affected_steps: List[str]
    # Hint describing how the plan could be corrected.
    suggested_fix: str


@dataclass()
class VerificationResult:
    """Outcome of verifying one plan.

    Positional field order is ``passed``, ``violations``, ``warnings``.
    """

    # True when no violation has a blocking severity.
    passed: bool
    # Every violation found, in the order the rules were evaluated.
    violations: List[Violation]
    # Free-form, non-blocking notes.
    warnings: List[str]


class SymbolicVerifier:
    """Rule-based safety checker for plans parsed into a ``PlanGraph``.

    Every failed rule produces a ``Violation``. A plan passes only when none
    of its violations has a severity in ``BLOCKING_SEVERITIES``.

    The ``PlanGraph`` attributes read here are ``step_lookup`` (step id ->
    step mapping with ``"action"`` / ``"params"`` keys), ``dependencies``
    (step id -> ids it depends on), ``children`` (presumably the inverse of
    ``dependencies``; the topological sort relies on that) and
    ``entry_steps``.
    """

    def __init__(self, max_plan_depth: int = 20) -> None:
        """Create a verifier.

        Args:
            max_plan_depth: Maximum number of steps allowed on the longest
                dependency chain of an acyclic plan.
        """
        self.max_plan_depth = max_plan_depth

    def verify(self, plan_graph: PlanGraph, context: Optional[Any] = None) -> VerificationResult:
        """Check ``plan_graph`` against every rule and collect the violations.

        Args:
            plan_graph: The plan to check.
            context: Optional object (attribute) or dict (key) providing
                ``available_actions``; when absent or empty,
                ``DEFAULT_ALLOWED_ACTIONS`` is used.

        Returns:
            A ``VerificationResult`` whose violations appear in rule order:
            allowed actions, unknown references, cycles, reachability,
            navigation safety, plan depth. ``passed`` is False when any
            violation is blocking.
        """
        violations: List[Violation] = []
        warnings: List[str] = []

        allowed_actions = self._resolve_allowed_actions(context)
        violations.extend(self._check_allowed_actions(plan_graph, allowed_actions))

        dangling = self._find_unknown_references(plan_graph)
        if dangling:
            violations.append(
                Violation(
                    rule_id="unknown_dependency",
                    severity="CRITICAL",
                    message="Some steps reference step IDs that do not exist in the plan.",
                    affected_steps=dangling,
                    suggested_fix="Reference only step IDs defined in the plan.",
                )
            )

        cycle_steps = self._detect_cycle(plan_graph)
        if cycle_steps:
            violations.append(
                Violation(
                    rule_id="no_cycles",
                    severity="CRITICAL",
                    message="Plan contains a cycle.",
                    affected_steps=cycle_steps,
                    suggested_fix="Remove circular dependencies between steps.",
                )
            )

        violations.extend(self._check_reachability(plan_graph))
        violations.extend(self._check_navigation_safety(plan_graph))

        # Longest-path depth is only defined on an acyclic graph.
        if not cycle_steps:
            max_depth = self._longest_path_length(plan_graph)
            if max_depth > self.max_plan_depth:
                violations.append(
                    Violation(
                        rule_id="max_plan_depth",
                        severity="HIGH",
                        message=f"Plan depth {max_depth} exceeds limit {self.max_plan_depth}.",
                        affected_steps=list(plan_graph.step_lookup.keys()),
                        suggested_fix="Shorten the plan or split it into smaller plans.",
                    )
                )

        passed = not any(v.severity in BLOCKING_SEVERITIES for v in violations)
        return VerificationResult(passed=passed, violations=violations, warnings=warnings)

    def format_feedback(self, result: VerificationResult) -> str:
        """Render ``result`` as plain-text feedback.

        Returns ``"Plan accepted."`` when the result passed. Otherwise it lists
        every violation with its rule id, severity, details, affected steps
        and suggested fix, separated by blank lines.
        """
        if result.passed:
            return "Plan accepted."

        lines = ["Your previous plan was REJECTED.", ""]
        for violation in result.violations:
            lines.append(f"VIOLATION: {violation.rule_id} ({violation.severity})")
            lines.append(f"DETAILS: {violation.message}")
            if violation.affected_steps:
                # str() keeps rendering robust if a step id is not a string.
                lines.append(f"AFFECTED: {', '.join(map(str, violation.affected_steps))}")
            if violation.suggested_fix:
                lines.append(f"FIX: {violation.suggested_fix}")
            lines.append("")

        return "\n".join(lines).strip()

    # ------------------------------------------------------------------ rules

    def _check_allowed_actions(
        self, plan_graph: PlanGraph, allowed_actions: Set[str]
    ) -> List[Violation]:
        """Flag every step whose ``action`` is not in ``allowed_actions``."""
        violations: List[Violation] = []
        for step_id, step in plan_graph.step_lookup.items():
            action = step.get("action")
            if allowed_actions and not self._is_allowed(action, allowed_actions):
                violations.append(
                    Violation(
                        rule_id="unknown_action",
                        severity="HIGH",
                        message=f'Action "{action}" is not in allowed actions.',
                        affected_steps=[step_id],
                        suggested_fix="Use only allowed actions for this robot.",
                    )
                )
        return violations

    def _check_reachability(self, plan_graph: PlanGraph) -> List[Violation]:
        """Flag a plan with no entry steps, or steps unreachable from them."""
        if not plan_graph.entry_steps:
            return [
                Violation(
                    rule_id="no_orphan_nodes",
                    severity="CRITICAL",
                    message="Plan has no entry steps.",
                    affected_steps=list(plan_graph.step_lookup.keys()),
                    suggested_fix="Ensure at least one step has no dependencies.",
                )
            ]

        unreachable = self._find_unreachable(plan_graph)
        if unreachable:
            return [
                Violation(
                    rule_id="no_orphan_nodes",
                    severity="CRITICAL",
                    message="Some steps are unreachable from the entry point.",
                    affected_steps=sorted(unreachable),
                    suggested_fix="Connect all steps to the main flow.",
                )
            ]
        return []

    def _check_navigation_safety(self, plan_graph: PlanGraph) -> List[Violation]:
        """Enforce the NAVIGATE preconditions.

        Every NAVIGATE step needs a CHECK_BATTERY ancestor. A NAVIGATE with
        ``speed == "fast"`` additionally needs a SCAN_AREA ancestor whose
        ``target`` is ``"humans"``.
        """
        violations: List[Violation] = []
        for step_id, step in plan_graph.step_lookup.items():
            if step.get("action") != "NAVIGATE":
                continue

            has_battery = self._has_ancestor(
                plan_graph, step_id, lambda s: s.get("action") == "CHECK_BATTERY"
            )
            if not has_battery:
                violations.append(
                    Violation(
                        rule_id="battery_before_movement",
                        severity="CRITICAL",
                        message="NAVIGATE requires CHECK_BATTERY beforehand.",
                        affected_steps=[step_id],
                        suggested_fix="Add CHECK_BATTERY before NAVIGATE.",
                    )
                )

            if self._params_of(step).get("speed") == "fast":
                has_scan = self._has_ancestor(
                    plan_graph,
                    step_id,
                    lambda s: s.get("action") == "SCAN_AREA"
                    and self._params_of(s).get("target") == "humans",
                )
                if not has_scan:
                    violations.append(
                        Violation(
                            rule_id="human_proximity_before_speed",
                            severity="CRITICAL",
                            message="Fast NAVIGATE requires SCAN_AREA targeting humans.",
                            affected_steps=[step_id],
                            suggested_fix="Add SCAN_AREA(target=humans) before fast NAVIGATE.",
                        )
                    )
        return violations

    # ---------------------------------------------------------------- helpers

    def _resolve_allowed_actions(self, context: Optional[Any]) -> Set[str]:
        """Return the action whitelist for ``context``.

        ``context.available_actions`` (attribute) wins, then
        ``context["available_actions"]`` for dicts. An absent or empty value
        falls back to a copy of ``DEFAULT_ALLOWED_ACTIONS``.
        """
        if context is None:
            return set(DEFAULT_ALLOWED_ACTIONS)

        actions = getattr(context, "available_actions", None)
        if not actions and isinstance(context, dict):
            actions = context.get("available_actions")
        if not actions:
            return set(DEFAULT_ALLOWED_ACTIONS)

        if isinstance(actions, str):
            # A bare string is one action name; set("WAIT") would split it
            # into single characters.
            return {actions}
        return set(actions)

    @staticmethod
    def _is_allowed(action: Any, allowed_actions: Set[str]) -> bool:
        """Return True if ``action`` is a member of ``allowed_actions``.

        An unhashable value (e.g. a list where an action name was expected)
        counts as not allowed instead of raising ``TypeError`` from the set
        lookup.
        """
        try:
            return action in allowed_actions
        except TypeError:
            return False

    @staticmethod
    def _params_of(step: Any) -> Any:
        """Return ``step["params"]``, or ``{}`` if it is missing, falsy or not mapping-like."""
        params = step.get("params") or {}
        return params if hasattr(params, "get") else {}

    def _has_ancestor(
        self,
        plan_graph: PlanGraph,
        step_id: str,
        predicate,
    ) -> bool:
        """Return True if any transitive dependency of ``step_id`` satisfies ``predicate``.

        ``predicate`` receives the step mapping. Dependency IDs missing from
        ``step_lookup`` are traversed but never tested. The visited set keeps
        cyclic graphs from looping forever.
        """
        stack = list(plan_graph.dependencies.get(step_id, set()))
        visited: Set[str] = set()
        while stack:
            current = stack.pop()
            if current in visited:
                continue
            visited.add(current)
            step = plan_graph.step_lookup.get(current)
            if step and predicate(step):
                return True
            stack.extend(plan_graph.dependencies.get(current, set()))
        return False

    def _find_unreachable(self, plan_graph: PlanGraph) -> Set[str]:
        """Return the steps not reachable from ``entry_steps`` via ``children``."""
        reachable: Set[str] = set()
        stack = list(plan_graph.entry_steps)
        while stack:
            current = stack.pop()
            if current in reachable:
                continue
            reachable.add(current)
            stack.extend(plan_graph.children.get(current, set()))
        return set(plan_graph.step_lookup.keys()) - reachable

    def _find_unknown_references(self, plan_graph: PlanGraph) -> List[str]:
        """Return the steps whose dependencies or children name a non-existent step.

        The result follows ``step_lookup`` iteration order.
        """
        known = plan_graph.step_lookup
        dangling: List[str] = []
        for step_id in known:
            refs = list(plan_graph.dependencies.get(step_id, ())) + list(
                plan_graph.children.get(step_id, ())
            )
            if any(ref not in known for ref in refs):
                dangling.append(step_id)
        return dangling

    def _detect_cycle(self, plan_graph: PlanGraph) -> List[str]:
        """Return the steps on or downstream of a cycle, or ``[]`` if acyclic.

        These are exactly the steps that Kahn's algorithm cannot order. They
        are returned in ``step_lookup`` order.
        """
        ordered = set(self._partial_topological_order(plan_graph))
        return [step_id for step_id in plan_graph.step_lookup if step_id not in ordered]

    def _partial_topological_order(self, plan_graph: PlanGraph) -> List[str]:
        """Run Kahn's algorithm and return as many steps as can be ordered.

        Steps on a cycle, and steps depending on one, never reach in-degree
        zero and are missing from the result. Dependency and child IDs that
        are not in ``step_lookup`` are ignored; ``_find_unknown_references``
        reports them separately.
        """
        known = plan_graph.step_lookup
        in_degree: Dict[str, int] = {
            step_id: sum(1 for dep in plan_graph.dependencies.get(step_id, set()) if dep in known)
            for step_id in known
        }

        queue = [step_id for step_id, degree in in_degree.items() if degree == 0]
        order: List[str] = []
        while queue:
            node = queue.pop()
            order.append(node)
            for child in plan_graph.children.get(node, set()):
                if child not in in_degree:
                    continue  # dangling reference; previously a KeyError
                in_degree[child] -= 1
                if in_degree[child] == 0:
                    queue.append(child)
        return order

    def _topological_sort(self, plan_graph: PlanGraph) -> Optional[List[str]]:
        """Return a full topological order of the known steps, or None if cyclic."""
        order = self._partial_topological_order(plan_graph)
        if len(order) != len(plan_graph.step_lookup):
            return None
        return order

    def _longest_path_length(self, plan_graph: PlanGraph) -> int:
        """Return the number of steps on the longest dependency chain.

        Returns 0 for an empty plan. For a cyclic plan it returns
        ``max_plan_depth + 1``, so callers treat the depth as exceeded.
        """
        order = self._topological_sort(plan_graph)
        if order is None:
            return self.max_plan_depth + 1

        distances: Dict[str, int] = {step_id: 1 for step_id in plan_graph.step_lookup}
        for node in order:
            for child in plan_graph.children.get(node, set()):
                if child in distances:
                    distances[child] = max(distances[child], distances[node] + 1)
        return max(distances.values(), default=0)
