🏠 首页
⬇️ 下载
import os
import sys
import re
import numpy as np
import torch
from sentence_transformers import SentenceTransformer

# ================= 🚑 H20 显卡防崩配置 =================
os.environ["OPENBLAS_CORETYPE"] = "Haswell"
if torch.cuda.is_available():
    try:
        # 即使是单条,也建议关掉 Flash Attention 以防万一
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)
    except:
        pass
# ======================================================

# ================= 配置区域 =================
INPUT_FILE = './data/文本.txt'
OUTPUT_FILE = './data/emb_st.npy'
LOCAL_MODEL_DIR = './models/paraphrase-multilingual-MiniLM-L12-v2'

# ⚠️ 核心修改:改为 1,彻底避开 Padding Bug
BATCH_SIZE = 1  
DEVICE = 'cuda'
# ===========================================

def is_valid_text(text):
    # 只保留含汉字的行
    if re.search(r'[\u4e00-\u9fa5]', text):
        return True
    return False

def main():
    print(f"🚀 运行设备: {torch.cuda.get_device_name(0)}")
    print("ℹ️  已启用【单条处理模式】(Batch Size = 1),确保 100% 稳定。")

    # 1. 读取数据
    if not os.path.exists(INPUT_FILE):
        print(f"❌ 找不到文件: {INPUT_FILE}")
        return

    raw_docs = []
    with open(INPUT_FILE, 'r', encoding='utf-8', errors='ignore') as f:
        raw_docs = [line.strip() for line in f]
    
    docs = [line for line in raw_docs if is_valid_text(line)]
    print(f"✅ 有效数据: {len(docs)} 条")

    # 2. 加载模型
    try:
        model = SentenceTransformer(LOCAL_MODEL_DIR, device=DEVICE)
        model.max_seq_length = 256
    except Exception as e:
        print(f"❌ 模型加载失败: {e}")
        return

    # 3. 生成向量
    print("🔄 开始生成向量...")
    try:
        embeddings = model.encode(
            docs, 
            batch_size=BATCH_SIZE,  # 这里是 1
            show_progress_bar=True,
            normalize_embeddings=False
        )
    except Exception as e:
        print(f"\n❌ 居然单条也崩?报错信息: {e}")
        return

    # 4. 保存
    print(f"✅ 生成完毕。形状: {embeddings.shape}")
    os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True)
    np.save(OUTPUT_FILE, embeddings)
    print(f"💾 结果已保存至: {OUTPUT_FILE}")

if __name__ == "__main__":
    main()
链接已复制!
可在终端使用 wget "链接" 下载