Source code for neo4j_graphrag.types

#  Copyright (c) "Neo4j"
#  Neo4j Sweden AB [https://neo4j.com]
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#      https://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
from __future__ import annotations

from enum import Enum
from typing import Any, Callable, Literal, Optional

import neo4j
from pydantic import (
    BaseModel,
    ConfigDict,
    PositiveInt,
    field_validator,
    model_validator,
)

from neo4j_graphrag.utils.validation import validate_search_query_input


class RawSearchResult(BaseModel):
    """
    Represents the raw result returned from the retriever get_search_result method.
    It needs to be formatted further before being returned as a RetrieverResult.

    Attributes:
        records (list[neo4j.Record]): A list of records from neo4j.
        metadata (Optional[dict[str, Any]]): Record-related metadata, such as score.
    """

    records: list[neo4j.Record]
    metadata: Optional[dict[str, Any]] = None

    # Lets pydantic accept the neo4j.Record field type declared above.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @field_validator("records")
    @classmethod
    def check_records(cls, value: list[neo4j.Record]) -> list[neo4j.Record]:
        """
        Ensure every element of ``records`` is a ``neo4j.Record``.

        Args:
            value (list[neo4j.Record]): The value supplied for the ``records`` field.

        Returns:
            list[neo4j.Record]: The same list, unchanged, when every element passes.

        Raises:
            TypeError: If any element is not an instance of ``neo4j.Record``.
        """
        for record in value:
            if not isinstance(record, neo4j.Record):
                raise TypeError(
                    f"Type mismatch, expected neo4j.Record, got {type(record)}"
                )
        return value
class RetrieverResultItem(BaseModel):
    """
    A single record returned from a retriever.

    Attributes:
        content (Any): The context as it will be provided to the LLM.
        metadata (Optional[dict[str, Any]]): Any metadata that can be included
            together with the text, related to that record (e.g. another node
            property). Defaults to None.
    """

    content: Any
    metadata: Optional[dict[str, Any]] = None
class RetrieverResult(BaseModel):
    """
    Represents a result returned from a retriever.

    Attributes:
        items (list[RetrieverResultItem]): A list of retrieved items.
        metadata (Optional[dict[str, Any]]): Context-related metadata such as
            the generated Cypher query in the Text2CypherRetriever. Defaults
            to None.
    """

    items: list[RetrieverResultItem]
    metadata: Optional[dict[str, Any]] = None
class Neo4jSchemaModel(BaseModel):
    """Wrapper holding a Neo4j schema as a single string."""

    neo4j_schema: str


class IndexModel(BaseModel):
    """Base model for index-related arguments; carries a validated driver."""

    # Declared as Any; the validator below enforces the actual type.
    driver: Any

    @field_validator("driver")
    @classmethod
    def check_driver_is_valid(cls, v: Any) -> neo4j.Driver:
        """
        Ensure ``driver`` is a ``neo4j.Driver`` instance.

        Raises:
            ValueError: If ``v`` is not an instance of ``neo4j.Driver``.
        """
        if not isinstance(v, neo4j.Driver):
            raise ValueError("driver must be an instance of neo4j.Driver")
        return v


class VectorIndexModel(IndexModel):
    """
    Arguments describing a vector index.

    ``dimensions`` must be a positive integer and ``similarity_fn`` must be
    either ``"euclidean"`` or ``"cosine"``.
    """

    name: str
    label: str
    embedding_property: str
    dimensions: PositiveInt
    similarity_fn: Literal["euclidean", "cosine"]


class FulltextIndexModel(IndexModel):
    """Arguments describing a fulltext index over one or more node properties."""

    name: str
    label: str
    node_properties: list[str]

    @field_validator("node_properties")
    @classmethod
    def check_node_properties_not_empty(cls, v: list[Any]) -> list[Any]:
        """
        Reject an empty ``node_properties`` list.

        Raises:
            ValueError: If ``v`` is empty.
        """
        if not v:
            raise ValueError("node_properties cannot be an empty list")
        return v


class VectorSearchModel(BaseModel):
    """
    Arguments for a vector search.

    Attributes:
        query_vector (Optional[list[float]]): The query embedding. Defaults to None.
        query_text (Optional[str]): The query text. Defaults to None.
        top_k (PositiveInt): Number of results requested. Defaults to 5.
        filters (Optional[dict[str, Any]]): Optional filters. Defaults to None.
    """

    query_vector: Optional[list[float]] = None
    query_text: Optional[str] = None
    top_k: PositiveInt = 5
    filters: Optional[dict[str, Any]] = None

    @model_validator(mode="before")
    @classmethod
    def check_query(cls, values: dict[str, Any]) -> dict[str, Any]:
        """
        Validates that one of either query_vector or query_text is provided exclusively.

        Runs before field validation, so ``values`` is the raw input mapping.
        The check itself is delegated to ``validate_search_query_input``.
        """
        query_vector = values.get("query_vector")
        query_text = values.get("query_text")
        validate_search_query_input(query_text, query_vector)
        return values


class VectorCypherSearchModel(VectorSearchModel):
    """Vector search arguments extended with optional query parameters."""

    query_params: Optional[dict[str, Any]] = None


class HybridSearchModel(BaseModel):
    """
    Arguments for a hybrid search.

    ``query_text`` is required; ``query_vector`` is optional and ``top_k``
    defaults to 5.
    """

    query_text: str
    query_vector: Optional[list[float]] = None
    top_k: PositiveInt = 5


class HybridCypherSearchModel(HybridSearchModel):
    """Hybrid search arguments extended with optional query parameters."""

    query_params: Optional[dict[str, Any]] = None


class Text2CypherSearchModel(BaseModel):
    """Arguments for a Text2Cypher search: a single required query text."""

    query_text: str


class SearchType(str, Enum):
    """Enumerator of the search strategies."""

    VECTOR = "vector"
    HYBRID = "hybrid"


class EmbedderModel(BaseModel):
    """Wrapper validating that an embedder object exposes ``embed_query``."""

    # Required field. None satisfies the annotation but is rejected by
    # check_embedder, since None has no callable ``embed_query``.
    embedder: Optional[Any]

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @field_validator("embedder")
    @classmethod
    def check_embedder(cls, value: Any) -> Any:
        """
        Ensure the embedder has a callable ``embed_query`` attribute.

        Raises:
            ValueError: If ``embed_query`` is missing or not callable.
        """
        # getattr with a None default covers both "missing" and "not callable".
        if not callable(getattr(value, "embed_query", None)):
            raise ValueError(
                "Provided embedder object must have an 'embed_query' callable method."
            )
        return value


class LLMModel(BaseModel):
    """Wrapper validating that an LLM object exposes ``invoke``."""

    llm: Any

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @field_validator("llm")
    @classmethod
    def check_llm(cls, value: object) -> object:
        """
        Ensure the llm has a callable ``invoke`` attribute.

        Raises:
            ValueError: If ``invoke`` is missing or not callable.
        """
        # getattr with a None default covers both "missing" and "not callable".
        if not callable(getattr(value, "invoke", None)):
            raise ValueError(
                "Provided llm object must have an 'invoke' callable method."
            )
        return value


class Neo4jDriverModel(BaseModel):
    """Wrapper validating that ``driver`` is a ``neo4j.Driver``."""

    driver: neo4j.Driver

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @field_validator("driver")
    @classmethod
    def check_driver(cls, value: neo4j.Driver) -> neo4j.Driver:
        """
        Ensure ``driver`` is a ``neo4j.Driver`` instance.

        Raises:
            ValueError: If ``value`` is not an instance of ``neo4j.Driver``.
        """
        if not isinstance(value, neo4j.Driver):
            raise ValueError("Provided driver needs to be of type neo4j.Driver")
        return value


class VectorRetrieverModel(BaseModel):
    """Constructor arguments for a vector retriever."""

    driver_model: Neo4jDriverModel
    index_name: str
    embedder_model: Optional[EmbedderModel] = None
    return_properties: Optional[list[str]] = None
    result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
    neo4j_database: Optional[str] = None


class VectorCypherRetrieverModel(BaseModel):
    """Constructor arguments for a vector retriever with a retrieval query."""

    driver_model: Neo4jDriverModel
    index_name: str
    retrieval_query: str
    embedder_model: Optional[EmbedderModel] = None
    result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
    neo4j_database: Optional[str] = None


class HybridRetrieverModel(BaseModel):
    """Constructor arguments for a hybrid (vector + fulltext index) retriever."""

    driver_model: Neo4jDriverModel
    vector_index_name: str
    fulltext_index_name: str
    embedder_model: Optional[EmbedderModel] = None
    return_properties: Optional[list[str]] = None
    result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
    neo4j_database: Optional[str] = None


class HybridCypherRetrieverModel(BaseModel):
    """Constructor arguments for a hybrid retriever with a retrieval query."""

    driver_model: Neo4jDriverModel
    vector_index_name: str
    fulltext_index_name: str
    retrieval_query: str
    embedder_model: Optional[EmbedderModel] = None
    result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
    neo4j_database: Optional[str] = None


class Text2CypherRetrieverModel(BaseModel):
    """Constructor arguments for a Text2Cypher retriever."""

    driver_model: Neo4jDriverModel
    llm_model: LLMModel
    neo4j_schema_model: Optional[Neo4jSchemaModel] = None
    examples: Optional[list[str]] = None
    result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
    custom_prompt: Optional[str] = None
    neo4j_database: Optional[str] = None