# NOTE(review): the next two items are web-page navigation residue ("🏠 首页" /
# "⬇️ 下载") left over from copy-pasting this script from a browser; commented
# out so the file parses as Python.
# 🏠 首页
# ⬇️ 下载
import os
import sys
import torch
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BertTokenizer, BertModel

# ================= 🚑 H20 GPU survival patch (must run before anything else) =================
# 1. Work around a CPU math-library conflict: pin OpenBLAS to the Haswell
#    kernel so it does not mis-detect the core type on this host.
#    NOTE(review): must be set before the numeric libraries spin up threads —
#    keep this at module top, before any heavy torch/numpy work.
os.environ["OPENBLAS_CORETYPE"] = "Haswell"

# 2. Work around low-level crashes/hangs of BERT on H20 GPUs: disable the
#    Flash and memory-efficient scaled-dot-product attention backends and
#    force the plain (slower but stable) math implementation.
if torch.cuda.is_available():
    try:
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)
        print("🛡️ 已禁用 Flash Attention,启用稳定数学模式 (Math SDP)")
    except Exception as e:
        # Best-effort: older torch builds may lack these toggles — warn, don't abort.
        print(f"⚠️ 无法设置加速后端: {e}")
# ====================================================================

# ================= Configuration =================
INPUT_FILE = './data/文本.txt'      # cleaned input text, one sentence per line
OUTPUT_FILE = './data/emb.npy'      # output .npy path for the embedding matrix
LOCAL_MODEL_DIR = './models/bert-base-chinese' # local BERT model directory

BATCH_SIZE = 32
MAX_LENGTH = 512                    # BERT's hard maximum sequence length
# ===========================================

def main():
    """Encode every non-empty line of INPUT_FILE into a BERT [CLS] vector.

    Reads the text file line by line, tokenizes in batches of BATCH_SIZE
    (truncated to MAX_LENGTH), runs the local bert-base-chinese model on
    GPU if available, and saves the stacked (N, hidden_size) float matrix
    to OUTPUT_FILE via numpy.save. Prints progress/errors and returns None.
    """
    # --- 1. Environment check ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🚀 运行设备: {device}")
    if torch.cuda.is_available():
        print(f"   显卡型号: {torch.cuda.get_device_name(0)}")

    # --- 2. Read and clean the data ---
    if not os.path.exists(INPUT_FILE):
        print(f"❌ 错误: 找不到输入文件 {INPUT_FILE}")
        return

    print("📖 正在读取文本数据...")
    valid_sentences = []
    with open(INPUT_FILE, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f:
            text = line.strip()
            # Drop blank lines so the tokenizer never sees empty strings.
            if text:
                valid_sentences.append(text)

    print(f"📊 有效数据量: {len(valid_sentences)} 条")

    # Fix: bail out *before* the expensive model load when there is nothing
    # to encode (the original only discovered this after inference).
    if not valid_sentences:
        print("⚠️ 未生成任何数据。")
        return

    # --- 3. Load model and tokenizer ---
    # Fix: the leading "Pm" in the original print was a mojibake'd emoji.
    print(f"📦 正在加载模型 (路径: {LOCAL_MODEL_DIR})...")
    try:
        tokenizer = BertTokenizer.from_pretrained(LOCAL_MODEL_DIR)
        model = BertModel.from_pretrained(LOCAL_MODEL_DIR)
    except Exception as e:
        print(f"❌ 模型加载失败: {e}")
        return

    model.to(device)
    model.eval()

    # --- 4. Batched inference ---
    # DataLoader over a list of strings yields lists of up to BATCH_SIZE
    # strings per iteration; shuffle=False keeps output rows aligned with
    # the input line order.
    data_loader = DataLoader(valid_sentences, batch_size=BATCH_SIZE, shuffle=False)
    cls_embeddings = []

    print("🔄 开始推理...")
    for batch_sentences in tqdm(data_loader, desc="Embedding"):
        inputs = tokenizer(
            list(batch_sentences),
            padding=True,
            truncation=True,        # force-truncate so over-long texts cannot hang the run
            max_length=MAX_LENGTH,  # cap at BERT's 512-token limit
            return_tensors="pt"
        )

        inputs = inputs.to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        # Take the [CLS] token (position 0) as the sentence embedding.
        batch_emb = outputs.last_hidden_state[:, 0, :].cpu().numpy()
        cls_embeddings.append(batch_emb)

    # --- 5. Merge and save ---
    if cls_embeddings:
        final_embeddings = np.vstack(cls_embeddings)
        print(f"✅ 生成完毕。最终矩阵形状: {final_embeddings.shape}")

        # Fix: makedirs('') raises FileNotFoundError — only create the
        # directory when OUTPUT_FILE actually has a directory component.
        out_dir = os.path.dirname(OUTPUT_FILE)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        np.save(OUTPUT_FILE, final_embeddings)
        print(f"💾 结果已保存至: {OUTPUT_FILE}")
    else:
        print("⚠️ 未生成任何数据。")

# Script entry point: run the embedding pipeline only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
# NOTE(review): trailing web-page residue ("链接已复制!" / wget hint) from the
# copy-paste source; commented out so the file parses as Python.
# 链接已复制!
# 可在终端使用 wget "链接" 下载