import os
import pandas as pd
import numpy as np

# ================= 配置区域 =================
RESULT_DIR = './result'   # 结果文件夹路径
TOP_N_WORDS = 10          # 计算每个主题前 10 个词的多样性
# ===========================================

def calculate_td():
    print("📊 正在初始化 Topic Diversity (主题多样性) 计算程序...")
    
    # 1. 检查结果目录
    if not os.path.exists(RESULT_DIR):
        print(f"❌ 错误：找不到 {RESULT_DIR} 文件夹。请先运行 main.py！")
        return

    # 2. 查找所有主题文件
    model_files = [f for f in os.listdir(RESULT_DIR) if f.endswith('_topics.csv')]
    
    if not model_files:
        print("❌ result 文件夹中没有找到 _topics.csv 文件。")
        return

    print(f"\n🚀 开始计算 {len(model_files)} 个模型的多样性得分...\n")
    print(f"{'模型名称':<15} | {'主题数':<6} | {'多样性得分 (越高越好)':<20}")
    print("-" * 55)

    results = []

    for file_name in model_files:
        model_name = file_name.replace('_topics.csv', '')
        file_path = os.path.join(RESULT_DIR, file_name)
        
        try:
            # 读取 CSV
            df = pd.read_csv(file_path)
            
            # 预处理：填充空值，强制转为字符串，并统一转小写（避免 China 和 china 被算作两个词）
            # 注意：这里的 fillna('') 防止读取到 nan 报错
            words_series = df['Words'].fillna('').astype(str).str.lower()
            
            all_top_words = []
            
            # 收集所有主题的前 N 个词
            for words_str in words_series:
                # 按空格切分，取前 N 个
                words = words_str.split()[:TOP_N_WORDS]
                if words:
                    all_top_words.extend(words)
            
            # 计算多样性
            total_words = len(all_top_words)
            unique_words = len(set(all_top_words))
            
            if total_words == 0:
                score = 0.0
                print(f"{model_name:<15} | 0      | 0.0000 (无有效词)")
            else:
                # 核心公式：唯一词数量 / 总词数量
                # 1.0 代表所有主题的词都互不相同（多样性最高）
                # 0.0 代表所有词都一样（多样性最低）
                score = unique_words / total_words
                print(f"{model_name:<15} | {len(words_series):<6} | {score:.4f}")
            
            results.append({'Model': model_name, 'Diversity': score})

        except Exception as e:
            print(f"{model_name:<15} | Error  | 计算失败: {str(e)[:20]}")

    print("-" * 55)
    
    # 打印最佳模型
    if results:
        best_model = max(results, key=lambda x: x['Diversity'])
        print(f"\n🏆 多样性最佳模型: {best_model['Model']} (Score: {best_model['Diversity']:.4f})")
    else:
        print("\n⚠️ 没有成功计算任何模型。")

if __name__ == "__main__":
    calculate_td()