# main.py
import os
import numpy as np
import pandas as pd
from bertopic import BERTopic
from umap import UMAP
from hdbscan import HDBSCAN
from sklearn.feature_extraction.text import CountVectorizer

def run_pipeline():
    # --- 1. Path configuration ---
    CUT_DATA_PATH = './data/切词.txt'        # pre-segmented text, one document per line
    EMBEDDING_PATH = './data/embedding.npy'  # precomputed document embeddings, row-aligned with the text file
    RESULT_DIR = './result'
    os.makedirs(RESULT_DIR, exist_ok=True)
    RESULT_CSV_PATH = os.path.join(RESULT_DIR, '聚类结果.csv')

    # --- 2. Load data ---
    print("Reading input files...")
    if not os.path.exists(CUT_DATA_PATH) or not os.path.exists(EMBEDDING_PATH):
        raise FileNotFoundError("Missing input files! Check that both files exist under ./data.")

    with open(CUT_DATA_PATH, 'r', encoding='utf-8') as f:
        docs = [line.strip() for line in f]
    
    embeddings = np.load(EMBEDDING_PATH)
    
    # Align documents and embeddings: they must correspond row by row.
    if len(docs) != embeddings.shape[0]:
        min_len = min(len(docs), embeddings.shape[0])
        docs = docs[:min_len]
        embeddings = embeddings[:min_len]
        print(f"⚠️ Length mismatch; truncated both to {min_len} rows")
    else:
        print(f"✅ Loaded {len(docs)} documents")

    # ==========================================
    # 3. Initialize model components
    # ==========================================
    print("Initializing model components...")
    
    umap_model = UMAP(n_neighbors=5, n_components=5, min_dist=0.0, metric='cosine', random_state=42)
    hdbscan_model = HDBSCAN(min_cluster_size=5, min_samples=2, metric='euclidean')
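    # These hyperparameters assume a fairly small corpus; for larger datasets
    # it is common to raise n_neighbors and min_cluster_size. Note: if you
    # later call topic_model.transform() on unseen documents, construct the
    # HDBSCAN model with prediction_data=True so approximate prediction works.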

    # ⚠️ Key change: drop the custom tokenizer function and use the plain regex r"(?u)[^ ]+".
    # Meaning: match every maximal run of non-space characters.
    # Advantage: a pure-string config that serializes cleanly for multiprocessing
    # and is unaffected by the system locale.
    vectorizer_model = CountVectorizer(
        token_pattern=r"(?u)[^ ]+",
        stop_words=['洛阳', '旅游', '文化', '我们', '一个', '可以', '觉得', '非常', '没有', '什么', '收起', '评价'],
        lowercase=False  # Chinese has no letter case; skipping lowercasing avoids needless Unicode handling
    )
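
    # Sanity check of the token pattern on a hypothetical, pre-segmented sample
    # (assumes the corpus was whitespace-tokenized upstream, e.g. with jieba):
    sample = "龙门 石窟 的 夜景 非常 壮观"
    print("Token check:", vectorizer_model.build_tokenizer()(sample))
    # expected: ['龙门', '石窟', '的', '夜景', '非常', '壮观']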

    # ==========================================
    # 4. Train BERTopic
    # ==========================================
    print("Training BERTopic...")
    topic_model = BERTopic(
        embedding_model=None,  # embeddings are precomputed, so no embedding backend is needed
        umap_model=umap_model,
        hdbscan_model=hdbscan_model,
        vectorizer_model=vectorizer_model,
        verbose=True
    )
    
    # fit_transform also returns (topics, probs); per-document assignments
    # are retrieved below via get_document_info instead.
    topic_model.fit_transform(docs, embeddings=embeddings)
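
    # Quick inspection: get_topic returns (word, c-TF-IDF weight) pairs for a
    # topic id; topic 0 is the largest non-outlier topic when one exists.
    if 0 in topic_model.get_topics():
        print("Topic 0 top words:", topic_model.get_topic(0)[:5])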

    # ==========================================
    # 5. Save results
    # ==========================================
    info = topic_model.get_topic_info()
    n_topics = len(info[info.Topic != -1])  # exclude the -1 outlier row if present
    print(f"\n✅ Training complete! Found {n_topics} topics.")
    print("Preview of the first 5 topics:")
    print(info.head())
    
    topic_model.get_document_info(docs).to_csv(RESULT_CSV_PATH, index=False, encoding='utf-8-sig')
    print(f"Per-document results saved to: {RESULT_CSV_PATH}")

    # ==========================================
    # 6. Generate visualizations
    # ==========================================
    print("\n=== Generating visualizations ===")
    def save_plot(fig, filename):
        path = os.path.join(RESULT_DIR, filename)
        fig.write_html(path)
        print(f"   -> saved: {filename}")

    # Some charts can legitimately fail (e.g. too few topics for a heatmap),
    # so each one is attempted independently and failures are reported
    # instead of silently swallowed.
    plots = [
        ("topics_distance.html", topic_model.visualize_topics, {}),
        ("topics_barchart.html", topic_model.visualize_barchart, {"top_n_topics": 20}),
        ("topics_heatmap.html", topic_model.visualize_heatmap, {}),
        ("topics_hierarchy.html", topic_model.visualize_hierarchy, {}),
    ]
    for filename, plot_fn, kwargs in plots:
        try:
            save_plot(plot_fn(**kwargs), filename)
        except Exception as e:
            print(f"   -> skipped {filename}: {e}")

    print(f"\n🎉 结果已生成到 {RESULT_DIR} 文件夹!")

if __name__ == "__main__":
    run_pipeline()