🏠 首页
⬇️ 下载
import os
import sys
import numpy as np
import pandas as pd
import jieba

# BERTopic 核心组件
from bertopic import BERTopic
from umap import UMAP
from hdbscan import HDBSCAN
from sklearn.feature_extraction.text import CountVectorizer

# ================= ⚙️ 配置区域 =================
DATA_DIR = './data'
RESULT_DIR = './result'

# 输入文件 (确保这两个文件都存在)
TEXT_FILE = os.path.join(DATA_DIR, '文本.txt')
EMBEDDING_FILE = os.path.join(DATA_DIR, 'emb_st.npy') # 使用之前脚本生成的向量

# 聚类参数 (针对 2000+ 条数据的标准设置)
UMAP_NEIGHBORS = 15       # 寻找邻居数量 (越大全局观越强,小数据设小,大数据设15)
HDBSCAN_MIN_SIZE = 15     # 最小聚类大小 (一个主题至少要有几条数据)
# ==============================================

def check_files():
    if not os.path.exists(TEXT_FILE):
        print(f"❌ 错误: 找不到文本文件: {TEXT_FILE}")
        sys.exit(1)
    if not os.path.exists(EMBEDDING_FILE):
        print(f"❌ 错误: 找不到向量文件: {EMBEDDING_FILE}")
        print("   请先运行 _02embedding_sentence_transformer.py 生成向量!")
        sys.exit(1)
    if not os.path.exists(RESULT_DIR):
        os.makedirs(RESULT_DIR)

def load_data():
    print("📖 正在读取数据...")
    # 读取文本
    with open(TEXT_FILE, 'r', encoding='utf-8', errors='ignore') as f:
        # 再次做一个简单的 strip 确保对应关系
        docs = [line.strip() for line in f if len(line.strip()) > 0]
    
    # 读取向量
    embeddings = np.load(EMBEDDING_FILE)
    
    # ⚠️ 安全检查:文本行数必须等于向量行数
    if len(docs) != embeddings.shape[0]:
        print(f"❌ 数据不匹配警告!")
        print(f"   文本行数: {len(docs)}")
        print(f"   向量行数: {embeddings.shape[0]}")
        print("   💡 请重新运行 _02 脚本,确保文本和向量是同一批生成的。")
        # 如果相差不大,可以尝试切片对齐 (危险操作,建议重跑)
        min_len = min(len(docs), embeddings.shape[0])
        docs = docs[:min_len]
        embeddings = embeddings[:min_len]
        print(f"   ⚠️ 已自动截断对齐至: {min_len} 条")
    
    print(f"✅ 数据加载完毕: {len(docs)} 条")
    return docs, embeddings

def main():
    check_files()
    docs, embeddings = load_data()

    print("🚀 正在初始化 BERTopic 模型...")

    # 1. 降维模型 (UMAP)
    # n_neighbors=15 是标准值,适合 2000 条数据
    umap_model = UMAP(
        n_neighbors=UMAP_NEIGHBORS, 
        n_components=5, 
        min_dist=0.0, 
        metric='cosine',
        random_state=42  # 固定随机种子,保证每次结果一样
    )

    # 2. 聚类模型 (HDBSCAN)
    # min_cluster_size=15,避免生成太多只有 2-3 句话的细碎主题
    hdbscan_model = HDBSCAN(
        min_cluster_size=HDBSCAN_MIN_SIZE, 
        min_samples=5, 
        metric='euclidean', 
        prediction_data=True
    )

    # 3. 分词模型 (CountVectorizer)
    # ⚠️ 关键:中文需要 jieba 分词,否则出来的主题全是整句
    def tokenizer_zh(text):
        return jieba.lcut(text)

    vectorizer_model = CountVectorizer(
        tokenizer=tokenizer_zh, 
        stop_words=['的', '了', '在', '是', '我', '也', '和', '去', '都', '我们', '就', '很', '有', '非常'], # 简单去停用词
        min_df=3  # 忽略只出现过 1-2 次的生僻词
    )

    # 4. 组装 BERTopic
    topic_model = BERTopic(
        embedding_model=None, # 我们手动传入 embeddings,所以这里设为 None
        umap_model=umap_model,
        hdbscan_model=hdbscan_model,
        vectorizer_model=vectorizer_model,
        verbose=True
    )

    # 5. 开始训练
    print("🔄 开始聚类 (Fit Transform)...")
    topics, probs = topic_model.fit_transform(docs, embeddings=embeddings)

    # 6. 打印结果
    topic_info = topic_model.get_topic_info()
    print("\n✅ 训练完成!Top 10 主题预览:")
    print(topic_info.head(10))

    # ================= 7. 保存结果 =================
    print("\n💾 正在保存详细 CSV...")
    # 将原文、主题ID、概率合并保存
    df_result = pd.DataFrame({
        "Document": docs, 
        "Topic": topics, 
        "Probability": probs if probs is not None else 0
    })
    # 把主题关键词也拼上去
    topic_names = {row['Topic']: row['Name'] for _, row in topic_info.iterrows()}
    df_result['Topic_Name'] = df_result['Topic'].map(topic_names)
    
    df_result.to_csv(os.path.join(RESULT_DIR, "topic_results.csv"), index=False, encoding='utf-8-sig')
    print("   ✅ CSV 已保存: topic_results.csv")

    # ================= 8. 可视化 (带保护) =================
    print("\n🎨 正在生成可视化网页...")
    
    nr_topics = len(topic_info) - 1 # 减去 -1 噪音类
    
    try:
        # 1. 主题距离图 (Intertopic Distance Map)
        # 只有主题数 > 4 才能画这个,否则 UMAP 报错 k>=N
        if nr_topics > 4:
            fig1 = topic_model.visualize_topics()
            fig1.write_html(os.path.join(RESULT_DIR, "chart_distance.html"))
            print("   ✅ chart_distance.html")
        else:
            print(f"   ⚠️ 主题数量过少({nr_topics}),跳过距离图绘制。")

        # 2. 条形图 (Barchart) - 展示每个主题的核心词
        fig2 = topic_model.visualize_barchart(top_n_topics=20)
        fig2.write_html(os.path.join(RESULT_DIR, "chart_bar.html"))
        print("   ✅ chart_bar.html")
        
        # 3. 热力图 (Similarity Heatmap)
        if nr_topics > 2:
            fig3 = topic_model.visualize_heatmap()
            fig3.write_html(os.path.join(RESULT_DIR, "chart_heatmap.html"))
            print("   ✅ chart_heatmap.html")

    except Exception as e:
        print(f"⚠️ 可视化部分报错 (不影响 CSV 结果): {e}")

    print("\n✨ 全部流程结束!")

if __name__ == "__main__":
    main()
链接已复制!
可在终端使用 wget "链接" 下载