Graphrag integration (#4612)

* add initial global search draft

* add graphrag dep

* fix local search embedding

* linting

* add from config constructor

* remove draft notebook

* update config factory and add docstrings

* add graphrag sample

* add sample prompts

* update readme

* update deps

* Add API docs

* Update python/samples/agentchat_graphrag/requirements.txt

* Update python/samples/agentchat_graphrag/requirements.txt

* update docstrings with snippet and doc ref

* lint

* improve set up instructions in docstring

* lint

* update lock

* Update python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_global_search.py

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

* Update python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_local_search.py

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

* add unit tests

* update lock

* update uv lock

* add docstring newlines

* stubs and typing on graphrag tests

* fix docstrings

* fix mypy error

* + linting and type fixes

* type fix graphrag sample

* Update python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_global_search.py

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

* Update python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_local_search.py

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

* Update python/samples/agentchat_graphrag/requirements.txt

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

* update overrides

* fix docstring client imports

* additional docstring fix

* add docstring missing import

* use openai and fix db path

* use console for displaying messages

* add model config and gitignore

* update readme

* lint

* Update python/samples/agentchat_graphrag/README.md

* Update python/samples/agentchat_graphrag/README.md

* Comment remaining azure config

---------

Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Leonardo Pinheiro 2025-01-15 21:04:17 +10:00 committed by GitHub
parent 8efe0c45b0
commit 95bd514a9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 2364 additions and 52 deletions

View File

@ -51,6 +51,7 @@ python/autogen_ext.teams.magentic_one
python/autogen_ext.models.openai
python/autogen_ext.models.replay
python/autogen_ext.tools.langchain
python/autogen_ext.tools.graphrag
python/autogen_ext.tools.code_execution
python/autogen_ext.code_executors.local
python/autogen_ext.code_executors.docker

View File

@ -0,0 +1,8 @@
autogen\_ext.tools.graphrag
===========================
.. automodule:: autogen_ext.tools.graphrag
:members:
:undoc-members:
:show-inheritance:

View File

@ -27,6 +27,7 @@ file-surfer = [
"autogen-agentchat==0.4.1",
"markitdown>=0.0.1a2",
]
graphrag = ["graphrag>=1.0.1"]
web-surfer = [
"autogen-agentchat==0.4.1",
"playwright>=1.48.0",
@ -57,6 +58,7 @@ packages = ["src/autogen_ext"]
dev = [
"autogen_test_utils",
"langchain-experimental",
"pandas-stubs>=2.2.3.241126",
]
[tool.ruff]

View File

@ -0,0 +1,25 @@
from ._config import (
GlobalContextConfig,
GlobalDataConfig,
LocalContextConfig,
LocalDataConfig,
MapReduceConfig,
SearchConfig,
)
from ._global_search import GlobalSearchTool, GlobalSearchToolArgs, GlobalSearchToolReturn
from ._local_search import LocalSearchTool, LocalSearchToolArgs, LocalSearchToolReturn
__all__ = [
"GlobalSearchTool",
"LocalSearchTool",
"GlobalDataConfig",
"LocalDataConfig",
"GlobalContextConfig",
"GlobalSearchToolArgs",
"GlobalSearchToolReturn",
"LocalContextConfig",
"LocalSearchToolArgs",
"LocalSearchToolReturn",
"MapReduceConfig",
"SearchConfig",
]

View File

@ -0,0 +1,59 @@
from pydantic import BaseModel
class DataConfig(BaseModel):
input_dir: str
entity_table: str = "create_final_nodes"
entity_embedding_table: str = "create_final_entities"
community_level: int = 2
class GlobalDataConfig(DataConfig):
community_table: str = "create_final_communities"
community_report_table: str = "create_final_community_reports"
class LocalDataConfig(DataConfig):
relationship_table: str = "create_final_relationships"
text_unit_table: str = "create_final_text_units"
class ContextConfig(BaseModel):
max_data_tokens: int = 8000
class GlobalContextConfig(ContextConfig):
use_community_summary: bool = False
shuffle_data: bool = True
include_community_rank: bool = True
min_community_rank: int = 0
community_rank_name: str = "rank"
include_community_weight: bool = True
community_weight_name: str = "occurrence weight"
normalize_community_weight: bool = True
max_data_tokens: int = 12000
class LocalContextConfig(ContextConfig):
text_unit_prop: float = 0.5
community_prop: float = 0.25
include_entity_rank: bool = True
rank_description: str = "number of relationships"
include_relationship_weight: bool = True
relationship_ranking_attribute: str = "rank"
class MapReduceConfig(BaseModel):
map_max_tokens: int = 1000
map_temperature: float = 0.0
reduce_max_tokens: int = 2000
reduce_temperature: float = 0.0
allow_general_knowledge: bool = False
json_mode: bool = False
response_type: str = "multiple paragraphs"
class SearchConfig(BaseModel):
max_tokens: int = 1500
temperature: float = 0.0
response_type: str = "multiple paragraphs"

View File

@ -0,0 +1,214 @@
# mypy: disable-error-code="no-any-unimported,misc"
from pathlib import Path
import pandas as pd
import tiktoken
from autogen_core import CancellationToken
from autogen_core.tools import BaseTool
from graphrag.config.config_file_loader import load_config_from_file
from graphrag.query.indexer_adapters import (
read_indexer_communities,
read_indexer_entities,
read_indexer_reports,
)
from graphrag.query.llm.base import BaseLLM
from graphrag.query.llm.get_client import get_llm
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
from graphrag.query.structured_search.global_search.search import GlobalSearch
from pydantic import BaseModel, Field
from ._config import GlobalContextConfig as ContextConfig
from ._config import GlobalDataConfig as DataConfig
from ._config import MapReduceConfig
_default_context_config = ContextConfig()
_default_mapreduce_config = MapReduceConfig()
class GlobalSearchToolArgs(BaseModel):
query: str = Field(..., description="The user query to perform global search on.")
class GlobalSearchToolReturn(BaseModel):
answer: str
class GlobalSearchTool(BaseTool[GlobalSearchToolArgs, GlobalSearchToolReturn]):
"""Enables running GraphRAG global search queries as an AutoGen tool.
This tool allows you to perform semantic search over a corpus of documents using the GraphRAG framework.
The search combines graph-based document relationships with semantic embeddings to find relevant information.
.. note::
This tool requires the :code:`graphrag` extra for the :code:`autogen-ext` package.
To install:
.. code-block:: bash
pip install -U "autogen-agentchat" "autogen-ext[graphrag]"
Before using this tool, you must complete the GraphRAG setup and indexing process:
1. Follow the GraphRAG documentation to initialize your project and settings
2. Configure and tune your prompts for the specific use case
3. Run the indexing process to generate the required data files
4. Ensure you have the settings.yaml file from the setup process
Please refer to the [GraphRAG documentation](https://microsoft.github.io/graphrag/)
for detailed instructions on completing these prerequisite steps.
Example usage with AssistantAgent:
.. code-block:: python
import asyncio
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.ui import Console
from autogen_ext.tools.graphrag import GlobalSearchTool
from autogen_agentchat.agents import AssistantAgent
async def main():
# Initialize the OpenAI client
openai_client = OpenAIChatCompletionClient(
model="gpt-4o-mini",
api_key="<api-key>",
)
# Set up global search tool
global_tool = GlobalSearchTool.from_settings(settings_path="./settings.yaml")
# Create assistant agent with the global search tool
assistant_agent = AssistantAgent(
name="search_assistant",
tools=[global_tool],
model_client=openai_client,
system_message=(
"You are a tool selector AI assistant using the GraphRAG framework. "
"Your primary task is to determine the appropriate search tool to call based on the user's query. "
"For broader, abstract questions requiring a comprehensive understanding of the dataset, call the 'global_search' function."
),
)
# Run a sample query
query = "What is the overall sentiment of the community reports?"
await Console(assistant_agent.run_stream(task=query))
if __name__ == "__main__":
asyncio.run(main())
"""
def __init__(
self,
token_encoder: tiktoken.Encoding,
llm: BaseLLM,
data_config: DataConfig,
context_config: ContextConfig = _default_context_config,
mapreduce_config: MapReduceConfig = _default_mapreduce_config,
):
super().__init__(
args_type=GlobalSearchToolArgs,
return_type=GlobalSearchToolReturn,
name="global_search_tool",
description="Perform a global search with given parameters using graphrag.",
)
# Use the provided LLM
self._llm = llm
# Load parquet files
community_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.community_table}.parquet") # type: ignore
entity_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_table}.parquet") # type: ignore
report_df: pd.DataFrame = pd.read_parquet( # type: ignore
f"{data_config.input_dir}/{data_config.community_report_table}.parquet"
)
entity_embedding_df: pd.DataFrame = pd.read_parquet( # type: ignore
f"{data_config.input_dir}/{data_config.entity_embedding_table}.parquet"
)
communities = read_indexer_communities(community_df, entity_df, report_df)
reports = read_indexer_reports(report_df, entity_df, data_config.community_level)
entities = read_indexer_entities(entity_df, entity_embedding_df, data_config.community_level)
context_builder = GlobalCommunityContext(
community_reports=reports,
communities=communities,
entities=entities,
token_encoder=token_encoder,
)
context_builder_params = {
"use_community_summary": context_config.use_community_summary,
"shuffle_data": context_config.shuffle_data,
"include_community_rank": context_config.include_community_rank,
"min_community_rank": context_config.min_community_rank,
"community_rank_name": context_config.community_rank_name,
"include_community_weight": context_config.include_community_weight,
"community_weight_name": context_config.community_weight_name,
"normalize_community_weight": context_config.normalize_community_weight,
"max_tokens": context_config.max_data_tokens,
"context_name": "Reports",
}
map_llm_params = {
"max_tokens": mapreduce_config.map_max_tokens,
"temperature": mapreduce_config.map_temperature,
"response_format": {"type": "json_object"},
}
reduce_llm_params = {
"max_tokens": mapreduce_config.reduce_max_tokens,
"temperature": mapreduce_config.reduce_temperature,
}
self._search_engine = GlobalSearch(
llm=self._llm,
context_builder=context_builder,
token_encoder=token_encoder,
max_data_tokens=context_config.max_data_tokens,
map_llm_params=map_llm_params,
reduce_llm_params=reduce_llm_params,
allow_general_knowledge=mapreduce_config.allow_general_knowledge,
json_mode=mapreduce_config.json_mode,
context_builder_params=context_builder_params,
concurrent_coroutines=32,
response_type=mapreduce_config.response_type,
)
async def run(self, args: GlobalSearchToolArgs, cancellation_token: CancellationToken) -> GlobalSearchToolReturn:
result = await self._search_engine.asearch(args.query)
assert isinstance(result.response, str), "Expected response to be a string"
return GlobalSearchToolReturn(answer=result.response)
@classmethod
def from_settings(cls, settings_path: str | Path) -> "GlobalSearchTool":
"""Create a GlobalSearchTool instance from GraphRAG settings file.
Args:
settings_path: Path to the GraphRAG settings.yaml file
Returns:
An initialized GlobalSearchTool instance
"""
# Load GraphRAG config
config = load_config_from_file(settings_path)
# Initialize token encoder
token_encoder = tiktoken.get_encoding(config.encoding_model)
# Initialize LLM using graphrag's get_client
llm = get_llm(config)
# Create data config from storage paths
data_config = DataConfig(
input_dir=str(Path(config.storage.base_dir)),
)
return cls(
token_encoder=token_encoder,
llm=llm,
data_config=data_config,
context_config=_default_context_config,
mapreduce_config=_default_mapreduce_config,
)

View File

@ -0,0 +1,227 @@
# mypy: disable-error-code="no-any-unimported,misc"
import os
from pathlib import Path
import pandas as pd
import tiktoken
from autogen_core import CancellationToken
from autogen_core.tools import BaseTool
from graphrag.config.config_file_loader import load_config_from_file
from graphrag.query.indexer_adapters import (
read_indexer_entities,
read_indexer_relationships,
read_indexer_text_units,
)
from graphrag.query.llm.base import BaseLLM, BaseTextEmbedding
from graphrag.query.llm.get_client import get_llm, get_text_embedder
from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from pydantic import BaseModel, Field
from ._config import LocalContextConfig, SearchConfig
from ._config import LocalDataConfig as DataConfig
_default_context_config = LocalContextConfig()
_default_search_config = SearchConfig()
class LocalSearchToolArgs(BaseModel):
query: str = Field(..., description="The user query to perform local search on.")
class LocalSearchToolReturn(BaseModel):
answer: str = Field(..., description="The answer to the user query.")
class LocalSearchTool(BaseTool[LocalSearchToolArgs, LocalSearchToolReturn]):
"""Enables running GraphRAG local search queries as an AutoGen tool.
This tool allows you to perform semantic search over a corpus of documents using the GraphRAG framework.
The search combines local document context with semantic embeddings to find relevant information.
.. note::
This tool requires the :code:`graphrag` extra for the :code:`autogen-ext` package.
To install:
.. code-block:: bash
pip install -U "autogen-agentchat" "autogen-ext[graphrag]"
Before using this tool, you must complete the GraphRAG setup and indexing process:
1. Follow the GraphRAG documentation to initialize your project and settings
2. Configure and tune your prompts for the specific use case
3. Run the indexing process to generate the required data files
4. Ensure you have the settings.yaml file from the setup process
Please refer to the [GraphRAG documentation](https://microsoft.github.io/graphrag/)
for detailed instructions on completing these prerequisite steps.
Example usage with AssistantAgent:
.. code-block:: python
import asyncio
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.ui import Console
from autogen_ext.tools.graphrag import LocalSearchTool
from autogen_agentchat.agents import AssistantAgent
async def main():
# Initialize the OpenAI client
openai_client = OpenAIChatCompletionClient(
model="gpt-4o-mini",
api_key="<api-key>",
)
# Set up local search tool
local_tool = LocalSearchTool.from_settings(settings_path="./settings.yaml")
# Create assistant agent with the local search tool
assistant_agent = AssistantAgent(
name="search_assistant",
tools=[local_tool],
model_client=openai_client,
system_message=(
"You are a tool selector AI assistant using the GraphRAG framework. "
"Your primary task is to determine the appropriate search tool to call based on the user's query. "
"For specific, detailed information about particular entities or relationships, call the 'local_search' function."
),
)
# Run a sample query
query = "What does the station-master say about Dr. Becher?"
await Console(assistant_agent.run_stream(task=query))
if __name__ == "__main__":
asyncio.run(main())
Args:
token_encoder (tiktoken.Encoding): The tokenizer used for text encoding
llm (BaseLLM): The language model to use for search
embedder (BaseTextEmbedding): The text embedding model to use
data_config (DataConfig): Configuration for data source locations and settings
context_config (LocalContextConfig, optional): Configuration for context building. Defaults to default config.
search_config (SearchConfig, optional): Configuration for search operations. Defaults to default config.
"""
def __init__(
self,
token_encoder: tiktoken.Encoding,
llm: BaseLLM,
embedder: BaseTextEmbedding,
data_config: DataConfig,
context_config: LocalContextConfig = _default_context_config,
search_config: SearchConfig = _default_search_config,
):
super().__init__(
args_type=LocalSearchToolArgs,
return_type=LocalSearchToolReturn,
name="local_search_tool",
description="Perform a local search with given parameters using graphrag.",
)
# Use the adapter
self._llm = llm
self._embedder = embedder
# Load parquet files
entity_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_table}.parquet") # type: ignore
entity_embedding_df: pd.DataFrame = pd.read_parquet( # type: ignore
f"{data_config.input_dir}/{data_config.entity_embedding_table}.parquet"
)
relationship_df: pd.DataFrame = pd.read_parquet( # type: ignore
f"{data_config.input_dir}/{data_config.relationship_table}.parquet"
)
text_unit_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.text_unit_table}.parquet") # type: ignore
# Read data using indexer adapters
entities = read_indexer_entities(entity_df, entity_embedding_df, data_config.community_level)
relationships = read_indexer_relationships(relationship_df)
text_units = read_indexer_text_units(text_unit_df)
# Set up vector store for entity embeddings
description_embedding_store = LanceDBVectorStore(
collection_name="default-entity-description",
)
description_embedding_store.connect(db_uri=os.path.join(data_config.input_dir, "lancedb"))
description_embedding_store = LanceDBVectorStore(
collection_name="default-entity-description",
)
description_embedding_store.connect(db_uri=f"{data_config.input_dir}/lancedb")
# Set up context builder
context_builder = LocalSearchMixedContext(
entities=entities,
entity_text_embeddings=description_embedding_store,
text_embedder=self._embedder,
text_units=text_units,
relationships=relationships,
token_encoder=token_encoder,
)
context_builder_params = {
"text_unit_prop": context_config.text_unit_prop,
"community_prop": context_config.community_prop,
"include_entity_rank": context_config.include_entity_rank,
"rank_description": context_config.rank_description,
"include_relationship_weight": context_config.include_relationship_weight,
"relationship_ranking_attribute": context_config.relationship_ranking_attribute,
"max_tokens": context_config.max_data_tokens,
}
llm_params = {
"max_tokens": search_config.max_tokens,
"temperature": search_config.temperature,
}
self._search_engine = LocalSearch(
llm=self._llm,
context_builder=context_builder,
token_encoder=token_encoder,
llm_params=llm_params,
context_builder_params=context_builder_params,
response_type=search_config.response_type,
)
async def run(self, args: LocalSearchToolArgs, cancellation_token: CancellationToken) -> LocalSearchToolReturn:
result = await self._search_engine.asearch(args.query) # type: ignore
assert isinstance(result.response, str), "Expected response to be a string"
return LocalSearchToolReturn(answer=result.response)
@classmethod
def from_settings(cls, settings_path: str | Path) -> "LocalSearchTool":
"""Create a LocalSearchTool instance from GraphRAG settings file.
Args:
settings_path: Path to the GraphRAG settings.yaml file
Returns:
An initialized LocalSearchTool instance
"""
# Load GraphRAG config
config = load_config_from_file(settings_path)
# Initialize token encoder
token_encoder = tiktoken.get_encoding(config.encoding_model)
# Initialize LLM and embedder using graphrag's get_client functions
llm = get_llm(config)
embedder = get_text_embedder(config)
# Create data config from storage paths
data_config = DataConfig(
input_dir=str(Path(config.storage.base_dir)),
)
return cls(
token_encoder=token_encoder,
llm=llm,
embedder=embedder,
data_config=data_config,
context_config=_default_context_config,
search_config=_default_search_config,
)

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,184 @@
# mypy: disable-error-code="no-any-unimported"
import os
import tempfile
from typing import Any, AsyncGenerator, Generator
import pandas as pd
import pytest
import tiktoken
from autogen_core import CancellationToken
from autogen_ext.tools.graphrag import GlobalSearchTool, GlobalSearchToolReturn, LocalSearchTool, LocalSearchToolReturn
from autogen_ext.tools.graphrag._config import GlobalDataConfig, LocalDataConfig
from graphrag.callbacks.llm_callbacks import BaseLLMCallback
from graphrag.model.types import TextEmbedder
from graphrag.query.llm.base import BaseLLM, BaseTextEmbedding
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument, VectorStoreSearchResult
class MockLLM(BaseLLM): # type: ignore
def generate(
self,
messages: str | list[Any],
streaming: bool = True,
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> str:
return "Mock response"
def stream_generate(
self, messages: str | list[Any], callbacks: list[BaseLLMCallback] | None = None, **kwargs: Any
) -> Generator[str, None, None]:
yield "Mock response"
async def agenerate(
self,
messages: str | list[Any],
streaming: bool = True,
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> str:
return "Mock response"
async def astream_generate( # type: ignore
self, messages: str | list[Any], callbacks: list[BaseLLMCallback] | None = None, **kwargs: Any
) -> AsyncGenerator[str, None]:
yield "Mock response"
class MockTextEmbedding(BaseTextEmbedding): # type: ignore
def embed(self, text: str, **kwargs: Any) -> list[float]:
return [0.1] * 10
async def aembed(self, text: str, **kwargs: Any) -> list[float]:
return [0.1] * 10
class MockVectorStore(BaseVectorStore): # type: ignore
def __init__(self, **kwargs: Any) -> None:
super().__init__(collection_name="mock", **kwargs)
self.documents: dict[str | int, VectorStoreDocument] = {}
def connect(self, **kwargs: Any) -> None:
pass
def load_documents(self, documents: list[VectorStoreDocument], overwrite: bool = True) -> None:
if overwrite:
self.documents = {}
for doc in documents:
self.documents[doc.id] = doc
def filter_by_id(self, include_ids: list[str] | list[int]) -> None:
return None
def similarity_search_by_vector(
self, query_embedding: list[float], k: int = 10, **kwargs: Any
) -> list[VectorStoreSearchResult]:
docs = list(self.documents.values())[:k]
return [VectorStoreSearchResult(document=doc, score=0.9) for doc in docs]
def similarity_search_by_text(
self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
) -> list[VectorStoreSearchResult]:
return self.similarity_search_by_vector([0.1] * 10, k)
def search_by_id(self, id: str) -> VectorStoreDocument:
return self.documents.get(id, VectorStoreDocument(id=id, text=None, vector=None))
@pytest.mark.asyncio
async def test_global_search_tool(
community_df_fixture: pd.DataFrame,
entity_df_fixture: pd.DataFrame,
report_df_fixture: pd.DataFrame,
entity_embedding_fixture: pd.DataFrame,
) -> None:
# Create a temporary directory to simulate the data config
with tempfile.TemporaryDirectory() as tempdir:
# Save fixtures to parquet files
community_table = os.path.join(tempdir, "create_final_communities.parquet")
entity_table = os.path.join(tempdir, "create_final_nodes.parquet")
community_report_table = os.path.join(tempdir, "create_final_community_reports.parquet")
entity_embedding_table = os.path.join(tempdir, "create_final_entities.parquet")
community_df_fixture.to_parquet(community_table) # type: ignore
entity_df_fixture.to_parquet(entity_table) # type: ignore
report_df_fixture.to_parquet(community_report_table) # type: ignore
entity_embedding_fixture.to_parquet(entity_embedding_table) # type: ignore
# Initialize the data config with the temporary directory
data_config = GlobalDataConfig(
input_dir=tempdir,
community_table="create_final_communities",
entity_table="create_final_nodes",
community_report_table="create_final_community_reports",
entity_embedding_table="create_final_entities",
)
# Initialize the GlobalSearchTool with mock data
token_encoder = tiktoken.encoding_for_model("gpt-4o")
llm = MockLLM()
global_search_tool = GlobalSearchTool(token_encoder=token_encoder, llm=llm, data_config=data_config)
# Example of running the tool and checking the result
query = "What is the overall sentiment of the community reports?"
cancellation_token = CancellationToken()
result = await global_search_tool.run_json(args={"query": query}, cancellation_token=cancellation_token)
assert isinstance(result, GlobalSearchToolReturn)
assert isinstance(result.answer, str)
@pytest.mark.asyncio
async def test_local_search_tool(
entity_df_fixture: pd.DataFrame,
relationship_df_fixture: pd.DataFrame,
text_unit_df_fixture: pd.DataFrame,
entity_embedding_fixture: pd.DataFrame,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Create a temporary directory to simulate the data config
with tempfile.TemporaryDirectory() as tempdir:
# Save fixtures to parquet files
entity_table = os.path.join(tempdir, "create_final_nodes.parquet")
relationship_table = os.path.join(tempdir, "create_final_relationships.parquet")
text_unit_table = os.path.join(tempdir, "create_final_text_units.parquet")
entity_embedding_table = os.path.join(tempdir, "create_final_entities.parquet")
entity_df_fixture.to_parquet(entity_table) # type: ignore
relationship_df_fixture.to_parquet(relationship_table) # type: ignore
text_unit_df_fixture.to_parquet(text_unit_table) # type: ignore
entity_embedding_fixture.to_parquet(entity_embedding_table) # type: ignore
# Initialize the data config with the temporary directory
data_config = LocalDataConfig(
input_dir=tempdir,
entity_table="create_final_nodes",
relationship_table="create_final_relationships",
text_unit_table="create_final_text_units",
entity_embedding_table="create_final_entities",
)
# Initialize the LocalSearchTool with mock data
token_encoder = tiktoken.encoding_for_model("gpt-4o")
llm = MockLLM()
embedder = MockTextEmbedding()
# Mock the vector store
def mock_vector_store_factory(*args: Any, **kwargs: dict[str, Any]) -> MockVectorStore:
store = MockVectorStore()
store.document_collection = store # Make the store act as its own collection
return store
# Patch the LanceDBVectorStore class
monkeypatch.setattr("autogen_ext.tools.graphrag._local_search.LanceDBVectorStore", mock_vector_store_factory) # type: ignore
local_search_tool = LocalSearchTool(
token_encoder=token_encoder, llm=llm, embedder=embedder, data_config=data_config
)
# Example of running the tool and checking the result
query = "What are the relationships between Dr. Becher and the station-master?"
cancellation_token = CancellationToken()
result = await local_search_tool.run_json(args={"query": query}, cancellation_token=cancellation_token)
assert isinstance(result, LocalSearchToolReturn)
assert isinstance(result.answer, str)

View File

@ -22,6 +22,12 @@ dev = [
"chainlit",
]
[tool.uv]
override-dependencies = [
"tenacity==9.0.0",
"aiofiles==24.1.0",
"chainlit==2.0.1",
]
[tool.uv.workspace]
members = ["packages/*"]

View File

@ -0,0 +1,3 @@
model_config.json
data
cache

View File

@ -0,0 +1,57 @@
# Building a Multi-Agent Application with AutoGen and GraphRAG
In this sample, we will build a chat interface that interacts with a `RoundRobinGroupChat` team built using the [AutoGen AgentChat](https://microsoft.github.io/autogen/dev/user-guide/agentchat-user-guide/index.html) API and the GraphRAG framework.
## High-Level Description
The `app.py` script sets up a chat interface that communicates with the AutoGen team. When a chat starts, it:
- Initializes an AgentChat team with both local and global search tools.
- As user query is sent to the team with the agent, which must select the appropriate tool to use, the query is then passed to the appropriate tool to respond.
- As agents respond/act, their responses are streamed back to the chat interface.
## What is GraphRAG?
GraphRAG (Graph-based Retrieval-Augmented Generation) is a framework designed to enhance multi-agent systems by providing robust tools for information retrieval and reasoning. It leverages graph structures to organize and query data efficiently, enabling both global and local search capabilities.
Global Search: Global search involves querying the entire indexed dataset to retrieve relevant information. It is ideal for broad queries where the required information might be scattered across multiple documents or nodes in the graph.
Local Search: Local search focuses on a specific subset of the data, such as a particular node or neighborhood in the graph. This approach is used for queries that are contextually tied to a specific segment of the data.
By combining these search strategies, GraphRAG ensures comprehensive and context-sensitive responses from the multi-agent team.
## Setup
To set up the project, follow these steps:
1. Install the required Python packages by running:
```bash
pip install -r requirements.txt
```
2. Download the plain text version of "The Adventures of Sherlock Holmes" from [Project Gutenberg](https://www.gutenberg.org/ebooks/1661) and save it to `data/input/sherlock_book.txt`.
3. Adjust the `settings.yaml` file with your LLM and embedding configuration. Ensure that the API keys and other necessary details are correctly set.
4. Create a `model_config.json` file with the Assistant model configuration. Use the `model_config_template.json` file as a reference. Make sure to remove the comments in the template file.
5. Run the `graphrag prompt-tune` command to tune the prompts. This step adjusts the prompts to better fit the context of the downloaded text.
6. After tuning, run the `graphrag index` command to index the data. This process will create the necessary data structures for performing searches. The indexing may take some time, at least 10 minutes on most machines, depending on the connection to the model API.
The outputs will be located in the `data/output/` directory.
## Running the Sample
Run the sample by executing the following command:
```bash
python app.py
Agent response: [FunctionCall(id='call_0xAXMOHLl62QFd9cfIb0S3BO', arguments='{"query":"station-master Dr. Becher"}', name='local_search_tool')]
Agent response: [FunctionExecutionResult(content='{"answer": "### Dr. Becher and the Station-Master\\n\\nDr. Becher is an Englishman who owns a house that caught fire, and he has a foreign patient staying with him [Data: Entities (489)]. The fire at Dr. Becher\'s house was a significant event, as it was described as a great widespread whitewashed building spouting fire at every chink and window, with fire-engines striving to control the blaze [Data: Sources (91); Entities (491)]. The station-master provided information about the fire, confirming that it broke out during the night and worsened, leading to the entire place being in a blaze [Data: Sources (91)].\\n\\nThe station-master also clarified a misunderstanding about Dr. Becher\'s nationality, stating that Dr. Becher is an Englishman, contrary to the engineer\'s assumption that he might be a German. The station-master humorously noted that Dr. Becher is well-fed, unlike his foreign patient, who could benefit from some good Berkshire beef [Data: Sources (91)].\\n\\n### The Fire Incident\\n\\nThe fire at Dr. Becher\'s house was linked to a larger criminal investigation involving a gang of coiners. The fire was inadvertently started by an oil-lamp that was crushed in a press, which was part of the machinery used by the gang. This incident was a turning point in the investigation, as it led to the discovery of the gang\'s operations, although the criminals managed to escape [Data: Sources (91)].\\n\\nThe fire-engines present at the scene were unable to prevent the destruction of the house, and the firemen were perturbed by the strange arrangements they found within the building. Despite their efforts, the house was reduced to ruins, with only some twisted cylinders and iron piping remaining [Data: Sources (91); Entities (491)].\\n\\nIn summary, Dr. Becher\'s house fire was a pivotal event in the investigation of a criminal gang, with the station-master providing key information about the incident and Dr. Becher\'s identity. The fire not only highlighted the dangers associated with the gang\'s activities but also underscored the challenges faced by law enforcement in apprehending the criminals."}', call_id='call_0xAXMOHLl62QFd9cfIb0S3BO')]
```

View File

@ -0,0 +1,66 @@
import argparse
import asyncio
import json
import logging
from typing import Any, Dict
from autogen_agentchat.ui import Console
from autogen_ext.tools.graphrag import (
GlobalSearchTool,
LocalSearchTool,
)
from autogen_agentchat.agents import AssistantAgent
from autogen_core.models import ChatCompletionClient
async def main(model_config: Dict[str, Any]) -> None:
# Initialize the model client from config
model_client = ChatCompletionClient.load_component(model_config)
# Set up global search tool
global_tool = GlobalSearchTool.from_settings(
settings_path="./settings.yaml"
)
local_tool = LocalSearchTool.from_settings(
settings_path="./settings.yaml"
)
# Create assistant agent with both search tools
assistant_agent = AssistantAgent(
name="search_assistant",
tools=[global_tool, local_tool],
model_client=model_client,
system_message=(
"You are a tool selector AI assistant using the GraphRAG framework. "
"Your primary task is to determine the appropriate search tool to call based on the user's query. "
"For specific, detailed information about particular entities or relationships, call the 'local_search' function. "
"For broader, abstract questions requiring a comprehensive understanding of the dataset, call the 'global_search' function. "
"Do not attempt to answer the query directly; focus solely on selecting and calling the correct function."
)
)
# Run a sample query
query = "What does the station-master says about Dr. Becher?"
print(f"\nQuery: {query}")
await Console(assistant_agent.run_stream(task=query))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run a GraphRAG search with an agent.")
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging.")
parser.add_argument(
"--model-config", type=str, help="Path to the model configuration file.", default="model_config.json"
)
args = parser.parse_args()
if args.verbose:
logging.basicConfig(level=logging.WARNING)
logging.getLogger("autogen_core").setLevel(logging.DEBUG)
handler = logging.FileHandler("graphrag_search.log")
logging.getLogger("autogen_core").addHandler(handler)
with open(args.model_config, "r") as f:
model_config = json.load(f)
asyncio.run(main(model_config))

View File

@ -0,0 +1,38 @@
// Use Azure OpenAI with AD token provider.
// {
// "provider": "AzureOpenAIChatCompletionClient",
// "config": {
// "model": "gpt-4o-2024-05-13",
// "azure_endpoint": "https://{your-custom-endpoint}.openai.azure.com/",
// "azure_deployment": "{your-azure-deployment}",
// "api_version": "2024-06-01",
// "azure_ad_token_provider": {
// "provider": "autogen_ext.auth.azure.AzureTokenProvider",
// "config": {
// "provider_kind": "DefaultAzureCredential",
// "scopes": [
// "https://cognitiveservices.azure.com/.default"
// ]
// }
// }
// }
// }
// Use Azure Open AI with key
// {
// "provider": "AzureOpenAIChatCompletionClient",
// "config": {
// "model": "gpt-4o-2024-05-13",
// "azure_endpoint": "https://{your-custom-endpoint}.openai.azure.com/",
// "azure_deployment": "{your-azure-deployment}",
// "api_version": "2024-06-01",
// "api_key": "REPLACE_WITH_YOUR_API_KEY"
// }
// }
// Use Open AI with key
{
"provider": "OpenAIChatCompletionClient",
"config": {
"model": "gpt-4o-2024-05-13",
"api_key": "REPLACE_WITH_YOUR_API_KEY"
}
}

View File

@ -0,0 +1,90 @@
You are an expert in literary analysis. You are skilled at dissecting texts to uncover themes, motifs, and character relationships. You are adept at helping people understand the intricate dynamics and structures within literary communities, facilitating deeper insights into how various works influence and reflect societal contexts.
# Goal
Write a comprehensive assessment report of a community taking on the role of a A literary analyst tasked with examining the provided text excerpt from a Sherlock Holmes story, focusing on character dynamics, thematic elements, and narrative structure. The analysis will explore the relationships between characters, the significance of dialogue, and the motifs present in the text. This report will be used to enhance understanding of the literary community surrounding Arthur Conan Doyle's works and their impact on the genre of detective fiction, as well as to inform discussions on character development and thematic depth in literature.. The content of this report includes an overview of the community's key entities and relationships.
# Report Structure
The report should include the following sections:
- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title.
- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant points associated with its entities.
- REPORT RATING: A float score between 0-10 that represents the relevance of the text to literary analysis, character development, narrative structure, and thematic exploration, with 1 being trivial or irrelevant and 10 being highly significant, profound, and impactful to the understanding of the text and its implications within the literary canon.
- RATING EXPLANATION: Give a single sentence explanation of the rating.
- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive.
Return output as a well-formed JSON-formatted string with the following format. Don't use any unnecessary escape sequences. The output should be a single JSON object that can be parsed by json.loads.
{
"title": "<report_title>",
"summary": "<executive_summary>",
"rating": <threat_severity_rating>,
"rating_explanation": "<rating_explanation>"
"findings": "[{"summary":"<insight_1_summary>", "explanation": "<insight_1_explanation"}, {"summary":"<insight_2_summary>", "explanation": "<insight_2_explanation"}]"
}
# Grounding Rules
After each paragraph, add data record reference if the content of the paragraph was derived from one or more data records. Reference is in the format of [records: <record_source> (<record_id_list>, ...<record_source> (<record_id_list>)]. If there are more than 10 data records, show the top 10 most relevant records.
Each paragraph should contain multiple sentences of explanation and concrete examples with specific named entities. All paragraphs must have these references at the start and end. Use "NONE" if there are no related roles or records. Everything should be in The primary language of the provided text is "English.".
Example paragraph with references added:
This is a paragraph of the output text [records: Entities (1, 2, 3), Claims (2, 5), Relationships (10, 12)]
# Example Input
-----------
Text:
Entities
id,entity,description
5,ABILA CITY PARK,Abila City Park is the location of the POK rally
Relationships
id,source,target,description
37,ABILA CITY PARK,POK RALLY,Abila City Park is the location of the POK rally
38,ABILA CITY PARK,POK,POK is holding a rally in Abila City Park
39,ABILA CITY PARK,POKRALLY,The POKRally is taking place at Abila City Park
40,ABILA CITY PARK,CENTRAL BULLETIN,Central Bulletin is reporting on the POK rally taking place in Abila City Park
Output:
{
"title": "Abila City Park and POK Rally",
"summary": "The community revolves around the Abila City Park, which is the location of the POK rally. The park has relationships with POK, POKRALLY, and Central Bulletin, all
of which are associated with the rally event.",
"rating": 5.0,
"rating_explanation": "The impact rating is moderate due to the potential for unrest or conflict during the POK rally.",
"findings": [
{
"summary": "Abila City Park as the central location",
"explanation": "Abila City Park is the central entity in this community, serving as the location for the POK rally. This park is the common link between all other
entities, suggesting its significance in the community. The park's association with the rally could potentially lead to issues such as public disorder or conflict, depending on the
nature of the rally and the reactions it provokes. [records: Entities (5), Relationships (37, 38, 39, 40)]"
},
{
"summary": "POK's role in the community",
"explanation": "POK is another key entity in this community, being the organizer of the rally at Abila City Park. The nature of POK and its rally could be a potential
source of threat, depending on their objectives and the reactions they provoke. The relationship between POK and the park is crucial in understanding the dynamics of this community.
[records: Relationships (38)]"
},
{
"summary": "POKRALLY as a significant event",
"explanation": "The POKRALLY is a significant event taking place at Abila City Park. This event is a key factor in the community's dynamics and could be a potential
source of threat, depending on the nature of the rally and the reactions it provokes. The relationship between the rally and the park is crucial in understanding the dynamics of this
community. [records: Relationships (39)]"
},
{
"summary": "Role of Central Bulletin",
"explanation": "Central Bulletin is reporting on the POK rally taking place in Abila City Park. This suggests that the event has attracted media attention, which could
amplify its impact on the community. The role of Central Bulletin could be significant in shaping public perception of the event and the entities involved. [records: Relationships
(40)]"
}
]
}
# Real Data
Use the following text for your answer. Do not make anything up in your answer.
Text:
{input_text}
Output:

View File

@ -0,0 +1,122 @@
-Goal-
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: One of the following types: [person, character, setting, dialogue, narrative technique, literary device]
- entity_description: Comprehensive description of the entity's attributes and activities
Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
- relationship_strength: an integer score between 1 to 10, indicating strength of the relationship between the source entity and target entity
Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_strength>)
3. Return output in The primary language of the provided text is "English." as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
4. If you have to translate into The primary language of the provided text is "English.", just translate the descriptions, nothing else!
5. When finished, output {completion_delimiter}.
-Examples-
######################
Example 1:
entity_types: [person, character, setting, dialogue, narrative technique, literary device]
text:
my kicks and shoves. Hullo!
I yelled. Hullo! Colonel! Let me out!
“And then suddenly in the silence I heard a sound which sent my heart
into my mouth. It was the clank of the levers and the swish of the
leaking cylinder. He had set the engine at work. The lamp still stood
upon the floor where I had placed it when examining the trough. By its
light I saw that the black ceiling was coming down upon me, slowly,
jerkily, but, as none knew better than myself, with a force which must
within a minute grind me to a shapeless pulp. I threw myself,
screaming, against the door, and dragged with my nails at the lock. I
implored the colonel to let me out, but the remorseless clanking of the
levers drowned my cries. The ceiling was only a foot or two above my
head,
------------------------
output:
("entity"{tuple_delimiter}COLONEL{tuple_delimiter}PERSON{tuple_delimiter}The Colonel is a character who is being addressed by the narrator, indicating a position of authority or control in the situation described.)
{record_delimiter}
("entity"{tuple_delimiter}NARRATOR{tuple_delimiter}CHARACTER{tuple_delimiter}The narrator is the character experiencing fear and desperation, trying to escape from a dangerous situation involving a descending ceiling.)
{record_delimiter}
("entity"{tuple_delimiter}LEVERS{tuple_delimiter)LITERARY DEVICE{tuple_delimiter}The levers symbolize the mechanism of control and the impending danger, contributing to the tension in the narrative.)
{record_delimiter}
("entity"{tuple_delimiter}CEILING{tuple_delimiter}SETTING{tuple_delimiter}The ceiling represents the physical threat to the narrator, creating a sense of claustrophobia and urgency in the scene.)
{record_delimiter}
("entity"{tuple_delimiter}DOOR{tuple_delimiter}SETTING{tuple_delimiter}The door is a barrier between the narrator and freedom, emphasizing the struggle for escape.)
{record_delimiter}
("entity"{tuple_delimiter}SILENCE{tuple_delimiter}LITERARY DEVICE{tuple_delimiter}Silence serves as a narrative technique that heightens the tension before the sound of the levers is heard, creating a dramatic contrast.)
{record_delimiter}
("relationship"{tuple_delimiter}NARRATOR{tuple_delimiter}COLONEL{tuple_delimiter}The narrator is pleading with the Colonel for help, indicating a relationship of desperation and authority.{tuple_delimiter}8)
{record_delimiter}
("relationship"{tuple_delimiter}NARRATOR{tuple_delimiter}CEILING{tuple_delimiter}The narrator is directly threatened by the descending ceiling, creating a relationship of fear and urgency.{tuple_delimiter}9)
{record_delimiter}
("relationship"{tuple_delimiter}NARRATOR{tuple_delimiter}DOOR{tuple_delimiter}The narrator is trying to escape through the door, establishing a relationship of struggle and confinement.{tuple_delimiter}7)
{record_delimiter}
("relationship"{tuple_delimiter}NARRATOR{tuple_delimiter}LEVERS{tuple_delimiter}The narrator's situation is exacerbated by the sound of the levers, which symbolize the mechanism of danger, linking them through tension.{tuple_delimiter}8)
{record_delimiter}
("relationship"{tuple_delimiter}SILENCE{tuple_delimiter}LEVERS{tuple_delimiter}The silence is broken by the sound of the levers, creating a relationship that emphasizes the shift from calm to chaos.{tuple_delimiter}6)
{completion_delimiter}
#############################
Example 2:
entity_types: [person, character, setting, dialogue, narrative technique, literary device]
text:
effect,” remarked Holmes. “This is wanting in the police
report, where more stress is laid, perhaps, upon the platitudes of the
magistrate than upon the details, which to an observer contain the
vital essence of the whole matter. Depend upon it, there is nothing so
unnatural as the commonplace.”
I smiled and shook my head. “I can quite understand your thinking so,”
I said. “Of course, in your position of unofficial adviser and helper
to everybody who is absolutely puzzled, throughout three continents,
you are brought in contact with all that is strange and bizarre. But
here”—I picked up the morning paper from the ground—“let us put it to a
practical test. Here is the first heading upon which I come. A
husbands cruelty to his wife. There is half a column of print, but I
know without reading it that it is all perfectly familiar to me. There
is, of
------------------------
output:
("entity"{tuple_delimiter}HOLMES{tuple_delimiter}PERSON{tuple_delimiter}Holmes is a character known for his keen observation and deduction skills, often serving as an unofficial adviser to those puzzled by strange occurrences.)
{record_delimiter}
("entity"{tuple_delimiter}POLICE REPORT{tuple_delimiter}LITERARY DEVICE{tuple_delimiter}The police report is a narrative element that emphasizes the contrast between mundane details and the more significant observations that Holmes values.)
{record_delimiter}
("entity"{tuple_delimiter}MAGISTRATE{tuple_delimiter}CHARACTER{tuple_delimiter}The magistrate is a character referenced in the context of the police report, representing the conventional authority that Holmes critiques.)
{record_delimiter}
("entity"{tuple_delimiter}MORNING PAPER{tuple_delimiter}SETTING{tuple_delimiter}The morning paper serves as a setting for the practical test Holmes proposes, representing the everyday reality that contrasts with the bizarre cases he encounters.)
{record_delimiter}
("entity"{tuple_delimiter}HUSBAND'S CRUELTY TO HIS WIFE{tuple_delimiter}DIALOGUE{tuple_delimiter}This heading from the morning paper exemplifies the commonplace nature of human cruelty, which Holmes finds familiar and unremarkable.)
{record_delimiter}
("relationship"{tuple_delimiter}HOLMES{tuple_delimiter}MAGISTRATE{tuple_delimiter}Holmes critiques the magistrate's focus on platitudes in the police report, highlighting a difference in their perspectives on what is significant in a case.{tuple_delimiter}8)
{record_delimiter}
("relationship"{tuple_delimiter}HOLMES{tuple_delimiter}POLICE REPORT{tuple_delimiter}Holmes contrasts the details in the police report with his own observations, indicating his belief that the report lacks the vital essence of the matter.{tuple_delimiter}9)
{record_delimiter}
("relationship"{tuple_delimiter}HOLMES{tuple_delimiter}MORNING PAPER{tuple_delimiter}Holmes uses the morning paper as a practical test to illustrate his point about the familiarity of commonplace events.{tuple_delimiter}7)
{record_delimiter}
("relationship"{tuple_delimiter}HUSBAND'S CRUELTY TO HIS WIFE{tuple_delimiter}MORNING PAPER{tuple_delimiter}The heading about the husband's cruelty is a specific example found in the morning paper, representing the mundane realities that Holmes finds unremarkable.{tuple_delimiter}6)
{completion_delimiter}
#############################
-Real Data-
######################
entity_types: [person, character, setting, dialogue, narrative technique, literary device]
text: {input_text}
######################
output:

View File

@ -0,0 +1,17 @@
You are an expert in literary analysis. You are skilled at dissecting texts to uncover themes, motifs, and character relationships. You are adept at helping people understand the intricate dynamics and structures within literary communities, facilitating deeper insights into how various works influence and reflect societal contexts.
Using your expertise, you're asked to generate a comprehensive summary of the data provided below.
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
Please concatenate all of these into a single, concise description in The primary language of the provided text is "English.". Make sure to include information collected from all the descriptions.
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
Make sure it is written in third person, and include the entity names so we have the full context.
Enrich it as much as you can with relevant information from the nearby text, this is very important.
If no answer is possible, or the description is empty, only convey information that is provided within the text.
#######
-Data-
Entities: {entity_name}
Description List: {description_list}
#######
Output:

View File

@ -0,0 +1,2 @@
autogen-agentchat>=0.4.0,<0.5
autogen-ext[graphrag,openai,azure]>=0.4.0,<0.5

View File

@ -0,0 +1,130 @@
### This config file contains required core defaults that must be set, along with a handful of common optional settings.
### For a full list of available settings, see https://microsoft.github.io/graphrag/config/yaml/
### LLM settings ###
## There are a number of settings to tune the threading and token limits for LLM calls - check the docs.
encoding_model: cl100k_base # this needs to be matched to your model!
llm:
api_key: null
type: openai_chat # or azure_openai_chat
model: gpt-4o
model_supports_json: true # recommended if this is available for your model.
# audience: "https://cognitiveservices.azure.com/.default"
# api_base: https://<resource-name>.openai.azure.com
# api_version: 2024-08-01-preview
# deployment_name: gpt-4o
parallelization:
stagger: 0.3
# num_threads: 50
async_mode: threaded # or asyncio
embeddings:
async_mode: threaded # or asyncio
vector_store:
type: lancedb
db_uri: 'data/output/lancedb'
container_name: default
overwrite: true
llm:
api_key: null
type: openai_embedding # or azure_openai_embedding
model: text-embedding-3-small
# api_base: https://<resource-name>.openai.azure.com
# api_version: "2023-05-15"
# audience: "https://cognitiveservices.azure.com/.default"
# deployment_name: text-embedding-3-small
### Input settings ###
input:
type: file # or blob
file_type: text # or csv
base_dir: "data/input"
file_encoding: utf-8
file_pattern: ".*\\.txt$"
chunks:
size: 1200
overlap: 100
group_by_columns: [id]
### Storage settings ###
## If blob storage is specified in the following four sections,
## connection_string and container_name must be provided
cache:
type: file # or blob
base_dir: "cache"
reporting:
type: file # or console, blob
base_dir: "logs"
storage:
type: file # or blob
base_dir: "data/output"
## only turn this on if running `graphrag index` with custom settings
## we normally use `graphrag update` with the defaults
update_index_storage:
# type: file # or blob
# base_dir: "update_output"
### Workflow settings ###
skip_workflows: []
entity_extraction:
prompt: "prompts/entity_extraction.txt"
entity_types: [organization,person,geo,event]
max_gleanings: 1
summarize_descriptions:
prompt: "prompts/summarize_descriptions.txt"
max_length: 500
claim_extraction:
enabled: false
prompt: "prompts/claim_extraction.txt"
description: "Any claims or facts that could be relevant to information discovery."
max_gleanings: 1
community_reports:
prompt: "prompts/community_report.txt"
max_length: 2000
max_input_length: 8000
cluster_graph:
max_cluster_size: 10
embed_graph:
enabled: false # if true, will generate node2vec embeddings for nodes
umap:
enabled: false # if true, will generate UMAP embeddings for nodes
snapshots:
graphml: false
raw_entities: false
top_level_nodes: false
embeddings: false
transient: false
### Query settings ###
## The prompt locations are required here, but each search method has a number of optional knobs that can be tuned.
## See the config docs: https://microsoft.github.io/graphrag/config/yaml/#query
local_search:
prompt: "prompts/local_search_system_prompt.txt"
global_search:
map_prompt: "prompts/global_search_map_system_prompt.txt"
reduce_prompt: "prompts/global_search_reduce_system_prompt.txt"
knowledge_prompt: "prompts/global_search_knowledge_system_prompt.txt"
drift_search:
prompt: "prompts/drift_search_system_prompt.txt"

File diff suppressed because it is too large Load Diff