# Source code for neo4j_graphrag.experimental.pipeline.pipeline

#  Copyright (c) "Neo4j"
#  Neo4j Sweden AB [https://neo4j.com]
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#      https://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
from __future__ import annotations

import asyncio
import datetime
import enum
import logging
import uuid
import warnings
from collections import defaultdict
from timeit import default_timer
from typing import Any, AsyncGenerator, Optional

from neo4j_graphrag.utils.logging import prettify

try:
    import pygraphviz as pgv
except ImportError:
    pgv = None

from pydantic import BaseModel, Field

from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
from neo4j_graphrag.experimental.pipeline.exceptions import (
    PipelineDefinitionError,
    PipelineMissingDependencyError,
    PipelineStatusUpdateError,
)
from neo4j_graphrag.experimental.pipeline.pipeline_graph import (
    PipelineEdge,
    PipelineGraph,
    PipelineNode,
)
from neo4j_graphrag.experimental.pipeline.stores import InMemoryStore, ResultStore
from neo4j_graphrag.experimental.pipeline.types import (
    ComponentDefinition,
    ConnectionDefinition,
    PipelineDefinition,
)

# Module-level logger; the task and orchestrator code in this module emits
# its debug traces through it.
logger = logging.getLogger(__name__)


class RunStatus(enum.Enum):
    """Execution status of a pipeline task.

    The member values are the strings persisted in the result store
    (the orchestrator writes ``status.value`` and reads it back with
    ``RunStatus(value)``).
    """

    # No status has been recorded yet for the task in the current run.
    UNKNOWN = "UNKNOWN"
    # The task has been started and has not finished yet.
    RUNNING = "RUNNING"
    # The task has finished.
    DONE = "DONE"


class RunResult(BaseModel):
    """Outcome of a single task execution."""

    # Status attached to the result; defaults to DONE.
    status: RunStatus = RunStatus.DONE
    # Output returned by the component, if any.
    result: Optional[DataModel] = None
    # Creation time of this object, timezone-aware (UTC).
    timestamp: datetime.datetime = Field(
        default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
    )


class TaskPipelineNode(PipelineNode):
    """Graph node that wraps a runnable component.

    Every node carries:
    - a name (unique within the pipeline)
    - the component instance to execute
    """

    def __init__(self, name: str, component: Component):
        """Build a runnable node; the node's data payload starts empty.

        Args:
            name (str): node's name
            component (Component): component instance
        """
        super().__init__(name, {})
        self.component = component

    async def execute(self, **kwargs: Any) -> RunResult | None:
        """Await the wrapped component and wrap its output.

        Args:
            **kwargs (Any): forwarded unchanged to ``component.run``.

        Returns:
            RunResult | None: a RunResult (default status) holding the
            component output. Exceptions raised by the component are
            not caught here.
        """
        return RunResult(result=await self.component.run(**kwargs))

    async def run(self, inputs: dict[str, Any]) -> RunResult | None:
        """Execute the task with ``inputs`` and log its start, end and duration.

        Args:
            inputs (dict[str, Any]): keyword arguments for the component.

        Returns:
            RunResult | None: whatever ``execute`` returned.
        """
        logger.debug(f"TASK START {self.name=} input={prettify(inputs)}")
        start_time = default_timer()
        res = await self.execute(**inputs)
        end_time = default_timer()
        # Timing covers the component call only, not the logging around it.
        logger.debug(
            f"TASK FINISHED {self.name} in {end_time - start_time} res={prettify(res)}"
        )
        return res


class Orchestrator:
    """Orchestrate a pipeline.

    The orchestrator is responsible for:
    - finding the next tasks to execute
    - building the inputs for each task
    - calling the run method on each task

    Once a TaskNode is done, it calls the `on_task_complete` callback
    that will save the results, find the next tasks to be executed
    (checking that all dependencies are met), and run them.

    Each orchestrator instance corresponds to one run, identified by a
    freshly generated ``run_id`` used as the key in the result stores.
    """

    def __init__(self, pipeline: "Pipeline"):
        self.pipeline = pipeline
        self.run_id = str(uuid.uuid4())
        # A single lock shared by every status update of this run. It must
        # live on the instance: a lock created inside `set_task_status`
        # would be private to each call and would never exclude anything.
        self._status_lock = asyncio.Lock()

    async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
        """Get inputs and run a specific task. Once the task is done,
        calls the on_task_complete method.

        Args:
            task (TaskPipelineNode): The task to be run
            data (dict[str, Any]): The pipeline input data

        Returns:
            None
        """
        param_mapping = self.get_input_config_for_task(task)
        inputs = await self.get_component_inputs(task.name, param_mapping, data)
        try:
            await self.set_task_status(task.name, RunStatus.RUNNING)
        except PipelineStatusUpdateError:
            # Another branch of the graph already started (or finished)
            # this task: nothing to do.
            logger.debug(
                f"ORCHESTRATOR: TASK ABORTED: {task.name} is already running or done, aborting"
            )
            return None
        res = await task.run(inputs)
        await self.set_task_status(task.name, RunStatus.DONE)
        if res:
            await self.on_task_complete(data=data, task=task, result=res)

    async def set_task_status(self, task_name: str, status: RunStatus) -> None:
        """Set a new status

        Args:
            task_name (str): Name of the component
            status (RunStatus): New status

        Raises:
            PipelineStatusUpdateError if the new status is not
                compatible with the current one.
        """
        # Serialize the read-check-write sequence so that two concurrent
        # coroutines cannot both observe the old status and both proceed.
        async with self._status_lock:
            current_status = await self.get_status_for_component(task_name)
            if status == current_status:
                raise PipelineStatusUpdateError(f"Status is already '{status}'")
            if status == RunStatus.RUNNING and current_status == RunStatus.DONE:
                raise PipelineStatusUpdateError("Can't go from DONE to RUNNING")
            return await self.pipeline.store.add_status_for_component(
                self.run_id, task_name, status.value
            )

    async def on_task_complete(
        self, data: dict[str, Any], task: TaskPipelineNode, result: RunResult
    ) -> None:
        """When a given task is complete, it will call this method
        to find the next tasks to run.

        Args:
            data (dict[str, Any]): The pipeline input data
            task (TaskPipelineNode): The task that just finished
            result (RunResult): The result returned by that task
        """
        # first save this component results
        res_to_save = None
        if result.result:
            res_to_save = result.result.model_dump()
        await self.add_result_for_component(
            task.name, res_to_save, is_final=task.is_leaf()
        )
        # then get the next tasks to be executed
        # and run them concurrently
        await asyncio.gather(*[self.run_task(n, data) async for n in self.next(task)])

    async def check_dependencies_complete(self, task: TaskPipelineNode) -> None:
        """Check that all parent tasks are complete.

        Raises:
            PipelineMissingDependencyError if a parent task's status is not DONE.
        """
        dependencies = self.pipeline.previous_edges(task.name)
        for d in dependencies:
            d_status = await self.get_status_for_component(d.start)
            if d_status != RunStatus.DONE:
                logger.debug(
                    f"ORCHESTRATOR {self.run_id}: TASK DELAYED: Missing dependency {d.start} for {task.name} "
                    f"(status: {d_status}). "
                    "Will try again when dependency is complete."
                )
                raise PipelineMissingDependencyError()

    async def next(
        self, task: TaskPipelineNode
    ) -> AsyncGenerator[TaskPipelineNode, None]:
        """Find the next tasks to be executed after `task` is complete.

        1. Find the task children
        2. Check each child's status:
            - if it's already running or done, do not need to run it again
            - otherwise, check that all its dependencies are met, if yes
                add this task to the list of next tasks to be executed
        """
        possible_next = self.pipeline.next_edges(task.name)
        for next_edge in possible_next:
            next_node = self.pipeline.get_node_by_name(next_edge.end)
            # check status
            next_node_status = await self.get_status_for_component(next_node.name)
            if next_node_status in [RunStatus.RUNNING, RunStatus.DONE]:
                # already running
                continue
            # check deps
            try:
                await self.check_dependencies_complete(next_node)
            except PipelineMissingDependencyError:
                # the child will be picked up again when its last
                # pending parent completes
                continue
            logger.debug(
                f"ORCHESTRATOR {self.run_id}: enqueuing next task: {next_node.name}"
            )
            yield next_node
        return

    def get_input_config_for_task(
        self, task: TaskPipelineNode
    ) -> dict[str, dict[str, str]]:
        """Build input definition for a given task.
        The method needs to access the input defs defined in the edges
        between this task and its parents.

        Args:
            task (TaskPipelineNode): the task to get the input config for

        Returns:
            dict: a dict of
                {input_parameter: {"component": source_component_name, "param": output_name}}
                (empty if nothing is mapped for this task)

        Raises:
            PipelineDefinitionError: if the pipeline has not been validated.
        """
        if not self.pipeline.is_validated:
            raise PipelineDefinitionError(
                "You must validate the pipeline input config first. Call `pipeline.validate_parameter_mapping()`"
            )
        return self.pipeline.param_mapping.get(task.name) or {}

    async def get_component_inputs(
        self,
        component_name: str,
        param_mapping: dict[str, dict[str, str]],
        input_data: dict[str, Any],
    ) -> dict[str, Any]:
        """Find the component inputs from:
        - input_config: the mapping between components results and inputs
            (results are stored in the pipeline result store)
        - input_data: the user input data

        Values propagated from other components take precedence over user
        input for the same parameter; a warning is emitted in that case.
        `input_data` itself is never modified.

        Args:
            component_name (str): the component/task name
            param_mapping (dict[str, dict[str, str]]): the input config
            input_data (dict[str, Any]): the pipeline input data (user input)

        Returns:
            dict[str, Any]: keyword arguments to call the component with.
        """
        # Work on a copy: writing propagated values straight into the
        # caller's dict would leak them into the user-provided data.
        component_inputs: dict[str, Any] = dict(input_data.get(component_name) or {})
        if param_mapping:
            for parameter, mapping in param_mapping.items():
                component = mapping["component"]
                output_param = mapping.get("param")
                component_result = await self.get_results_for_component(component)
                if output_param is not None:
                    value = component_result.get(output_param)
                else:
                    # no specific field mapped: pass the full result
                    value = component_result
                if parameter in component_inputs:
                    # describe the *source* of the replacing value
                    source = f"{component}.{output_param}" if output_param else component
                    warnings.warn(
                        f"In component '{component_name}', parameter '{parameter}' from user input will be ignored and replaced by '{source}'"
                    )
                component_inputs[parameter] = value
        return component_inputs

    async def add_result_for_component(
        self, name: str, result: dict[str, Any] | None, is_final: bool = False
    ) -> None:
        """This is where we save the results in the result store and, optionally,
        in the final result store.

        Args:
            name (str): the component/task name
            result (dict[str, Any] | None): the result to save
            is_final (bool): if True, also record the result in the
                pipeline's final results for this run.
        """
        await self.pipeline.store.add_result_for_component(self.run_id, name, result)
        if is_final:
            # The pipeline only returns the results
            # of the leaf nodes
            # TODO: make this configurable in the future.
            existing_results = await self.pipeline.final_results.get(self.run_id) or {}
            existing_results[name] = result
            await self.pipeline.final_results.add(
                self.run_id, existing_results, overwrite=True
            )

    async def get_results_for_component(self, name: str) -> Any:
        """Return what the result store holds for component `name` in this run."""
        return await self.pipeline.store.get_result_for_component(self.run_id, name)

    async def get_status_for_component(self, name: str) -> RunStatus:
        """Return the stored status of component `name` for this run,
        or RunStatus.UNKNOWN if the store has no status for it.
        """
        status = await self.pipeline.store.get_status_for_component(self.run_id, name)
        if status is None:
            return RunStatus.UNKNOWN
        return RunStatus(status)

    async def run(self, data: dict[str, Any]) -> None:
        """Run the pipeline, starting from the root nodes
        (node without any parent). Then the callback on_task_complete
        will handle the task dependencies.

        Args:
            data (dict[str, Any]): The pipeline input data
        """
        tasks = [self.run_task(root, data) for root in self.pipeline.roots()]
        await asyncio.gather(*tasks)


class PipelineResult(BaseModel):
    """Value returned by ``Pipeline.run``."""

    # Identifier of the run (taken from the orchestrator that executed it).
    run_id: str
    # Content of the pipeline's final result store for that run.
    result: Any


class Pipeline(PipelineGraph[TaskPipelineNode, PipelineEdge]):
    """This is the main pipeline, where components
    and their execution order are defined"""

    def __init__(self, store: Optional[ResultStore] = None) -> None:
        """Create an empty pipeline.

        Args:
            store (Optional[ResultStore]): Where intermediate results and
                statuses are stored. By default, uses the InMemoryStore.
        """
        super().__init__()
        self.store = store or InMemoryStore()
        self.final_results = InMemoryStore()
        self.is_validated = False
        # Dict structure:
        # { component_name : {
        #         input_param_name: {
        #             "component": "",  # source component name
        #             "param": "",      # source output field; absent when the
        #                               # full source result is mapped
        #         }
        #     }
        # }
        self.param_mapping: dict[str, dict[str, dict[str, str]]] = defaultdict(dict)
        # component_name -> mandatory inputs not provided by any connection
        # (they must come from the user data passed to `run`)
        self.missing_inputs: dict[str, list[str]] = defaultdict(list)

    @classmethod
    def from_template(
        cls, pipeline_template: PipelineDefinition, store: Optional[ResultStore] = None
    ) -> Pipeline:
        """Deprecated alias of `from_definition`; emits a DeprecationWarning."""
        warnings.warn(
            "from_template is deprecated, use from_definition instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return cls.from_definition(pipeline_template, store)

    @classmethod
    def from_definition(
        cls,
        pipeline_definition: PipelineDefinition,
        store: Optional[ResultStore] = None,
    ) -> Pipeline:
        """Create a Pipeline from a pydantic model defining the components and their connections

        Args:
            pipeline_definition (PipelineDefinition): An object defining components and how they are connected to each other.
            store (Optional[ResultStore]): Where the results are stored. By default, uses the InMemoryStore.

        Returns:
            Pipeline: the pipeline with all components and connections added.
        """
        pipeline = Pipeline(store=store)
        for component in pipeline_definition.components:
            pipeline.add_component(
                component.component,
                component.name,
            )
        for edge in pipeline_definition.connections:
            pipeline_edge = PipelineEdge(
                edge.start, edge.end, data={"input_config": edge.input_config}
            )
            pipeline.add_edge(pipeline_edge)
        return pipeline

    def show_as_dict(self) -> dict[str, Any]:
        """Return the pipeline (components and connections) as a dict,
        built by dumping an equivalent PipelineDefinition.
        """
        component_config = [
            ComponentDefinition(name=name, component=task.component)
            for name, task in self._nodes.items()
        ]
        connection_config = [
            ConnectionDefinition(
                start=edge.start,
                end=edge.end,
                input_config=edge.data["input_config"] if edge.data else {},
            )
            for edge in self._edges
        ]
        pipeline_config = PipelineDefinition(
            components=component_config, connections=connection_config
        )
        return pipeline_config.model_dump()

    def draw(
        self, path: str, layout: str = "dot", hide_unused_outputs: bool = True
    ) -> Any:
        """Render the pipeline graph to a file with pygraphviz.

        Args:
            path (str): output file path, passed to pygraphviz' ``draw``
            layout (str): layout program name, passed to pygraphviz' ``layout``
            hide_unused_outputs (bool): if True, component outputs that are
                not mapped to any other component are not drawn.

        Raises:
            ImportError: if pygraphviz is not installed.
        """
        G = self.get_pygraphviz_graph(hide_unused_outputs)
        G.layout(layout)
        G.draw(path)

    def get_pygraphviz_graph(self, hide_unused_outputs: bool = True) -> pgv.AGraph:
        """Build a pygraphviz graph with one node per component, one node per
        component output, and one edge per parameter mapping.

        Args:
            hide_unused_outputs (bool): if True, remove output nodes that
                have no outgoing edge.

        Raises:
            ImportError: if pygraphviz is not installed.
            PipelineDefinitionError: if the parameter mapping is invalid.
        """
        if pgv is None:
            raise ImportError(
                "Could not import pygraphviz. "
                "Follow installation instruction in pygraphviz documentation "
                "to get it up and running on your system."
            )
        # param_mapping is only populated by validation
        self.validate_parameter_mapping()
        G = pgv.AGraph(strict=False, directed=True)
        # create a node for each component
        for n, node in self._nodes.items():
            comp_inputs = ",".join(
                f"{i}: {d['annotation']}"
                for i, d in node.component.component_inputs.items()
            )
            G.add_node(
                n,
                node_type="component",
                shape="rectangle",
                label=f"{node.component.__class__.__name__}: {n}({comp_inputs})",
            )
            # create a node for each output field and connect it to its component
            for o in node.component.component_outputs:
                param_node_name = f"{n}.{o}"
                G.add_node(param_node_name, label=o, node_type="output")
                G.add_edge(n, param_node_name)
        # then we create the edges between a component output
        # and the component it gets added to
        for component_name, params in self.param_mapping.items():
            for param, mapping in params.items():
                source_component = mapping["component"]
                source_param_name = mapping.get("param")
                if source_param_name:
                    source_output_node = f"{source_component}.{source_param_name}"
                else:
                    # full component result mapped: link from the component itself
                    source_output_node = source_component
                G.add_edge(source_output_node, component_name, label=param)
        # remove outputs that are not mapped
        if hide_unused_outputs:
            for n in G.nodes():
                if n.attr["node_type"] == "output" and G.out_degree(n) == 0:  # type: ignore
                    G.remove_node(n)
        return G

    def add_component(self, component: Component, name: str) -> None:
        """Add a new component. Components are uniquely identified by their name.
        If 'name' is already in the pipeline, a ValueError is raised."""
        task = TaskPipelineNode(name, component)
        self.add_node(task)
        # invalidate the pipeline if it was already validated
        self.invalidate()

    def set_component(self, name: str, component: Component) -> None:
        """Replace a component with another. If 'name' is not yet in the pipeline,
        raises ValueError.
        """
        task = TaskPipelineNode(name, component)
        self.set_node(task)
        # invalidate the pipeline if it was already validated
        self.invalidate()

    def connect(
        self,
        start_component_name: str,
        end_component_name: str,
        input_config: Optional[dict[str, str]] = None,
    ) -> None:
        """Connect one component to another.

        Args:
            start_component_name (str): name of the component as defined in
                the add_component method
            end_component_name (str): name of the component as defined in
                the add_component method
            input_config (Optional[dict[str, str]]): end component input configuration:
                propagate previous components outputs.

        Raises:
            PipelineDefinitionError: if the provided component are not in the Pipeline
                or if the graph that would be created by this connection is cyclic.
        """
        edge = PipelineEdge(
            start_component_name,
            end_component_name,
            data={"input_config": input_config},
        )
        try:
            self.add_edge(edge)
        except KeyError as e:
            # keep the original lookup failure attached for debugging
            raise PipelineDefinitionError(
                f"{start_component_name} or {end_component_name} is not in the Pipeline"
            ) from e
        # NOTE(review): when a cycle is detected the edge has already been
        # added and is not removed here - confirm whether callers are
        # expected to discard the pipeline after this error.
        if self.is_cyclic():
            raise PipelineDefinitionError("Cyclic graph are not allowed")
        # invalidate the pipeline if it was already validated
        self.invalidate()

    def invalidate(self) -> None:
        """Reset the validation state: the parameter mapping and the list of
        missing inputs are cleared and will be rebuilt by the next validation.
        """
        self.is_validated = False
        self.param_mapping = defaultdict(dict)
        self.missing_inputs = defaultdict(list)

    def validate_parameter_mapping(self) -> None:
        """Go through the graph and make sure parameter mapping is valid
        (without considering user input yet)

        Raises:
            PipelineDefinitionError: if any parameter mapping is invalid.
        """
        if self.is_validated:
            return
        for task in self._nodes.values():
            self.validate_parameter_mapping_for_task(task)
        self.is_validated = True

    def validate_input_data(self, data: dict[str, Any]) -> bool:
        """Performs parameter and data validation before running the pipeline:
        - Check parameters defined in the connect method
        - Make sure the missing parameters are present in the input `data` dict.

        Args:
            data (dict[str, Any]): input data to use for validation
                (usually from Pipeline.run)

        Returns:
            bool: True when validation succeeds.

        Raises:
            PipelineDefinitionError if any parameter mapping is invalid or if a
                parameter is missing.
        """
        if not self.is_validated:
            self.validate_parameter_mapping()
        for task in self._nodes.values():
            if task.name not in self.param_mapping:
                # no mapping recorded for this task: (re)compute its
                # missing inputs before checking the user data
                self.validate_parameter_mapping_for_task(task)
            missing_params = self.missing_inputs[task.name]
            task_data = data.get(task.name) or {}
            for param in missing_params:
                if param not in task_data:
                    raise PipelineDefinitionError(
                        f"Parameter '{param}' not provided for component '{task.name}'"
                    )
        return True

    def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool:
        """Make sure that all the parameter mapping for a given task are valid.
        Does not consider user input yet.

        Considering the naming {param => target (component, [output_parameter]) },
        the mapping is valid if:
         - 'param' is a valid input for task
         - 'param' has not already been mapped
         - The target component exists in the pipeline and, if specified, the
            target output parameter is a valid field in the target component's
            result model.

        This method builds the param_mapping and missing_inputs instance variables.

        Args:
            task (TaskPipelineNode): the task whose incoming connections are checked

        Returns:
            bool: True when the mapping is valid.

        Raises:
            PipelineDefinitionError: if one of the rules above is violated.
        """
        component = task.component
        expected_mandatory_inputs = [
            param_name
            for param_name, config in component.component_inputs.items()
            if config["has_default"] is False
        ]
        # inputs provided through connections with parent components
        actual_inputs = []
        prev_edges = self.previous_edges(task.name)
        # iterate over all parents to find the parameter propagation
        for edge in prev_edges:
            edge_data = edge.data or {}
            edge_inputs = edge_data.get("input_config") or {}
            # check that the previous component is actually returning
            # the mapped parameter
            for param, path in edge_inputs.items():
                if param in self.param_mapping[task.name]:
                    raise PipelineDefinitionError(
                        f"Parameter '{param}' already mapped to {self.param_mapping[task.name][param]}"
                    )
                if param not in task.component.component_inputs:
                    raise PipelineDefinitionError(
                        f"Parameter '{param}' is not a valid input for component '{task.name}' of type '{task.component.__class__.__name__}'"
                    )
                try:
                    source_component_name, param_name = path.split(".")
                except ValueError:
                    # no specific output mapped
                    # the full source component result will be
                    # passed to the next component
                    # NOTE(review): a path with more than one "." also lands
                    # here and is stored verbatim as the component name -
                    # confirm this is intended.
                    self.param_mapping[task.name][param] = {
                        "component": path,
                    }
                    continue
                try:
                    source_node = self.get_node_by_name(source_component_name)
                except KeyError:
                    raise PipelineDefinitionError(
                        f"Component {source_component_name} does not exist in the pipeline,"
                        f" can not map {param} to {path} for {task.name}."
                    )
                source_component = source_node.component
                source_component_outputs = source_component.component_outputs
                if param_name and param_name not in source_component_outputs:
                    raise PipelineDefinitionError(
                        f"Parameter {param_name} is not valid output for "
                        f"{source_component_name} (must be one of "
                        f"{list(source_component_outputs.keys())})"
                    )
                self.param_mapping[task.name][param] = {
                    "component": source_component_name,
                    "param": param_name,
                }
            actual_inputs.extend(list(edge_inputs.keys()))
        # mandatory inputs not covered by any connection must be
        # supplied by the user at run time
        missing_inputs = list(set(expected_mandatory_inputs) - set(actual_inputs))
        self.missing_inputs[task.name] = missing_inputs
        return True

    async def run(self, data: dict[str, Any]) -> PipelineResult:
        """Validate the pipeline against `data`, then execute it.

        Args:
            data (dict[str, Any]): user input, keyed by component name.

        Returns:
            PipelineResult: the run id and the content of the final result
            store for that run.

        Raises:
            PipelineDefinitionError: if the parameter mapping is invalid or a
                mandatory parameter is missing from `data`.
        """
        logger.debug("PIPELINE START")
        start_time = default_timer()
        # always re-validate from scratch before a run
        self.invalidate()
        self.validate_input_data(data)
        orchestrator = Orchestrator(self)
        logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}")
        await orchestrator.run(data)
        end_time = default_timer()
        logger.debug(
            f"PIPELINE FINISHED {orchestrator.run_id} in {end_time - start_time}s"
        )
        return PipelineResult(
            run_id=orchestrator.run_id,
            result=await self.final_results.get(orchestrator.run_id),
        )