import time
from collections import defaultdict, deque
from functools import partial

import cv2
import numpy as np


def pad_image(image: np.ndarray, target_width: int, target_height: int) -> np.ndarray:
    """
    Letterbox an image to the target dimensions while preserving aspect ratio.

    The image is resized with ``cv2.resize`` (default interpolation) so that it
    fits entirely within ``target_width`` x ``target_height``. It is then centered
    on a zero-filled (black) canvas of exactly that size. Leftover space is split
    between the two sides; when the remainder is odd, the extra pixel goes to
    the right/bottom.

    Args:
        image (np.ndarray): Input image of shape (H, W) or (H, W, C), e.g. BGR.
            The output keeps the input's channel layout and dtype.
        target_width (int): Desired width of the output image (must be > 0).
        target_height (int): Desired height of the output image (must be > 0).

    Returns:
        np.ndarray: Padded image of shape (target_height, target_width) plus the
        input's trailing channel dimensions, with the input's dtype.

    Raises:
        ValueError: If the image has zero height/width or a target dimension is
            not positive.

    Examples:
    >>> import cv2
    >>> from utils import pad_image
    >>> img = cv2.imread('input.jpg')
    >>> padded_img = pad_image(img, 1920, 1080)
    >>> cv2.imwrite('padded_output.jpg', padded_img)
    """
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        raise ValueError(f"image must be non-empty, got shape {image.shape}")
    if target_width <= 0 or target_height <= 0:
        raise ValueError(
            f"target size must be positive, got {target_width}x{target_height}"
        )

    # Scaling factor that fits the image inside the target while preserving aspect ratio.
    scale: float = min(target_width / w, target_height / h)

    # Round instead of truncating, so float error in ``w * scale`` cannot shave
    # a pixel off the limiting dimension (e.g. (1 / 49) * 49 == 0.999...).
    # Clamp so the result fits the canvas and never collapses to 0, a size
    # cv2.resize rejects.
    new_w: int = min(target_width, max(1, round(w * scale)))
    new_h: int = min(target_height, max(1, round(h * scale)))

    resized: np.ndarray = cv2.resize(image, (new_w, new_h))
    # cv2.resize returns a 2-D array for single-channel input, so an (H, W, 1)
    # image comes back as (H, W). Restore the input's channel layout so the
    # resized image fits the canvas slice below.
    resized = resized.reshape((new_h, new_w) + image.shape[2:])

    # Black canvas that matches the input's channel layout and dtype, rather
    # than always being 3-channel uint8.
    padded: np.ndarray = np.zeros(
        (target_height, target_width) + image.shape[2:], dtype=image.dtype
    )

    # Offsets that center the resized image on the canvas.
    x_offset: int = (target_width - new_w) // 2
    y_offset: int = (target_height - new_h) // 2

    padded[y_offset : y_offset + new_h, x_offset : x_offset + new_w] = resized

    return padded


class StageTimer:
    """Track per-stage and whole-loop durations over a sliding window.

    Typical use inside a processing loop::

        timer.start()
        ...               # stage A
        timer.mark("a")
        ...               # stage B
        timer.mark("b")
        timer.end()

    Each ``mark`` records the time since the previous ``mark`` (or ``start``).
    ``end`` records the time since ``start``. Only the most recent
    ``window_size`` samples of each series are kept. Averages are reported in
    milliseconds.
    """

    def __init__(self, window_size=100, clock=time.perf_counter):
        """
        Args:
            window_size (int): Number of most recent samples kept per series.
            clock (Callable[[], float]): Returns the current time in seconds.
                Defaults to ``time.perf_counter`` (monotonic, high resolution,
                meant for measuring intervals). Override it to inject a
                deterministic clock, e.g. in tests.
        """
        self.window_size = window_size
        self.clock = clock
        # stage name -> deque of the last ``window_size`` stage durations (seconds)
        self.stage_times = defaultdict(partial(deque, maxlen=window_size))
        # durations (seconds) of the last ``window_size`` completed loops
        self.loop_times = deque(maxlen=window_size)
        self.first_t = None  # clock value at the current loop's start()
        self.last_t = None  # clock value at the most recent start()/mark()

    def start(self):
        """Call at the start of each loop."""
        # Read the clock once so the first stage and the loop share one origin.
        now = self.clock()
        self.first_t = now
        self.last_t = now

    def mark(self, stage_name):
        """Call at the end of a stage; records the time since the previous mark/start.

        Raises:
            RuntimeError: If called before start() (or after end()).
        """
        if self.last_t is None:
            raise RuntimeError("StageTimer.mark() called before start()")
        now = self.clock()
        self.stage_times[stage_name].append(now - self.last_t)
        self.last_t = now

    def end(self):
        """Call at the end of a loop; records the time since start().

        Raises:
            RuntimeError: If start() has never been called.
        """
        if self.first_t is None:
            raise RuntimeError("StageTimer.end() called before start()")
        self.loop_times.append(self.clock() - self.first_t)
        # Require a fresh start() before the next mark().
        self.last_t = None

    @staticmethod
    def _mean_ms(samples):
        """Return the mean of ``samples`` (seconds) in milliseconds, or NaN if there are none."""
        if not samples:
            return float("nan")
        return 1000 * sum(samples) / len(samples)

    def get(self, key):
        """Mean duration of stage ``key`` over the window in ms (NaN if never marked)."""
        # dict.get does not trigger the defaultdict factory, so querying an
        # unknown stage does not insert an empty deque.
        return self._mean_ms(self.stage_times.get(key))

    def mean_loop_duration(self):
        """Mean loop duration over the window in ms (NaN if no loop has ended)."""
        return self._mean_ms(self.loop_times)


# Module-level timer instance, created on import.
# NOTE(review): it is kept outside the __main__ guard in case other modules
# import it; if nothing does, move it under the guard.
timer = StageTimer(window_size=2)

if __name__ == "__main__":
    # Demo: two artificial stages (~100 ms and ~10 ms) per iteration; the
    # sliding-window averages are printed every other iteration.
    for i in range(100):
        print(i)
        timer.start()
        time.sleep(0.1)
        timer.mark("first")

        time.sleep(0.01)
        timer.mark("second")

        timer.end()

        if i % 2 == 0:
            loop_ms = timer.mean_loop_duration()
            print(
                f"Mean times: first {timer.get('first'):.1f} ms, "
                f"second {timer.get('second'):.1f} ms, "
                f"loop {loop_ms:.1f} ms ~ {1000 / loop_ms:.1f} FPS"
            )
