from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Dict, List, Optional
import threading
from collections import deque


from Copilot.numericLayer.Core.SignalInsight import SignalInsight, SignalStatus
from Copilot.numericLayer.Core.ruleHits import *
from Copilot.SituationLayer.Core.Incident import IncidentRef, IncidentSignalGroup, IncidentSeverity
from Copilot.SituationLayer.Core.IncidentEvents import IncidentEvents
from Copilot.SituationLayer.Core.MachineSituationState import MachineSituationState
from Copilot.SituationLayer.Core.machineParameterSnapshot import MachineParameterSnapshot
from Copilot.SituationLayer.Core.parmsHistory import ParameterHistory, ParameterHistoryPoint



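# =========================
# STATE
# (per-machine state is MachineSituationState, imported from
#  Copilot.SituationLayer.Core.MachineSituationState)
# =========================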



# =========================
# MANAGER
# =========================

class SituationManager:
    """Aggregate per-machine signal insights into grouped incidents.

    Signals are fed in through two entry points:

    * :meth:`add_to_history` keeps the latest value per parameter and a
      rolling per-parameter history.
    * :meth:`process` tracks active alerts and collects non-NORMAL signals
      that arrive close together into an ``IncidentSignalGroup``.

    :meth:`flush_stale_groups` later turns closed groups that qualify into
    ``IncidentRef`` objects, subject to a per-incident-type cooldown.

    Public entry points acquire ``self._lock`` (a non-reentrant
    ``threading.Lock``); the private helpers assume the caller already holds
    it. All time comparisons use the signals' own timestamps (plus the
    ``now`` passed to :meth:`flush_stale_groups`), never the wall clock.
    """

    def __init__(
        self,
        incident_cooldown_seconds: int = 10,
        grouping_window_seconds: int = 10,
        max_group_lifetime: int = 30,
        min_warning_count: int = 2,
        max_pending_groups: int = 200,
        history_retention_seconds: int = 30,
    ):
        """Configure grouping, cooldown and history retention.

        Args:
            incident_cooldown_seconds: Minimum gap, measured between group end
                timestamps, before another incident of the same type may be
                emitted for a machine.
            grouping_window_seconds: Maximum silence between consecutive
                signals for them to stay in the same group.
            max_group_lifetime: Maximum age in seconds, counted from a group's
                first signal, before the group is closed and a new one started.
            min_warning_count: Number of WARNING signals a group needs to
                qualify as an incident when it contains no CRITICAL signal.
            max_pending_groups: Per-machine cap on closed-but-unflushed
                groups; the oldest are evicted first.
            history_retention_seconds: Length of the rolling per-parameter
                history window.
        """
        self.incident_cooldown = timedelta(seconds=incident_cooldown_seconds)
        self.grouping_window = timedelta(seconds=grouping_window_seconds)
        self.max_group_lifetime = timedelta(seconds=max_group_lifetime)

        self.min_warning_count = min_warning_count
        self.max_pending_groups = max_pending_groups
        # Raw seconds are kept as well; they are reported in incident metadata.
        self.history_retention_seconds = history_retention_seconds
        self.history_retention = timedelta(seconds=history_retention_seconds)

        self.machine_states: Dict[int, MachineSituationState] = {}
        self._lock = threading.Lock()

    # =========================
    # ENTRY
    # =========================

    def add_to_history(self, signal: SignalInsight) -> None:
        """Record ``signal`` as its parameter's latest value and history sample.

        After appending, the parameter's history is pruned to samples whose
        timestamp is no older than ``history_retention`` before this signal's
        timestamp.
        """
        with self._lock:
            state = self._get_or_create_state(signal.machine.id)
            param_id = signal.parameter.id

            # Latest value per parameter; read by _build_machine_snapshot.
            state.latest_signals_insight[param_id] = signal

            history = state.history_by_parameter.get(param_id)
            if history is None:
                history = ParameterHistory(parameter=signal.parameter, samples=[])
                state.history_by_parameter[param_id] = history

            history.samples.append(
                ParameterHistoryPoint(
                    timestamp=signal.timestamp,
                    value=signal.value,
                    status=signal.status,
                    trend=signal.trend,
                )
            )

            # Filter the whole list (rather than trimming the head) so that
            # samples arriving out of timestamp order are handled correctly.
            cutoff = signal.timestamp - self.history_retention
            history.samples = [s for s in history.samples if s.timestamp >= cutoff]

    def process(self, signal: SignalInsight) -> None:
        """Update active alerts for ``signal`` and feed it into grouping.

        WARNING/CRITICAL signals become the parameter's active alert; any
        other status clears it. NORMAL signals stop there; every other status
        is passed on to the grouping logic. The latest-value snapshot is not
        touched here; it is maintained by :meth:`add_to_history`.
        """
        with self._lock:
            state = self._get_or_create_state(signal.machine.id)
            param_id = signal.parameter.id
            status = signal.status

            if status in (SignalStatus.WARNING, SignalStatus.CRITICAL):
                state.active_alerts[param_id] = signal
            else:
                state.active_alerts.pop(param_id, None)

            # NORMAL signals must neither open nor extend a group.
            if status == SignalStatus.NORMAL:
                return

            self._handle_grouping(state, signal)

    # =========================
    # GROUPING
    # =========================

    def _handle_grouping(self, state: MachineSituationState, signal: SignalInsight) -> None:
        """Add ``signal`` to the machine's open group, rotating the group if needed.

        The signal joins the open group only when that group is still within
        ``max_group_lifetime`` of its start AND within ``grouping_window`` of
        its last signal. Otherwise the open group is closed (queued for
        flushing) and a new group is started with this signal.
        """
        group = state.open_group

        if group is not None:
            now = signal.timestamp
            within_lifetime = now - group.started_at <= self.max_group_lifetime
            within_window = now - group.last_signal_at <= self.grouping_window
            if within_lifetime and within_window:
                # NOTE(review): the window check relies on
                # IncidentSignalGroup.add() advancing last_signal_at; confirm.
                group.add(signal)
                return
            self._close_group(state)

        self._open_group(state, signal)

    def _open_group(self, state: MachineSituationState, signal: SignalInsight) -> None:
        """Start a new open group for ``state``, seeded with ``signal``."""
        state.open_group = IncidentSignalGroup(
            started_at=signal.timestamp,
            last_signal_at=signal.timestamp,
        )
        state.open_group.add(signal)

    def _close_group(self, state: MachineSituationState) -> None:
        """Move the open group (if any) onto the pending queue.

        Oldest pending groups are evicted first so that, after appending, the
        queue holds at most ``max_pending_groups`` entries. The just-closed
        group is always queued, even when ``max_pending_groups <= 0``.
        """
        group = state.open_group
        if group is None:
            return

        pending = state.pending_groups
        # The emptiness guard avoids pop() on an empty queue when the cap is
        # zero or negative. The loop also trims a queue that is already over
        # the cap.
        while pending and len(pending) >= self.max_pending_groups:
            pending.pop(0)

        pending.append(group)
        state.open_group = None

    # =========================
    # FLUSH
    # =========================

    def flush_stale_groups(self, now: datetime) -> List[IncidentRef]:
        """Close expired open groups and turn pending groups into incidents.

        For every machine, the open group is first closed if, relative to
        ``now``, it has been silent longer than ``grouping_window`` or is
        older than ``max_group_lifetime``. Each pending group is then either
        emitted as an incident (it qualifies and passes the per-type cooldown)
        or discarded. The pending queue is empty afterwards.

        Args:
            now: Reference time, used only to decide whether open groups have
                expired.

        Returns:
            The incidents built during this flush, in the order machines and
            their pending groups were visited.
        """
        with self._lock:
            incidents: List[IncidentRef] = []

            for state in self.machine_states.values():
                if self.should_close_open_group(state, now):
                    self._close_group(state)

                for group in state.pending_groups:
                    if not self._group_qualifies(group):
                        continue

                    incident_type = self._derive_group_incident_type(group.signals)
                    end_time = group.last_signal_at

                    if not self._cooldown_ok(state, incident_type, end_time):
                        # Suppressed by cooldown. Cooldown compares signal
                        # timestamps, not `now`, so waiting cannot change the
                        # outcome. Re-queueing would only accumulate dead
                        # groups in the pending queue.
                        continue

                    incidents.append(self._build_incident_from_group(state, group))
                    state.last_emitted_incident_by_type[incident_type] = end_time

                state.pending_groups = []

            return incidents

    # =========================
    # FILTER
    # =========================

    def _group_qualifies(self, group: IncidentSignalGroup) -> bool:
        """Return True if the group has a CRITICAL signal or enough WARNINGs."""
        critical = sum(1 for s in group.signals if s.status == SignalStatus.CRITICAL)
        warning = sum(1 for s in group.signals if s.status == SignalStatus.WARNING)

        return critical >= 1 or warning >= self.min_warning_count

    # =========================
    # COOLDOWN
    # =========================

    def _cooldown_ok(
        self, state: MachineSituationState, incident_type: str, timestamp: datetime
    ) -> bool:
        """Return True if an incident of ``incident_type`` ending at ``timestamp`` may be emitted.

        The check passes if no incident of this type was emitted yet for the
        machine, or if at least ``incident_cooldown`` separates the last
        emitted one from ``timestamp``.
        """
        last = state.last_emitted_incident_by_type.get(incident_type)
        if last is None:
            return True
        return (timestamp - last) >= self.incident_cooldown

    # =========================
    # INCIDENT BUILD
    # =========================

    def pick_trigger(self, signals):
        """Return the first CRITICAL signal, else the first WARNING, else the last one.

        Raises:
            IndexError: if ``signals`` is empty.
        """
        for wanted in (SignalStatus.CRITICAL, SignalStatus.WARNING):
            for s in signals:
                if s.status == wanted:
                    return s
        return signals[-1]

    def _build_parameter_histories(
            self,
            state: MachineSituationState,
            active_signals: List[SignalInsight],
            incident_time: datetime,
        ) -> List[ParameterHistory]:
        """Collect one history slice per distinct parameter in ``active_signals``.

        Each slice holds the stored samples with timestamp >= ``incident_time``
        minus ``history_retention``. Parameters with no stored history, or no
        samples in that range, are skipped. Output follows the first-seen
        order of parameters in ``active_signals``.

        NOTE(review): there is no upper bound at ``incident_time``, so samples
        recorded after the incident (but before the flush) are included too.
        Confirm this is intended.
        """
        histories: List[ParameterHistory] = []
        seen_params = set()

        cutoff = incident_time - self.history_retention

        for signal in active_signals:
            param_id = signal.parameter.id

            # A group can hold several signals for the same parameter; emit
            # that parameter's history only once.
            if param_id in seen_params:
                continue
            seen_params.add(param_id)

            history = state.history_by_parameter.get(param_id)
            if history is None or not history.samples:
                continue

            samples = [s for s in history.samples if s.timestamp >= cutoff]
            if not samples:
                continue

            histories.append(
                ParameterHistory(
                    parameter=history.parameter,
                    samples=samples
                )
            )

        return histories

    def _build_incident_from_group(
        self, state: MachineSituationState, group: IncidentSignalGroup
    ) -> IncidentRef:
        """Build an ``IncidentRef`` describing ``group`` for ``state``'s machine.

        Signals are ordered by timestamp. The trigger is chosen by
        :meth:`pick_trigger`. The machine snapshot reflects the latest
        per-parameter values stored at the time this is called.
        """
        signals = sorted(group.signals, key=lambda s: s.timestamp)
        trigger = self.pick_trigger(signals)

        incident_type = self._derive_group_incident_type(signals)
        severity = self._calculate_group_severity(signals)

        snapshot = self._build_machine_snapshot(state)
        histories = self._build_parameter_histories(state, signals, group.last_signal_at)

        return IncidentRef(
            incident_type=incident_type,
            severity=severity,
            machine=trigger.machine,
            trigger_parameter=trigger.parameter,
            trigger_value=trigger.value,
            trigger_timestamp=trigger.timestamp,
            trigger_rules=trigger.rule_hits,
            summary=f"{incident_type} detected on {trigger.machine.display_name} ({len(signals)} signals)",
            active_signals=signals,
            machine_snapshot=snapshot,
            parameter_histories=histories,
            metadata={
                "start": group.started_at.isoformat(),
                "end": group.last_signal_at.isoformat(),
                "count": len(signals),
                "history_window_seconds": self.history_retention_seconds,
            },
        )

    # =========================
    # SNAPSHOT
    # =========================

    def _build_machine_snapshot(self, state: MachineSituationState) -> List[MachineParameterSnapshot]:
        """Return the latest value of every known parameter, sorted by parameter id."""
        return sorted(
            [
                MachineParameterSnapshot(
                    parameter=s.parameter,
                    value=s.value,
                    timestamp=s.timestamp,
                    trend=s.trend,
                    status=s.status,
                )
                for s in state.latest_signals_insight.values()
            ],
            key=lambda x: x.parameter.id
        )

    # =========================
    # INCIDENT TYPE
    # =========================

    def _derive_group_incident_type(self, signals):
        """Classify a group of signals into an ``IncidentEvents`` value.

        The checks are applied in priority order:

        1. Threshold breach: any ABOVE_MAX/BELOW_MIN rule hit.
        2. Instability: a SUDDEN_JUMP rule hit.
        3. Oscillation: an OSCILLATING trend.
        4. Machine anomaly: everything else.

        The result does not depend on signal order.
        """
        rules = set()
        trends = []

        for s in signals:
            rules.update(s.rule_hits)
            trends.append(s.trend)

        if ThreshRuleHits.ABOVE_MAX in rules or ThreshRuleHits.BELOW_MIN in rules:
            return IncidentEvents.THRESHOLD_BREACH

        if OtherRuleHits.SUDDEN_JUMP in rules:
            return IncidentEvents.INSTABILITY

        # NOTE(review): assumes SignalInsight.trend holds TrendRuleHits values;
        # confirm, otherwise this membership test can never match.
        if TrendRuleHits.OSCILLATING in trends:
            return IncidentEvents.OSCILLATION

        return IncidentEvents.MACHINE_ANOMALY

    # =========================
    # SEVERITY
    # =========================

    def _calculate_group_severity(self, signals):
        """Map CRITICAL/WARNING counts to an ``IncidentSeverity``.

        The mapping is:

        * two or more CRITICAL: HIGH
        * one CRITICAL, or three or more WARNING: MEDIUM
        * otherwise: LOW
        """
        critical = sum(1 for s in signals if s.status == SignalStatus.CRITICAL)
        warning = sum(1 for s in signals if s.status == SignalStatus.WARNING)

        if critical >= 2:
            return IncidentSeverity.HIGH

        if critical >= 1:
            return IncidentSeverity.MEDIUM

        if warning >= 3:
            return IncidentSeverity.MEDIUM

        return IncidentSeverity.LOW

    # =========================
    # STATE
    # =========================

    def _get_or_create_state(self, machine_id: int) -> MachineSituationState:
        """Return the state for ``machine_id``, creating it on first use."""
        state = self.machine_states.get(machine_id)
        if state is None:
            state = MachineSituationState(machine_id=machine_id)
            self.machine_states[machine_id] = state
        return state

    def should_close_open_group(self, state: MachineSituationState, now: datetime) -> bool:
        """Return True if ``state``'s open group has expired as of ``now``.

        A group expires when it has been silent longer than
        ``grouping_window``, or when it started more than
        ``max_group_lifetime`` ago. Returns False when there is no open group.
        This method does not take the lock itself.
        """
        g = state.open_group
        if not g:
            return False

        silent_too_long = now - g.last_signal_at > self.grouping_window
        too_old = now - g.started_at > self.max_group_lifetime

        return silent_too_long or too_old