# KAG/kag/common/retriever/kag_retriever.py
#
# 423 lines
# 17 KiB
# Python

# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
import json
from tenacity import retry, stop_after_attempt
from kag.common.base.prompt_op import PromptOp
from kag.common.vectorizer import Vectorizer
from knext.graph_algo.client import GraphAlgoClient
from kag.interface.retriever.chunk_retriever_abc import ChunkRetrieverABC
from typing import List, Dict
import numpy as np
import logging
from knext.reasoner.client import ReasonerClient
from knext.schema.client import CHUNK_TYPE, OTHER_TYPE
from knext.project.client import ProjectClient
from kag.common.utils import processing_phrases
from kag.common.llm.client.llm_client import LLMClient
from knext.search.client import SearchClient
from kag.solver.logic.core_modules.common.schema_utils import SchemaUtils
from kag.solver.logic.core_modules.config import LogicFormConfiguration
logger = logging.getLogger(__name__)
class DefaultRetriever(ChunkRetrieverABC):
    """
    Default chunk retriever that recalls document chunks from the knowledge graph.

    The implementation references HippoRAG for the combination of dpr & ppr: chunks are
    scored by vector similarity to the query and by PageRank started from the entities
    matched in the query, and the two score sets are then merged. Developers can define
    their own retriever on top of ChunkRetrieverABC.

    Parameters:
    - project_id (str, optional): Project ID to load specific project configurations.
      When omitted, the KAG_PROJECT_ID environment variable is used.
    """

    def __init__(self, project_id: str = None):
        self.project_id = project_id or os.environ.get("KAG_PROJECT_ID")
        self._init_llm()
        self._init_search(self.project_id)

        biz_scene = os.getenv('KAG_PROMPT_BIZ_SCENE', 'default')
        language = os.getenv('KAG_PROMPT_LANGUAGE', 'en')
        self.ner_prompt = PromptOp.load(biz_scene, "question_ner")(
            language=language, project_id=self.project_id
        )
        self.std_prompt = PromptOp.load(biz_scene, "std")(language=language)

        # Entity-match scores strictly above this value make recall_docs rank by PageRank alone.
        self.pagerank_threshold = 0.9
        # Minimum vector-search score for a typed entity hit to be accepted on its own.
        self.match_threshold = 0.8
        # Weight of the PageRank score in calculate_combined_scores (similarity gets 1 - weight).
        self.pagerank_weight = 0.5

        self.reranker_model_path = os.getenv("KAG_RETRIEVER_RERANKER_MODEL_PATH")
        if self.reranker_model_path:
            # Imported lazily so the reranker package is only needed when a model path is set.
            from kag.common.reranker.reranker import BGEReranker
            self.reranker = BGEReranker(self.reranker_model_path, use_fp16=True)
        else:
            self.reranker = None

        # When True, recall_docs also runs entity standardization after NER.
        self.with_semantic = True

    def _init_llm(self):
        """
        Create the LLM client.

        The configuration is read from the KAG_LLM environment variable and then updated
        with the "llm" section of the project configuration, which takes precedence.
        """
        # SECURITY NOTE(review): eval() executes whatever expression KAG_LLM contains.
        # This is only acceptable if the environment is fully trusted; consider
        # ast.literal_eval once the value is confirmed to always be a plain literal.
        llm_config = eval(os.getenv("KAG_LLM", "{}"))
        project_id = int(self.project_id)
        config = ProjectClient().get_config(project_id)
        llm_config.update(config.get("llm", {}))
        self.llm = LLMClient.from_config(llm_config)

    def _init_search(self, project_id):
        """
        Create the schema helper, search client, vectorizer, reasoner and graph-algo clients.

        Parameters:
        - project_id: Project ID; it is converted with int() for the knext clients.
        """
        host_addr = os.getenv("KAG_PROJECT_HOST_ADDR")
        self.schema_util = SchemaUtils(LogicFormConfiguration({
            "project_id": project_id,
            "host_addr": host_addr,
        }))
        self.sc: SearchClient = SearchClient(host_addr, int(project_id))
        # SECURITY NOTE(review): eval() on KAG_VECTORIZER has the same trust requirement
        # as the KAG_LLM evaluation in _init_llm.
        vectorizer_config = eval(os.getenv("KAG_VECTORIZER", "{}"))
        self.vectorizer = Vectorizer.from_config(vectorizer_config)
        self.reason: ReasonerClient = ReasonerClient(host_addr, int(project_id))
        self.graph_algo = GraphAlgoClient(host_addr, int(project_id))

    @retry(stop=stop_after_attempt(3))
    def named_entity_recognition(self, query: str):
        """
        Perform named entity recognition.

        Invokes the configured LLM client (self.llm) on the query with the NER prompt
        (self.ner_prompt). The call is attempted up to three times.

        Parameters:
            query (str): The text input provided by the user or system.
        Returns:
            The result returned by the LLM client; its type and format depend on the prompt.
        """
        return self.llm.invoke({"input": query}, self.ner_prompt)

    @retry(stop=stop_after_attempt(3))
    def named_entity_standardization(self, query: str, entities: List[Dict]):
        """
        Standardize the named entities recognized in a query.

        Invokes the configured LLM client with the standardization prompt (self.std_prompt)
        so that different representations of the same entity can be unified. The call is
        attempted up to three times.

        Parameters:
        - query: A string containing the query with named entities.
        - entities: A list of dictionaries, each describing one named entity.
        Returns:
        - The result of the LLM call, typically standardized named entity information.
        """
        return self.llm.invoke(
            {"input": query, "named_entities": entities}, self.std_prompt
        )

    @staticmethod
    def append_official_name(source_entities: List[Dict], entities_with_official_name: List[Dict]):
        """
        Copy official names onto the matching source entities, in place.

        Entities are matched on their "category" and "entity" values.

        Parameters:
            source_entities (List[Dict]): Entities to enrich; each needs "entity" and "category".
            entities_with_official_name (List[Dict]): Entities carrying "entity", "category"
                and "official_name".
        """
        official_names = {
            f"{item['category']}{item['entity']}": item["official_name"]
            for item in entities_with_official_name
        }
        for item in source_entities:
            key = f"{item['category']}{item['entity']}"
            if key in official_names:
                item["official_name"] = official_names[key]

    def calculate_sim_scores(self, query: str, doc_nums: int):
        """
        Calculate the vector similarity scores between a query and document chunks.

        Failures are logged and produce an empty result instead of raising.

        Parameters:
            query (str): The user's query text.
            doc_nums (int): The number of document chunks to return.
        Returns:
            dict: Document chunk IDs mapped to their vector similarity scores.
        """
        scores = dict()
        try:
            query_vector = self.vectorizer.vectorize(query)
            top_k = self.sc.search_vector(
                label=self.schema_util.get_label_within_prefix(CHUNK_TYPE),
                property_key="content",
                query_vector=query_vector,
                topk=doc_nums
            )
            scores = {item["node"]["id"]: item["score"] for item in top_k}
        except Exception as e:
            # Best effort: recall_docs can still rank with PageRank scores alone.
            logger.error(
                f"run calculate_sim_scores failed, info: {e}", exc_info=True
            )
        return scores

    def calculate_pagerank_scores(self, start_nodes: List[Dict]):
        """
        Calculate PageRank scores for chunks, starting from the given nodes.

        Parameters:
            start_nodes (list): Matched entities (dicts with "name" and "type") used as the
                starting nodes of the PageRank computation.
        Returns:
            dict: Whatever the graph-algo client returns for the chunk label, expected to map
            chunk IDs to PageRank scores. An empty dict is returned when `start_nodes` is
            empty or when the computation fails (the failure is logged).
        """
        if not start_nodes:
            return dict()
        scores = dict()
        try:
            scores = self.graph_algo.calculate_pagerank_scores(
                self.schema_util.get_label_within_prefix(CHUNK_TYPE),
                start_nodes
            )
        except Exception as e:
            logger.error(
                f"run calculate_pagerank_scores failed, info: {e}, start_nodes: {start_nodes}", exc_info=True
            )
        return scores

    def match_entities(self, queries: Dict[str, str], top_k: int = 1):
        """
        Match query entities against graph nodes by vector search on the "name" property.

        For each entity the typed label and the generic OTHER label are searched. A typed hit
        scoring above `self.match_threshold` is accepted on its own; otherwise the generic hit
        (if any) is added, followed by the weaker typed hit when both exist.

        :param queries: A dictionary mapping entity keywords to their labels.
        :param top_k: The number of results requested per search. Default is 1.
        :return: A tuple of (matched entities, their scores), index-aligned.
        """
        matched_entities = []
        matched_entities_scores = []
        other_label = self.schema_util.get_label_within_prefix(OTHER_TYPE)

        def add_match(hit, entity_type):
            matched_entities.append({"name": hit["node"]["name"], "type": entity_type})
            matched_entities_scores.append(hit["score"])

        for query, query_type in queries.items():
            query = processing_phrases(query)
            if query_type not in self.schema_util.node_en_zh.keys():
                # Unknown labels are searched under the generic OTHER type.
                query_type = other_label

            # Vectorize once and reuse the vector for both searches.
            query_vector = self.vectorizer.vectorize(query)
            typed_nodes = self.sc.search_vector(
                label=query_type,
                property_key="name",
                query_vector=query_vector,
                topk=top_k,
            )
            if query_type != other_label:
                nontyped_nodes = self.sc.search_vector(
                    label=other_label,
                    property_key="name",
                    query_vector=query_vector,
                    topk=top_k,
                )
            else:
                nontyped_nodes = typed_nodes

            if not typed_nodes:
                if nontyped_nodes:
                    add_match(nontyped_nodes[0], OTHER_TYPE)
                continue

            best_typed = typed_nodes[0]
            if best_typed["score"] > self.match_threshold:
                add_match(best_typed, query_type)
            elif nontyped_nodes:
                # Low-confidence typed hit: keep the generic hit first, then the typed one.
                add_match(nontyped_nodes[0], OTHER_TYPE)
                add_match(best_typed, query_type)

        if not matched_entities:
            logger.info(f"No entities matched for {queries}")
        return matched_entities, matched_entities_scores

    def calculate_combined_scores(self, sim_scores: Dict[str, float], pagerank_scores: Dict[str, float]):
        """
        Combine similarity scores and PageRank scores into a single score per key.

        Keys missing from one input are scored 0.0 there. Each score set is min-max normalized
        over the union of keys, then mixed as
        `sim * (1 - pagerank_weight) + pagerank * pagerank_weight`.
        The input dictionaries are not modified.

        Parameters:
            sim_scores (Dict[str, float]): Similarity scores keyed by identifier.
            pagerank_scores (Dict[str, float]): PageRank scores keyed by identifier.
        Returns:
            Dict[str, float]: Combined scores keyed by identifier, PageRank keys first.
        """
        def min_max_normalize(x):
            if len(x) == 0:
                return []
            lowest = np.min(x)
            spread = np.max(x) - lowest
            if spread > 0:
                return (x - lowest) / spread
            # All values equal: normalize everything to zero.
            return x - lowest

        # Deterministic union of keys: PageRank keys, then similarity-only keys.
        keys = list(pagerank_scores)
        keys.extend(key for key in sim_scores if key not in pagerank_scores)

        sim_norm = min_max_normalize(
            np.array([sim_scores.get(key, 0.0) for key in keys])
        )
        pagerank_norm = min_max_normalize(
            np.array([pagerank_scores.get(key, 0.0) for key in keys])
        )

        weight = self.pagerank_weight
        return {
            key: sim * (1 - weight) + rank * weight
            for key, sim, rank in zip(keys, sim_norm, pagerank_norm)
        }

    def recall_docs(self, query: str, top_k: int = 5, **kwargs):
        """
        Recall relevant documents based on the query string.

        Steps: NER (plus optional standardization) on the query, vector-similarity recall of
        `top_k * 20` candidate chunks, entity matching, PageRank from the matched entities,
        score selection/combination, and finally document lookup.

        Parameters:
        - query (str): The user's query string.
        - top_k (int, optional): The number of documents to return, default is 5.
        Keyword Arguments:
        - kwargs: Additional keyword arguments (unused here).
        Returns:
        - list: A list containing the top_k most relevant documents.
        """
        assert isinstance(query, str), "Query must be a string"
        chunk_nums = top_k * 20
        if chunk_nums == 0:
            return []

        ner_list = self.named_entity_recognition(query)
        logger.debug(f"ner_list: {ner_list}")
        if self.with_semantic:
            std_ner_list = self.named_entity_standardization(query, ner_list)
            self.append_official_name(ner_list, std_ner_list)

        entities = {}
        for item in ner_list:
            entity = item.get("entity", "")
            category = item.get("category", "")
            official_name = item.get("official_name", "")
            if not entity or not (category or official_name):
                continue
            if category.lower() in ["works", "person", "other"]:
                entities[entity] = category
            else:
                entities[entity] = official_name or category

        sim_scores = self.calculate_sim_scores(query, chunk_nums)
        matched_entities, matched_scores = self.match_entities(entities)
        pagerank_scores = self.calculate_pagerank_scores(matched_entities)

        if not matched_entities:
            combined_scores = sim_scores
        elif pagerank_scores and np.min(matched_scores) > self.pagerank_threshold:
            # Every entity matched with high confidence: trust the graph ranking alone.
            combined_scores = pagerank_scores
        else:
            # Also reached when PageRank produced nothing (e.g. it failed and was logged),
            # so the similarity ranking is used instead of returning an empty result.
            combined_scores = self.calculate_combined_scores(sim_scores, pagerank_scores)

        sorted_scores = sorted(
            combined_scores.items(), key=lambda item: item[1], reverse=True
        )
        logger.debug(f"sorted_scores: {sorted_scores}")
        return self.get_all_docs_by_id(query, sorted_scores, top_k)

    def get_all_docs_by_id(self, query: str, doc_ids: list, top_k: int):
        """
        Retrieve documents by ID and blend in the best full-text match for the query.

        Each returned entry is formatted as "#<name>#<content>#<score>". After loading up to
        `top_k` chunks, a text search for the query is run (topk=1); if its hit is not already
        among the loaded chunk names, it replaces the last loaded entry.

        Parameters:
        - query (str): The query string for text matching.
        - doc_ids (list): (doc_id, score) tuples in ranked order. A mapping of doc_id to score
          also works, since non-tuple items are looked up in `doc_ids` itself.
        - top_k (int): The maximum number of documents to return.
        Returns:
        - list: A list of matched documents.
        """
        matched_docs = []
        hits_docs = set()
        for counter, doc_id in enumerate(doc_ids):
            if counter == top_k:
                break
            if isinstance(doc_id, tuple):
                doc_id, doc_score = doc_id[0], doc_id[1]
            else:
                doc_score = doc_ids[doc_id]
            node = self.reason.query_node(
                label=self.schema_util.get_label_within_prefix(CHUNK_TYPE),
                id_value=doc_id,
            )
            node_dict = dict(node.items())
            matched_docs.append(f"#{node_dict['name']}#{node_dict['content']}#{doc_score}")
            hits_docs.add(node_dict['name'])

        try:
            text_matched = self.sc.search_text(
                query, [self.schema_util.get_label_within_prefix(CHUNK_TYPE)], topk=1
            )
            if text_matched:
                for item in text_matched:
                    title = item["node"]["name"]
                    if title not in hits_docs:
                        # Make room for the text hit by dropping the lowest-ranked entry.
                        if len(matched_docs) > 0:
                            matched_docs.pop()
                        else:
                            logger.warning(f"{query} matched docs is empty")
                        matched_docs.append(
                            f'#{item["node"]["name"]}#{item["node"]["content"]}#{item["score"]}'
                        )
                        break
        except Exception as e:
            # Text search is a best-effort supplement; keep the ranked results on failure.
            logger.warning(f"{query} query chunk failed: {e}", exc_info=True)

        logger.debug(f"matched_docs: {matched_docs}")
        return matched_docs

    def rerank_docs(self, queries: List[str], passages: List[str]):
        """
        Re-rank the given passages based on the provided queries.

        Parameters:
        - queries (List[str]): A list of queries.
        - passages (List[str]): A list of passages.
        Returns:
        - List[str]: The re-ranked passages, or `passages` unchanged when no reranker
          is configured.
        """
        if self.reranker is None:
            return passages
        return self.reranker.rerank(queries, passages)