from __future__ import annotations

from dataclasses import dataclass
import copy
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

try:
    import jsonschema

    _HAS_JSONSCHEMA = True
except Exception:
    jsonschema = None
    _HAS_JSONSCHEMA = False


# Default schema location: ``verifier/schema.json`` inside the parent of the
# directory that contains this module (``parents[1]`` of the resolved file).
SCHEMA_PATH = Path(__file__).resolve().parents[1] / "verifier" / "schema.json"


def load_schema(schema_path: Optional[Path] = None) -> Dict[str, Any]:
    """Read and decode the plan JSON schema.

    Args:
        schema_path: Schema file to read. ``str`` (or other path-like) values
            are accepted and converted to ``Path``; when omitted or empty,
            ``SCHEMA_PATH`` is used.

    Returns:
        The decoded JSON content of the schema file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Coerce to Path so plain string paths work; ``.open`` only exists on Path.
    path = Path(schema_path) if schema_path else SCHEMA_PATH
    with path.open("r", encoding="utf-8") as handle:
        return json.load(handle)


@dataclass
class ParseError:
    """Why ``DAGParser`` rejected a planner output."""

    # One-line summary of the failure.
    message: str
    # Specifics, e.g. the offending step_id/dependency or up to five joined
    # schema-validation messages.
    details: Optional[str] = None
    # Suggested correction, phrased as an instruction to whoever produced the plan.
    hint: Optional[str] = None


@dataclass
class PlanGraph:
    """Validated plan plus the adjacency maps ``DAGParser.parse`` derives from it.

    All maps are keyed by ``step_id``.
    """

    # Deep copy of the input plan; ``workflow_topology`` is filled from the
    # schema default when it was missing or empty.
    plan: Dict[str, Any]
    # The plan's step dicts (the same list object as ``plan["steps"]``), each
    # carrying a ``dependencies`` list.
    steps: List[Dict[str, Any]]
    # step_id -> step dict.
    step_lookup: Dict[str, Dict[str, Any]]
    # step_id -> step_ids listed in that step's ``dependencies``.
    dependencies: Dict[str, Set[str]]
    # step_id -> step_ids that list it as a dependency; every step has an entry.
    children: Dict[str, Set[str]]
    # step_ids with no dependencies, in plan order.
    entry_steps: List[str]


@dataclass
class ParseResult:
    """Outcome of ``DAGParser.parse``.

    On success ``ok`` is True, ``plan_graph`` is set and ``error`` is None; on
    failure ``ok`` is False, ``plan_graph`` is None and ``error`` explains why.
    """

    ok: bool
    plan_graph: Optional[PlanGraph]
    error: Optional[ParseError]


class DAGParser:
    """Turn raw planner output into a validated :class:`PlanGraph`.

    ``parse`` works in stages:

    1. Decode the input.
    2. Validate it against the JSON schema, using ``jsonschema`` when
       importable and ``_basic_validate`` otherwise.
    3. Fill a default ``workflow_topology``.
    4. Optionally check the plan ``intent``.
    5. Verify the step list.
    6. Build the dependency maps.

    Problems with the planner output are reported as
    ``ParseResult(ok=False, ...)`` carrying a :class:`ParseError`; they are not
    raised.

    NOTE(review): dependency cycles (including a step that depends on itself)
    are not rejected here, so a cyclic plan comes back with ``ok=True``.
    Confirm a downstream check handles this before assuming the graph is
    acyclic.
    """

    def __init__(self, schema_path: Optional[Path] = None, strict_intent: bool = True) -> None:
        """Load the validation schema.

        Args:
            schema_path: Schema file to load; ``SCHEMA_PATH`` when omitted.
                A ``str`` path is accepted and converted to ``Path``.
            strict_intent: When true, ``parse`` rejects a plan whose ``intent``
                differs from the ``expected_intent`` it is given.
        """
        self.schema_path = Path(schema_path) if schema_path else SCHEMA_PATH
        self.schema = load_schema(self.schema_path)
        self.strict_intent = strict_intent

    def parse(self, raw_json: Any, expected_intent: Optional[str] = None) -> ParseResult:
        """Parse and validate planner output.

        Args:
            raw_json: A JSON document as ``str``/``bytes``, or an already
                decoded ``dict``. A dict input is not modified; all
                normalization happens on a deep copy.
            expected_intent: If given and ``strict_intent`` is set, the plan's
                ``intent`` must equal it exactly.

        Returns:
            ``ParseResult(ok=True, plan_graph=...)`` on success; otherwise
            ``ParseResult(ok=False, error=...)`` describing why parsing stopped.
        """
        data = self._parse_json(raw_json)
        if isinstance(data, ParseError):
            return ParseResult(ok=False, plan_graph=None, error=data)

        schema_errors = self._validate_schema(data)
        if schema_errors:
            return self._failure(
                "Schema validation failed.",
                details="; ".join(schema_errors[:5]),
                hint="Return a JSON object that matches schema.json.",
            )

        # Work on a copy so the defaults filled in below never leak into a
        # dict the caller passed in.
        plan = copy.deepcopy(data)
        if not plan.get("workflow_topology"):
            plan["workflow_topology"] = (
                self.schema.get("properties", {})
                .get("workflow_topology", {})
                .get("default", "DAG")
            )

        if (
            expected_intent is not None
            and self.strict_intent
            and plan.get("intent") != expected_intent
        ):
            return self._failure(
                "Plan intent does not match user intent.",
                details=f'Plan intent="{plan.get("intent")}"',
                hint="Set intent to the original user intent string.",
            )

        steps = plan.get("steps") or []
        if not steps:
            return self._failure(
                "Plan has no steps.",
                hint="Provide at least one step in the plan.",
            )

        failure = self._normalize_steps(steps)
        if failure is not None:
            return failure

        step_ids: List[str] = [step["step_id"] for step in steps]

        # Single pass: anything seen twice is a duplicate.
        seen: Set[str] = set()
        duplicates: Set[str] = set()
        for step_id in step_ids:
            if step_id in seen:
                duplicates.add(step_id)
            seen.add(step_id)
        if duplicates:
            return self._failure(
                "Duplicate step_id values found.",
                details=", ".join(sorted(duplicates)),
                hint="Ensure each step_id is unique.",
            )

        if plan.get("workflow_topology") == "SEQUENTIAL":
            # Steps that declare no dependencies are chained onto the step
            # immediately before them.
            for previous, current in zip(steps, steps[1:]):
                if not current["dependencies"]:
                    current["dependencies"] = [previous["step_id"]]

        for step in steps:
            for dep in step["dependencies"]:
                try:
                    known = dep in seen
                except TypeError:
                    # Unhashable values (e.g. a nested object) cannot name a step.
                    known = False
                if not known:
                    return self._failure(
                        "Dependency refers to unknown step_id.",
                        details=f'step_id="{step.get("step_id")}" dependency="{dep}"',
                        hint="Ensure dependencies reference valid step_id values.",
                    )

        return ParseResult(ok=True, plan_graph=self._build_graph(plan, steps), error=None)

    @staticmethod
    def _failure(
        message: str, details: Optional[str] = None, hint: Optional[str] = None
    ) -> ParseResult:
        """Build an ``ok=False`` result wrapping a :class:`ParseError`."""
        return ParseResult(
            ok=False,
            plan_graph=None,
            error=ParseError(message=message, details=details, hint=hint),
        )

    def _normalize_steps(self, steps: List[Any]) -> Optional[ParseResult]:
        """Default missing/null ``dependencies`` to ``[]`` in place and check step shape.

        Returns a failure result for the first malformed step, or ``None`` when
        every step is a dict with a string ``step_id`` and a list of dependencies.
        """
        for index, step in enumerate(steps):
            if not isinstance(step, dict):
                return self._failure(
                    "Each step must be a JSON object.",
                    details=f"index={index} type={type(step).__name__}",
                    hint="Make every entry of steps an object with a step_id.",
                )
            if step.get("dependencies") is None:
                step["dependencies"] = []
            if not isinstance(step["dependencies"], list):
                return self._failure(
                    "Step dependencies must be a list.",
                    details=f'step_id="{step.get("step_id")}"',
                    hint="Set dependencies to [] or a list of step_ids.",
                )
            # Graph maps are keyed by step_id, so it must be present and a string.
            if not isinstance(step.get("step_id"), str):
                return self._failure(
                    "Step is missing a string step_id.",
                    details=f"index={index}",
                    hint="Give every step a unique string step_id.",
                )
        return None

    @staticmethod
    def _build_graph(plan: Dict[str, Any], steps: List[Dict[str, Any]]) -> PlanGraph:
        """Derive lookup, dependency, children and entry-step maps from validated steps."""
        step_lookup: Dict[str, Dict[str, Any]] = {}
        dependencies: Dict[str, Set[str]] = {}
        # Pre-seed so every step has a (possibly empty) children entry.
        children: Dict[str, Set[str]] = {step["step_id"]: set() for step in steps}
        for step in steps:
            step_id = step["step_id"]
            step_lookup[step_id] = step
            dependencies[step_id] = set(step["dependencies"])
            for parent in dependencies[step_id]:
                children[parent].add(step_id)

        entry_steps = [step_id for step_id, deps in dependencies.items() if not deps]
        return PlanGraph(
            plan=plan,
            steps=steps,
            step_lookup=step_lookup,
            dependencies=dependencies,
            children=children,
            entry_steps=entry_steps,
        )

    def _parse_json(self, raw_json: Any) -> Any:
        """Decode ``raw_json`` into a Python value, or return a :class:`ParseError`.

        Input handling by type:

        - ``bytes``/``bytearray``: decoded as UTF-8 (a leading byte-order mark
          is dropped; undecodable bytes become U+FFFD), then treated like
          ``str``.
        - ``str``: parsed with ``json.loads`` and may yield any JSON value;
          non-object results are rejected later by ``_validate_schema``.
        - ``dict``: returned unchanged.
        - Any other type: an error.
        """
        if isinstance(raw_json, (bytes, bytearray)):
            # utf-8-sig strips a BOM if present; plain utf-8 would leave U+FEFF
            # in front and json.loads would reject the document.
            raw_json = raw_json.decode("utf-8-sig", errors="replace")

        if isinstance(raw_json, str):
            try:
                return json.loads(raw_json)
            except json.JSONDecodeError as exc:
                return ParseError(
                    message="Invalid JSON.",
                    details=str(exc),
                    hint="Return a JSON object, not markdown or text.",
                )

        if isinstance(raw_json, dict):
            return raw_json

        return ParseError(
            message="Planner output must be a JSON object.",
            details=f"type={type(raw_json).__name__}",
            hint="Return a JSON object only.",
        )

    def _validate_schema(self, data: Dict[str, Any]) -> List[str]:
        """Return schema-violation messages for ``data`` (empty list when valid).

        Uses ``jsonschema.Draft7Validator`` when ``jsonschema`` imported
        successfully, otherwise the partial checks in ``_basic_validate``.
        """
        if not isinstance(data, dict):
            return ["Top-level JSON must be an object."]

        if _HAS_JSONSCHEMA:
            validator = jsonschema.Draft7Validator(self.schema)
            # Sort by instance path so messages come out in document order.
            errors = sorted(validator.iter_errors(data), key=lambda err: list(err.path))
            return [self._format_schema_error(error) for error in errors]

        return self._basic_validate(data)

    def _basic_validate(self, data: Dict[str, Any]) -> List[str]:
        """Approximate schema validation used when ``jsonschema`` is unavailable.

        Checks the following, and returns human-readable messages; an empty
        list means nothing was detected:

        - the schema's top-level ``required`` list;
        - the types of a few string fields;
        - the ``workflow_topology`` and step ``type`` enums, only when the
          schema declares them;
        - the basic shape of each step.
        """
        errors: List[str] = []
        properties = self.schema.get("properties", {})

        for field in self.schema.get("required", []):
            if field not in data:
                errors.append(f'Missing required field "{field}".')

        for field in ("intent", "plan_id", "response_text"):
            if field in data and not isinstance(data[field], str):
                errors.append(f'Field "{field}" must be a string.')

        # Enforce an enum only when the schema declares one; an empty list would
        # otherwise reject every value.
        topology_enum = properties.get("workflow_topology", {}).get("enum", [])
        if (
            "workflow_topology" in data
            and topology_enum
            and data["workflow_topology"] not in topology_enum
        ):
            choices = " or ".join(f'"{value}"' for value in topology_enum)
            errors.append(f"workflow_topology must be {choices}.")

        steps = data.get("steps")
        if steps is not None and not isinstance(steps, list):
            errors.append('Field "steps" must be a list.')
            return errors

        step_type_enum = (
            properties.get("steps", {})
            .get("items", {})
            .get("properties", {})
            .get("type", {})
            .get("enum", [])
        )
        for idx, step in enumerate(steps or []):
            if not isinstance(step, dict):
                errors.append(f"Step {idx} must be an object.")
                continue
            for field in ("step_id", "type", "agent_role", "action"):
                if field not in step:
                    errors.append(f'Step {idx} missing "{field}".')
            for field in ("step_id", "agent_role", "action"):
                if field in step and not isinstance(step[field], str):
                    errors.append(f"Step {idx} {field} must be a string.")

            step_type = step.get("type")
            if step_type and step_type_enum and step_type not in step_type_enum:
                errors.append(f'Step {idx} type "{step_type}" is invalid.')

            dependencies = step.get("dependencies")
            if dependencies is not None and not isinstance(dependencies, list):
                errors.append(f"Step {idx} dependencies must be a list.")

        return errors

    @staticmethod
    def _format_schema_error(error: Any) -> str:
        """Render a jsonschema error as ``"dotted.path: message"`` (path omitted at the root)."""
        path = ".".join(str(item) for item in error.path)
        location = f"{path}: " if path else ""
        return f"{location}{error.message}"
