import os

# ================= H20 GPU crash-prevention settings =================
# OpenBLAS reads OPENBLAS_CORETYPE while its shared library is being loaded,
# so the variable must be in the environment BEFORE numpy / torch are
# imported. Assigning it after those imports has no effect on this process.
os.environ["OPENBLAS_CORETYPE"] = "Haswell"

import re  # noqa: E402
import sys  # noqa: E402

import numpy as np  # noqa: E402
import torch  # noqa: E402
from sentence_transformers import SentenceTransformer  # noqa: E402

if torch.cuda.is_available():
    try:
        # Even with single-item batches, switch off the flash and
        # memory-efficient attention kernels as a precaution and keep only
        # the plain math implementation enabled.
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)
    except Exception as exc:
        # Best effort only (for instance, a PyTorch version that lacks these
        # switches): report the reason and carry on with the default
        # backends. Ctrl-C / SystemExit are no longer swallowed here.
        print(f"⚠️ 无法切换 SDP 注意力后端,沿用默认设置: {exc}", file=sys.stderr)
# ======================================================================
# ================= Configuration =================
# Text file to embed; it is read line by line.
INPUT_FILE: str = "./data/文本.txt"
# Destination of the embedding array (written with np.save).
OUTPUT_FILE: str = "./data/emb_st.npy"
# Local directory passed to SentenceTransformer as the model location.
LOCAL_MODEL_DIR: str = "./models/paraphrase-multilingual-MiniLM-L12-v2"
# Key change: a batch size of 1 sidesteps the padding bug entirely.
BATCH_SIZE: int = 1
# Device string handed to SentenceTransformer.
DEVICE: str = "cuda"
# =================================================
# Compiled once at import time; matches any code point in U+4E00..U+9FA5.
_HAN_CHAR_RE = re.compile(r'[\u4e00-\u9fa5]')


def is_valid_text(text):
    """Return True when *text* contains at least one Chinese character.

    A "Chinese character" here is any code point in the range
    U+4E00..U+9FA5. Lines made up only of ASCII, digits, punctuation or
    other scripts are rejected, as is the empty string.

    Args:
        text: The string to inspect.

    Returns:
        bool: True if a character in the range is present, otherwise False.
    """
    return _HAN_CHAR_RE.search(text) is not None
def main():
    """Embed every Chinese-containing line of INPUT_FILE and save the result.

    Steps:
        1. Read INPUT_FILE and keep the stripped lines that ``is_valid_text``
           accepts.
        2. Load the SentenceTransformer stored in LOCAL_MODEL_DIR.
        3. Encode the kept lines with batch size BATCH_SIZE.
        4. Write the embeddings to OUTPUT_FILE with ``np.save``.

    Every failure (missing input, nothing to embed, model load error,
    encoding error) is reported on stdout and the function returns early;
    it does not raise for those cases.

    Returns:
        None
    """
    # DEVICE may name CUDA on a host without a usable GPU; querying the
    # device name would raise there, so fall back to the CPU instead.
    device = DEVICE
    if device.startswith('cuda') and not torch.cuda.is_available():
        print("⚠️ 未检测到可用的 CUDA 设备,改用 CPU 运行。")
        device = 'cpu'

    if device.startswith('cuda'):
        gpu_name = torch.cuda.get_device_name(torch.device(device))
        print(f"🚀 运行设备: {gpu_name}")
    else:
        print(f"🚀 运行设备: {device}")

    # Only announce single-item mode when it is actually configured.
    if BATCH_SIZE == 1:
        print("ℹ️ 已启用【单条处理模式】(Batch Size = 1),确保 100% 稳定。")

    # 1. Read the data
    if not os.path.exists(INPUT_FILE):
        print(f"❌ 找不到文件: {INPUT_FILE}")
        return

    # errors='ignore' drops undecodable bytes instead of aborting the run.
    with open(INPUT_FILE, 'r', encoding='utf-8', errors='ignore') as f:
        stripped_lines = (line.strip() for line in f)
        docs = [text for text in stripped_lines if is_valid_text(text)]
    print(f"✅ 有效数据: {len(docs)} 条")

    # Nothing to embed: skip the model load and do not write an empty array.
    if not docs:
        print("⚠️ 没有可处理的有效文本,已跳过向量生成。")
        return

    # 2. Load the model
    try:
        model = SentenceTransformer(LOCAL_MODEL_DIR, device=device)
        model.max_seq_length = 256
    except Exception as e:
        print(f"❌ 模型加载失败: {e}")
        return

    # 3. Generate the embeddings
    print("🔄 开始生成向量...")
    try:
        embeddings = model.encode(
            docs,
            batch_size=BATCH_SIZE,
            show_progress_bar=True,
            normalize_embeddings=False,
        )
    except Exception as e:
        print(f"\n❌ 居然单条也崩?报错信息: {e}")
        return

    # 4. Save
    print(f"✅ 生成完毕。形状: {embeddings.shape}")
    # dirname() is '' for a bare file name, and os.makedirs('') raises.
    output_dir = os.path.dirname(OUTPUT_FILE)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    np.save(OUTPUT_FILE, embeddings)
    print(f"💾 结果已保存至: {OUTPUT_FILE}")


if __name__ == "__main__":
    main()