from __future__ import annotations

from dataclasses import asdict, dataclass, field, is_dataclass
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional
import urllib.error
import urllib.request

from .dag_parser import load_schema, SCHEMA_PATH


class LLMClientError(RuntimeError):
    """Error raised by the LLM client for failed requests or malformed replies."""


@dataclass
class LLMConfig:
    """Connection and sampling settings consumed by ``LLMClient``.

    Fields left at ``None`` are either resolved from the environment by
    ``LLMClient.__init__`` (``api_key``, ``accept_language``) or simply
    omitted from the request payload (``max_tokens``, ``top_p``, ``top_k``).
    """

    # URL the chat-completions request is POSTed to.
    endpoint: str = "https://modelapi.klass.dev/v1/chat/completions"
    # Value sent as the "model" field of the request payload.
    model: str = "Qwen3-Next-80B-A3B-FP8"
    # Sent as a Bearer token; when None the client looks it up in env vars.
    api_key: Optional[str] = None
    # Timeout passed to the HTTP call / SDK client.
    timeout_s: int = 30
    # When true, the payload asks for {"type": "json_object"} responses.
    include_response_format: bool = True
    # Accept-Language header; when None the client reads LLM_ACCEPT_LANGUAGE.
    accept_language: Optional[str] = None
    # Extra headers merged last, so they override the generated ones.
    extra_headers: Dict[str, str] = field(default_factory=dict)
    # Optional sampling parameters; only sent when not None.
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    # Route requests through the openai SDK (also enabled by LLM_USE_OPENAI_SDK).
    use_openai_sdk: bool = False
    # Request a streamed response; only read on the SDK path
    # (also enabled by LLM_STREAM).
    stream: bool = False


class LLMClient:
    """Client that asks a chat-completions endpoint for a schema-shaped JSON plan.

    The client builds a system prompt embedding the loaded JSON schema, a user
    prompt carrying the intent / context / feedback, and sends both through one
    of three transports: the ``openai`` SDK (opt-in), ``requests`` when it is
    importable, or ``urllib`` as the fallback.

    Every transport or response-shape failure is reported as ``LLMClientError``.
    """

    def __init__(
        self,
        config: Optional["LLMConfig"] = None,
        schema_path: Optional[Path] = None,
    ) -> None:
        """Resolve unset config values from the environment and load the schema.

        Note: the ``config`` object passed in is updated in place (API key,
        Accept-Language and the two boolean flags), not copied.

        Args:
            config: Settings to use; a default ``LLMConfig`` is created if omitted.
            schema_path: Schema file to load; ``SCHEMA_PATH`` is used if omitted.
        """
        self.config = config or LLMConfig()
        if self.config.api_key is None:
            self.config.api_key = self._resolve_api_key(self.config.endpoint)
        if self.config.accept_language is None:
            self.config.accept_language = os.getenv("LLM_ACCEPT_LANGUAGE")
        # Environment flags can only switch an option on, never off.
        if not self.config.use_openai_sdk:
            self.config.use_openai_sdk = self._env_flag("LLM_USE_OPENAI_SDK")
        if not self.config.stream:
            self.config.stream = self._env_flag("LLM_STREAM")
        self.schema = load_schema(schema_path or SCHEMA_PATH)

    def generate(
        self,
        intent: str,
        context: Optional[Any] = None,
        feedback: Optional[str] = None,
        temperature: float = 0.3,
    ) -> str:
        """Send one planning request and return the model's reply as text.

        Args:
            intent: User intent, placed verbatim in the user prompt.
            context: Optional dict, dataclass instance or plain object describing
                the situation; ``available_actions`` on it restricts the planner.
            feedback: Optional feedback from a previous attempt.
            temperature: Sampling temperature sent with the request.

        Returns:
            The message content; a dict content is re-encoded as a JSON string.

        Raises:
            LLMClientError: If the request fails or the response lacks
                ``choices[0].message.content`` as a string or object.
        """
        available_actions = self._resolve_actions(context)
        system_prompt = self._build_system_prompt(available_actions)
        user_prompt = self._build_user_prompt(intent, context, feedback)

        payload: Dict[str, Any] = {
            "model": self.config.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": temperature,
        }
        if self.config.include_response_format:
            payload["response_format"] = {"type": "json_object"}
        # Optional sampling knobs are only sent when explicitly configured.
        for option in ("max_tokens", "top_p", "top_k"):
            value = getattr(self.config, option)
            if value is not None:
                payload[option] = value

        response = self._post(payload)
        try:
            content = response["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as exc:
            raise LLMClientError(f"Unexpected LLM response format: {exc}") from exc

        if isinstance(content, dict):
            return json.dumps(content)
        if not isinstance(content, str):
            raise LLMClientError("LLM response content is not a string or object.")
        return content

    def _build_system_prompt(self, available_actions: Optional[set]) -> str:
        """Return the system prompt embedding the schema and any allowed actions."""
        actions_text = ""
        if available_actions:
            # Stringify before sorting so non-string or mixed-type action
            # identifiers cannot break sorted() or str.join().
            names = sorted(str(action) for action in available_actions)
            actions_text = f"Allowed actions: {', '.join(names)}."

        schema_text = json.dumps(self.schema, indent=2)
        return (
            "You are a robotic action planner.\n"
            "Return a JSON object that strictly matches this schema:\n"
            f"{schema_text}\n"
            "Use the exact user intent string for the intent field.\n"
            f"{actions_text}\n"
            "Return only JSON. Do not include markdown."
        )

    def _build_user_prompt(
        self,
        intent: str,
        context: Optional[Any],
        feedback: Optional[str],
    ) -> str:
        """Return the user prompt: intent, JSON-encoded context, optional feedback."""
        context_payload = self._serialize_context(context)
        context_text = json.dumps(
            context_payload,
            indent=2,
            sort_keys=True,
            default=self._json_default,
        )

        sections = [f"Intent: {intent}", "Context:", context_text]
        if feedback:
            sections.append("Feedback from previous attempt:")
            sections.append(feedback)
        return "\n".join(sections)

    @staticmethod
    def _json_default(value: Any) -> Any:
        """Encode context values that ``json`` cannot serialize natively.

        Sets become lists ordered by their string form (deterministic prompts),
        dataclass instances become dicts, and anything else falls back to
        ``str()``. The fallback is deliberate: the output is prompt text for
        the model, not data that is ever parsed back.
        """
        if isinstance(value, (set, frozenset)):
            return sorted(value, key=str)
        if is_dataclass(value) and not isinstance(value, type):
            return asdict(value)
        return str(value)

    def _serialize_context(self, context: Optional[Any]) -> Dict[str, Any]:
        """Convert ``context`` into a dict suitable for the user prompt.

        ``None`` gives ``{}``; dicts are returned as-is; dataclass instances go
        through ``asdict``; other objects expose a copy of their ``__dict__``;
        everything else (including classes) is wrapped as ``{"value": str(...)}``.
        """
        if context is None:
            return {}
        if isinstance(context, dict):
            return context
        # is_dataclass() is also true for dataclass *types*, which asdict()
        # rejects, and a class __dict__ is not meaningful context either, so
        # types are excluded here and handled by the opaque fallback below.
        is_type = isinstance(context, type)
        if is_dataclass(context) and not is_type:
            return asdict(context)
        if hasattr(context, "__dict__") and not is_type:
            return dict(context.__dict__)
        return {"value": str(context)}

    def _resolve_actions(self, context: Optional[Any]) -> Optional[set]:
        """Return the set of allowed actions found on ``context``, or ``None``.

        Looks for a truthy ``available_actions`` dict key first, then for an
        attribute of the same name.
        """
        if context is None:
            return None
        actions = None
        if isinstance(context, dict):
            actions = context.get("available_actions")
        if not actions:
            actions = getattr(context, "available_actions", None)
        if not actions:
            return None
        # A bare string is one action name, not an iterable of characters.
        if isinstance(actions, str):
            return {actions}
        return set(actions)

    def _post(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Send ``payload`` with the configured transport and return the JSON reply.

        Raises:
            LLMClientError: On any transport, HTTP status or decoding failure.
        """
        headers = {"Content-Type": "application/json"}
        if self.config.api_key:
            headers["Authorization"] = f"Bearer {self.config.api_key}"
        if self.config.accept_language:
            headers["Accept-Language"] = self.config.accept_language
        if self.config.extra_headers:
            # Applied last so callers can override the generated headers.
            headers.update(self.config.extra_headers)

        if self.config.use_openai_sdk:
            return self._post_openai_sdk(payload, headers)

        try:
            import requests
        except ImportError:
            # requests is optional; fall back to the standard library.
            return self._post_urllib(payload, headers)
        return self._post_requests(requests, payload, headers)

    def _post_requests(
        self,
        requests: Any,
        payload: Dict[str, Any],
        headers: Dict[str, str],
    ) -> Dict[str, Any]:
        """POST via the ``requests`` module passed in by ``_post``."""
        try:
            response = requests.post(
                self.config.endpoint,
                json=payload,
                headers=headers,
                timeout=self.config.timeout_s,
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as exc:
            # Report status and body the same way the urllib transport does;
            # the server's error body is usually the only useful diagnostic.
            failed = exc.response
            if failed is None:
                raise LLMClientError(f"LLM request failed: {exc}") from exc
            raise LLMClientError(
                f"LLM HTTP error {failed.status_code}: {failed.text}"
            ) from exc
        except Exception as exc:
            raise LLMClientError(f"LLM request failed: {exc}") from exc

    def _post_urllib(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Dict[str, Any]:
        """POST via ``urllib`` (used when ``requests`` is not installed)."""
        request = urllib.request.Request(
            self.config.endpoint,
            data=json.dumps(payload).encode("utf-8"),
            headers=headers,
            method="POST",
        )
        try:
            with urllib.request.urlopen(request, timeout=self.config.timeout_s) as handle:
                body = handle.read().decode("utf-8")
                return json.loads(body)
        except urllib.error.HTTPError as exc:
            detail = exc.read().decode("utf-8", errors="replace")
            raise LLMClientError(f"LLM HTTP error {exc.code}: {detail}") from exc
        except Exception as exc:
            raise LLMClientError(f"LLM request failed: {exc}") from exc

    def _post_openai_sdk(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Dict[str, Any]:
        """Send ``payload`` through the ``openai`` SDK.

        Streamed replies are concatenated and returned in the same
        ``{"choices": [{"message": {"content": ...}}]}`` shape as a plain reply.

        Raises:
            LLMClientError: If the SDK is missing, the client cannot be built,
                the request fails, or the response type is not recognised.
        """
        try:
            from openai import OpenAI
        except ImportError as exc:
            raise LLMClientError(
                "openai package not installed; disable SDK mode or install openai."
            ) from exc

        request_kwargs: Dict[str, Any] = {
            "model": payload["model"],
            "messages": payload["messages"],
            "temperature": payload.get("temperature", 0.3),
            "stream": self.config.stream,
            # The SDK sets Content-Type itself.
            "extra_headers": {k: v for k, v in headers.items() if k != "Content-Type"},
        }
        for option in ("response_format", "max_tokens", "top_p"):
            if option in payload:
                request_kwargs[option] = payload[option]
        if "top_k" in payload:
            # top_k is not a named SDK argument, so it travels in the raw body.
            request_kwargs["extra_body"] = {"top_k": payload["top_k"]}

        try:
            # Client construction is inside the try: it can itself raise
            # (e.g. when no API key is available).
            client = OpenAI(
                api_key=self.config.api_key,
                base_url=self._to_base_url(self.config.endpoint),
                timeout=self.config.timeout_s,
            )
            completion = client.chat.completions.create(**request_kwargs)

            if self.config.stream:
                parts = []
                for chunk in completion:
                    choices = getattr(chunk, "choices", None) or []
                    if not choices:
                        continue
                    delta = getattr(choices[0], "delta", None)
                    text = getattr(delta, "content", None) if delta else None
                    if text:
                        parts.append(text)
                return {"choices": [{"message": {"content": "".join(parts)}}]}

            if hasattr(completion, "model_dump"):
                return completion.model_dump()
            if isinstance(completion, dict):
                return completion
            raise LLMClientError("Unexpected SDK response type.")
        except LLMClientError:
            # Already in its final form; do not wrap it a second time.
            raise
        except Exception as exc:
            raise LLMClientError(f"LLM request failed: {exc}") from exc

    @staticmethod
    def _env_flag(name: str) -> bool:
        """Return True when env var ``name`` is set to 1/true/yes/on (any case)."""
        return os.getenv(name, "").strip().lower() in {"1", "true", "yes", "on"}

    @staticmethod
    def _resolve_api_key(endpoint: str) -> Optional[str]:
        """Pick an API key from the environment based on the endpoint URL.

        Endpoints containing ``api.z.ai`` prefer ``Z_AI_API_KEY``; all others
        try ``MODELAPI_KEY`` then ``MODEL_API_KEY``. ``OPENAI_API_KEY`` is the
        final fallback in both cases. Returns ``None`` when nothing is set.
        """
        host = endpoint.lower()
        if "api.z.ai" in host:
            return os.getenv("Z_AI_API_KEY") or os.getenv("OPENAI_API_KEY")
        return (
            os.getenv("MODELAPI_KEY")
            or os.getenv("MODEL_API_KEY")
            or os.getenv("OPENAI_API_KEY")
        )

    @staticmethod
    def _to_base_url(endpoint: str) -> str:
        """Strip a trailing ``/chat/completions`` so the SDK gets its base URL.

        A trailing slash after the marker is tolerated; endpoints that do not
        end with the marker are returned unchanged.
        """
        marker = "/chat/completions"
        trimmed = endpoint.rstrip("/")
        if trimmed.endswith(marker):
            return trimmed[: -len(marker)]
        return endpoint
