# plot.py
import os
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
# URL handed to plotly's ``write_html(include_plotlyjs=...)`` for every chart
# (presumably so the HTML pages load plotly.js from this mirror rather than
# embedding the library -- confirm against the plotly docs).
CUSTOM_JS_URL = 'https://mirrors.sustech.edu.cn/cdnjs/ajax/libs/plotly.js/2.27.1/plotly.min.js'
# Directory that is scanned for ``*_topics.csv`` result files.
RESULT_DIR = './result'
# Directory where the generated HTML charts are written.
PIC_DIR = os.path.join(RESULT_DIR, 'pic')
# =============================================
def init_dir():
    """Create the chart output directory ``PIC_DIR`` if it does not exist.

    ``exist_ok=True`` replaces the previous "check, then create" sequence,
    which could fail if the directory appeared between the check and the
    ``makedirs`` call. Intermediate directories are created as needed.

    Raises:
        FileExistsError: if ``PIC_DIR`` exists but is not a directory.
    """
    os.makedirs(PIC_DIR, exist_ok=True)
def parse_words_simple(words_str, max_words=10):
    """Split a whitespace-separated keyword string into words and rank scores.

    Example: ``"word1 word2 word3"`` -> ``(["word1", "word2", "word3"], [3, 2, 1])``.

    Args:
        words_str: Keywords separated by whitespace. Any non-string value
            (e.g. ``None`` or a NaN float) yields two empty lists.
        max_words: Maximum number of leading words to keep. Defaults to 10;
            negative values are treated as 0.

    Returns:
        A ``(words, scores)`` tuple of equal-length lists. The scores are
        placeholder weights derived only from position: the first word gets
        ``len(words)``, the last word gets 1.
    """
    if not isinstance(words_str, str):
        return [], []

    # Clamp so a negative limit cannot turn into a "drop from the end" slice.
    words = words_str.split()[:max(max_words, 0)]
    # Descending placeholder weights: earlier words rank higher.
    scores = list(range(len(words), 0, -1))
    return words, scores
def enrich_topics_with_counts(df_topics, topics_file_path):
    """Make sure the topics table carries a ``Count`` column.

    If ``df_topics`` already has ``Count`` it is returned unchanged.
    Otherwise the sibling ``*_docs.csv`` file (same path with ``_topics.csv``
    replaced by ``_docs.csv``) is read and the number of rows per ``Topic``
    value is counted and joined onto the topics table.

    When the docs file is missing or the computation fails, every topic is
    given ``Count = 1`` as a best-effort fallback. The fallback is returned
    as a new DataFrame, so the caller's ``df_topics`` is never modified in
    place.

    Args:
        df_topics: Topics table; must contain a ``Topic`` column for the
            join to succeed.
        topics_file_path: Path of the CSV that ``df_topics`` was loaded from.

    Returns:
        A DataFrame with an integer ``Count`` column (the input object itself
        when it already had one).
    """
    # Nothing to do when the counts are already present.
    if 'Count' in df_topics.columns:
        return df_topics

    print(" Note: 主题表中缺少 'Count' 列,正在从文档表中计算...")

    # Derive the path of the matching per-document table.
    docs_file_path = topics_file_path.replace('_topics.csv', '_docs.csv')
    if not os.path.exists(docs_file_path):
        print(f" ⚠️ 警告: 找不到 {docs_file_path},无法计算主题热度。将默认设为 1。")
        # assign() returns a copy, leaving the caller's frame untouched.
        return df_topics.assign(Count=1)

    try:
        df_docs = pd.read_csv(docs_file_path)

        # value_counts() yields a Series indexed by topic; flatten it into a
        # two-column frame so it can be merged on 'Topic'.
        topic_counts = df_docs['Topic'].value_counts().reset_index()
        topic_counts.columns = ['Topic', 'Count']

        # Left join keeps topics that have no documents at all; their
        # missing count becomes 0.
        df_merged = pd.merge(df_topics, topic_counts, on='Topic', how='left')
        df_merged['Count'] = df_merged['Count'].fillna(0).astype(int)
        return df_merged
    except Exception as e:
        # Deliberate best-effort: a broken docs file must not stop plotting.
        print(f" ⚠️ 计算 Count 失败: {e}")
        return df_topics.assign(Count=1)
def plot_model_results(file_path):
    """Render the charts for one ``*_topics.csv`` result file.

    Two HTML files are written into ``PIC_DIR``:

    1. A donut chart of document counts per topic (the 15 largest topics
       plus an aggregated "Others" slice when there are more than 15).
    2. A grid of horizontal bar charts showing the leading keywords of the
       (up to) six largest topics, excluding topic ``-1`` and "Others".

    A failure while reading the CSV or drawing the pie chart is reported and
    swallowed; any other error propagates to the caller.

    Args:
        file_path: Path to a ``*_topics.csv`` file with at least ``Topic``
            and ``Words`` columns.
    """
    max_pie_slices = 15   # topics shown individually in the pie chart
    max_bar_topics = 6    # topics that get a keyword bar chart
    grid_width = 3        # bar-chart subplots per row

    filename = os.path.basename(file_path)
    model_name = filename.replace('_topics.csv', '')
    print(f"🎨 正在绘制: {model_name} ...")

    try:
        df = pd.read_csv(file_path)
    except Exception as e:
        print(f" ⚠️ 读取失败: {e}")
        return

    # Missing keyword strings become empty so later parsing is uniform.
    df['Words'] = df['Words'].fillna('')

    # Guarantees a Count column (computed from the docs table if absent).
    df = enrich_topics_with_counts(df, file_path)

    # Coerce Count to numbers *before* it is compared or ranked; previously
    # this happened only for the bar chart, after the filter and nlargest
    # below had already needed numeric values.
    df = df.assign(Count=pd.to_numeric(df['Count'], errors='coerce').fillna(0))

    # Topics without any documents are not worth plotting.
    df = df[df['Count'] > 0]

    # ================= 1. Topic distribution pie chart =================
    df_pie = df.copy()

    # Keep the chart readable: largest topics individually, rest as "Others".
    if len(df_pie) > max_pie_slices:
        largest = df_pie.nlargest(max_pie_slices, 'Count')
        other_count = df_pie['Count'].sum() - largest['Count'].sum()
        other_row = pd.DataFrame([{'Topic': 'Others', 'Count': other_count, 'Words': '...'}])
        df_pie = pd.concat([largest, other_row], ignore_index=True)

    try:
        fig_pie = px.pie(
            df_pie,
            values='Count',
            names='Topic',
            title=f'<b>{model_name}</b> - 主题文档数量分布',
            hover_data=['Words'],
            hole=0.3
        )
        fig_pie.update_traces(textposition='inside', textinfo='percent+label')
        fig_pie.write_html(
            os.path.join(PIC_DIR, f"{model_name}_分布饼图.html"),
            include_plotlyjs=CUSTOM_JS_URL
        )
    except Exception as e:
        # A broken pie chart should not prevent the bar chart from rendering.
        print(f" ⚠️ 饼图绘制失败: {e}")

    # ================= 2. Core keyword bar charts =================
    # Skip the -1 topic and any pre-aggregated "Others" row.
    valid_df = df[~df['Topic'].isin([-1, 'Others'])]
    top_topics = valid_df.sort_values('Count', ascending=False).head(max_bar_topics)

    if top_topics.empty:
        print(" ⚠️ 无有效主题,跳过条形图")
        return

    n_plots = len(top_topics)
    rows = (n_plots + grid_width - 1) // grid_width   # ceil(n_plots / grid_width)
    cols = min(n_plots, grid_width)

    subplot_titles = [
        f"Topic {row['Topic']} (n={row['Count']})"
        for _, row in top_topics.iterrows()
    ]

    fig_bar = make_subplots(
        rows=rows, cols=cols,
        subplot_titles=subplot_titles,
        horizontal_spacing=0.1,
        vertical_spacing=0.15
    )

    for i, (_, row) in enumerate(top_topics.iterrows()):
        words, scores = parse_words_simple(row['Words'])
        if not words:
            # No keywords for this topic: leave its subplot empty.
            continue

        # Fill the grid row by row (1-based indices for plotly).
        r, c = (i // grid_width) + 1, (i % grid_width) + 1

        # Reverse so the highest-ranked word ends up at the top of the
        # horizontal bar chart.
        fig_bar.add_trace(
            go.Bar(
                x=scores[::-1],
                y=words[::-1],
                orientation='h',
                marker=dict(color=scores[::-1], colorscale='Viridis'),
                name=f"Topic {row['Topic']}"
            ),
            row=r, col=c
        )

    fig_bar.update_layout(
        height=400 * max(rows, 1),  # scale figure height with the row count
        width=1200,
        title_text=f"<b>{model_name}</b> - 核心关键词 (Top {n_plots} Topics)",
        showlegend=False,
        template="plotly_white"
    )

    fig_bar.write_html(
        os.path.join(PIC_DIR, f"{model_name}_核心词图.html"),
        include_plotlyjs=CUSTOM_JS_URL
    )
    print(f" ✅ 图表已保存至 {PIC_DIR}")
def main():
    """Plot every ``*_topics.csv`` file found in ``RESULT_DIR``.

    The output directory is prepared first. Each result file is plotted
    independently: an error in one file is reported and the remaining files
    are still processed. If no result file exists, a hint is printed and
    nothing is plotted.
    """
    init_dir()

    # Collect the result files produced by the modelling step.
    topic_files = []
    for entry in os.listdir(RESULT_DIR):
        if entry.endswith('_topics.csv'):
            topic_files.append(entry)

    found = len(topic_files)
    if found == 0:
        print(f"❌ 在 {RESULT_DIR} 中未找到 *_topics.csv 文件。请先运行 main.py。")
        return

    print(f"📂 找到 {found} 个模型结果文件,开始绘图...")

    for csv_name in topic_files:
        csv_path = os.path.join(RESULT_DIR, csv_name)
        try:
            plot_model_results(csv_path)
        except Exception as err:
            # Keep going: one bad file must not abort the whole batch.
            print(f"⚠️ {csv_name} 绘制异常: {err}")

    print(f"\n✨ 全部绘图完成!请查看文件夹: {PIC_DIR}")
# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()