首页
⬇️ 下载
# plot.py 
import os
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

CUSTOM_JS_URL = 'https://mirrors.sustech.edu.cn/cdnjs/ajax/libs/plotly.js/2.27.1/plotly.min.js' 

RESULT_DIR = './result'
PIC_DIR = os.path.join(RESULT_DIR, 'pic')
# =============================================

def init_dir():
    if not os.path.exists(PIC_DIR): os.makedirs(PIC_DIR)

def parse_words_simple(words_str):
    """
    解析空格分隔的字符串 "word1 word2 word3"
    """
    if not isinstance(words_str, str):
        return [], []
    
    words = words_str.split()[:10] # 只取前10个
    scores = list(range(len(words), 0, -1)) # 虚拟权重
    
    return words, scores

def enrich_topics_with_counts(df_topics, topics_file_path):
    """
    核心修复函数:如果 topics 表里没有 Count 列,
    就去读取对应的 docs 表,现场统计每个主题有多少文章。
    """
    # 如果已经有 Count 列,直接返回
    if 'Count' in df_topics.columns:
        return df_topics

    print(f"   Note: 主题表中缺少 'Count' 列,正在从文档表中计算...")
    
    # 推断 _docs.csv 的路径
    docs_file_path = topics_file_path.replace('_topics.csv', '_docs.csv')
    
    if not os.path.exists(docs_file_path):
        print(f"   ⚠️ 警告: 找不到 {docs_file_path},无法计算主题热度。将默认设为 1。")
        df_topics['Count'] = 1
        return df_topics
    
    try:
        # 读取文档表
        df_docs = pd.read_csv(docs_file_path)
        
        # 统计每个 Topic 出现的次数
        # value_counts 返回 Series,index是Topic,value是次数
        topic_counts = df_docs['Topic'].value_counts().reset_index()
        topic_counts.columns = ['Topic', 'Count'] # 重命名方便合并
        
        # 合并到主题表
        # how='left' 保证即使某个主题没有文章(count=0)也能保留下来
        df_merged = pd.merge(df_topics, topic_counts, on='Topic', how='left')
        
        # 填充 NaN 为 0
        df_merged['Count'] = df_merged['Count'].fillna(0).astype(int)
        
        return df_merged
        
    except Exception as e:
        print(f"   ⚠️ 计算 Count 失败: {e}")
        df_topics['Count'] = 1
        return df_topics

def plot_model_results(file_path):
    filename = os.path.basename(file_path)
    model_name = filename.replace('_topics.csv', '')
    
    print(f"🎨 正在绘制: {model_name} ...")
    
    try:
        df = pd.read_csv(file_path)
    except Exception as e:
        print(f"   ⚠️ 读取失败: {e}")
        return

    # 预处理:填充空值
    df['Words'] = df['Words'].fillna('')
    
    # 自动补全 Count 列
    df = enrich_topics_with_counts(df, file_path)

    # 过滤掉 Count 为 0 的行
    if 'Count' in df.columns:
        df = df[df['Count'] > 0]

    # ================= 1. 绘制分布饼图 =================
    df_pie = df.copy()
    
    # 限制饼图显示的切片数量 (取前 15 个大主题 + "其他")
    if len(df_pie) > 15:
        top_15 = df_pie.nlargest(15, 'Count')
        other_count = df_pie['Count'].sum() - top_15['Count'].sum()
        other_row = pd.DataFrame([{'Topic': 'Others', 'Count': other_count, 'Words': '...'}])
        df_pie = pd.concat([top_15, other_row], ignore_index=True)

    try:
        fig_pie = px.pie(
            df_pie, 
            values='Count', 
            names='Topic', 
            title=f'<b>{model_name}</b> - 主题文档数量分布',
            hover_data=['Words'],
            hole=0.3
        )
        fig_pie.update_traces(textposition='inside', textinfo='percent+label')
        fig_pie.write_html(f"{PIC_DIR}/{model_name}_分布饼图.html", include_plotlyjs=CUSTOM_JS_URL)
    except Exception as e:
        print(f"   ⚠️ 饼图绘制失败: {e}")
    
    # ================= 2. 绘制核心词汇条形图 =================
    valid_df = df[~df['Topic'].isin([-1, 'Others'])].copy()
    valid_df['Count'] = pd.to_numeric(valid_df['Count'], errors='coerce').fillna(0)
    
    top_topics = valid_df.sort_values('Count', ascending=False).head(6)
    
    if top_topics.empty:
        print("   ⚠️ 无有效主题,跳过条形图")
        return

    n_plots = len(top_topics)
    rows = (n_plots + 2) // 3
    cols = min(n_plots, 3)
    
    subplot_titles = [f"Topic {row['Topic']} (n={row['Count']})" for _, row in top_topics.iterrows()]
    
    fig_bar = make_subplots(
        rows=rows, cols=cols, 
        subplot_titles=subplot_titles,
        horizontal_spacing=0.1,
        vertical_spacing=0.15
    )
    
    for i, (idx, row) in enumerate(top_topics.iterrows()):
        words, scores = parse_words_simple(row['Words'])
        if not words: continue
        
        r, c = (i // 3) + 1, (i % 3) + 1
        
        fig_bar.add_trace(
            go.Bar(
                x=scores[::-1], 
                y=words[::-1], 
                orientation='h',
                marker=dict(color=scores[::-1], colorscale='Viridis'),
                name=f"Topic {row['Topic']}"
            ), 
            row=r, col=c
        )

    fig_bar.update_layout(
        height=400 * max(rows, 1), # 动态调整高度
        width=1200, 
        title_text=f"<b>{model_name}</b> - 核心关键词 (Top {n_plots} Topics)", 
        showlegend=False,
        template="plotly_white"
    )
    
    fig_bar.write_html(f"{PIC_DIR}/{model_name}_核心词图.html", include_plotlyjs=CUSTOM_JS_URL)
    print(f"   ✅ 图表已保存至 {PIC_DIR}")

def main():
    init_dir()
    csv_files = [f for f in os.listdir(RESULT_DIR) if f.endswith('_topics.csv')]
    
    if not csv_files:
        print(f"❌ 在 {RESULT_DIR} 中未找到 *_topics.csv 文件。请先运行 main.py。")
        return

    print(f"📂 找到 {len(csv_files)} 个模型结果文件,开始绘图...")
    for file in csv_files:
        try:
            plot_model_results(os.path.join(RESULT_DIR, file))
        except Exception as e:
            print(f"⚠️ {file} 绘制异常: {e}")

    print(f"\n✨ 全部绘图完成!请查看文件夹: {PIC_DIR}")

if __name__ == "__main__":
    main()
直链已复制!
可使用 wget 直接下载