import os
import pickle
import time
import numpy as np
import tiktoken
import re

from DocumentManager.Constants import DUMMY_EMBED_SHAPE, EMBED_TOKENS, NET_ACCESS, TIMESLEEP
from DocumentManager.ProgressManager import ProgressManager

MAX_TOKENS = 8192  # token limit of the text-embedding-ada-002 model


class Document:
    """Base class for a file that is chunked, embedded and cached on disk.

    Subclasses implement :meth:`extract_content`. The processed result
    (chunks, metadata, embeddings) is pickled next to the source file under
    ``<path> + CHACHE_EXTENTION`` and reused while the file's MD5 is unchanged.
    """

    # NOTE: the attribute name is misspelled but kept as-is, because code
    # outside this class may reference it.
    CHACHE_EXTENTION = ".pkl"

    def __init__(self, path):
        """Create an unprocessed document for the file at *path*."""
        self.path = path
        self.chunks = []        # chunk strings, aligned with metadata/embeddings
        self.metadata = []      # one dict per chunk
        self.embeddings = None  # float32 array once embedded or loaded
        self.hash = None        # MD5 hex digest set by compute_hash()
        self.progress_manager = ProgressManager()

    @staticmethod
    def path2cached_path(path: str):
        """Return the cache-file path that belongs to the document *path*."""
        return path + Document.CHACHE_EXTENTION

    @staticmethod
    def cached_path2path(cached_path: str):
        """Return the document path that belongs to *cached_path*.

        A path that does not end with the cache extension is returned
        unchanged instead of having its last characters chopped off.
        """
        if cached_path.endswith(Document.CHACHE_EXTENTION):
            return cached_path[: -len(Document.CHACHE_EXTENTION)]
        return cached_path

    def load_or_process(self, embedding_func, chunk_size=500, chunk_overlap_sentence=2, batch_size=32):
        """Load chunks and embeddings from cache, or build and cache them.

        Args:
            embedding_func: callable forwarded to :meth:`embed_chunks`.
            chunk_size: maximum number of words per chunk.
            chunk_overlap_sentence: sentences carried over between chunks.
            batch_size: unused here; kept so existing callers keep working.
        """
        cache_path = Document.path2cached_path(self.path)
        current_hash = self.compute_hash()

        cached = self._read_cache(cache_path)
        # Reuse the cache only if the source file has not been modified.
        if cached is not None and cached.get("hash") == current_hash:
            self.chunks = cached["chunks"]
            self.metadata = cached["metadata"]
            self.embeddings = np.array(cached["embeddings"], dtype="float32")
            return

        # Only the text is used; the second return value is ignored here.
        text, _images = self.extract_content()
        self.chunks, self.metadata = self.chunk_text(text, chunk_size, chunk_overlap_sentence)
        self.embed_chunks(embedding_func)

        with open(cache_path, "wb") as f:
            pickle.dump({
                "hash": current_hash,
                "chunks": self.chunks,
                "metadata": self.metadata,
                "embeddings": self.embeddings
            }, f)

    @staticmethod
    def _read_cache(cache_path):
        """Return the cached dict at *cache_path*, or None if unusable."""
        if not os.path.exists(cache_path):
            return None
        # SECURITY: pickle can execute arbitrary code while loading. This is
        # only safe as long as the cache files are written by this
        # application and are not attacker-controlled.
        try:
            with open(cache_path, "rb") as f:
                data = pickle.load(f)
        except (pickle.UnpicklingError, EOFError, AttributeError, ImportError, IndexError):
            # A truncated or corrupt cache is a cache miss, not a crash.
            return None
        required = ("hash", "chunks", "metadata", "embeddings")
        if not isinstance(data, dict) or any(key not in data for key in required):
            return None
        return data

    def compute_hash(self):
        """Compute, store in ``self.hash`` and return the file's MD5 digest."""
        import hashlib

        digest = hashlib.md5()
        with open(self.path, "rb") as f:
            # Hash in fixed-size blocks so large files are not read at once.
            for block in iter(lambda: f.read(1 << 20), b""):
                digest.update(block)
        self.hash = digest.hexdigest()
        return self.hash

    def extract_content(self):
        """Abstract method: return a ``(text, images)`` pair for the file."""
        raise NotImplementedError

    def clean_text(self, text):
        """Normalise whitespace in *text*.

        The order of the substitutions matters: single newlines are folded
        first, so that only paragraph breaks survive the later steps.
        """
        # A lone newline becomes a single space.
        text = re.sub(r'(?<!\n)\n(?!\n)', ' ', text)

        # Collapse runs of spaces/tabs between words.
        text = re.sub(r'[ \t]{2,}', ' ', text)

        # More than two newlines -> exactly two.
        text = re.sub(r'\n{3,}', '\n\n', text)

        return text.strip()

    def remove_repeated_lines(self, text, min_repeats=3):
        """Remove lines that repeat often, such as page headers and footers.

        A non-blank line occurring *min_repeats* times or more is dropped
        everywhere. Blank lines are always kept, so paragraph breaks are
        not mistaken for a repeated header.
        """
        from collections import Counter

        lines = text.split('\n')
        counts = Counter(lines)
        kept = [
            line for line in lines
            if not line.strip() or counts[line] < min_repeats
        ]
        return '\n'.join(kept)

    def fix_text(self, text):
        """Reshape *text* with arabic_reshaper and reorder it with bidi."""
        import arabic_reshaper
        from bidi.algorithm import get_display
        reshaped = arabic_reshaper.reshape(text)
        return get_display(reshaped)

    def chunk_text(self, text, chunk_size_words=100, overlap_sentence=2, min_chunk_words=10):
        """Split *text* into sentence-aligned chunks of limited word count.

        Args:
            text: text to split.
            chunk_size_words: a sentence is added to the current chunk only
                while the total stays at or below this many words.
            overlap_sentence: number of trailing sentences of a chunk that
                also start the next one; zero or less means no overlap.
            min_chunk_words: chunks with fewer words are discarded.

        Returns:
            ``(chunks, metadata)``: the chunk strings and one
            ``{"file": self.path}`` dict per chunk.
        """
        sentences = re.split(r'(?<=[.!؟?])\s+', text)

        chunks = []
        metadata = []

        current = []       # sentences of the chunk being built
        counts = []        # word count of each sentence in `current`
        current_words = 0  # running total of `counts`

        def emit():
            # Keep the chunk only if it is non-empty and long enough.
            if current and current_words >= min_chunk_words:
                chunks.append(" ".join(current))
                metadata.append({"file": self.path})

        for sentence in sentences:
            n_words = len(sentence.split())

            if current_words + n_words <= chunk_size_words:
                current.append(sentence)
                counts.append(n_words)
                current_words += n_words
                continue

            emit()

            # Start the next chunk with the overlap sentences. `lst[-0:]` is
            # the whole list, so a zero overlap must be handled explicitly.
            if overlap_sentence > 0:
                current = current[-overlap_sentence:]
                counts = counts[-overlap_sentence:]
            else:
                current = []
                counts = []
            current.append(sentence)
            counts.append(n_words)
            current_words = sum(counts)

        # Append the last chunk.
        emit()

        return chunks, metadata

    def embed_chunks(self, embedding_func):
        """Embed ``self.chunks`` in batches that respect the token budget.

        Chunks are grouped so that no batch exceeds 90% of ``MAX_TOKENS``.
        A chunk that alone exceeds the budget cannot be embedded; it is
        removed from ``self.chunks`` (and from ``self.metadata`` when the
        two have equal length) so that row *i* of ``self.embeddings``
        always belongs to ``self.chunks[i]``.

        Args:
            embedding_func: called as ``embedding_func(batch, batch=True)``
                when ``NET_ACCESS`` is true; otherwise random vectors are
                generated instead.
        """
        token_budget = MAX_TOKENS * 0.9
        enc = tiktoken.encoding_for_model("text-embedding-ada-002") if NET_ACCESS else None

        # Count tokens once and drop the chunks that cannot be embedded.
        has_metadata = len(self.metadata) == len(self.chunks)
        kept_chunks = []
        kept_metadata = []
        token_counts = []
        for idx, chunk in enumerate(self.chunks):
            tokens = len(enc.encode(chunk)) if enc is not None else EMBED_TOKENS
            if tokens > token_budget:
                continue
            kept_chunks.append(chunk)
            token_counts.append(tokens)
            if has_metadata:
                kept_metadata.append(self.metadata[idx])
        self.chunks = kept_chunks
        if has_metadata:
            self.metadata = kept_metadata

        vectors = []
        total_chunk = len(self.chunks)
        completed_chunk = 0
        batch = []
        batch_tokens = 0

        for chunk, tokens in zip(self.chunks, token_counts):
            # Send the batch before this chunk would push it over the budget.
            if batch and batch_tokens + tokens > token_budget:
                vectors.extend(self._embed_batch(batch, embedding_func))
                completed_chunk += len(batch)
                self._report_embed_progress(completed_chunk, total_chunk)
                batch = []
                batch_tokens = 0

            batch.append(chunk)
            batch_tokens += tokens

        # Send the last batch if not empty, and report it as well so the
        # progress actually reaches `total`.
        if batch:
            vectors.extend(self._embed_batch(batch, embedding_func))
            completed_chunk += len(batch)
            self._report_embed_progress(completed_chunk, total_chunk)

        self.embeddings = np.array(vectors, dtype="float32")

    def _embed_batch(self, batch, embedding_func):
        """Return the embedding vectors of one *batch* as a list."""
        if NET_ACCESS:
            res = embedding_func(batch, batch=True)
        else:
            # Offline mode: simulate latency and produce random vectors.
            time.sleep(TIMESLEEP)
            res = np.random.rand(*((len(batch),) + DUMMY_EMBED_SHAPE))
        if isinstance(res, (np.ndarray, list)):
            return list(res)
        # Any other result is treated as a single vector.
        return [res]

    def _report_embed_progress(self, completed, total):
        """Push the embedding progress to the progress manager."""
        info = {
            "title": _("مطالعه متن:"),
            "id": "EMBEDCHUNKS",
            "completed": completed,
            "total": total,
            "details": {},
        }
        self.progress_manager.update(info)

    

    