import os
import sys
import numpy as np
import pandas as pd
import jieba
# BERTopic 核心组件
from bertopic import BERTopic
from umap import UMAP
from hdbscan import HDBSCAN
from sklearn.feature_extraction.text import CountVectorizer
# ================= ⚙️ 配置区域 =================
DATA_DIR = './data'
RESULT_DIR = './result'
# 输入文件 (确保这两个文件都存在)
TEXT_FILE = os.path.join(DATA_DIR, '文本.txt')
EMBEDDING_FILE = os.path.join(DATA_DIR, 'emb_st.npy') # 使用之前脚本生成的向量
# 聚类参数 (针对 2000+ 条数据的标准设置)
UMAP_NEIGHBORS = 15 # 寻找邻居数量 (越大全局观越强,小数据设小,大数据设15)
HDBSCAN_MIN_SIZE = 15 # 最小聚类大小 (一个主题至少要有几条数据)
# ==============================================
def check_files():
if not os.path.exists(TEXT_FILE):
print(f"❌ 错误: 找不到文本文件: {TEXT_FILE}")
sys.exit(1)
if not os.path.exists(EMBEDDING_FILE):
print(f"❌ 错误: 找不到向量文件: {EMBEDDING_FILE}")
print(" 请先运行 _02embedding_sentence_transformer.py 生成向量!")
sys.exit(1)
if not os.path.exists(RESULT_DIR):
os.makedirs(RESULT_DIR)
def load_data():
print("📖 正在读取数据...")
# 读取文本
with open(TEXT_FILE, 'r', encoding='utf-8', errors='ignore') as f:
# 再次做一个简单的 strip 确保对应关系
docs = [line.strip() for line in f if len(line.strip()) > 0]
# 读取向量
embeddings = np.load(EMBEDDING_FILE)
# ⚠️ 安全检查:文本行数必须等于向量行数
if len(docs) != embeddings.shape[0]:
print(f"❌ 数据不匹配警告!")
print(f" 文本行数: {len(docs)}")
print(f" 向量行数: {embeddings.shape[0]}")
print(" 💡 请重新运行 _02 脚本,确保文本和向量是同一批生成的。")
# 如果相差不大,可以尝试切片对齐 (危险操作,建议重跑)
min_len = min(len(docs), embeddings.shape[0])
docs = docs[:min_len]
embeddings = embeddings[:min_len]
print(f" ⚠️ 已自动截断对齐至: {min_len} 条")
print(f"✅ 数据加载完毕: {len(docs)} 条")
return docs, embeddings
def main():
check_files()
docs, embeddings = load_data()
print("🚀 正在初始化 BERTopic 模型...")
# 1. 降维模型 (UMAP)
# n_neighbors=15 是标准值,适合 2000 条数据
umap_model = UMAP(
n_neighbors=UMAP_NEIGHBORS,
n_components=5,
min_dist=0.0,
metric='cosine',
random_state=42 # 固定随机种子,保证每次结果一样
)
# 2. 聚类模型 (HDBSCAN)
# min_cluster_size=15,避免生成太多只有 2-3 句话的细碎主题
hdbscan_model = HDBSCAN(
min_cluster_size=HDBSCAN_MIN_SIZE,
min_samples=5,
metric='euclidean',
prediction_data=True
)
# 3. 分词模型 (CountVectorizer)
# ⚠️ 关键:中文需要 jieba 分词,否则出来的主题全是整句
def tokenizer_zh(text):
return jieba.lcut(text)
vectorizer_model = CountVectorizer(
tokenizer=tokenizer_zh,
stop_words=['的', '了', '在', '是', '我', '也', '和', '去', '都', '我们', '就', '很', '有', '非常'], # 简单去停用词
min_df=3 # 忽略只出现过 1-2 次的生僻词
)
# 4. 组装 BERTopic
topic_model = BERTopic(
embedding_model=None, # 我们手动传入 embeddings,所以这里设为 None
umap_model=umap_model,
hdbscan_model=hdbscan_model,
vectorizer_model=vectorizer_model,
verbose=True
)
# 5. 开始训练
print("🔄 开始聚类 (Fit Transform)...")
topics, probs = topic_model.fit_transform(docs, embeddings=embeddings)
# 6. 打印结果
topic_info = topic_model.get_topic_info()
print("\n✅ 训练完成!Top 10 主题预览:")
print(topic_info.head(10))
# ================= 7. 保存结果 =================
print("\n💾 正在保存详细 CSV...")
# 将原文、主题ID、概率合并保存
df_result = pd.DataFrame({
"Document": docs,
"Topic": topics,
"Probability": probs if probs is not None else 0
})
# 把主题关键词也拼上去
topic_names = {row['Topic']: row['Name'] for _, row in topic_info.iterrows()}
df_result['Topic_Name'] = df_result['Topic'].map(topic_names)
df_result.to_csv(os.path.join(RESULT_DIR, "topic_results.csv"), index=False, encoding='utf-8-sig')
print(" ✅ CSV 已保存: topic_results.csv")
# ================= 8. 可视化 (带保护) =================
print("\n🎨 正在生成可视化网页...")
nr_topics = len(topic_info) - 1 # 减去 -1 噪音类
try:
# 1. 主题距离图 (Intertopic Distance Map)
# 只有主题数 > 4 才能画这个,否则 UMAP 报错 k>=N
if nr_topics > 4:
fig1 = topic_model.visualize_topics()
fig1.write_html(os.path.join(RESULT_DIR, "chart_distance.html"))
print(" ✅ chart_distance.html")
else:
print(f" ⚠️ 主题数量过少({nr_topics}),跳过距离图绘制。")
# 2. 条形图 (Barchart) - 展示每个主题的核心词
fig2 = topic_model.visualize_barchart(top_n_topics=20)
fig2.write_html(os.path.join(RESULT_DIR, "chart_bar.html"))
print(" ✅ chart_bar.html")
# 3. 热力图 (Similarity Heatmap)
if nr_topics > 2:
fig3 = topic_model.visualize_heatmap()
fig3.write_html(os.path.join(RESULT_DIR, "chart_heatmap.html"))
print(" ✅ chart_heatmap.html")
except Exception as e:
print(f"⚠️ 可视化部分报错 (不影响 CSV 结果): {e}")
print("\n✨ 全部流程结束!")
if __name__ == "__main__":
main()