from collections import defaultdict

from Copilot.numericLayer.Filters.Signal.signalBaseFilter import baseSignalFilter    
from Copilot.numericLayer.Core.SignalInsight import SignalInsight
from Copilot.numericLayer.Core.ruleHits import ThreshRuleHits
from Copilot.numericLayer.Filters.Rule.ruleBaseFilter import BaseRuleFilter

class ConsecutiveRuleFilter(BaseRuleFilter):
    """Suppress threshold rules until they have fired on enough consecutive signals.

    A streak counter is kept for every ``(machine id, parameter id, rule)``
    combination. A rule in ``target_rules`` is passed through only once it has
    been present in at least ``min_occurrences`` consecutive signals for that
    machine/parameter. A signal that does not contain the rule resets its
    streak. Rules outside ``target_rules`` always pass through unchanged.
    """

    def __init__(self, min_occurrences: int = 3):
        """
        Args:
            min_occurrences: number of consecutive signals a target rule must
                appear in before it is emitted.
        """
        self.min_occurrences = min_occurrences

        # Only these rules are debounced; every other rule passes straight through.
        self.target_rules = {
            ThreshRuleHits.ABOVE_MAX,
            ThreshRuleHits.BELOW_MIN,
        }

        # (machine id, parameter id, rule) -> length of the current streak.
        self.counts = defaultdict(int)

    def apply(self, signal: SignalInsight) -> list:
        """Return the subset of ``signal.rule_hits`` that should be emitted.

        Updates the streak counters as a side effect, so the result depends on
        the signals previously passed to this instance. Returned rules keep
        the order in which they appear in ``signal.rule_hits``.
        """
        machine_id = signal.machine.id
        param_id = signal.parameter.id

        # Count each target rule at most once per signal, so a rule listed
        # twice in one signal does not advance its streak twice.
        present = {rule for rule in signal.rule_hits if rule in self.target_rules}

        for rule in self.target_rules:
            key = (machine_id, param_id, rule)
            if rule in present:
                self.counts[key] += 1
            else:
                # Streak broken. Drop the entry instead of storing 0 so the
                # dict does not accumulate keys for rules that never fired.
                self.counts.pop(key, None)

        return [
            rule
            for rule in signal.rule_hits
            if rule not in self.target_rules
            or self.counts[(machine_id, param_id, rule)] >= self.min_occurrences
        ]



    



class CooldownFilter(BaseRuleFilter):
    """Rate-limit rules so each is emitted at most once per cooldown window.

    The window is tracked independently for every
    ``(machine id, parameter id, rule)`` combination. The first time a
    combination is seen, its rule is always emitted.
    """

    def __init__(self, cooldown_seconds=5):
        """
        Args:
            cooldown_seconds: minimum number of seconds that must separate two
                emissions of the same rule for the same machine/parameter.
        """
        self.cooldown_seconds = cooldown_seconds
        # (machine id, parameter id, rule) -> timestamp of the last emission.
        self.last_emit = {}

    def apply(self, signal: SignalInsight):
        """Return the rules from ``signal.rule_hits`` whose cooldown has elapsed.

        Each emitted rule records ``signal.timestamp`` as its new last-emission
        time. Suppressed rules leave their stored timestamp untouched.
        """
        emitted = []

        for rule in signal.rule_hits:
            key = (signal.machine.id, signal.parameter.id, rule)
            previous = self.last_emit.get(key)

            cooled_down = (
                previous is None
                or (signal.timestamp - previous).total_seconds()
                >= self.cooldown_seconds
            )
            if cooled_down:
                self.last_emit[key] = signal.timestamp
                emitted.append(rule)

        return emitted