首页
⬇️ 下载
import os
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader

# --- Configuration ---
# Embedding backend selector: 'bert' (local bert-base-chinese, CLS vectors)
# or 'st' (local SentenceTransformer). Any other value yields no embeddings.
METHOD: str = 'bert'
# Input corpus: UTF-8 text, one sentence per line (blank lines are skipped).
RAW_DATA_PATH: str = './data/文本.txt'
# Where the stacked embedding matrix is written with np.save.
OUTPUT_PATH: str = './data/embedding.npy'
# Number of sentences per forward pass (used by the 'bert' backend only).
BATCH_SIZE: int = 16

def get_embeddings():
    if not os.path.exists(RAW_DATA_PATH):
        raise FileNotFoundError(f"找不到输入文件: {RAW_DATA_PATH}")

    with open(RAW_DATA_PATH, 'r', encoding='utf-8') as f:
        sentences = [line.strip() for line in f.readlines() if line.strip()]

    if METHOD == 'bert':
        from transformers import BertTokenizer, BertModel
        model_path = './models/bert-base-chinese'
        
        print(f"正在从本地加载 BERT: {model_path}")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tokenizer = BertTokenizer.from_pretrained(model_path, local_files_only=True)
        model = BertModel.from_pretrained(model_path, local_files_only=True).to(device)
        model.eval()

        # 补全缺失的批量生成逻辑
        all_embeddings = []
        data_loader = DataLoader(sentences, batch_size=BATCH_SIZE)
        
        print("开始生成 BERT 向量...")
        with torch.no_grad():
            for batch in tqdm(data_loader):
                inputs = tokenizer(batch, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
                outputs = model(**inputs)
                # 取 CLS token 向量并转为 numpy
                cls_emb = outputs.last_hidden_state[:, 0, :].cpu().numpy()
                all_embeddings.append(cls_emb)
        
        # ⚠️ 关键:必须 return 结果
        return np.vstack(all_embeddings)
        
    elif METHOD == 'st':
        from sentence_transformers import SentenceTransformer
        model_path = './models/paraphrase-multilingual-MiniLM-L12-v2'
        print(f"正在从本地加载 SentenceTransformer: {model_path}")
        model = SentenceTransformer(model_path)
        return model.encode(sentences, show_progress_bar=True)

    return None

if __name__ == "__main__":
    # Script entry point: build the embeddings and persist them to OUTPUT_PATH.
    embeddings = get_embeddings()
    if embeddings is None:
        print("错误:向量生成失败,返回值为 None。")
        # Exit non-zero so shells / CI can detect the failure.
        raise SystemExit(1)

    output_dir = os.path.dirname(OUTPUT_PATH)
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when OUTPUT_PATH actually has a directory component.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    np.save(OUTPUT_PATH, embeddings)
    print(f"向量已保存,形状: {embeddings.shape}")
直链已复制!
可使用 wget 直接下载