# Standard
import asyncio
from enum import IntEnum
from multiprocessing.connection import Connection
import json
import logging
import argparse
import typing_extensions
import uuid
import traceback
import psutil
import gc
import re
import requests

# external
import numpy as np
from numpy._typing import NDArray
import soundfile as sf

# Project
from whisper_online import add_shared_args, set_logging, asr_factory

# GStreamer
import gi
import os

gi.require_version("Gst", "1.0")
from gi.repository import Gst

logger = logging.getLogger(__name__)

# One shared format so console and file output both carry timestamps.
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")

# Console handler: passes on every record the logger lets through.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# Records are handled here only and not forwarded to ancestor (root) handlers.
logger.propagate = False

# File handler: mirrors the log to /opt/exp/logconsole.txt.
file_handler = logging.FileHandler("/opt/exp/logconsole.txt")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

# global config
SAMPLING_RATE = 16000  # Hz
LLM_URL = "https://modelapi.klass.dev/v1/chat/completions"
# Matched case-insensitively against corrected transcripts.
ACTIVATION_PHRASE: str = os.getenv("ACTIVATION_PHRASE", "Hey robot").lower()
# The extracted robot instruction is written here as {"command": ...}.
COMMAND_FILE: str = "/opt/command.json"


# One mono output WAV per input stream, indexed by the stream's position in
# AudioWebMHandler.audio_streams; filtered audio is appended as it is unloaded.
filter_audio_list = [
    sf.SoundFile("/opt/exp/wav_file1.wav", "w", SAMPLING_RATE, 1),
    sf.SoundFile("/opt/exp/wav_file2.wav", "w", SAMPLING_RATE, 1),
]


def get_args():
    """Parse the command line and pin the streaming decoder's model settings.

    Shared whisper options are registered by ``add_shared_args`` and logging is
    configured through ``set_logging``. Afterwards, language and model are
    forced to English ``small.en`` loaded from ``/opt/app/faster_whisper``.
    """
    arg_parser = argparse.ArgumentParser()
    add_shared_args(arg_parser)
    parsed = arg_parser.parse_args()
    set_logging(parsed, logger, other="")

    # the streaming worker always decodes with the English-only small model
    parsed.lan = "en"
    parsed.model = "small.en"
    parsed.model_dir = f"/opt/app/faster_whisper/{parsed.model}"
    return parsed


# additional user defined function to detect hallucination from whisper
def is_ngram_repeated(sentence, ngram_cnt=5):
    """Heuristically detect a whisper hallucination loop in ``sentence``.

    For each n-gram size from 1 to ``ngram_cnt``, every group produced by
    ``detect_repeated_ngrams`` is checked with ``found_adjacency_repeated``,
    using a threshold of 5 for single words and 3 for longer n-grams.
    Returns True as soon as any group is flagged.
    """
    return any(
        found_adjacency_repeated(ngram_group, 5 if size == 1 else 3)
        for size in range(1, ngram_cnt + 1)
        for ngram_group in detect_repeated_ngrams(sentence, size)
    )


def remove_punctuation(sentence):
    """Return ``sentence`` with every non-word, non-whitespace character removed."""
    non_word_or_space = re.compile(r"[^\w\s]")
    return non_word_or_space.sub("", sentence)


def detect_repeated_ngrams(sentence, n=2):
    """Cut ``sentence`` into groups of non-overlapping word n-grams.

    Punctuation is stripped and the text lowercased first. The result has one
    list per starting offset ``x`` in ``range(n)``. Each list holds the n-grams
    that start at word ``x, x + n, x + 2n, ...``; only complete n-grams are kept.
    """
    tokens = remove_punctuation(sentence).lower().split()
    last_start = len(tokens) - n + 1
    return [
        [" ".join(tokens[start : start + n]) for start in range(offset, last_start, n)]
        for offset in range(n)
    ]


def found_adjacency_repeated(ngram_list, N=3):
    """Return True if some item occurs at least ``N`` times in a row.

    Used to spot whisper hallucinations such as "thank you thank you thank you".

    Args:
        ngram_list: sequence of comparable items (n-gram strings here).
        N: minimum length of a run of identical adjacent items.

    Returns:
        True as soon as a run of length >= N is found, otherwise False.
        A sequence with fewer than two items always yields False.
    """
    run_length = 1
    for previous, current in zip(ngram_list, ngram_list[1:]):
        # Extend the current run, or start a new run at ``current``. Deciding
        # this once per pair means the first repeat of a later run is counted.
        run_length = run_length + 1 if current == previous else 1
        if run_length >= N:
            return True
    return False


# given np array, calculate db
def calculate_dB(audio_seg):
    """Return the RMS level of int16 PCM audio in dB, rounded to 2 decimals.

    Args:
        audio_seg: bytes-like buffer of native-endian int16 samples.

    Returns:
        Level relative to full scale (samples are normalized by 32768).
        Silence and empty input both return ``-np.inf``, so they compare as
        quieter than any real signal.
    """
    # normalize int16 samples to [-1, 1)
    pcm_data = np.frombuffer(audio_seg, dtype=np.int16).astype(np.float32) / 32768.0

    # The mean of an empty array is NaN, which would poison max()/index()
    # comparisons between streams; treat it as silence instead.
    if pcm_data.size == 0:
        return -np.inf

    # root mean square (RMS) of the signal
    rms = np.sqrt(np.mean(pcm_data**2))
    if rms == 0:
        return -np.inf

    return round(20 * np.log10(rms), 2)


def convert_seconds_to_num_bytes(num_of_seconds):
    """Convert a duration into a byte count at 16000 samples/s x 2 bytes/sample.

    The product is truncated toward zero and then, if odd, incremented by one,
    so the result never splits a 2-byte sample.
    """
    byte_count = int(num_of_seconds * 16000 * 2)
    # Python's % of an int by 2 is always 0 or 1, also for negative values
    return byte_count + byte_count % 2


def convert_from_num_bytes_to_seconds(num_of_bytes):
    """Convert a byte count at 16000 samples/s x 2 bytes/sample into seconds (2 d.p.)."""
    bytes_per_second = 16000 * 2
    seconds = num_of_bytes / bytes_per_second
    return round(seconds, 2)


class LLMResGenerator(object):
    """Turn a transcribed voice command into a robot instruction using a remote LLM.

    On success, the model's instruction is pushed onto ``llm_res_queue`` as
    ``{"chat_text": ..., "role": "llm"}`` and written to ``COMMAND_FILE`` as
    ``{"command": ...}``. If the request, parsing or file write raises, a fixed
    apology message with the same shape is pushed instead.
    """

    # JSON strings destined for the client; drained by handle_llm_res().
    llm_res_queue = asyncio.Queue()

    # Upper bound (seconds) for the blocking HTTP call, so an unresponsive
    # endpoint cannot pin a worker thread forever. A timeout ends up in the
    # same apology path as any other failure.
    REQUEST_TIMEOUT_SECS: float = 60.0

    @classmethod
    async def generate_llm_res(cls, cmd: str) -> None:
        """Ask the LLM to extract the robot instruction from ``cmd`` and publish it.

        Args:
            cmd: transcript text captured from the activation phrase onwards.
        """
        prompt: list[dict[str, str]] = [
            {
                "role": "user",
                "content": f"""Please identify instruction that can be given to a robot after the phase "Hey Robot!" in the following paragraph.
                {cmd}
                Only give me the instruction in laymen terms starting with an action word in continuous tenses, such as moving or spinning.
                I would also like numbers to be represented in numeric form.
                /nothink""",
            },
        ]

        payload = {
            "messages": prompt,
            "model": "qwen3-32b",
            "stream": False,
            "temperature": 0.6,
            "top_p": 0.95,
            "min_p": 0,
            "top_k": 20,
            "seed": 888,
        }

        try:
            # requests is blocking, so it runs in a worker thread.
            # NOTE(review): verify=False disables TLS certificate verification
            # for LLM_URL, presumably for a self-signed internal endpoint. Confirm.
            response = await asyncio.to_thread(
                requests.post,
                LLM_URL,
                json=payload,
                verify=False,
                timeout=cls.REQUEST_TIMEOUT_SECS,
            )
            # surface HTTP errors explicitly rather than via a KeyError below
            response.raise_for_status()
            response_data = response.json()
            llm_response_content = response_data["choices"][0]["message"]["content"]

            # The model may prefix its answer with a <think>...</think> block;
            # keep only the text after the last closing tag. A reply without
            # the tag is used as-is instead of being dropped.
            if "</think>" in llm_response_content:
                llm_response_content = llm_response_content.rpartition("</think>")[2]
            llm_response_content = llm_response_content.strip()

            cls.llm_res_queue.put_nowait(
                json.dumps({"chat_text": llm_response_content, "role": "llm"})
            )

            # write command to file for robot to move
            with open(COMMAND_FILE, "w") as fp:
                json.dump({"command": llm_response_content}, fp)

        except Exception:
            logger.exception("LLM command generation failed")
            cls.llm_res_queue.put_nowait(
                json.dumps(
                    {
                        "chat_text": "Sorry, I do not understand your command.",
                        "role": "llm",
                    }
                )
            )


class MultiStreamHandler:
    """Mix several microphone streams and run a slower "correction" ASR pass.

    Raw int16 PCM bytes for stream ``i`` are appended to ``AUDIO_BUFFERS[i]``
    by ``AudioWebMHandler.on_new_sample``. ``compare_amplitude`` copies, chunk
    by chunk, only the loudest stream into that stream's
    ``global_audio_buffer``; the others receive silence. ``correction_results_task``
    re-transcribes finished segments with a larger model, and ``unload_buffer``
    periodically moves old audio to disk.
    """

    # Raw per-stream audio (int16 PCM bytes), indexed like AudioWebMHandler.audio_streams.
    AUDIO_BUFFERS: list[bytearray] = []
    # (asr_result, stream_index, segment_id) tuples awaiting correction.
    CORRECTION_QUEUE = asyncio.Queue()
    # Minimum level (dB, see calculate_dB) the loudest stream must reach to pass through.
    THRESHOLD = -40
    # Bytes dropped from the front of each buffer by the latest unload_buffer().
    UNLOADED_BUFFER = 0

    def __init__(self):
        self.args = get_args()
        self.start_processing_asr = True
        # set by unload_buffer() so compare_amplitude() can rebase its offsets
        self.unloaded = False
        # cumulative bytes dropped from the front of the buffers; maps absolute
        # transcript timestamps onto the trimmed buffers
        self.total_unloaded = 0
        self.loop = asyncio.get_running_loop()
        # move old audio to disk every minute
        self.loop.call_later(60, self.unload_buffer)

    @staticmethod
    def process_correction_chunk(
        proc,
        pcm_data: NDArray,
        intended_start_time_in_secs: float = 0,
        initial_prompt="",
    ):
        """Run the correction model's ``transcribe_for_correction`` on ``pcm_data``."""
        return proc.transcribe_for_correction(
            pcm_data,
            beam_size=2,
            intended_start_time_in_secs=intended_start_time_in_secs,
            initial_prompt=initial_prompt,
        )

    async def correction_results_task(self):
        """Re-transcribe each finished segment with the medium model and publish it.

        Consumes ``CORRECTION_QUEUE`` forever. Each corrected transcript is
        delivered through the originating stream's ``handle_whisper_result``
        with ``correction=True``.
        """
        # initialise whisper model for correction stream
        args = get_args()
        # use this for slower transcription speed but better accuracy for correction
        # args.backend = "whisper_timestamped"
        args.model = "medium.en"
        args.model_dir = "/opt/app/faster_whisper/" + args.model
        online = asr_factory(args)
        self.asr_processor_correction = online

        while True:  # TODO: needs a proper exit condition
            asr_result, sock_num, seg_id = (
                await MultiStreamHandler.CORRECTION_QUEUE.get()
            )
            transcript_with_timestamps, final_seg, num_tokens, full_transcript = (
                asr_result
            )

            if not transcript_with_timestamps:
                continue

            # pad the segment by 0.2 s on each side (timestamps are in seconds)
            start_time = round(transcript_with_timestamps[0]["start"] - 0.2, 2)
            end_time = round(transcript_with_timestamps[-1]["end"] + 0.2, 2)

            audio_buffer_to_correct = MultiStreamHandler.AUDIO_BUFFERS[sock_num]

            # Timestamps are absolute, but ``total_unloaded`` bytes have been
            # trimmed from the buffer's front. Clamp both offsets into the
            # buffer: a negative offset (segment starts < 0.2 s into the stream,
            # or before the retained window) would otherwise index from the END.
            start_time_offset_in_bytearray = max(
                convert_seconds_to_num_bytes(start_time) - self.total_unloaded, 0
            )
            end_time_offset_in_bytearray = max(
                min(
                    convert_seconds_to_num_bytes(end_time) - self.total_unloaded,
                    len(audio_buffer_to_correct),
                ),
                0,
            )

            audio_to_correct = audio_buffer_to_correct[
                start_time_offset_in_bytearray:end_time_offset_in_bytearray
            ]
            if not audio_to_correct:
                # nothing of this segment is still held in memory
                continue

            try:
                pcm_data = (
                    np.frombuffer(audio_to_correct, dtype=np.int16).astype(np.float32)
                    / 32768.0
                )

                asr_result = await asyncio.to_thread(
                    MultiStreamHandler.process_correction_chunk,
                    self.asr_processor_correction,
                    pcm_data,
                )

                txnscript = ""

                if asr_result:
                    txnscript = " ".join(w["word"] for w in asr_result)

                    # the larger model can loop on a phrase; fall back to the
                    # transcript from the streaming pass in that case
                    if is_ngram_repeated(txnscript):
                        txnscript = full_transcript

                # keep the rest of info same, except the transcript is corrected
                AudioWebMHandler.audio_streams[sock_num].handle_whisper_result(
                    (asr_result, final_seg, num_tokens, txnscript),
                    seg_id,
                    True,
                )

            except Exception:
                logger.exception(
                    "Correction failed for stream %s segment %s", sock_num, seg_id
                )

    async def compare_amplitude(self):
        """Route each chunk of the loudest stream into its filtered buffer.

        For every new chunk, each stream's level is measured over that chunk plus
        up to ``NUM_OF_STRIDE_IN_THE_PAST`` preceding chunks. Only the loudest
        stream, and only if it reaches ``THRESHOLD``, contributes real audio.
        Every other stream gets the same number of zero bytes, so all filtered
        buffers advance in step with the raw ones.
        """
        # NOTE: offsets index byte buffers of int16 samples, so this is
        # 1600 bytes = 800 samples (50 ms at 16 kHz), not 1600 samples.
        chunk_size = int(SAMPLING_RATE * 0.1)

        NUM_OF_STRIDE_IN_THE_PAST = 20
        WINDOW_TO_COMPARE = NUM_OF_STRIDE_IN_THE_PAST * chunk_size

        # byte offsets into AUDIO_BUFFERS of the next chunk to process
        offset_processing = 0
        offset_end_processing = 0

        while self.start_processing_asr:
            if self.unloaded:
                offset_processing -= MultiStreamHandler.UNLOADED_BUFFER
                offset_end_processing -= MultiStreamHandler.UNLOADED_BUFFER
                self.unloaded = False
            # Wait until enough audio data is accumulated
            while any(
                len(audio[offset_processing:]) < chunk_size
                for audio in MultiStreamHandler.AUDIO_BUFFERS
            ):
                await asyncio.sleep(0.1)

                if self.unloaded:
                    offset_processing -= MultiStreamHandler.UNLOADED_BUFFER
                    offset_end_processing -= MultiStreamHandler.UNLOADED_BUFFER
                    self.unloaded = False

                if not self.start_processing_asr:
                    break

            # stopped while waiting: do not process a partial chunk
            if not self.start_processing_asr:
                break

            # read until chunk size
            offset_end_processing += chunk_size

            # compare over the current chunk plus a window of past chunks
            # (never before the start of the buffer)
            modified_offset_processing = max(offset_processing - WINDOW_TO_COMPARE, 0)

            db_list = [
                calculate_dB(
                    audio_data[modified_offset_processing:offset_end_processing]
                )
                for audio_data in MultiStreamHandler.AUDIO_BUFFERS
            ]

            # find louder audio
            louder = db_list.index(max(db_list))

            for idx, stream in enumerate(AudioWebMHandler.audio_streams):
                audio_data = MultiStreamHandler.AUDIO_BUFFERS[idx][
                    offset_processing:offset_end_processing
                ]

                # make sure loudness is above threshold to eliminate background noise
                if idx == louder and db_list[louder] >= MultiStreamHandler.THRESHOLD:
                    stream.global_audio_buffer.extend(audio_data)
                else:
                    # replace softer audio with silence of the same byte length
                    stream.global_audio_buffer.extend(bytes(len(audio_data)))

            # after done comparing, proceed to next chunk
            offset_processing = offset_end_processing

    def unload_buffer(self):
        """Write all but the last 30 s of each filtered buffer to disk and drop it.

        Rescheduled every 60 s through ``loop.call_later``. The same number of
        bytes is removed from the filtered and the raw buffer, and consumers
        are told how many through ``unloaded``/``buffer_unloaded``,
        ``UNLOADED_BUFFER`` and ``total_unloaded``.
        """
        # keep last 30s (16000 samples/s * 2 bytes per sample)
        audio_chunk_keep = int(30 * 16000 * 2)

        try:
            for idx, stream in enumerate(AudioWebMHandler.audio_streams):
                if len(stream.global_audio_buffer) <= audio_chunk_keep:
                    continue

                # filtered audio outside the retained window
                audio_data = stream.global_audio_buffer[:-audio_chunk_keep]
                num_unloaded = len(audio_data)

                pcm_data = np.frombuffer(audio_data, dtype=np.int16)
                # write to disk
                filter_audio_list[idx].write(pcm_data)
                filter_audio_list[idx].flush()

                # Drop exactly ``num_unloaded`` bytes from BOTH buffers. The raw
                # buffer can be ahead of the filtered one (compare_amplitude
                # copies whole chunks only). Trimming it to "the last 30 s" would
                # remove more bytes than reported and misalign the rebased offsets.
                # NOTE(review): if on_new_sample runs on a GStreamer streaming
                # thread, samples appended during this reassignment may be lost.
                stream.global_audio_buffer = stream.global_audio_buffer[num_unloaded:]
                MultiStreamHandler.AUDIO_BUFFERS[idx] = (
                    MultiStreamHandler.AUDIO_BUFFERS[idx][num_unloaded:]
                )

                stream.buffer_unloaded = num_unloaded
                stream.unloaded = True
                MultiStreamHandler.UNLOADED_BUFFER = num_unloaded
                if idx == 0:
                    self.total_unloaded += num_unloaded
                self.unloaded = True

                logger.info("buffer cleared")
                gc.collect()
        finally:
            # reschedule even if unloading raised, so periodic unloading continues
            self.loop.call_later(60, self.unload_buffer)


class AudioWebMHandler:
    """One client audio stream: WebM bytes in, decoded PCM out, streaming ASR on top.

    Incoming bytes are pushed through a GStreamer pipeline
    (appsrc -> decodebin -> audioconvert -> audioresample -> appsink, with
    appsink caps rate=16000, channels=1). Decoded samples are appended to this
    stream's entry in ``MultiStreamHandler.AUDIO_BUFFERS``.
    ``MultiStreamHandler.compare_amplitude`` fills ``global_audio_buffer``,
    which ``process_audio_task`` feeds to the streaming whisper processor.
    Results and control messages are queued on ``ws_send_queue`` as JSON strings.
    """

    # every handler in creation order; a handler's position here is its stream
    # index into MultiStreamHandler.AUDIO_BUFFERS and filter_audio_list
    audio_streams: list[typing_extensions.Self] = []

    def __init__(self, min_chunk: float, speaker_name: str):
        self.loop = asyncio.get_running_loop()

        self.min_chunk_size = min_chunk  # Minimum chunk size in seconds
        self.pipeline = None
        self.num_segments = 0
        self.start_processing_asr = False
        self.request_id = str(uuid.uuid4())
        self.eos = False
        self.speaker_name = speaker_name

        # filtered audio (written by MultiStreamHandler.compare_amplitude) fed to ASR
        self.global_audio_buffer = bytearray()
        # raw decoded audio for this stream (written by on_new_sample)
        MultiStreamHandler.AUDIO_BUFFERS.append(bytearray())

        # set by MultiStreamHandler.unload_buffer with the number of bytes dropped
        self.unloaded = False
        self.buffer_unloaded = 0

        # Initialize GStreamer
        Gst.init(None)
        self.create_pipeline()

        # flag to indicate to start processing audio chunks for ASR
        self.start_processing_asr = True

        # queue to store the results to be sent to client
        self.ws_send_queue = asyncio.Queue()
        AudioWebMHandler.audio_streams.append(self)
        self.wait_closed: asyncio.Future[bool] = self.loop.create_future()

        # activation-phrase / command capture state (see check_activation_keyword)
        self.transcript_history: str = ""
        self.command_in_progress: bool = False
        self.sentence_count_after_activation: int = 0
        self.activation_timeout_cb: asyncio.TimerHandle | None = None

        # Strong references to fire-and-forget tasks: the event loop keeps only
        # weak references, so an unreferenced task may be garbage-collected.
        self._background_tasks: set[asyncio.Task] = set()

        self.build_asr_factory()

    def _spawn(self, coro) -> asyncio.Task:
        """Schedule ``coro`` on this handler's loop and keep it referenced until done."""
        task = self.loop.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def build_asr_factory(self) -> None:
        """Create this stream's streaming ASR processor and start ``process_audio_task``."""
        args = get_args()
        args.sock_num = AudioWebMHandler.audio_streams.index(self)
        online = asr_factory(args)
        self.asr_processor = online
        # handle incoming audio
        self._spawn(self.process_audio_task(online))

    def release_gstreamer_worker(self):
        """Stop this stream: mark EOS, reset the pipeline and end ASR processing.

        Once every stream has reached EOS, a ``process terminated`` message is
        queued for the client.
        """
        try:
            self.eos = True
            self.num_segments = 0
            self.reset_pipeline()
        except Exception:
            logger.exception("Failed to release GStreamer worker")
        finally:
            # force process_audio_task to exit, even if the reset failed
            self.start_processing_asr = False
            if all(stream.eos for stream in AudioWebMHandler.audio_streams):
                msg = json.dumps({"message": "process terminated"})
                self.ws_send_queue.put_nowait(msg)

    async def handle_browser_client_message(self, message: bytes | str | None):
        """Dispatch one message from the client connection.

        ``bytes`` are pushed into the GStreamer pipeline. The strings ``"EOS"``,
        ``"load whisper"`` and ``"speaker_name:<name>"`` are control commands,
        and ``None`` releases the stream.
        """
        if isinstance(message, bytes):
            try:
                # Push incoming bytes to appsrc
                buf = Gst.Buffer.new_allocate(None, len(message), None)
                buf.fill(0, message)
                self.appsrc.emit("push-buffer", buf)
            except Exception as e:
                logger.error(f"Error pushing buffer to appsrc: {e}")
        elif message == "EOS":
            logger.info("Received EOS signal. >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")

            # Append 3 s of silence (16000 samples/s * 2 bytes) so whisper
            # finishes the last sentence, then give the ASR task time to drain.
            self.global_audio_buffer.extend(bytes(3 * 16000 * 2))
            await asyncio.sleep(5)

            try:
                logger.info("Processing the entire audio...")
                await self.process_entire_audio()
                self.release_gstreamer_worker()

            except Exception as ex:
                logger.error("error on closing EOS ...... {}".format(ex))

            self.appsrc.emit("end-of-stream")
        elif message == "load whisper":
            msg = json.dumps({"message": "whisper loaded"})
            self.ws_send_queue.put_nowait(msg)
            logger.info("Whisper model loaded and ready for decoding.")
        elif message is None:
            self.release_gstreamer_worker()
        elif message.startswith("speaker_name:"):
            name = message.split(":", 1)[1]
            self.speaker_name = name

            logger.info(f"Set speaker name to {self.speaker_name}")

    def create_pipeline(self):
        """
        Create the GStreamer pipeline for decoding, converting, and resampling WebM audio.
        This pipeline ensures audio is resampled to 16 kHz mono PCM.
        """
        try:
            self.pipeline = Gst.Pipeline()

            # Create GStreamer elements
            self.appsrc = Gst.ElementFactory.make("appsrc", "audio_input")
            self.decodebin = Gst.ElementFactory.make("decodebin", "decoder")
            self.audioconvert = Gst.ElementFactory.make("audioconvert", "converter")
            self.audioresample = Gst.ElementFactory.make("audioresample", "resampler")
            self.appsink = Gst.ElementFactory.make("appsink", "audio_output")

            if not all(
                [
                    self.appsrc,
                    self.decodebin,
                    self.audioconvert,
                    self.audioresample,
                    self.appsink,
                ]
            ):
                raise RuntimeError("Failed to create GStreamer elements.")

            # Configure appsrc
            self.appsrc.set_property("is-live", True)
            self.appsrc.set_property("format", Gst.Format.TIME)
            caps = Gst.Caps.from_string("audio/webm, rate=48000, channels=2")
            self.appsrc.set_property("caps", caps)

            # Configure appsink
            resample_caps = Gst.Caps.from_string("audio/x-raw, rate=16000, channels=1")
            self.appsink.set_property("emit-signals", True)
            self.appsink.set_property("sync", False)
            self.appsink.set_property("caps", resample_caps)
            self.appsink.connect("new-sample", self.on_new_sample)

            # Add elements to the pipeline
            self.pipeline.add(self.appsrc)
            self.pipeline.add(self.decodebin)
            self.pipeline.add(self.audioconvert)
            self.pipeline.add(self.audioresample)
            self.pipeline.add(self.appsink)

            # Link elements; decodebin's source pad appears later (_on_pad_added)
            self.appsrc.link(self.decodebin)
            self.decodebin.connect("pad-added", self._on_pad_added)
            self.audioconvert.link(self.audioresample)
            self.audioresample.link(self.appsink)

            self.pipeline.set_state(Gst.State.PLAYING)
            logger.info("GStreamer pipeline created and initialized.")
        except Exception as e:
            logger.error(f"Failed to create GStreamer pipeline: {e}")
            raise

    def _on_pad_added(self, element, pad):
        """
        Handles dynamic pads from decodebin and links to audioconvert.
        """
        logger.info("Dynamic pad added from decodebin.")
        sink_pad = self.audioconvert.get_static_pad("sink")
        if not sink_pad.is_linked():
            link_result = pad.link(sink_pad)
            if link_result != Gst.PadLinkReturn.OK:
                logger.error(f"Failed to link decodebin pad to audioconvert: {link_result}")

    def reset_pipeline(self):
        """
        Resets the GStreamer pipeline to handle EOS or session termination.
        """
        logger.info("Resetting GStreamer pipeline.")
        self.pipeline.set_state(Gst.State.NULL)

        self.global_audio_buffer.clear()

        self.asr_processor.reset()
        self.asr_processor.init()

    async def process_entire_audio(self):
        """Append the remaining filtered audio to this stream's WAV file and close it.

        The final ``eos`` result is sent to the client even if saving fails, so
        the client is not left waiting.
        """
        audio_data = self.global_audio_buffer[:]

        try:
            # Convert buffer to NumPy array (from int16 format)
            leftover_pcm_data = np.frombuffer(audio_data, dtype=np.int16)
            stream_index = AudioWebMHandler.audio_streams.index(self)

            sfp = filter_audio_list[stream_index]
            sfp.write(leftover_pcm_data)
            sfp.flush()
            sfp.close()

            logger.info("done saving using soundfile with 16kHz.")
        except Exception:
            logger.exception("Failed to save the remaining audio")

        # send EOS to client
        result_json = {"status": 0, "id": self.request_id, "result": {"eos": True}}

        logger.info(f"Sending EOS to client: {result_json}")
        self.ws_send_queue.put_nowait(json.dumps(result_json))

    def on_new_sample(self, sink):
        """
        Callback for new samples from the appsink.
        Appends the decoded, resampled PCM bytes to this stream's raw buffer.
        """
        sample = sink.emit("pull-sample")
        if sample:
            buf = sample.get_buffer()
            result, mapinfo = buf.map(Gst.MapFlags.READ)
            if result:
                try:
                    # Interpret as int16 (raises on an odd byte count). The
                    # result of np.frombuffer is always 1-D and the appsink caps
                    # request channels=1, so no down-mixing is needed here.
                    pcm_data = np.frombuffer(mapinfo.data, dtype=np.int16)

                    # Add processed PCM data to the buffer (copied before unmap)
                    MultiStreamHandler.AUDIO_BUFFERS[
                        AudioWebMHandler.audio_streams.index(self)
                    ].extend(pcm_data.tobytes())

                except Exception as e:
                    logger.error(f"Error processing audio sample: {e}")
                finally:
                    buf.unmap(mapinfo)

        return Gst.FlowReturn.OK

    def handle_whisper_result(self, result, seg_id=0, correction=False):
        """Send one ASR result to the client and route finished segments onward.

        Args:
            result: ``(transcript_with_timestamps, final_seg, num_tokens,
                full_transcript)``. Word entries need ``start``/``end`` (seconds)
                and, for corrections, ``word``/``segment_id``. Anything that is
                not a 4-item result, or that has no words, is ignored.
            seg_id: segment number reported when ``correction`` is True.
            correction: True when ``result`` comes from the correction pass.
        """
        # anything other than exactly 4 items would fail the unpack below
        if not result or len(result) != 4:
            return

        transcript_with_timestamps, final_seg, num_tokens, full_transcript = result

        if not transcript_with_timestamps:
            return

        # the first word start time is the start of the utterance
        # the last word end time is the end of the utterance
        duration = (
            transcript_with_timestamps[-1]["end"]
            - transcript_with_timestamps[0]["start"]
        )

        segment_start = round(transcript_with_timestamps[0]["start"], 2)

        response = {
            "status": 0,
            "result": {
                "final": final_seg,
                "correction": correction,
                "hypotheses": [
                    {
                        "transcript": full_transcript,
                        "transcript_with_timestamps": transcript_with_timestamps,
                    }
                ],
                "eos": False,
            },
            "num_tokens": num_tokens,
            "segment_start": segment_start,  # try to use the first word detected 's start time
            "segment_length": round(duration, 2),
            "segment": seg_id if correction else self.num_segments,
            "id": self.request_id,
            "speaker_name": self.speaker_name,
        }

        # a finished streaming segment is queued for the correction pass
        if final_seg and not correction:
            self.correction_segments = self.num_segments
            self.num_segments += 1
            MultiStreamHandler.CORRECTION_QUEUE.put_nowait(
                (
                    result,
                    AudioWebMHandler.audio_streams.index(self),
                    self.correction_segments,
                )
            )
        # the activation phrase is only searched in corrected transcripts
        if correction:
            self.check_activation_keyword(
                transcript_with_timestamps, self.command_in_progress
            )

        self.ws_send_queue.put_nowait(json.dumps(response))

    @staticmethod
    def process_audio_chunk(proc, pcm_data: NDArray):
        """Feed ``pcm_data`` to ``proc`` and return its ``process_iter()`` result."""
        # Pass PCM data to ASR processor
        proc.insert_audio_chunk(pcm_data)

        # Process the chunk with the ASR model
        return proc.process_iter()

    async def process_audio_task(self, proc):
        """
        Feed newly filtered audio from ``global_audio_buffer`` to the ASR
        processor and forward each result to ``handle_whisper_result``.
        """

        # byte offset of the first unprocessed byte in global_audio_buffer
        offset_processing = 0
        offset_end_processing = -1

        while self.start_processing_asr:
            # rebase after MultiStreamHandler.unload_buffer trimmed the buffer
            if self.unloaded:
                offset_processing -= self.buffer_unloaded
                self.unloaded = False

            # read until the end of the buffer, it doesnt matter if there are more than chunk_size bytes to read
            offset_end_processing = len(self.global_audio_buffer)

            # Extract everything appended since the last iteration
            audio_data = self.global_audio_buffer[
                offset_processing:offset_end_processing
            ]

            if not audio_data:
                await asyncio.sleep(0.1)
                continue

            # Convert raw audio to PCM (normalized to [-1, 1])
            pcm_data = (
                np.frombuffer(audio_data, dtype=np.int16).astype(np.float32)
                / 32768.0
            )

            try:
                asr_result = await asyncio.to_thread(
                    AudioWebMHandler.process_audio_chunk, proc, pcm_data
                )
                self.handle_whisper_result(asr_result)

            except Exception:
                # best effort: e.g. not enough transcript to process yet
                traceback.print_exc()
            # after done inference, proceed to next
            offset_processing = offset_end_processing

    def check_activation_keyword(
        self,
        transcript_with_timestamps: list[dict[str, str | int]],
        is_second_trigger: bool,
    ) -> None:
        """Capture a spoken command that starts with ``ACTIVATION_PHRASE``.

        Words are regrouped into sentences by ``segment_id``. Text from the
        activation phrase onwards is accumulated in ``transcript_history``, for
        at most 3 sentences counted across calls. A command is queued when 3
        sentences were captured, on the next corrected result after activation
        (``is_second_trigger``), or 2 s after activation via a timer.
        """

        sentences: dict[int, str] = {}

        for segment in transcript_with_timestamps:
            word: str = segment["word"]
            sentence_id: int = segment["segment_id"]

            # remove commas but not whitespace
            word = word.replace(",", "")

            if sentence_id not in sentences:
                sentences[sentence_id] = word
            else:
                sentences[sentence_id] += word

        sentences_after_activation = ""

        for sentence_id, sentence in sentences.items():

            if not self.command_in_progress:
                sentence_lower = sentence.lower()

                # check if activation key in sentence
                idx = sentence_lower.find(ACTIVATION_PHRASE)
                self.command_in_progress = idx != -1

                if not self.command_in_progress:
                    continue

                print(f"Found activation phrase in sentence #{sentence_id}")

                sentences_after_activation += sentence[idx:]
                self.sentence_count_after_activation += 1
            else:
                sentences_after_activation += sentence
                self.sentence_count_after_activation += 1

                # allow max 3 segment as command
                if self.sentence_count_after_activation >= 3:
                    break

        self.transcript_history += sentences_after_activation

        full_command = self.transcript_history[:]

        if not is_second_trigger and self.command_in_progress:  # first run + activated
            self.activation_timeout_cb = self.loop.call_later(
                2.0, self.queue_command, full_command
            )

        if (
            self.sentence_count_after_activation >= 3 or is_second_trigger
        ):  # this is the second callback since activation
            if self.activation_timeout_cb is not None:
                self.activation_timeout_cb.cancel()

            self.queue_command(full_command)

    def queue_command(self, full_command: str):
        """Echo ``full_command`` to the client and ask the LLM to interpret it.

        Also resets the activation state so the next command needs the
        activation phrase again.
        """
        self.sentence_count_after_activation = 0
        self.command_in_progress = False
        # the pending timer (if any) has fired or been cancelled by now
        self.activation_timeout_cb = None
        LLMResGenerator.llm_res_queue.put_nowait(
            json.dumps({"chat_text": full_command, "role": "user"})
        )
        self.transcript_history = ""

        self._spawn(LLMResGenerator.generate_llm_res(full_command))


def print_my_memory_usage():
    """Print this process's resident memory in MB, then reschedule itself in 10 s.

    Must be called while an event loop is running.
    """
    running_loop = asyncio.get_running_loop()
    rss_mb = psutil.Process().memory_info().rss / (1024 * 1024)

    print(f"Memory usage: {rss_mb:.2f} MB")

    running_loop.call_later(10, print_my_memory_usage)


async def handle_request(child_conn: Connection, awh: AudioWebMHandler):
    """Forward every message received on ``child_conn`` to ``awh``, forever.

    ``Connection.recv`` blocks, so each receive runs in a worker thread.
    """
    while True:
        incoming = await asyncio.to_thread(child_conn.recv)
        await awh.handle_browser_client_message(incoming)


async def handle_response(child_conn: Connection, awh: AudioWebMHandler):
    """Relay every message queued on ``awh.ws_send_queue`` to ``child_conn``, forever."""
    while True:
        outgoing = await awh.ws_send_queue.get()
        child_conn.send(outgoing)


async def handle_llm_res(child_conn: Connection):
    """Relay every message from ``LLMResGenerator.llm_res_queue`` to ``child_conn``, forever."""
    while True:
        chat_message = await LLMResGenerator.llm_res_queue.get()
        child_conn.send(chat_message)


async def init_event_loop(child_conn: Connection):
    """Wire one audio stream and the multi-stream helpers to ``child_conn``.

    Starts the request, response and LLM relay tasks plus the
    amplitude-comparison and correction tasks, then waits until
    ``awh.wait_closed`` resolves.
    """
    loop = asyncio.get_running_loop()
    awh = AudioWebMHandler(1.0, "Interviewer")

    # Keep strong references for the lifetime of this coroutine: the loop holds
    # only weak references to tasks, and rebinding ``_`` released earlier ones.
    background_tasks = [
        loop.create_task(handle_request(child_conn, awh)),
        loop.create_task(handle_response(child_conn, awh)),
        loop.create_task(handle_llm_res(child_conn)),
    ]
    # start correction stream and amplitude comparison
    msh = MultiStreamHandler()
    background_tasks.append(loop.create_task(msh.compare_amplitude()))
    background_tasks.append(loop.create_task(msh.correction_results_task()))
    logger.info("correction stream started (%d background tasks)", len(background_tasks))

    await awh.wait_closed


def awh_entrypoint(child_conn: Connection):
    """Process entry point: run the stream's event loop on ``child_conn`` to completion."""
    main_coroutine = init_event_loop(child_conn)
    asyncio.run(main_coroutine)
