# Source code for neo4j_graphrag.generation.graphrag

#  Copyright (c) "Neo4j"
#  Neo4j Sweden AB [https://neo4j.com]
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#      https://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
from __future__ import annotations

import logging
import warnings
from typing import Any, Optional

from pydantic import ValidationError

from neo4j_graphrag.exceptions import (
    RagInitializationError,
    SearchValidationError,
)
from neo4j_graphrag.generation.prompts import RagTemplate
from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.llm.types import LLMMessage
from neo4j_graphrag.retrievers.base import Retriever
from neo4j_graphrag.types import RetrieverResult

logger = logging.getLogger(__name__)


class GraphRAG:
    """Performs a GraphRAG search using a specific retriever and LLM.

    Example:

    .. code-block:: python

        import neo4j
        from neo4j_graphrag.retrievers import VectorRetriever
        from neo4j_graphrag.llm.openai_llm import OpenAILLM
        from neo4j_graphrag.generation import GraphRAG

        driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)

        retriever = VectorRetriever(driver, "vector-index-name", custom_embedder)
        llm = OpenAILLM()
        graph_rag = GraphRAG(retriever, llm)
        graph_rag.search(query_text="Find me a book about Fremen")

    Args:
        retriever (Retriever): The retriever used to find relevant context to pass to the LLM.
        llm (LLMInterface): The LLM used to generate the answer.
        prompt_template (RagTemplate): The prompt template that will be formatted with context and user question and passed to the LLM.

    Raises:
        RagInitializationError: If validation of the input arguments fails.
    """

    def __init__(
        self,
        retriever: Retriever,
        llm: LLMInterface,
        prompt_template: RagTemplate = RagTemplate(),
    ):
        try:
            validated_data = RagInitModel(
                retriever=retriever,
                llm=llm,
                prompt_template=prompt_template,
            )
        except ValidationError as e:
            raise RagInitializationError(e.errors()) from e
        self.retriever = validated_data.retriever
        self.llm = validated_data.llm
        self.prompt_template = validated_data.prompt_template
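
    # A minimal construction sketch with a custom prompt template (illustrative
    # only; it assumes RagTemplate accepts a custom `template` string using the
    # same placeholders as the default template):
    #
    #     template = RagTemplate(
    #         template="Context:\n{context}\n\n{examples}\n\nQuestion: {query_text}\nAnswer:"
    #     )
    #     rag = GraphRAG(retriever=retriever, llm=llm, prompt_template=template)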

    def search(
        self,
        query_text: str = "",
        message_history: Optional[list[LLMMessage]] = None,
        examples: str = "",
        retriever_config: Optional[dict[str, Any]] = None,
        return_context: bool | None = None,
    ) -> RagResultModel:
        """
        .. warning::
            The default value of 'return_context' will change from 'False' to 'True' in a future version.

        This method performs a full RAG search:

        1. Retrieval: context retrieval
        2. Augmentation: prompt formatting
        3. Generation: answer generation with LLM

        Args:
            query_text (str): The user question.
            message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
            examples (str): Examples added to the LLM prompt.
            retriever_config (Optional[dict]): Parameters passed to the retriever's search method, e.g. top_k.
            return_context (bool): Whether to append the retriever result to the final result (default: False).

        Returns:
            RagResultModel: The LLM-generated answer.
        """
        if return_context is None:
            warnings.warn(
                "The default value of 'return_context' will change from 'False' to 'True' in a future version.",
                DeprecationWarning,
            )
            return_context = False
        try:
            validated_data = RagSearchModel(
                query_text=query_text,
                examples=examples,
                retriever_config=retriever_config or {},
                return_context=return_context,
            )
        except ValidationError as e:
            raise SearchValidationError(e.errors()) from e
        query = self.build_query(validated_data.query_text, message_history)
        retriever_result: RetrieverResult = self.retriever.search(
            query_text=query, **validated_data.retriever_config
        )
        context = "\n".join(item.content for item in retriever_result.items)
        prompt = self.prompt_template.format(
            query_text=query_text, context=context, examples=validated_data.examples
        )
        logger.debug(f"RAG: retriever_result={retriever_result}")
        logger.debug(f"RAG: prompt={prompt}")
        answer = self.llm.invoke(
            prompt,
            message_history,
            system_instruction=self.prompt_template.system_instructions,
        )
        result: dict[str, Any] = {"answer": answer.content}
        if return_context:
            result["retriever_result"] = retriever_result
        return RagResultModel(**result)
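
    # Usage sketch for `search` (illustrative; assumes `rag` was built as in the
    # class docstring and that the retriever's search method accepts `top_k`):
    #
    #     result = rag.search(
    #         query_text="Find me a book about Fremen",
    #         retriever_config={"top_k": 5},
    #         return_context=True,  # set explicitly to avoid the DeprecationWarning
    #     )
    #     print(result.answer)
    #     for item in result.retriever_result.items:
    #         print(item.content)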

    def build_query(
        self, query_text: str, message_history: Optional[list[LLMMessage]] = None
    ) -> str:
        """Build the query passed to the retriever: when a message history is
        provided, summarize it with the LLM and combine the summary with the
        current question; otherwise return the question unchanged."""
        summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words."
        if message_history:
            summarization_prompt = self.chat_summary_prompt(
                message_history=message_history
            )
            summary = self.llm.invoke(
                input=summarization_prompt,
                system_instruction=summary_system_message,
            ).content
            return self.conversation_prompt(summary=summary, current_query=query_text)
        return query_text
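
    # Sketch of the two `build_query` branches (illustrative; the message
    # contents are made up):
    #
    #     history: list[LLMMessage] = [
    #         {"role": "user", "content": "Who wrote Dune?"},
    #         {"role": "assistant", "content": "Frank Herbert wrote Dune."},
    #     ]
    #     rag.build_query("When was it published?", history)
    #     # -> the LLM summarizes the history, and conversation_prompt(...) is returned
    #     rag.build_query("When was it published?")
    #     # -> "When was it published?" (no history: returned unchanged)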

    def chat_summary_prompt(self, message_history: list[LLMMessage]) -> str:
        """Format the message history into a single summarization prompt,
        rendering each message as 'role: content' on its own line."""
        message_list = [
            ": ".join([f"{value}" for _, value in message.items()])
            for message in message_history
        ]
        history = "\n".join(message_list)
        return f"""
Summarize the message history:

{history}
"""

    def conversation_prompt(self, summary: str, current_query: str) -> str:
        """Combine the conversation summary and the current question into the
        final retriever query."""
        return f"""
Message Summary:
{summary}

Current Query:
{current_query}
"""
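
    # Illustration of the two prompt helpers on the made-up history above; with
    # messages shaped like {"role": ..., "content": ...}, chat_summary_prompt
    # renders:
    #
    #     Summarize the message history:
    #
    #     user: Who wrote Dune?
    #     assistant: Frank Herbert wrote Dune.
    #
    # and conversation_prompt(summary, current_query) renders:
    #
    #     Message Summary:
    #     <summary>
    #
    #     Current Query:
    #     <current_query>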