import os
import pandas as pd
import pickle
import numpy as np
import warnings
from gensim.corpora import Dictionary
from gensim.models import CoherenceModel

warnings.filterwarnings("ignore")

# ================= 配置 =================
RESULT_DIR = './result'
TOP_N_WORDS = 10 
# =======================================

def calculate_tc():
    print("📊 初始化 Coherence (NPMI) ...")

    # 1. 加载参考语料
    corpus_path = os.path.join(RESULT_DIR, 'aligned_segmented_texts.pkl')
    if not os.path.exists(corpus_path):
        print(f"❌ 找不到 {corpus_path}")
        return

    with open(corpus_path, 'rb') as f:
        texts = pickle.load(f)
    
    # 【关键修改】构建词典时，强制转为小写，确保和模型输出一致
    # 这样 'China' (语料) 和 'china' (模型) 就能匹配上了
    print("   -> 构建词典 (自动转小写)...")
    tokenized_texts = [str(text).lower().split() for text in texts]
    dictionary = Dictionary(tokenized_texts)
    
    # 使用 set 加速查找
    vocab_set = set(dictionary.token2id.keys())
    print(f"   -> 词典大小: {len(vocab_set)}")

    # 2. 计算模型
    model_files = [f for f in os.listdir(RESULT_DIR) if f.endswith('_topics.csv')]
    
    print(f"\n{'模型名称':<15} | {'有效度':<8} | {'NPMI 得分':<15} | {'状态'}")
    print("-" * 65)

    results = []

    for file_name in model_files:
        model_name = file_name.replace('_topics.csv', '')
        file_path = os.path.join(RESULT_DIR, file_name)
        
        try:
            df = pd.read_csv(file_path)
            # 强制转字符串并转小写
            raw_words_series = df['Words'].fillna('').astype(str).str.lower()
            
            topics = []
            
            # 诊断计数器
            total_words_count = 0
            matched_words_count = 0
            
            for words_str in raw_words_series:
                words = words_str.split()[:TOP_N_WORDS]
                if not words: continue
                
                total_words_count += len(words)
                
                # 过滤：只保留词典里有的词
                valid_words = [w for w in words if w in vocab_set]
                matched_words_count += len(valid_words)
                
                # 只有当剩下的有效词 >= 2 个时，才能计算分数
                if len(valid_words) >= 2:
                    topics.append(valid_words)

            # 计算模型整体的词汇匹配度
            match_rate = matched_words_count / total_words_count if total_words_count > 0 else 0
            
            # 如果有效主题太少，直接跳过
            if not topics:
                print(f"{model_name:<15} | {match_rate:.0%}      | nan             | ⚠️ 词汇完全不匹配")
                # 打印诊断：看看它是啥词匹配不上
                sample_words = raw_words_series.iloc[0].split()[:5]
                print(f"   (诊断: 模型输出了 {sample_words}，但词典里没有)")
                continue

            # 计算分数
            cm = CoherenceModel(
                topics=topics, 
                texts=tokenized_texts, 
                dictionary=dictionary, 
                coherence='c_npmi',
                processes=1
            )
            score = cm.get_coherence()
            
            print(f"{model_name:<15} | {match_rate:.0%}      | {score:.4f}          | ✅ 正常")
            results.append({'Model': model_name, 'NPMI': score})
            
        except Exception as e:
            print(f"{model_name:<15} | -        | error           | {str(e)[:20]}")

    print("-" * 65)
    if results:
        best = max(results, key=lambda x: x['NPMI'])
        print(f"\n🏆 最佳模型: {best['Model']} (NPMI: {best['NPMI']:.4f})")

if __name__ == "__main__":
    calculate_tc()