Source code for neo4j_graphrag.experimental.components.embedder

#  Copyright (c) "Neo4j"
#  Neo4j Sweden AB [https://neo4j.com]
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#      https://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
from pydantic import validate_call

from neo4j_graphrag.embeddings.base import Embedder
from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks
from neo4j_graphrag.experimental.pipeline.component import Component


class TextChunkEmbedder(Component):
    """Component for creating embeddings from text chunks.

    Args:
        embedder (Embedder): The embedder to use to create the embeddings.

    Example:

    .. code-block:: python

        from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
        from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
        from neo4j_graphrag.experimental.pipeline import Pipeline

        embedder = OpenAIEmbeddings(model="text-embedding-3-large")
        chunk_embedder = TextChunkEmbedder(embedder)
        pipeline = Pipeline()
        pipeline.add_component(chunk_embedder, "chunk_embedder")
    """

    def __init__(self, embedder: Embedder):
        self._embedder = embedder

    def _embed_chunk(self, text_chunk: TextChunk) -> TextChunk:
        """Embed a single text chunk.

        The input chunk is not modified: its metadata is shallow-copied
        before the embedding is added, so the caller's object keeps its
        original metadata.

        Args:
            text_chunk (TextChunk): The text chunk to embed.

        Returns:
            TextChunk: A new text chunk with the same text, index and uid,
            whose metadata contains an added "embedding" key holding the
            embedding of the text chunk's text. An existing "embedding"
            entry is overwritten.
        """
        embedding = self._embedder.embed_query(text_chunk.text)
        # Copy instead of aliasing: writing into text_chunk.metadata directly
        # would mutate the caller's input chunk as a hidden side effect.
        metadata = dict(text_chunk.metadata) if text_chunk.metadata else {}
        metadata["embedding"] = embedding
        return TextChunk(
            text=text_chunk.text,
            index=text_chunk.index,
            metadata=metadata,
            uid=text_chunk.uid,
        )

    @validate_call
    async def run(self, text_chunks: TextChunks) -> TextChunks:
        """Embed a list of text chunks.

        Chunks are embedded one after another, in input order, each through
        a synchronous call to the embedder's ``embed_query``.

        Args:
            text_chunks (TextChunks): The text chunks to embed.

        Returns:
            TextChunks: New text chunks mirroring the input ones, each one
            having an added embedding in its metadata.
        """
        return TextChunks(
            chunks=[self._embed_chunk(text_chunk) for text_chunk in text_chunks.chunks]
        )