# standard libraries
from typing import Any, ClassVar
from contextlib import asynccontextmanager
import asyncio
import logging
import os
import json
from multiprocessing import Process, Pipe
from multiprocessing.connection import Connection

# external libraries
import requests
import uvicorn
from starlette.types import Message
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.websockets import WebSocketState

# shared libraries
from whisper_gstreamer import awh_entrypoint

GET_MODEL_URL = "https://modelapi.klass.dev/v1/models"
# initial number of gstreamer processes to pre-spawn at startup
PROCESS_COUNT = 1
# max number of concurrent websocket connections (required environment variable)
try:
    MAX_WORKER_COUNT = int(os.environ["MAX_WORKER_COUNT"])
except KeyError:
    # int(os.getenv(...)) on a missing variable fails with an opaque
    # "int() argument must be ... not 'NoneType'" TypeError; say what is wrong.
    raise RuntimeError(
        "MAX_WORKER_COUNT environment variable must be set to an integer"
    ) from None


class ConnectionManager(object):
    """Bridge one websocket client to one pre-spawned gstreamer process.

    Class-level state is shared by every connection in this server process:
    ``available_whisper_worker_count`` caps concurrent sessions, and
    ``process_pool`` holds idle ``(parent_conn, child_conn, process)`` triples
    created by :meth:`populate_process_pool`.
    """

    # Remaining number of sessions that may be opened concurrently.
    available_whisper_worker_count: ClassVar[int] = MAX_WORKER_COUNT
    # Idle processes, each with both ends of its Pipe() (shared by all instances).
    process_pool: ClassVar[list[tuple[Connection, Connection, Process]]] = []

    @classmethod
    def populate_process_pool(cls, process_count=PROCESS_COUNT) -> None:
        """Start ``process_count`` gstreamer processes and append them to the pool.

        Each process runs ``awh_entrypoint`` with the child end of a fresh
        ``Pipe()``; this process keeps both ends in the pool entry.
        """
        for _ in range(process_count):
            parent_conn, child_conn = Pipe()
            gstreamer_process = Process(target=awh_entrypoint, args=(child_conn,))
            gstreamer_process.start()
            cls.process_pool.append((parent_conn, child_conn, gstreamer_process))

    @classmethod
    def init_process_pool(cls) -> None:
        """Pre-spawn the initial ``PROCESS_COUNT`` processes."""
        print("Initialising the gstreamer process pool...", end="")
        cls.populate_process_pool()
        print("Done")

    def __init__(self, ws: WebSocket):
        self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
        self.ws: WebSocket = ws
        self.send_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
        self.closing: bool = False
        # Assigned by on_ws_open() only when a worker slot is granted.
        self.parent_conn: Connection[Any]
        self.child_conn: Connection[Any]
        self.gstreamer_process: Process | None = None
        # Set once this session's pipe ends have been closed.
        self._pipes_closed: bool = False
        # Strong references to background tasks: the event loop keeps only weak
        # references, so an unreferenced task can be garbage-collected mid-run.
        self._tasks: set[asyncio.Task[None]] = set()
        self._spawn(self.handle_writing())

    def _spawn(self, coro) -> None:
        """Schedule *coro* on the loop and hold a reference until it finishes."""
        task = self.loop.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    @property
    def closed(self) -> bool:
        """True once close() was initiated or the client has disconnected."""
        return self.closing or self.ws.client_state == WebSocketState.DISCONNECTED

    # send to ws
    async def handle_writing(self) -> None:
        """Forward queued events to the websocket until the session closes.

        The 5 s ``wait_for`` timeout exists so ``closed`` is re-checked while the
        queue is idle; previously the timeout path skipped that check and the
        task looped forever after the client left.
        """
        while not self.closed:
            try:
                event = await asyncio.wait_for(self.send_queue.get(), 5)
            except asyncio.TimeoutError:
                continue
            if self.closed:
                break
            try:
                await self.ws.send_json(event)
            except (WebSocketDisconnect, RuntimeError, OSError) as ex:
                # The socket went away between the closed-check and the send.
                logging.info("stopped forwarding results to websocket: %s", ex)
                break

    async def on_ws_open(self):
        """Reserve a worker slot and a pooled process, or reject the client."""
        if ConnectionManager.available_whisper_worker_count <= 0:
            event = {
                "status": "STATUS_NOT_AVAILABLE",
                "message": "No decoder available, try again later",
            }

            await self.ws.send_json(event)
            await self.close()
            return

        if not ConnectionManager.process_pool:
            # The pool is normally refilled after every take (below); this only
            # triggers if init_process_pool() did not run or a refill failed.
            ConnectionManager.populate_process_pool(1)
        # retrieve process from existing pool
        self.parent_conn, self.child_conn, self.gstreamer_process = (
            ConnectionManager.process_pool.pop()
        )
        # Decrement only after a process was actually obtained, so a failed pop
        # cannot leak a worker slot (on_close restores it only with a process).
        ConnectionManager.available_whisper_worker_count -= 1
        self._spawn(self.handle_results_from_gstreamer())
        # add new process into pool after taking existing
        # this means that there will always be an additional process on top of all ongoing process
        # NOTE: remove this if theres limit GPU resource
        ConnectionManager.populate_process_pool(1)
        event = {
            "status": 0,
            "result": {"message": "Whisper model setup successfully!"},
        }

        self.send_queue.put_nowait(event)

    async def on_message(self, msg_obj: Message):
        """Forward a received websocket payload to the gstreamer process.

        ASGI receive events may carry ``"bytes": None`` alongside a text payload,
        so fall back to ``"text"`` whenever ``"bytes"`` is absent *or* None. Events
        with neither (e.g. a disconnect event) forward ``None``, as before.
        """
        message: str | bytes | None = msg_obj.get("bytes")
        if message is None:
            message = msg_obj.get("text")
        if self.gstreamer_process is None or self._pipes_closed:
            return
        self.parent_conn.send(message)

    async def close(self):
        """Close the websocket once; later calls are no-ops."""
        if not self.closed:
            self.closing = True
            await self.ws.close()

    def on_close(self):
        """Release the worker slot and stop this session's gstreamer process."""
        logging.info("audio input disconnected")

        if self.gstreamer_process is None:
            return
        ConnectionManager.available_whisper_worker_count += 1
        # Sending None on the child end unblocks the parent_conn.recv() call in
        # handle_results_from_gstreamer so that task can finish. Skip it if that
        # task already exited and closed the pipe.
        if not self._pipes_closed:
            self.child_conn.send(None)

        if self.gstreamer_process.is_alive():
            self.gstreamer_process.kill()

    def _close_pipes(self) -> None:
        """Close both pipe ends held by this process so their fds are released."""
        self._pipes_closed = True
        self.parent_conn.close()
        self.child_conn.close()

    # gstreamer returning result
    async def handle_results_from_gstreamer(self):
        """Relay JSON results from the gstreamer process into ``send_queue``.

        Stops when the session closes, the ``None`` sentinel from on_close()
        arrives, or the pipe fails; always closes the pipe ends on exit.
        """
        try:
            while not self.closed:
                try:
                    message = await asyncio.to_thread(self.parent_conn.recv)
                except (EOFError, OSError) as ex:
                    logging.info("gstreamer pipe closed: %s", ex)
                    return

                if self.closed or message is None:
                    return

                try:
                    event = json.loads(message)
                    # close socket when gstreamer terminate
                    if "message" in event and "terminated" in event["message"]:
                        await self.close()
                    else:
                        self.send_queue.put_nowait(event)

                except Exception as ex:
                    logging.info(
                        "********  WhisperSocketHandler exception caught {}".format(ex)
                    )
        finally:
            self._close_pipes()


@asynccontextmanager
async def entrypoint(app: FastAPI):
    """FastAPI lifespan handler: set up logging and own the idle process pool.

    Startup pre-spawns the gstreamer process pool. Shutdown kills and reaps
    the processes still idle in ``ConnectionManager.process_pool`` and closes
    their pipe ends; processes checked out by a connection are no longer in
    the pool and are stopped by that connection's ``on_close``.
    """
    logging.basicConfig(
        level=logging.NOTSET, format="%(levelname)8s %(asctime)s %(message)s "
    )
    logging.debug("Starting up server")
    logging.info("Running in HTTP")
    ConnectionManager.init_process_pool()

    try:
        yield
    finally:
        # Pool processes are created without daemon=True, and multiprocessing
        # joins non-daemon children at interpreter exit, so an idle worker that
        # never exits on its own could otherwise hang shutdown.
        while ConnectionManager.process_pool:
            parent_conn, child_conn, process = ConnectionManager.process_pool.pop()
            if process.is_alive():
                process.kill()
            await asyncio.to_thread(process.join, 5)
            parent_conn.close()
            child_conn.close()


# ASGI application; `entrypoint` is its lifespan handler, which starts the
# gstreamer process pool at startup.
app = FastAPI(
    lifespan=entrypoint,
)


@app.websocket("/ws/audio_stream")
async def websocket_endpoint(websocket: WebSocket):
    """Relay client messages to a pooled gstreamer worker for one session.

    Results flow back to the client through the ConnectionManager's own tasks;
    ``on_close`` always runs when the receive loop ends.
    """
    await websocket.accept()

    cm = ConnectionManager(websocket)

    try:
        # Await setup before reading: as a detached task it could run after the
        # first receive(), and on_message drops payloads while no process is
        # attached. Awaiting also keeps any setup error from vanishing in an
        # unreferenced task.
        await cm.on_ws_open()

        while websocket.client_state == WebSocketState.CONNECTED:
            data = await websocket.receive()
            await cm.on_message(data)

    except WebSocketDisconnect:
        logging.info("Client disconnected")

    finally:
        cm.on_close()


@app.get("/models")
async def get_models():
    """List the model ids published at ``GET_MODEL_URL``.

    Returns a JSON response ``{"models": [<id>, ...]}`` built from the ``id`` of
    every entry in the upstream response's ``data`` array. Upstream HTTP error
    statuses and timeouts propagate as ``requests`` exceptions.
    """
    # requests applies no timeout by default; without one a stalled upstream
    # blocks this request, and the to_thread worker thread, indefinitely.
    timeout_seconds = 10
    # NOTE(review): verify=False disables TLS certificate verification on this
    # HTTPS call — confirm it is intentional, otherwise drop it.
    response = await asyncio.to_thread(
        requests.get, GET_MODEL_URL, verify=False, timeout=timeout_seconds
    )
    # Fail with the real HTTP error instead of a KeyError on an error body.
    response.raise_for_status()
    obj: dict[str, Any] = response.json()
    model_list = [x["id"] for x in obj["data"]]

    return JSONResponse(content={"models": model_list})


def main():
    """Serve this module's ``app`` with uvicorn on ``WHISPER_SERVER_PORT`` (default 8300)."""
    listen_port = int(os.environ.get("WHISPER_SERVER_PORT", 8300))
    # Pass the app object, not the "whisper_server:app" import string: the
    # string makes uvicorn import the module again as a separate copy, so
    # middleware added to this module's `app` below was never actually served.
    # (An app object is accepted as long as workers == 1 and reload is off.)
    uvicorn.run(app, host="0.0.0.0", port=listen_port, workers=1)


if __name__ == "__main__":
    # Permissive CORS: every origin, method and header, with credentials.
    # NOTE(review): wildcard origins combined with allow_credentials=True lets any
    # site make credentialed requests — confirm this is intended.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    main()
