
Weaviate Usage (Part 3): Two Methods of Importing Data

Published on June 15, 2025  •  3 min  •  476 words

When importing data, one approach is to generate the embeddings first, save them to a CSV file, and then attach them to each object in batch.add_object. The other approach is to embed the text on the fly during the import itself.

Method 1

First, create the collection: Create collection

import weaviate
import weaviate.classes.config as wc
from weaviate.classes.config import Configure

client = weaviate.connect_to_local()
# client.collections.delete("MovieCustomVector")
client.collections.create(
    name="MovieCustomVector",
    properties=[
        wc.Property(name="title", data_type=wc.DataType.TEXT),
        wc.Property(name="overview", data_type=wc.DataType.TEXT),
        wc.Property(name="vote_average", data_type=wc.DataType.NUMBER),
        wc.Property(name="genre_ids", data_type=wc.DataType.INT_ARRAY),
        wc.Property(name="release_date", data_type=wc.DataType.DATE),
        wc.Property(name="tmdb_id", data_type=wc.DataType.INT),
    ],
    vectorizer_config=Configure.Vectorizer.none(),
    generative_config=Configure.Generative.ollama(
        api_endpoint="http://host.docker.internal:11434", model="llama3.2"
    ),
)

client.close()
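
Since vectorizer_config is set to none, Weaviate will store whatever vectors we supply but will not generate any itself. As an optional sanity check, the following minimal sketch confirms the collection was created, assuming a local Weaviate instance is running:

import weaviate

client = weaviate.connect_to_local()
# True if the collection was created successfully
print(client.collections.exists("MovieCustomVector"))
# Inspect the stored schema, including the (absent) vectorizer
print(client.collections.get("MovieCustomVector").config.get())
client.close()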

Vectorize and store the text: embed_text

This step combines each movie's title and overview into the text to vectorize, encodes it with BGE-M3, and saves the resulting embeddings to a CSV file.

import pandas as pd
import os
from FlagEmbedding import BGEM3FlagModel

def generate_embeddings(df: pd.DataFrame, batch_size: int = 50) -> pd.DataFrame:
    """Generate vector representations of the movie texts.

    Args:
        df: DataFrame containing the movie data
        batch_size: number of texts to encode per batch

    Returns:
        DataFrame: a DataFrame containing all embeddings
    """
    model = BGEM3FlagModel("BAAI/bge-m3")
    emb_dfs = []
    src_texts = []
    indices = []

    # Process the texts in batches
    for i, row in enumerate(df.itertuples(index=False)):
        # Combine the title and overview
        text = f"Title: {row.title};Overview: {row.overview}"
        src_texts.append(text)
        indices.append(i)

        # Encode once the batch is full or all rows have been processed
        if len(src_texts) == batch_size or i + 1 == len(df):
            embeddings = model.encode(src_texts)["dense_vecs"]
            emb_df = pd.DataFrame(embeddings, index=indices)
            emb_dfs.append(emb_df)
            src_texts = []
            indices = []

    return pd.concat(emb_dfs).sort_index()

def main():
    """主函数,处理电影数据并生成向量"""
    # 读取数据
    data_path = os.path.join(os.path.dirname(__file__), "scratch/movies_data_1990_2024.json")
    df = pd.read_json(data_path)

    # Generate the embeddings and save them next to the input data
    embeddings_df = generate_embeddings(df)
    out_dir = os.path.join(os.path.dirname(__file__), "scratch")
    os.makedirs(out_dir, exist_ok=True)
    embeddings_df.to_csv(
        os.path.join(out_dir, "movies_data_1990_2024_embeddings.csv"), index=False
    )

if __name__ == "__main__":
    main()
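
A quick, optional sanity check on the output: the CSV should contain one row per movie and one column per vector dimension. BGE-M3 dense embeddings are 1024-dimensional, so the shape should be (n_movies, 1024). A minimal sketch:

import pandas as pd

# Verify the saved embeddings have the expected shape
embs_df = pd.read_csv("scratch/movies_data_1990_2024_embeddings.csv")
print(embs_df.shape)                  # expect (n_movies, 1024) for BGE-M3 dense vectors
print(embs_df.iloc[0][:5].tolist())   # first few components of the first vector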

Import the data: Import data

When importing the data (batch.add_object), the vector and the object properties are passed in separately. The uuid uniquely identifies each object; here generate_uuid5 derives a deterministic UUID from the TMDB id, so re-running the import updates existing objects instead of creating duplicates.

import weaviate
import pandas as pd
from datetime import datetime, timezone
import json
from weaviate.util import generate_uuid5
from tqdm import tqdm
import os

"""
Data source: "https://raw.githubusercontent.com/weaviate-tutorials/edu-datasets/main/movies_data_1990_2024.json"
"""

data_path = os.path.join(
    os.path.dirname(__file__), "scratch", "movies_data_1990_2024.json"
)
df = pd.read_json(data_path)

embs_path = os.path.join(
    os.path.dirname(__file__), "scratch", "movies_data_1990_2024_embeddings.csv"
)
embs_df = pd.read_csv(embs_path)

client = weaviate.connect_to_local()
movies = client.collections.get("MovieCustomVector")

with movies.batch.fixed_size(batch_size=200) as batch:
    for i, movie in tqdm(df.iterrows(), total=len(df)):
        # Convert a JSON date to `datetime` and add time zone information
        release_date = datetime.strptime(movie["release_date"], "%Y-%m-%d").replace(
            tzinfo=timezone.utc
        )
        # Convert a JSON array to a list of integers
        genre_ids = json.loads(movie["genre_ids"])

        # Build the object payload
        movie_obj = {
            "title": movie["title"],
            "overview": movie["overview"],
            "vote_average": movie["vote_average"],
            "genre_ids": genre_ids,
            "release_date": release_date,
            "tmdb_id": movie["id"],
        }

        vector = embs_df.iloc[i].tolist()
        batch.add_object(
            properties=movie_obj, uuid=generate_uuid5(movie["id"]), vector=vector
        )

if len(movies.batch.failed_objects) > 0:
    print(f"Failed to import {len(movies.batch.failed_objects)} objects")

client.close()
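
With the data in place, we can query it. Because the collection has no vectorizer module, Weaviate cannot embed query text itself; the query must be encoded client-side with the same BGE-M3 model and passed to near_vector. A minimal sketch (the query text is just an illustrative example):

import weaviate
from FlagEmbedding import BGEM3FlagModel

client = weaviate.connect_to_local()
movies = client.collections.get("MovieCustomVector")

# Encode the query with the same model used at import time
model = BGEM3FlagModel("BAAI/bge-m3")
query_vector = model.encode(["a heist movie with a clever twist"])["dense_vecs"][0]

# Semantic search against the imported custom vectors
response = movies.query.near_vector(near_vector=query_vector.tolist(), limit=3)
for obj in response.objects:
    print(obj.properties["title"], obj.properties["vote_average"])

client.close()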

Method 2

The two steps above, embed_text and import_data, can be merged into one: vectorize the text while the data is being imported.

import weaviate
import pandas as pd
from datetime import datetime, timezone
import json
from weaviate.util import generate_uuid5
from tqdm import tqdm
import os
from FlagEmbedding import BGEM3FlagModel

"""
Data source: "https://raw.githubusercontent.com/weaviate-tutorials/edu-datasets/main/movies_data_1990_2024.json"
"""

def generate_embeddings(texts: list, model: BGEM3FlagModel) -> list:
    """Generate vector representations of the texts.

    Args:
        texts: list of texts to vectorize
        model: BGE-M3 model instance

    Returns:
        list: the encoded dense vectors
    """
    return model.encode(texts)["dense_vecs"]

data_path = os.path.join(
    os.path.dirname(__file__), "scratch", "movies_data_1990_2024.json"
)
df = pd.read_json(data_path)

client = weaviate.connect_to_local()
movies = client.collections.get("MovieCustomVector")

# Initialize the BGE-M3 model
model = BGEM3FlagModel("BAAI/bge-m3")

# Embedding batch size (independent of the Weaviate insert batch of 200 below)
batch_size = 50
src_texts = []
batch_objects = []

with movies.batch.fixed_size(batch_size=200) as batch:
    for i, movie in tqdm(df.iterrows(), total=len(df)):
        # Collect the text to vectorize
        text = f"Title: {movie['title']};Overview: {movie['overview']}"
        src_texts.append(text)

        # Build the object payload
        release_date = datetime.strptime(movie["release_date"], "%Y-%m-%d").replace(
            tzinfo=timezone.utc
        )
        genre_ids = json.loads(movie["genre_ids"])
        movie_obj = {
            "title": movie["title"],
            "overview": movie["overview"],
            "vote_average": movie["vote_average"],
            "genre_ids": genre_ids,
            "release_date": release_date,
            "tmdb_id": movie["id"],
        }
        batch_objects.append((movie_obj, movie["id"]))

        # Vectorize and import once the batch is full or all rows have been processed
        if len(src_texts) == batch_size or i + 1 == len(df):

            vectors = generate_embeddings(src_texts, model)

            # Add the vectorized batch to Weaviate
            for j, (obj, movie_id) in enumerate(batch_objects):
                batch.add_object(
                    properties=obj,
                    uuid=generate_uuid5(movie_id),
                    vector=vectors[j]
                )

            # Reset the batch buffers
            src_texts = []
            batch_objects = []

if len(movies.batch.failed_objects) > 0:
    print(f"Failed to import {len(movies.batch.failed_objects)} objects")

client.close()
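
To confirm the combined import worked, we can check the object count and fetch an object back through its deterministic UUID. A minimal sketch (TMDB id 680, Pulp Fiction, is used purely as an example and assumes that movie is in the dataset):

import weaviate
from weaviate.util import generate_uuid5

client = weaviate.connect_to_local()
movies = client.collections.get("MovieCustomVector")

# Total number of objects in the collection
print(movies.aggregate.over_all(total_count=True).total_count)

# Since UUIDs are derived from the TMDB id, objects can be looked up directly
obj = movies.query.fetch_object_by_id(generate_uuid5(680))  # example TMDB id
if obj is not None:
    print(obj.properties["title"])

client.close()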