from __future__ import annotations

from collections import deque
from dataclasses import dataclass, field
from datetime import datetime
from statistics import mean
from typing import Dict, Deque, List, Optional

import numpy as np

from Copilot.dataAcquisition.Core.CollectedData import CollectedData, ParameterRef
from Copilot.numericLayer.Core.SignalInsight import SignalInsight, SignalStatus
from Copilot.numericLayer.Core.ruleHits import *
from Copilot.numericLayer.Analyzers.mathUtils import *


# =========================
# EMISSION STATE
# =========================

@dataclass
class SignalEmissionState:
    """Record of the last signal state written for one parameter.

    Every field starts empty (``None`` or an empty set). The analyzer in this
    module overwrites all four fields after each processed sample.
    """

    # Status stored for the most recent sample (None until the first update).
    last_status: Optional[SignalStatus] = None
    # Raw value stored for the most recent sample.
    last_value: Optional[float] = None
    # Timestamp of the most recent stored sample.
    last_emitted_at: Optional[datetime] = None
    # Rule hits of the most recent sample; each instance owns a fresh set.
    last_rule_hits: set = field(default_factory=set)


# =========================
# PARAM STATE
# =========================

def _recent_history() -> Deque:
    """Return an empty bounded buffer that keeps only the 20 newest entries."""
    return deque(maxlen=20)


@dataclass
class ParameterSignalStats:
    """Rolling state kept for a single (machine, parameter) signal.

    ``values`` and ``timestamps`` are bounded to the 20 most recent samples;
    the remaining fields are derived statistics and bookkeeping that start as
    ``None``/0 until the first sample is processed.
    """

    # Bounded sample history (oldest entries drop off automatically).
    values: Deque[float] = field(default_factory=_recent_history)
    timestamps: Deque[datetime] = field(default_factory=_recent_history)

    # Most recent raw sample and its smoothed counterpart.
    last_value: Optional[float] = None
    smoothed_value: Optional[float] = None

    # Mean / standard deviation over the current history.
    rolling_mean: Optional[float] = None
    rolling_std: Optional[float] = None

    # Extremes over the current history.
    min_recent: Optional[float] = None
    max_recent: Optional[float] = None

    # Number of consecutive samples (nearly) equal to their predecessor.
    repeated_count: int = 0
    last_update: Optional[datetime] = None

    # Per-signal emission bookkeeping; each instance gets its own object.
    emission_state: SignalEmissionState = field(default_factory=SignalEmissionState)


# =========================
# ANALYZER
# =========================

class primeryNumericalAnalyzer:
    """Streaming analyzer that turns ``CollectedData`` samples into ``SignalInsight``s.

    Rolling statistics are kept per ``(machine.id, parameter.id)`` pair in
    ``parms_stats``. Each :meth:`process` call updates that pair's state,
    evaluates threshold, sudden-jump and trend rules, and returns a
    ``SignalInsight``.

    Args:
        window_size: used only as a warm-up gate for sudden-jump detection;
            the rule is evaluated once the history holds more than
            ``0.6 * window_size`` samples.
        rising_ratio_threshold: stored but not read by any method of this class.
        sudden_jump_ratio: stored but not read by any method of this class.
        stuck_count_threshold: stored but not read by any method of this class
            (``repeated_count`` is tracked, but no "stuck" rule uses it).

    NOTE(review): the history length is the ``maxlen`` of the deques in
    ``ParameterSignalStats`` (20), not ``window_size``. Combined with the
    warm-up gate above, ``window_size >= 34`` means the sudden-jump rule can
    never fire. Confirm whether the history size should follow ``window_size``.
    """

    def __init__(
        self,
        window_size: int = 30,
        rising_ratio_threshold: float = 0.05,
        sudden_jump_ratio: float = 0.1,
        stuck_count_threshold: int = 5,
    ):
        self.window_size = window_size
        self.rising_ratio_threshold = rising_ratio_threshold
        self.sudden_jump_ratio = sudden_jump_ratio
        self.stuck_count_threshold = stuck_count_threshold

        # machine id -> parameter id -> rolling state for that signal.
        self.parms_stats: Dict[int, Dict[str, ParameterSignalStats]] = {}

    # =========================
    # MAIN PROCESS
    # =========================

    def process(self, event: CollectedData):
        """Ingest one sample and return the resulting ``SignalInsight``.

        Args:
            event: sample carrying ``machine``, ``parameter``, ``value`` and
                ``timestamp``.

        Returns:
            A ``SignalInsight`` for this sample. ``rule_hits`` is sorted by its
            display label so the result is reproducible across runs.

        Raises:
            ValueError: if ``event.machine`` or ``event.parameter`` is None.
        """
        if event.machine is None or event.parameter is None:
            raise ValueError("CollectedData must have machine and parameter.")

        state = self._get_state(event.machine.id, event.parameter.id)

        # Pre-update snapshot: sudden-jump detection compares the new sample
        # against statistics that do NOT yet include it.
        prev_value = state.last_value
        prev_std = state.rolling_std
        prev_mean = state.rolling_mean

        values = self._update_snapshot(state, event)

        # =========================
        # RULE ENGINE
        # =========================

        rule_hits: set = set()
        # NOTE(review): status is never escalated from NORMAL, even when rules
        # fire. Confirm this is intended.
        status = SignalStatus.NORMAL

        min_v = event.parameter.min_value
        max_v = event.parameter.max_value

        if min_v is not None and event.value < min_v:
            rule_hits.add(ThreshRuleHits.BELOW_MIN)

        if max_v is not None and event.value > max_v:
            rule_hits.add(ThreshRuleHits.ABOVE_MAX)

        trend_threshold = self._get_dynamic_trend_threshold(values, event.parameter)
        trend = self._detect_trend(values, trend_threshold)

        # Only look for jumps once enough history has accumulated; the cheap
        # length check runs first so detection is skipped during warm-up.
        warmed_up = len(values) > self.window_size * 0.6
        if warmed_up and self._detect_sudden_jump(event.value, prev_value, prev_std, prev_mean):
            rule_hits.add(OtherRuleHits.SUDDEN_JUMP)

        if trend in (TrendRuleHits.RISING, TrendRuleHits.FALLING):
            rule_hits.add(trend)

        # =========================
        # BUILD SIGNAL
        # =========================

        # Set iteration order is not stable across runs; sort for reproducibility.
        rule_hits_list = sorted(rule_hits, key=self._rule_hit_label)

        message = self._build_message(event, status, trend, rule_hits_list)

        signal = SignalInsight(
            parameter=event.parameter,
            machine=event.machine,
            value=event.value,
            timestamp=event.timestamp,
            status=status,
            trend=trend,
            rule_hits=rule_hits_list,
            message=message,
        )

        # =========================
        # EMISSION STATE
        # =========================

        # The emission state is overwritten on every call; no suppression or
        # de-duplication decision is made here.
        em = state.emission_state
        em.last_status = status
        em.last_value = event.value
        em.last_emitted_at = event.timestamp
        em.last_rule_hits = set(rule_hits)

        return signal

    # =========================
    # HELPERS
    # =========================

    def _get_state(self, machine_id, p_id) -> ParameterSignalStats:
        """Return the rolling state for ``(machine_id, p_id)``, creating it on first use."""
        machine_stats = self.parms_stats.setdefault(machine_id, {})
        # Explicit check (not setdefault) so a fresh stats object is built only when needed.
        if p_id not in machine_stats:
            machine_stats[p_id] = ParameterSignalStats()
        return machine_stats[p_id]

    def _update_snapshot(self, state, event) -> List[float]:
        """Append ``event`` to ``state``'s history and refresh its derived statistics.

        Returns:
            The current history window as a list, oldest sample first.
        """
        prev_value = state.last_value

        state.values.append(event.value)
        state.timestamps.append(event.timestamp)

        values = list(state.values)

        state.rolling_mean = mean(values)
        state.rolling_std = np.std(values)
        state.min_recent = min(values)
        state.max_recent = max(values)
        state.smoothed_value = ema(values)[-1]
        state.last_update = event.timestamp

        # Count consecutive (nearly) identical samples; any change resets it.
        if prev_value is not None and nearly_equal(prev_value, event.value):
            state.repeated_count += 1
        else:
            state.repeated_count = 0

        state.last_value = event.value
        return values

    def _detect_sudden_jump(self, new_value, prev_value, rolling_std, baseline, k_thresh=4):
        """Return True when ``new_value`` lies ``k_thresh`` or more deviations from ``baseline``.

        ``prev_value`` is only used to detect "no previous sample" (returns
        False). The ``1e-3`` term keeps the ratio finite when ``rolling_std`` is 0.
        """
        if prev_value is None or baseline is None or rolling_std is None:
            return False
        return abs(new_value - baseline) / (rolling_std + 1e-3) >= k_thresh

    def _get_dynamic_trend_threshold(
        self,
        window_values: List[float],
        parameter: ParameterRef,
        range_ratio=0.03,
        k=2,
        minimum_threshold=0.0015,
    ):
        """Return the total change over the window that counts as a trend.

        This is the largest of three terms:

        * ``range_ratio`` times the parameter's configured range (0 when either
          bound is missing);
        * ``k`` times the window's population standard deviation (0 for fewer
          than 3 samples);
        * a magnitude floor, ``minimum_threshold * |mean| * len(window)``.
          ``abs()`` makes negative-valued signals get a positive floor too.
        """
        param_range = 0
        if parameter.min_value is not None and parameter.max_value is not None:
            param_range = parameter.max_value - parameter.min_value

        std = np.std(window_values) if len(window_values) >= 3 else 0

        magnitude_floor = minimum_threshold * abs(np.mean(window_values)) * len(window_values)

        return max(param_range * range_ratio, std * k, magnitude_floor)

    def _detect_trend(self, values: list[float], threshold: float):
        """Classify the window's trend as a ``TrendRuleHits`` member.

        * ``INSUFFICIENT_DATA``: fewer than 5 samples.
        * ``OSCILLATING``: the smoothed series crosses its own mean at least
          ``max(4, 0.6 * len(values))`` times.
        * ``RISING`` / ``FALLING``: the regression-projected change over the
          window is non-zero and reaches ``threshold`` in magnitude.
        * ``STABLE``: otherwise.

        NOTE(review): assumes ``ema`` returns one value per input sample and
        ``linear_regression`` returns ``(slope, intercept)`` with slope per
        sample index. Confirm against ``mathUtils``.
        """
        if len(values) < 5:
            return TrendRuleHits.INSUFFICIENT_DATA

        smooth = ema(values)
        slope, _ = linear_regression(smooth)

        # Projected change between the first and last sample of the window.
        total_delta = slope * (len(values) - 1)

        mean_value = mean(smooth)

        crossings = 0
        for i in range(1, len(smooth)):
            if (smooth[i - 1] - mean_value) * (smooth[i] - mean_value) < 0:
                crossings += 1

        if crossings >= max(4, len(values) * 0.6):
            return TrendRuleHits.OSCILLATING

        # The sign checks stop a zero threshold (e.g. an all-zero window with
        # no configured range) from labelling a flat signal as a trend.
        if total_delta > 0 and total_delta >= threshold:
            return TrendRuleHits.RISING
        if total_delta < 0 and total_delta <= -threshold:
            return TrendRuleHits.FALLING

        return TrendRuleHits.STABLE

    @staticmethod
    def _rule_hit_label(hit) -> str:
        """Return the display text for a rule hit.

        Enum members render as their ``value``. For ``str``-based enums this
        matches what ``str.join`` produced before. Other objects fall back to ``str()``.
        """
        return str(getattr(hit, "value", hit))

    def _build_message(self, event, status, trend, rule_hits):
        """Return a one-line summary of the sample, trend, status and rule hits.

        ``rule_hits`` is rendered in the given order; members that are not
        ``str``-based enums no longer raise ``TypeError`` in the join.
        """
        unit = getattr(event.parameter, "unit", "") or ""
        suffix = f" {unit}" if unit else ""

        if not rule_hits:
            return f"{event.parameter.display_name}={event.value}{suffix}, trend={trend}, status={status}"

        rules = ", ".join(self._rule_hit_label(hit) for hit in rule_hits)
        return (
            f"{event.parameter.display_name}={event.value}{suffix}, "
            f"trend={trend}, status={status}, rules={rules}"
        )