Source code for neo4j_graphrag.experimental.components.entity_relation_extractor

#  Copyright (c) "Neo4j"
#  Neo4j Sweden AB [https://neo4j.com]
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
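#  #
#      https://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and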
#  limitations under the License.
from __future__ import annotations

import abc
import asyncio
import enum
import json
import logging
from typing import Any, List, Optional, Union, cast

import json_repair
from pydantic import ValidationError, validate_call

from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.experimental.components.lexical_graph import LexicalGraphBuilder
from neo4j_graphrag.experimental.components.schema import SchemaConfig
from neo4j_graphrag.experimental.components.types import (
    DocumentInfo,
    LexicalGraphConfig,
    Neo4jGraph,
    TextChunk,
    TextChunks,
)
from neo4j_graphrag.experimental.pipeline.component import Component
from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError
from neo4j_graphrag.generation.prompts import ERExtractionTemplate, PromptTemplate
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.utils.logging import prettify

logger = logging.getLogger(__name__)

class OnError(enum.Enum):
    """Strategy to apply when extraction fails for a chunk.

    RAISE: propagate the failure to the caller.
    IGNORE: do not raise; ``LLMEntityRelationExtractor`` logs the error and
        uses an empty graph for the offending chunk.
    """

    RAISE = "RAISE"
    IGNORE = "IGNORE"

    @classmethod
    def possible_values(cls) -> List[str]:
        """Return the string values of all members, in definition order."""
        return [e.value for e in cls]

def balance_curly_braces(json_string: str) -> str:
    """
    Balances curly braces `{}` in a JSON string. This function ensures that every opening brace has a corresponding
    closing brace, but only when they are not part of a string value. If there are unbalanced closing braces,
    they are ignored. If there are missing closing braces, they are appended at the end of the string.

    Braces inside a double-quoted string value are copied through untouched
    and never counted; backslash escapes inside strings are honoured, so an
    escaped quote (\\") does not end the string.

    Args:
        json_string (str): A potentially malformed JSON string with unbalanced curly braces.

    Returns:
        str: A JSON string with balanced curly braces.
    """
    # Only "{" is ever tracked, so a counter of currently-open braces suffices.
    depth = 0
    fixed_json: list[str] = []
    in_string = False
    escape = False

    for char in json_string:
        if char == '"' and not escape:
            # An unescaped quote opens or closes a string value.
            in_string = not in_string
        elif char == "\\" and in_string:
            # Toggle rather than set, so an escaped backslash ("\\\\") does
            # not escape the character that follows it.
            escape = not escape
            fixed_json.append(char)
            continue
        else:
            escape = False

        if in_string:
            fixed_json.append(char)
        elif char == "{":
            depth += 1
            fixed_json.append(char)
        elif char == "}":
            if depth > 0:
                depth -= 1
                fixed_json.append(char)
            # An unmatched closing brace outside a string is dropped.
        else:
            fixed_json.append(char)

    # If braces are still open, add the missing closing braces.
    fixed_json.append("}" * depth)
    return "".join(fixed_json)

def fix_invalid_json(raw_json: str) -> str:
    """Repair a possibly malformed JSON string with ``json_repair``.

    The repaired text is stripped of surrounding whitespace before being
    returned.

    Args:
        raw_json (str): JSON-like text, typically produced by an LLM.

    Returns:
        str: The repaired JSON text.

    Raises:
        InvalidJSONError: If the repair produced only an empty JSON string
            literal (``""``) or an empty string.
    """
    cleaned = cast(str, json_repair.repair_json(raw_json)).strip()

    # json_repair yields '""' when nothing usable could be recovered.
    if cleaned == '""':
        raise InvalidJSONError("JSON repair resulted in an empty or invalid JSON.")
    if cleaned == "":
        raise InvalidJSONError("JSON repair resulted in an empty string.")
    return cleaned

class EntityRelationExtractor(Component, abc.ABC):
    """Abstract class for entity relation extraction components.

    Args:
        on_error (OnError): What to do when an error occurs during extraction.
            Defaults to ``OnError.IGNORE``; subclasses may pass a different
            default (``LLMEntityRelationExtractor`` uses ``OnError.RAISE``).
        create_lexical_graph (bool): Whether to include the text chunks in the graph in addition to the
            extracted entities and relations. Defaults to True.
    """

    def __init__(
        self,
        *args: Any,
        on_error: OnError = OnError.IGNORE,
        create_lexical_graph: bool = True,
        **kwargs: Any,
    ) -> None:
        # NOTE(review): *args / **kwargs are accepted but neither stored nor
        # forwarded to a parent initializer.
        self.on_error = on_error
        self.create_lexical_graph = create_lexical_graph

    @abc.abstractmethod
    async def run(
        self,
        chunks: TextChunks,
        document_info: Optional[DocumentInfo] = None,
        lexical_graph_config: Optional[LexicalGraphConfig] = None,
        **kwargs: Any,
    ) -> Neo4jGraph:
        """Extract entities and relations from ``chunks`` into a Neo4jGraph."""
        pass

    def update_ids(
        self,
        graph: Neo4jGraph,
        chunk: TextChunk,
    ) -> Neo4jGraph:
        """Make node IDs unique across chunks, document and pipeline runs
        by prefixing them with a unique prefix.

        ``graph`` is modified in place and also returned. Every node ID, and
        the start/end node IDs of every relationship, become
        ``"<chunk.chunk_id>:<original id>"``. Every node also gets a
        ``chunk_index`` property set to ``chunk.index``.

        Note: the rewrite is not idempotent; calling it twice on the same
        graph prefixes the IDs twice.
        """
        prefix = f"{chunk.chunk_id}"
        for node in graph.nodes:
            node.id = f"{prefix}:{node.id}"
            # Nodes may come without a properties dict; create one before updating.
            if node.properties is None:
                node.properties = {}
            node.properties.update({"chunk_index": chunk.index})
        for rel in graph.relationships:
            # Keep relationship endpoints consistent with the renamed nodes.
            rel.start_node_id = f"{prefix}:{rel.start_node_id}"
            rel.end_node_id = f"{prefix}:{rel.end_node_id}"
        return graph
class LLMEntityRelationExtractor(EntityRelationExtractor):
    """
    Extracts a knowledge graph from a series of text chunks using a large language model.

    Args:
        llm (LLMInterface): The language model to use for extraction.
        prompt_template (ERExtractionTemplate | str): A custom prompt template to use for extraction.
        create_lexical_graph (bool): Whether to include the text chunks in the graph in addition to the
            extracted entities and relations. Defaults to True.
        on_error (OnError): What to do when an error occurs during extraction. Defaults to raising an error.
        max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests
            to the LLM. Must be at least 1. Defaults to 5.

    Raises:
        ValueError: If ``max_concurrency`` is smaller than 1.

    Example:

    .. code-block:: python

        from neo4j_graphrag.experimental.components.entity_relation_extractor import LLMEntityRelationExtractor
        from neo4j_graphrag.llm import OpenAILLM
        from neo4j_graphrag.experimental.pipeline import Pipeline

        llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0, "response_format": {"type": "json_object"}})

        extractor = LLMEntityRelationExtractor(llm=llm)
        pipe = Pipeline()
        pipe.add_component(extractor, "extractor")

    """

    def __init__(
        self,
        llm: LLMInterface,
        prompt_template: ERExtractionTemplate | str = ERExtractionTemplate(),
        create_lexical_graph: bool = True,
        on_error: OnError = OnError.RAISE,
        max_concurrency: int = 5,
    ) -> None:
        # asyncio.Semaphore(0) would make every chunk wait forever in ``run``,
        # and a negative value only fails later; reject both up front.
        if max_concurrency < 1:
            raise ValueError(
                f"max_concurrency must be at least 1, got {max_concurrency}"
            )
        super().__init__(on_error=on_error, create_lexical_graph=create_lexical_graph)
        self.llm = llm  # with response_format={ "type": "json_object" },
        self.max_concurrency = max_concurrency
        if isinstance(prompt_template, str):
            template = PromptTemplate(prompt_template, expected_inputs=[])
        else:
            template = prompt_template
        self.prompt_template = template

    async def extract_for_chunk(
        self, schema: SchemaConfig, examples: str, chunk: TextChunk
    ) -> Neo4jGraph:
        """Run entity extraction for a given text chunk.

        The LLM output is repaired with ``fix_invalid_json``, parsed as JSON
        and validated as a ``Neo4jGraph``.

        Raises:
            LLMGenerationError: If the output is not valid JSON or does not
                match the ``Neo4jGraph`` model, and ``on_error`` is
                ``OnError.RAISE``. Otherwise the error is logged and an empty
                graph is returned for this chunk.
        """
        prompt = self.prompt_template.format(
            text=chunk.text, schema=schema.model_dump(), examples=examples
        )
        llm_result = await self.llm.ainvoke(prompt)
        try:
            llm_generated_json = fix_invalid_json(llm_result.content)
            result = json.loads(llm_generated_json)
        except (json.JSONDecodeError, InvalidJSONError) as e:
            if self.on_error == OnError.RAISE:
                raise LLMGenerationError("LLM response is not valid JSON") from e
            logger.error(
                "LLM response is not valid JSON for chunk_index=%s", chunk.index
            )
            logger.debug("Invalid JSON: %s", llm_result.content)
            result = {"nodes": [], "relationships": []}
        try:
            chunk_graph = Neo4jGraph.model_validate(result)
        except ValidationError as e:
            if self.on_error == OnError.RAISE:
                raise LLMGenerationError("LLM response has improper format") from e
            logger.error(
                "LLM response has improper format for chunk_index=%s", chunk.index
            )
            logger.debug("Invalid JSON format: %s", result)
            chunk_graph = Neo4jGraph()
        return chunk_graph

    async def post_process_chunk(
        self,
        chunk_graph: Neo4jGraph,
        chunk: TextChunk,
        lexical_graph_builder: Optional[LexicalGraphBuilder] = None,
    ) -> None:
        """Perform post-processing after entity and relation extraction:
        - Update node IDs to make them unique across chunks
        - Build the lexical graph if requested
        """
        self.update_ids(chunk_graph, chunk)
        if lexical_graph_builder:
            await lexical_graph_builder.process_chunk_extracted_entities(
                chunk_graph,
                chunk,
            )

    def combine_chunk_graphs(
        self, lexical_graph: Optional[Neo4jGraph], chunk_graphs: List[Neo4jGraph]
    ) -> Neo4jGraph:
        """Combine sub-graphs obtained for each chunk into a single Neo4jGraph object.

        The lexical graph, when given, is deep-copied so that it is not mutated.
        """
        if lexical_graph is not None:
            graph = lexical_graph.model_copy(deep=True)
        else:
            graph = Neo4jGraph()
        for chunk_graph in chunk_graphs:
            graph.nodes.extend(chunk_graph.nodes)
            graph.relationships.extend(chunk_graph.relationships)
        return graph

    async def run_for_chunk(
        self,
        sem: asyncio.Semaphore,
        chunk: TextChunk,
        schema: SchemaConfig,
        examples: str,
        lexical_graph_builder: Optional[LexicalGraphBuilder] = None,
    ) -> Neo4jGraph:
        """Run extraction and post processing for a single chunk.

        ``sem`` bounds how many chunks are processed at the same time.
        """
        async with sem:
            chunk_graph = await self.extract_for_chunk(schema, examples, chunk)
            await self.post_process_chunk(
                chunk_graph,
                chunk,
                lexical_graph_builder,
            )
            return chunk_graph

    @validate_call
    async def run(
        self,
        chunks: TextChunks,
        document_info: Optional[DocumentInfo] = None,
        lexical_graph_config: Optional[LexicalGraphConfig] = None,
        schema: Union[SchemaConfig, None] = None,
        examples: str = "",
        **kwargs: Any,
    ) -> Neo4jGraph:
        """Perform entity and relation extraction for all chunks in a list.

        Optionally, creates the "lexical graph" by adding nodes and relationships
        to represent the document and its chunks in the returned graph
        (For more details, see the :ref:`Lexical Graph Builder doc <lexical-graph-builder>` and
        the :ref:`User Guide <lexical-graph-in-er-extraction>`)

        Args:
            chunks (TextChunks): List of text chunks to extract entities and relations from.
            document_info (Optional[DocumentInfo], optional): Document the chunks are coming from.
                Used in the lexical graph creation step.
            lexical_graph_config (Optional[LexicalGraphConfig], optional): Lexical graph configuration
                to customize node labels and relationship types in the lexical graph.
            schema (SchemaConfig | None): Definition of the schema to guide the LLM in its extraction.
                Caution: at the moment, there is no guarantee that the extracted entities and
                relations will strictly obey the schema.
            examples (str): Examples for few-shot learning in the prompt.

        Returns:
            Neo4jGraph: The lexical graph (if created) merged with the graphs
            extracted from every chunk.
        """
        lexical_graph_builder = None
        lexical_graph = None
        if self.create_lexical_graph:
            config = lexical_graph_config or LexicalGraphConfig()
            lexical_graph_builder = LexicalGraphBuilder(config=config)
            lexical_graph_result = await lexical_graph_builder.run(
                text_chunks=chunks, document_info=document_info
            )
            lexical_graph = lexical_graph_result.graph
        elif lexical_graph_config:
            # No lexical graph is returned, but the builder is still used to
            # link extracted entities to their chunks in post-processing.
            lexical_graph_builder = LexicalGraphBuilder(config=lexical_graph_config)
        schema = schema or SchemaConfig(entities={}, relations={}, potential_schema=[])
        examples = examples or ""
        sem = asyncio.Semaphore(self.max_concurrency)
        tasks = [
            self.run_for_chunk(
                sem,
                chunk,
                schema,
                examples,
                lexical_graph_builder,
            )
            for chunk in chunks.chunks
        ]
        chunk_graphs: list[Neo4jGraph] = list(await asyncio.gather(*tasks))
        graph = self.combine_chunk_graphs(lexical_graph, chunk_graphs)
        # prettify() serializes the whole graph; skip it when DEBUG is off.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("Extracted graph: %s", prettify(graph))
        return graph