import os
import sys
import torch
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
# ---- H20 GPU compatibility patch (must run before any model work) ----
# Pin the CPU math library's kernel selection to avoid OpenBLAS core-type
# misdetection conflicts on this host.
os.environ["OPENBLAS_CORETYPE"] = "Haswell"

# On CUDA machines, force scaled-dot-product attention onto the plain math
# backend: flash / memory-efficient kernels hang or crash BERT on H20 cards.
if torch.cuda.is_available():
    try:
        for toggle, enabled in (
            (torch.backends.cuda.enable_flash_sdp, False),
            (torch.backends.cuda.enable_mem_efficient_sdp, False),
            (torch.backends.cuda.enable_math_sdp, True),
        ):
            toggle(enabled)
        print("🛡️ 已禁用 Flash Attention,启用稳定数学模式 (Math SDP)")
    except Exception as e:
        # Older torch builds may lack these switches; report and continue.
        print(f"⚠️ 无法设置加速后端: {e}")
# ----------------------------------------------------------------------
# ================= Configuration =================
INPUT_FILE = './data/文本.txt'  # cleaned input text file, one sentence per line
OUTPUT_FILE = './data/emb.npy'  # destination .npy path for the embedding matrix
LOCAL_MODEL_DIR = './models/bert-base-chinese'  # local BERT checkpoint directory
BATCH_SIZE = 32  # sentences per inference batch
MAX_LENGTH = 512  # BERT's hard maximum sequence length (tokens)
# =================================================
def _read_sentences(path):
    """Return the stripped, non-empty lines of *path* (utf-8, bad bytes ignored)."""
    with open(path, 'r', encoding='utf-8', errors='ignore') as f:
        stripped = (line.strip() for line in f)
        # Drop blank lines up front so the tokenizer never sees empty strings.
        return [text for text in stripped if text]


def _embed_sentences(sentences, tokenizer, model, device):
    """Run batched BERT inference over *sentences*.

    Returns a (len(sentences), hidden_size) numpy array of [CLS] vectors,
    in the same order as the input.
    """
    # DataLoader is used purely as a batcher; shuffle=False keeps row order
    # aligned with the input file.
    loader = DataLoader(sentences, batch_size=BATCH_SIZE, shuffle=False)
    cls_embeddings = []
    print("🔄 开始推理...")
    # no_grad hoisted around the whole loop: no gradients are ever needed.
    with torch.no_grad():
        for batch in tqdm(loader, desc="Embedding"):
            inputs = tokenizer(
                list(batch),
                padding=True,
                truncation=True,        # hard-truncate overlong texts to avoid hangs
                max_length=MAX_LENGTH,  # BERT positional-embedding limit
                return_tensors="pt",
            ).to(device)
            outputs = model(**inputs)
            # [CLS] token sits at position 0 of the last hidden state.
            cls_embeddings.append(outputs.last_hidden_state[:, 0, :].cpu().numpy())
    return np.vstack(cls_embeddings)


def main():
    """Embed every non-empty line of INPUT_FILE with BERT [CLS] vectors.

    Loads the local bert-base-chinese checkpoint, runs batched inference on
    the configured device, and saves the stacked float32 matrix to
    OUTPUT_FILE as a .npy file. Exits early (with a message) on a missing
    input file, a failed model load, or an empty input.
    """
    # --- 1. Environment check ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🚀 运行设备: {device}")
    if torch.cuda.is_available():
        print(f" 显卡型号: {torch.cuda.get_device_name(0)}")

    # --- 2. Read and clean the data ---
    if not os.path.exists(INPUT_FILE):
        print(f"❌ 错误: 找不到输入文件 {INPUT_FILE}")
        return
    print("📖 正在读取文本数据...")
    valid_sentences = _read_sentences(INPUT_FILE)
    print(f"📊 有效数据量: {len(valid_sentences)} 条")
    if not valid_sentences:
        # Nothing to embed — skip the (expensive) model load entirely.
        print("⚠️ 未生成任何数据。")
        return

    # --- 3. Load model & tokenizer ---
    # (was a garbled "Pm" glyph; restored to a readable marker)
    print(f"📦 正在加载模型 (路径: {LOCAL_MODEL_DIR})...")
    try:
        tokenizer = BertTokenizer.from_pretrained(LOCAL_MODEL_DIR)
        model = BertModel.from_pretrained(LOCAL_MODEL_DIR)
    except Exception as e:
        print(f"❌ 模型加载失败: {e}")
        return
    model.to(device)
    model.eval()  # disable dropout for deterministic inference

    # --- 4. Batched inference ---
    final_embeddings = _embed_sentences(valid_sentences, tokenizer, model, device)

    # --- 5. Save ---
    print(f"✅ 生成完毕。最终矩阵形状: {final_embeddings.shape}")
    out_dir = os.path.dirname(OUTPUT_FILE)
    if out_dir:  # guard: dirname is "" when OUTPUT_FILE is a bare filename
        os.makedirs(out_dir, exist_ok=True)
    np.save(OUTPUT_FILE, final_embeddings)
    print(f"💾 结果已保存至: {OUTPUT_FILE}")
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()