#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import sys
import time
from pathlib import Path
from typing import Iterable, Optional, Sequence


# Make the repo's in-tree "src" package importable when this script is run
# directly (i.e. without installing the project); guard avoids duplicates.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))

from planner.dag_parser import DAGParser, PlanGraph  # noqa: E402
from planner.llm_client import LLMClient, LLMConfig, LLMClientError  # noqa: E402
from planner.planner import PlanContext, PlanRequest, PlannerService  # noqa: E402


# Fallback robot action whitelist: used as the default for the --actions CLI
# flag and as the capability list in the prompt when none are provided.
DEFAULT_ACTIONS = [
    "CHECK_BATTERY",
    "NAVIGATE",
    "SCAN_AREA",
    "VERIFY_OBJECT",
    "SPEAK",
    "WAIT",
    "IDENTIFY_PERSON",
    "ALERT_OPERATOR",
]


class TemplatedLLMClient(LLMClient):
    """LLM client whose system prompt is rendered from an on-disk template.

    The template may contain the placeholders ``{{SCHEMA_JSON}}``,
    ``{{MAX_PLAN_DEPTH}}``, ``{{ROBOT_CAPABILITIES}}`` and
    ``{{FEW_SHOT_EXAMPLES}}``. If no template file is found, prompt
    construction falls back to the base class implementation.
    """

    def __init__(
        self,
        config: Optional[LLMConfig] = None,
        system_prompt_path: Optional[Path] = None,
        max_plan_depth: int = 20,
    ) -> None:
        super().__init__(config=config)
        self.max_plan_depth = max_plan_depth
        self.system_prompt_template = self._load_template(system_prompt_path)

    def _load_template(self, system_prompt_path: Optional[Path]) -> Optional[str]:
        """Read the prompt template file, or return None when it is absent."""
        path = system_prompt_path or (ROOT / "src" / "planner" / "prompts" / "system.txt")
        if path.exists():
            return path.read_text(encoding="utf-8")
        return None

    def _build_system_prompt(self, available_actions: Optional[set]) -> str:
        """Render the system prompt by filling every template placeholder."""
        if not self.system_prompt_template:
            # No template on disk: defer to the base client's prompt builder.
            return super()._build_system_prompt(available_actions)

        if available_actions:
            capabilities = ", ".join(sorted(available_actions))
        else:
            capabilities = ", ".join(DEFAULT_ACTIONS)

        # A single worked example showing the expected plan JSON shape.
        example_plan = {
            "intent": "Explore office zone and report hazards.",
            "workflow_topology": "DAG",
            "plan_id": "plan_example_001",
            "steps": [
                {
                    "step_id": "s1",
                    "type": "EXECUTION",
                    "agent_role": "robot",
                    "action": "CHECK_BATTERY",
                    "params": {"min_level": 30},
                },
                {
                    "step_id": "s2",
                    "type": "EXECUTION",
                    "agent_role": "robot",
                    "action": "NAVIGATE",
                    "dependencies": ["s1"],
                    "params": {"destination": "office_zone", "speed": "normal"},
                },
                {
                    "step_id": "s3",
                    "type": "PERCEPTION",
                    "agent_role": "robot",
                    "action": "SCAN_AREA",
                    "dependencies": ["s2"],
                    "params": {"target": "hazards"},
                },
                {
                    "step_id": "s4",
                    "type": "EXECUTION",
                    "agent_role": "robot",
                    "action": "ALERT_OPERATOR",
                    "dependencies": ["s3"],
                    "params": {"reason": "hazard_report"},
                },
            ],
        }

        substitutions = {
            "{{SCHEMA_JSON}}": json.dumps(self.schema, indent=2),
            "{{MAX_PLAN_DEPTH}}": str(self.max_plan_depth),
            "{{ROBOT_CAPABILITIES}}": capabilities,
            "{{FEW_SHOT_EXAMPLES}}": json.dumps(example_plan, indent=2),
        }
        prompt = self.system_prompt_template
        for placeholder, value in substitutions.items():
            prompt = prompt.replace(placeholder, value)
        return prompt


def to_mermaid(plan_graph: PlanGraph) -> str:
    """Render a plan graph as Mermaid ``graph TD`` source text.

    Emits one node per step labelled ``step_id | action``, then one edge per
    dependency (parent --> child). Double quotes inside an action name are
    replaced with single quotes, since a literal ``"`` would terminate the
    quoted Mermaid node label and produce invalid diagram source.

    Returns the diagram text terminated by a trailing newline.
    """
    lines = ["graph TD"]
    for step in plan_graph.steps:
        step_id = step["step_id"]
        action = step.get("action", "")
        # Sanitize quotes so the label stays inside its "..." delimiters.
        label = f"{step_id} | {action}".replace('"', "'")
        lines.append(f'  {step_id}["{label}"]')

    for step in plan_graph.steps:
        step_id = step["step_id"]
        # dependencies may be absent or explicitly None; treat both as empty.
        for parent in step.get("dependencies") or []:
            lines.append(f"  {parent} --> {step_id}")
    return "\n".join(lines) + "\n"


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Define and evaluate the CLI for the live planner demo.

    Passing ``argv=None`` parses ``sys.argv[1:]`` as usual.
    """
    p = argparse.ArgumentParser(
        description="Run live orchestrator planner call and emit graph visualization."
    )

    # Mission / robot context fed into the planner request.
    p.add_argument(
        "--intent",
        default="Explore the office area, identify people, and report hazards.",
        help="User instruction for planner.",
    )
    p.add_argument("--robot-id", default="ghost-v60")
    p.add_argument("--location", default="office_entry")
    p.add_argument("--battery-level", type=int, default=78)
    p.add_argument("--venue-id", default="onyx-lab")
    p.add_argument("--mission-type", default="exploration")
    p.add_argument(
        "--actions",
        nargs="*",
        default=DEFAULT_ACTIONS,
        help="Allowed actions for planner context.",
    )

    # LLM transport / model selection.
    p.add_argument(
        "--endpoint",
        default="https://modelapi.klass.dev/v1/chat/completions",
        help="Model API endpoint.",
    )
    p.add_argument(
        "--model",
        default="Qwen3-Next-80B-A3B-FP8",
        help="Model id for API request.",
    )
    p.add_argument("--timeout-s", type=int, default=30, help="LLM API timeout in seconds.")
    p.add_argument(
        "--no-response-format",
        action="store_true",
        help="Do not send response_format=json_object in chat payload.",
    )
    p.add_argument(
        "--accept-language",
        default=None,
        help="Optional Accept-Language header, e.g. en-US,en",
    )

    # Optional sampling controls; None leaves each unset.
    p.add_argument("--max-tokens", type=int, default=None)
    p.add_argument("--top-p", type=float, default=None)
    p.add_argument("--top-k", type=int, default=None)

    p.add_argument(
        "--use-openai-sdk",
        action="store_true",
        help="Use openai-python client against OpenAI-compatible endpoint.",
    )
    p.add_argument(
        "--stream",
        action="store_true",
        help="Enable streaming mode (chunks are accumulated before parse).",
    )

    # Input template and output artifact locations.
    p.add_argument(
        "--system-prompt",
        default=str(ROOT / "src" / "planner" / "prompts" / "system.txt"),
        help="System prompt template file.",
    )
    p.add_argument(
        "--out-dir",
        default=str(ROOT / "artifacts" / "planner_demo"),
        help="Directory for JSON + Mermaid outputs.",
    )
    return p.parse_args(argv)


def unique_actions(actions: Iterable[str]) -> list[str]:
    """Strip whitespace, drop empty entries, and dedupe in first-seen order."""
    stripped = (action.strip() for action in actions)
    # dict preserves insertion order, so fromkeys gives an order-stable dedupe.
    return list(dict.fromkeys(value for value in stripped if value))


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Run one live planner round-trip and write JSON + Mermaid artifacts.

    Exit codes: 0 on success, 2 when the LLM request fails, 3 when the
    planner returns no plan payload, 4 when the returned plan fails DAG
    parsing.
    """
    args = parse_args(argv)
    actions = unique_actions(args.actions)

    # Translate CLI flags into the LLM transport configuration.
    llm_config = LLMConfig(
        endpoint=args.endpoint,
        model=args.model,
        timeout_s=args.timeout_s,
        include_response_format=not args.no_response_format,
        accept_language=args.accept_language,
        max_tokens=args.max_tokens,
        top_p=args.top_p,
        top_k=args.top_k,
        use_openai_sdk=args.use_openai_sdk,
        stream=args.stream,
    )
    llm_client = TemplatedLLMClient(
        config=llm_config,
        system_prompt_path=Path(args.system_prompt),
        max_plan_depth=20,
    )
    service = PlannerService(llm_client=llm_client)

    # Assemble the planning request from the mission/robot context flags.
    context = PlanContext(
        robot_id=args.robot_id,
        current_location=args.location,
        battery_level=args.battery_level,
        available_actions=actions,
        venue_id=args.venue_id,
        mission_type=args.mission_type,
    )
    request = PlanRequest(
        intent=args.intent,
        context=context,
        # Timestamp suffix keeps request ids unique across repeated runs.
        request_id=f"live-demo-{int(time.time())}",
    )

    print(f"[planner] endpoint={args.endpoint}")
    print(f"[planner] model={args.model}")
    print(f"[planner] intent={args.intent}")

    try:
        result = service.plan(request)
    except LLMClientError as exc:
        print(f"[error] LLM request failed: {exc}")
        return 2

    # Summarize the planner outcome regardless of success/failure.
    print(
        json.dumps(
            {
                "success": result.success,
                "is_safe_noop": result.is_safe_noop,
                "attempts": result.attempts,
                "failure_reason": result.failure_reason,
                "trace_id": result.trace_id,
                "total_latency_ms": result.total_latency_ms,
            },
            indent=2,
        )
    )

    if not result.plan:
        print("[error] Planner did not return a plan payload.")
        return 3

    # Validate the raw plan payload into a graph before visualizing it.
    parser = DAGParser()
    parsed = parser.parse(result.plan, expected_intent=args.intent)
    if not parsed.ok or not parsed.plan_graph:
        print("[error] Returned plan could not be parsed.")
        if parsed.error:
            print(parsed.error.message)
            if parsed.error.details:
                print(parsed.error.details)
        return 4

    # Write timestamped artifacts so repeated runs do not overwrite each other.
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    stamp = time.strftime("%Y%m%d-%H%M%S")
    json_path = out_dir / f"plan-{stamp}.json"
    mermaid_path = out_dir / f"plan-{stamp}.mmd"

    json_path.write_text(json.dumps(result.plan, indent=2), encoding="utf-8")
    mermaid_path.write_text(to_mermaid(parsed.plan_graph), encoding="utf-8")

    print(f"[ok] plan json: {json_path}")
    print(f"[ok] graph mermaid: {mermaid_path}")
    print("\n--- Mermaid Preview ---")
    print(mermaid_path.read_text(encoding="utf-8"))

    return 0


if __name__ == "__main__":
    # SystemExit propagates main()'s integer return as the process exit code.
    raise SystemExit(main())
