# standard libraries
import io
import logging
import os
import sys
from contextlib import asynccontextmanager
from datetime import datetime

# external libraries
import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse
from torchaudio.transforms import Resample

# local modules
from TTSRequest import TTSRequest

# CosyVoice (and its Matcha-TTS dependency) is imported from the local
# 'CosyVoice' directory, which must be on sys.path before the import below.
sys.path.append('CosyVoice/third_party/Matcha-TTS')
sys.path.append('CosyVoice')
from cosyvoice.cli.cosyvoice import CosyVoice3  # noqa: E402

# Log every run to its own timestamped file under /app/logs, and mirror it to stdout.
LOG_DIR = "/app/logs"
# logging.FileHandler raises if the directory is missing, so make sure it exists.
os.makedirs(LOG_DIR, exist_ok=True)
log_filepath = os.path.join(LOG_DIR, f"app_log_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.log")
logging.basicConfig(
    level=logging.INFO,  # minimum level that is emitted
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[
        logging.FileHandler(log_filepath, mode="w"),  # this run's timestamped log file
        logging.StreamHandler(sys.stdout),  # console output
    ],
    # Remove any root handlers installed by earlier imports so this config takes effect.
    force=True,
)

# environment variables
# TTS_HTTPS is only consulted when this file is run as a script (the __main__ block),
# so it defaults to "false" instead of making every import fail when it is unset.
https_mode = os.getenv("TTS_HTTPS", "false")
# Required; synthesize_tts_bytes accepts "SPEAKER_ID" or "ZEROSHOT" (case-insensitive).
inference_mode = os.environ["TTS_INFERENCE_MODE"]
# Built-in speaker used in SPEAKER_ID mode.
speaker_id = os.getenv("TTS_SPEAKER_ID")
# ZEROSHOT mode: transcript of the reference clip, and the clip's file name under /app/ref.
ref_text = os.getenv("TTS_REF_TEXT")
ref_audio_file = os.getenv("TTS_REF_AUDIO_FILE")


def maybe_resample(waveform, orig_sr: int, target_sr: int):
    """Resample ``waveform`` from ``orig_sr`` to ``target_sr`` Hz when the rates differ.

    Args:
        waveform: Audio tensor from the TTS model; torchaudio's ``Resample``
            operates on the last (time) dimension.
        orig_sr: Sample rate of ``waveform`` in Hz.
        target_sr: Desired output sample rate in Hz.

    Returns:
        ``waveform`` itself when the rates already match, otherwise a new
        resampled tensor (not bytes).

    Raises:
        ValueError: if ``target_sr`` is not positive. Without this check,
            ``Resample`` would fail with an obscure division error.
    """
    if target_sr <= 0:
        raise ValueError(f"target sample rate must be positive, got {target_sr}")
    if orig_sr == target_sr:
        return waveform
    resampler = Resample(orig_freq=orig_sr, new_freq=target_sr)
    return resampler(waveform)


def load_tts_model(app: FastAPI, model_dir: str = "/app/model"):
    """Load the CosyVoice3 model from ``model_dir`` and attach it to ``app.tts``.

    TensorRT, vLLM and fp16 are all disabled via the constructor flags.

    Args:
        app: FastAPI application; the model is stored as ``app.tts`` for request handlers.
        model_dir: Directory containing the model files (defaults to ``/app/model``).
    """
    logging.info(f"Loading TTS model from {model_dir}")
    app.tts = CosyVoice3(model_dir, load_trt=False, load_vllm=False, fp16=False)
    logging.info("Loaded TTS model")


def synthesize_tts_bytes(text: str, sample_rate: int, app: FastAPI) -> bytes:
    """Synthesize ``text`` with the loaded TTS model and return it as WAV-encoded bytes.

    The mode comes from ``TTS_INFERENCE_MODE``:

    * ``SPEAKER_ID``: uses the built-in speaker ``TTS_SPEAKER_ID``.
    * ``ZEROSHOT``: clones the voice in ``/app/ref/<TTS_REF_AUDIO_FILE>``, whose
      transcript is ``TTS_REF_TEXT``.

    If ``TTS_SAVE_WAV`` is truthy, a copy of the returned WAV is also written
    to ``/app/outputs``.

    Args:
        text: Text to speak.
        sample_rate: Sample rate (Hz) of the returned WAV.
        app: FastAPI application whose ``tts`` attribute holds the model.

    Returns:
        The complete WAV file as bytes.

    Raises:
        ValueError: if the inference mode is not recognised.
        RuntimeError: if the model yields no audio.
    """
    logging.info(f"Received input text \"{text}\"")
    mode = inference_mode.upper()
    if mode == "SPEAKER_ID":
        logging.info("Generating TTS using speaker ID...")
        iterator = app.tts.inference_sft(f"You are a helpful assistant.<|endofprompt|>{text}", speaker_id, stream=False)
    elif mode == "ZEROSHOT":
        logging.info("Generating TTS in zero-shot mode...")
        iterator = app.tts.inference_zero_shot(text, f"You are a helpful assistant.<|endofprompt|>{ref_text}", f"/app/ref/{ref_audio_file}", stream=False)
    else:
        logging.error(f"Invalid inference mode \"{inference_mode}\"")
        raise ValueError(f"Invalid inference mode \"{inference_mode}\"")

    # Even with stream=False, CosyVoice's inference generators yield one output per
    # text segment (its frontend splits long text), so taking only the first output
    # truncates long inputs. Join all segments along the time axis instead.
    # NOTE(review): assumes each "tts_speech" is (channels, frames), as torchaudio.save expects.
    segments = [output["tts_speech"] for output in iterator]
    if not segments:
        raise RuntimeError("TTS model returned no audio")
    waveform = segments[0] if len(segments) == 1 else torch.cat(segments, dim=-1)
    resampled_waveform = maybe_resample(waveform, app.tts.sample_rate, sample_rate)

    buf = io.BytesIO()
    torchaudio.save(
        buf,
        resampled_waveform,
        sample_rate=sample_rate,
        format="wav"
    )
    wav_bytes = buf.getvalue()
    logging.info("Generated byte array representing TTS output")

    save_wav = os.getenv("TTS_SAVE_WAV")
    if save_wav and save_wav.lower() in {"t", "true", "yes", "on", "1"}:
        # Microseconds in the name so requests finishing in the same second don't
        # overwrite each other. Write the already-encoded bytes so the saved file
        # is exactly what the client receives.
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_%f")
        with open(f"/app/outputs/{timestamp}.wav", "wb") as f:
            f.write(wav_bytes)

    return wav_bytes


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the server's startup and shutdown steps around its serving period.

    Startup loads the TTS model onto ``app.tts`` before any request is served.
    If loading raises, the exception propagates out of startup. Shutdown only
    logs a message.
    """
    logging.info("Server starting up...")
    load_tts_model(app)
    try:
        # The application serves requests while suspended here.
        yield
    finally:
        # Log shutdown even if the lifespan is exited through an exception.
        logging.info("Server shutting down...")


app = FastAPI(lifespan=lifespan)

# Cross-origin policy: only the frontend origin may call this API from a browser.
# Credentials are not allowed, but any HTTP method and header is.
_cors_settings = {
    "allow_origins": ["https://singlish-tts"],  # frontend domain
    "allow_credentials": False,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)


@app.post("/tts")
def tts(req: TTSRequest):
    """Synthesize ``req.text`` and return it inline as a WAV file.

    Declared with plain ``def`` so FastAPI runs the blocking model inference in
    its threadpool rather than on the event loop.

    Raises:
        HTTPException: 400 if the text is empty or only whitespace.
    """
    text = req.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="text must be non-empty")

    wav_bytes = synthesize_tts_bytes(text, req.sample_rate, app)

    # The whole file is already in memory, so send it as one body. This sets
    # Content-Length. A StreamingResponse over a BytesIO would instead iterate it
    # line by line, splitting binary audio at arbitrary b"\n" bytes.
    headers = {
        "Content-Disposition": 'inline; filename="tts.wav"'
    }
    return Response(content=wav_bytes, media_type="audio/wav", headers=headers)


@app.get("/health")
async def health():
    """Liveness probe: answers ``{"status": "ok"}`` whenever the server is serving."""
    payload = {"status": "ok"}
    return payload

if __name__ == "__main__":
    # Parse TTS_HTTPS with the same truthy spellings as TTS_SAVE_WAV, so values like
    # "True" or "1" are not silently served over plain HTTP.
    if https_mode.strip().lower() in {"t", "true", "yes", "on", "1"}:
        # TODO: serve TLS through uvicorn.run(..., ssl_certfile=..., ssl_keyfile=...)
        # once certificates are provisioned (an earlier draft read them from /opt/app/cert/).
        raise NotImplementedError("HTTPS support is not implemented yet!")

    logging.info("Running in HTTP")
    # log_config=None stops uvicorn from installing its own logging config, so its
    # loggers fall through to the root handlers set up by logging.basicConfig.
    uvicorn.run(app, host="0.0.0.0", port=8200, log_config=None)
    