"""Tests for the periodic trigger system.

Covers the PeriodicTrigger and TriggerCondition classes, which fire actions
during navigation based on elapsed time (interval) or progress milestones,
as well as their integration with ActionExecutor.
"""

from __future__ import annotations

import asyncio
from unittest.mock import MagicMock, AsyncMock, patch

import pytest

from executor.triggers import (
    PeriodicTrigger,
    TriggerCondition,
)
from executor.action_executor import (
    ActionExecutor,
    NavState,
    NavigationContext,
)


# =============================================================================
# Test TriggerCondition
# =============================================================================


class TestTriggerCondition:
    """Test TriggerCondition construction, firing checks, and reset."""

    def test_interval_based_condition(self):
        """Test creating time-based trigger condition."""
        condition = TriggerCondition(interval_sec=2.0)

        assert condition.interval_sec == 2.0
        assert condition.progress_milestones == set()
        assert condition.while_navigating is True

    def test_progress_milestones_condition(self):
        """Test creating progress-based trigger condition."""
        condition = TriggerCondition(
            progress_milestones=[0.25, 0.5, 0.75, 1.0]
        )

        assert condition.interval_sec is None
        # Milestones given as a list are exposed as a set.
        assert condition.progress_milestones == {0.25, 0.5, 0.75, 1.0}
        assert condition.while_navigating is True

    def test_combined_condition(self):
        """Test condition with both interval and milestones."""
        condition = TriggerCondition(
            interval_sec=1.0,
            progress_milestones=[0.5, 1.0],
        )

        assert condition.interval_sec == 1.0
        assert condition.progress_milestones == {0.5, 1.0}

    def test_should_fire_time_based(self):
        """Test time-based firing condition."""
        condition = TriggerCondition(interval_sec=2.0)

        # Not enough time elapsed
        assert not condition.should_fire_time(1.0)

        # Exactly at interval (boundary is inclusive)
        assert condition.should_fire_time(2.0)

        # Past interval
        assert condition.should_fire_time(3.0)

    def test_should_fire_progress_based(self):
        """Test progress-based firing condition."""
        condition = TriggerCondition(
            progress_milestones=[0.25, 0.5, 0.75]
        )

        # Before first milestone
        assert not condition.should_fire_progress(0.1)

        # At first milestone
        assert condition.should_fire_progress(0.25)

        # Past the first milestone: still True because the 0.25 milestone
        # has not been marked as fired yet.
        assert condition.should_fire_progress(0.3)

        # Mark the 0.25 milestone itself as fired. Pass the milestone value
        # (as test_mark_milestone_fired does), not a progress value such as
        # 0.3 that is not one of the configured milestones.
        condition.mark_milestone_fired(0.25)
        # Now between milestones, should not fire
        assert not condition.should_fire_progress(0.4)

        # At second milestone
        assert condition.should_fire_progress(0.5)

    def test_mark_milestone_fired(self):
        """Test marking milestones as fired."""
        condition = TriggerCondition(
            progress_milestones=[0.25, 0.5, 0.75]
        )

        assert condition.should_fire_progress(0.25)
        condition.mark_milestone_fired(0.25)

        # Should not fire again for same milestone
        assert not condition.should_fire_progress(0.25)

        # Next milestone still available
        assert condition.should_fire_progress(0.5)

    def test_reset_clears_fired_milestones(self):
        """Test reset clears fired milestone tracking."""
        condition = TriggerCondition(
            progress_milestones=[0.25, 0.5, 0.75]
        )

        condition.mark_milestone_fired(0.25)
        condition.mark_milestone_fired(0.5)

        assert not condition.should_fire_progress(0.25)
        assert not condition.should_fire_progress(0.5)

        condition.reset()

        # All milestones available again
        assert condition.should_fire_progress(0.25)
        assert condition.should_fire_progress(0.5)

    def test_while_navigating_flag(self):
        """Test while_navigating flag."""
        condition = TriggerCondition(
            interval_sec=1.0,
            while_navigating=False,
        )

        assert condition.while_navigating is False


# =============================================================================
# Test PeriodicTrigger
# =============================================================================


@pytest.mark.asyncio
class TestPeriodicTrigger:
    """Test PeriodicTrigger lifecycle (start/stop), firing, and callbacks.

    Every test that starts a trigger stops it in a ``finally`` clause, so a
    failing assertion does not leave a background task running.
    """

    async def test_trigger_creation(self):
        """Test creating a periodic trigger."""
        condition = TriggerCondition(interval_sec=1.0)
        fire_count = []

        def on_fire():
            fire_count.append(1)

        trigger = PeriodicTrigger(
            action_id="SCAN_AREA",
            condition=condition,
            on_fire=on_fire,
            logger=None,
        )

        assert trigger.action_id == "SCAN_AREA"
        assert trigger.condition is condition
        assert trigger.on_fire is on_fire
        assert not trigger._running

    async def test_start_stop_trigger(self):
        """Test starting and stopping a trigger."""
        condition = TriggerCondition(interval_sec=0.2)
        fire_count = []

        def on_fire():
            fire_count.append(1)

        trigger = PeriodicTrigger(
            action_id="TEST_ACTION",
            condition=condition,
            on_fire=on_fire,
            logger=None,
        )

        trigger.start()
        try:
            # `_running` is falsy before start() (see test_trigger_creation).
            # The former `is not None` check also passed for False, so it
            # could never fail.
            assert trigger._running

            # Wait for trigger to fire
            await asyncio.sleep(0.3)
        finally:
            trigger.stop()

        assert len(fire_count) >= 1

    async def test_trigger_with_progress_milestones(self):
        """Test trigger with progress milestones."""
        condition = TriggerCondition(
            progress_milestones=[0.25, 0.5, 0.75, 1.0]
        )
        fire_count = []

        def on_fire():
            fire_count.append(1)

        trigger = PeriodicTrigger(
            action_id="PROGRESS_TRIGGER",
            condition=condition,
            on_fire=on_fire,
            logger=None,
        )

        trigger.start()
        try:
            # Update progress through milestones
            trigger.update_progress(0.25)
            await asyncio.sleep(0.2)

            trigger.update_progress(0.5)
            await asyncio.sleep(0.2)

            trigger.update_progress(1.0)
            await asyncio.sleep(0.2)
        finally:
            trigger.stop()

        # Loose lower bound: at least the 0.25 and 0.5 milestones fired.
        assert len(fire_count) >= 2

    async def test_on_navigation_complete(self):
        """Test on_navigation_complete stops trigger."""
        condition = TriggerCondition(interval_sec=0.1)
        fire_count = []

        def on_fire():
            fire_count.append(1)

        trigger = PeriodicTrigger(
            action_id="NAV_TRIGGER",
            condition=condition,
            on_fire=on_fire,
            logger=None,
        )

        trigger.start()
        try:
            # Leave margin beyond the 0.1s interval for a first fire.
            await asyncio.sleep(0.25)

            initial_count = len(fire_count)
            # Without this, a trigger that never fires at all would pass the
            # "no more fires" check below vacuously.
            assert initial_count >= 1

            trigger.on_navigation_complete()

            # Wait and verify no more fires
            await asyncio.sleep(0.2)
            assert len(fire_count) == initial_count
        finally:
            # Cleanup if an assertion fails before on_navigation_complete().
            trigger.stop()

    async def test_update_progress(self):
        """Test update_progress method."""
        condition = TriggerCondition(
            progress_milestones=[0.5, 1.0]
        )
        fire_count = []

        def on_fire():
            fire_count.append(1)

        trigger = PeriodicTrigger(
            action_id="PROGRESS_TRIGGER",
            condition=condition,
            on_fire=on_fire,
            logger=None,
        )

        trigger.start()
        try:
            trigger.update_progress(0.3)
            await asyncio.sleep(0.15)

            trigger.update_progress(0.6)  # Pass 0.5 milestone
            await asyncio.sleep(0.15)
        finally:
            trigger.stop()

        # Should have fired at 0.5 milestone
        assert len(fire_count) >= 1

    async def test_trigger_with_async_callback(self):
        """Test trigger with async callback."""
        condition = TriggerCondition(interval_sec=0.1)
        fire_count = []

        async def on_fire():
            fire_count.append(1)
            await asyncio.sleep(0.01)  # Simulate async work

        trigger = PeriodicTrigger(
            action_id="ASYNC_TRIGGER",
            condition=condition,
            on_fire=on_fire,
            logger=None,
        )

        trigger.start()
        try:
            await asyncio.sleep(0.3)
        finally:
            trigger.stop()

        assert len(fire_count) >= 1

    async def test_multiple_triggers_independent(self):
        """Test multiple triggers run independently."""
        condition1 = TriggerCondition(interval_sec=0.1)
        condition2 = TriggerCondition(interval_sec=0.2)

        fire_count1 = []
        fire_count2 = []

        trigger1 = PeriodicTrigger(
            action_id="TRIGGER_1",
            condition=condition1,
            on_fire=lambda: fire_count1.append(1),
            logger=None,
        )

        trigger2 = PeriodicTrigger(
            action_id="TRIGGER_2",
            condition=condition2,
            on_fire=lambda: fire_count2.append(1),
            logger=None,
        )

        trigger1.start()
        trigger2.start()
        try:
            await asyncio.sleep(0.35)
        finally:
            trigger1.stop()
            trigger2.stop()

        # Trigger1 should fire more often (shorter interval)
        assert len(fire_count1) > len(fire_count2)

    async def test_multiple_start_calls_ignored(self):
        """Test multiple start calls don't create duplicate tasks."""
        condition = TriggerCondition(interval_sec=1.0)
        trigger = PeriodicTrigger(
            action_id="TEST",
            condition=condition,
            on_fire=lambda: None,
            logger=None,
        )

        trigger.start()
        try:
            first_task = trigger._task
            trigger.start()
            second_task = trigger._task

            assert first_task is second_task
        finally:
            trigger.stop()


# =============================================================================
# Test ActionExecutor Integration
# =============================================================================


@pytest.mark.asyncio
class TestActionExecutorTriggerIntegration:
    """Test integration of triggers with ActionExecutor."""

    async def test_add_trigger(self):
        """Test adding a trigger to executor."""
        executor = ActionExecutor()

        condition = TriggerCondition(interval_sec=1.0)
        trigger = PeriodicTrigger(
            action_id="TEST",
            condition=condition,
            on_fire=lambda: None,
            logger=None,
        )

        executor.add_trigger(trigger)
        assert len(executor._triggers) == 1

    async def test_clear_triggers(self):
        """Test clearing all triggers."""
        executor = ActionExecutor()

        # Give each trigger its own condition. TriggerCondition holds mutable
        # state (fired milestones, reset()), so sharing one instance would
        # couple the two triggers.
        trigger1 = PeriodicTrigger(
            action_id="TEST1",
            condition=TriggerCondition(interval_sec=1.0),
            on_fire=lambda: None,
            logger=None,
        )
        trigger2 = PeriodicTrigger(
            action_id="TEST2",
            condition=TriggerCondition(interval_sec=1.0),
            on_fire=lambda: None,
            logger=None,
        )

        executor.add_trigger(trigger1)
        executor.add_trigger(trigger2)
        assert len(executor._triggers) == 2

        executor.clear_triggers()
        assert len(executor._triggers) == 0

    async def test_triggers_start_stop_during_navigation(self):
        """Test triggers start and stop with navigation."""
        executor = ActionExecutor()

        fire_count = []

        def on_fire():
            fire_count.append(1)

        condition = TriggerCondition(interval_sec=0.1)
        trigger = PeriodicTrigger(
            action_id="NAV_SCAN",
            condition=condition,
            on_fire=on_fire,
            logger=None,
        )

        executor.add_trigger(trigger)

        # Verify trigger was added
        assert len(executor._triggers) == 1

        # Test start/stop methods work
        await executor._start_triggers()
        try:
            assert trigger._running

            # Wait well past the 0.1s interval. The former 0.15s sleep left
            # only 0.05s of slack and could be flaky on slow CI runners.
            await asyncio.sleep(0.25)
        finally:
            await executor._stop_triggers()
        assert not trigger._running

        # Trigger should have fired at least once
        assert len(fire_count) >= 1

    async def test_update_trigger_progress(self):
        """Test progress updates propagate to triggers."""
        executor = ActionExecutor()

        condition = TriggerCondition(
            progress_milestones=[0.5, 1.0]
        )
        fire_count = []

        trigger = PeriodicTrigger(
            action_id="PROGRESS_TRIGGER",
            condition=condition,
            on_fire=lambda: fire_count.append(1),
            logger=None,
        )

        executor.add_trigger(trigger)
        trigger.start()
        try:
            # Update progress via executor
            await executor._update_trigger_progress(0.5)
            await asyncio.sleep(0.15)
        finally:
            trigger.stop()

        # Should have fired at 0.5 milestone
        assert len(fire_count) >= 1


# =============================================================================
# Test Edge Cases
# =============================================================================


class TestTriggerEdgeCases:
    """Edge cases: empty conditions, out-of-range progress, premature stop."""

    @staticmethod
    def _idle_trigger():
        """Build a never-started 1s-interval trigger with a no-op callback."""
        return PeriodicTrigger(
            action_id="TEST",
            condition=TriggerCondition(interval_sec=1.0),
            on_fire=lambda: None,
            logger=None,
        )

    def test_condition_with_no_criteria(self):
        """A condition with neither interval nor milestones never fires."""
        empty = TriggerCondition()

        assert empty.interval_sec is None
        assert empty.progress_milestones == set()

        # Neither a long elapsed time nor full progress may trigger it.
        fires_on_time = empty.should_fire_time(100.0)
        fires_on_progress = empty.should_fire_progress(1.0)
        assert not fires_on_time
        assert not fires_on_progress

    def test_progress_clamping(self):
        """update_progress clamps its argument into the range [0.0, 1.0]."""
        trig = self._idle_trigger()

        # (raw input, expected stored value): below range, then above range.
        for raw, clamped in ((-0.5, 0.0), (1.5, 1.0)):
            trig.update_progress(raw)
            assert trig._current_progress == clamped

    def test_stop_without_start_safe(self):
        """Stopping a trigger that was never started must not raise."""
        trig = self._idle_trigger()

        # start() was never called; both teardown paths must be no-ops.
        trig.stop()
        trig.on_navigation_complete()
