# Source code for neo4j_graphrag.experimental.components.neo4j_reader

#  Copyright (c) "Neo4j"
#  Neo4j Sweden AB [https://neo4j.com]
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#      https://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
from __future__ import annotations

from typing import Optional

import neo4j
from pydantic import validate_call

from neo4j_graphrag.experimental.components.types import (
    LexicalGraphConfig,
    TextChunk,
    TextChunks,
)
from neo4j_graphrag.experimental.pipeline import Component


class Neo4jChunkReader(Component):
    """Reads text chunks from a Neo4j database.

    Args:
        driver (neo4j.Driver): The Neo4j driver to connect to the database.
        fetch_embeddings (bool): If True, the embedding property is also returned. Default to False.
        neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).

    Example:

    .. code-block:: python

        from neo4j import GraphDatabase
        from neo4j_graphrag.experimental.components.neo4j_reader import Neo4jChunkReader

        URI = "neo4j://localhost:7687"
        AUTH = ("neo4j", "password")
        DATABASE = "neo4j"

        driver = GraphDatabase.driver(URI, auth=AUTH)
        reader = Neo4jChunkReader(driver=driver, neo4j_database=DATABASE)
        await reader.run()

    """

    def __init__(
        self,
        driver: neo4j.Driver,
        fetch_embeddings: bool = False,
        neo4j_database: Optional[str] = None,
    ):
        self.driver = driver
        self.fetch_embeddings = fetch_embeddings
        self.neo4j_database = neo4j_database

    @staticmethod
    def _escape_identifier(name: str) -> str:
        """Return ``name`` as a backtick-quoted Cypher identifier.

        Backticks inside the name are doubled so the name cannot terminate
        the quoted identifier early.

        Args:
            name (str): Raw label or property name.

        Returns:
            str: The name wrapped in backticks, with inner backticks doubled.
        """
        return "`" + name.replace("`", "``") + "`"

    def _get_query(
        self,
        chunk_label: str,
        index_property: str,
        embedding_property: str,
    ) -> str:
        """Build the Cypher query that returns every chunk node as a map.

        Each node carrying ``chunk_label`` is projected with all of its
        properties (``.*``). When ``fetch_embeddings`` is False, the embedding
        property is additionally projected as ``null`` so its stored value is
        not returned.

        Args:
            chunk_label (str): Label of the chunk nodes to match.
            index_property (str): Property used to order the results. No
                ``ORDER BY`` clause is added when it is empty.
            embedding_property (str): Property holding the chunk embedding.

        Returns:
            str: The Cypher query text; each row exposes one ``chunk`` map.
        """
        return_properties = [".*"]
        # An empty property name would yield a malformed ": null" entry, so
        # it is skipped the same way an empty index property is skipped below.
        if not self.fetch_embeddings and embedding_property:
            return_properties.append(
                f"{self._escape_identifier(embedding_property)}: null"
            )
        # Label and property names are all quoted through the same helper so
        # names with spaces, dashes or backticks still produce valid Cypher.
        query = (
            f"MATCH (c:{self._escape_identifier(chunk_label)}) "
            f"RETURN c {{ {', '.join(return_properties)} }} as chunk "
        )
        if index_property:
            query += f"ORDER BY c.{self._escape_identifier(index_property)}"
        return query

    @validate_call
    async def run(
        self,
        lexical_graph_config: LexicalGraphConfig = LexicalGraphConfig(),
    ) -> TextChunks:
        """Reads text chunks from a Neo4j database.

        Args:
            lexical_graph_config (LexicalGraphConfig): Node labels and relationship types for the lexical graph.

        Returns:
            TextChunks: One ``TextChunk`` per returned row. The text, index and
            (when present) id properties are mapped to ``text``, ``index`` and
            ``uid``; every remaining property is stored in ``metadata``.
        """
        query = self._get_query(
            lexical_graph_config.chunk_node_label,
            lexical_graph_config.chunk_index_property,
            lexical_graph_config.chunk_embedding_property,
        )
        # NOTE: this call is not awaited; it runs to completion before the
        # coroutine continues.
        result, _, _ = self.driver.execute_query(
            query,
            database_=self.neo4j_database,
            routing_=neo4j.RoutingControl.READ,
        )
        chunks = []
        for record in result:
            chunk = record.get("chunk")
            # Known properties are popped out of the map so that whatever is
            # left can be passed along unchanged as metadata.
            input_data = {
                "text": chunk.pop(lexical_graph_config.chunk_text_property, ""),
                "index": chunk.pop(lexical_graph_config.chunk_index_property, -1),
            }
            # "uid" is only set when the node has an id; otherwise the key is
            # left out of the TextChunk arguments entirely.
            if (
                uid := chunk.pop(lexical_graph_config.chunk_id_property, None)
            ) is not None:
                input_data["uid"] = uid
            input_data["metadata"] = chunk
            chunks.append(TextChunk(**input_data))
        return TextChunks(chunks=chunks)