# KAG/kag/builder/component/extractor/kag_extractor.py
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import copy
import logging
import os
from typing import Dict, Type, List

from tenacity import stop_after_attempt, retry

from kag.builder.model.chunk import Chunk
from kag.builder.model.sub_graph import SubGraph
from kag.builder.prompt.spg_prompt import SPG_KGPrompt
from kag.common.base.prompt_op import PromptOp
from kag.common.utils import processing_phrases, to_camel_case
from kag.interface.builder import ExtractorABC
from knext.common.base.runnable import Input, Output
from knext.schema.client import SchemaClient, OTHER_TYPE, CHUNK_TYPE, BASIC_TYPES
from knext.schema.model.base import SpgTypeEnum
from knext.schema.model.property import Property

logger = logging.getLogger(__name__)


class KAGExtractor(ExtractorABC):
    """
    A class for extracting knowledge graph subgraphs from text using a large language model (LLM).
    Inherits from the ExtractorABC base class.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.llm = self._init_llm()
        self.biz_scene = os.getenv("KAG_PROMPT_BIZ_SCENE", "default")
        self.language = os.getenv("KAG_PROMPT_LANGUAGE", "en")
        self.schema = SchemaClient(project_id=self.project_id).load()
        self.ner_prompt = PromptOp.load(self.biz_scene, "ner")(
            language=self.language, project_id=self.project_id
        )
        self.std_prompt = PromptOp.load(self.biz_scene, "std")(language=self.language)
        self.triple_prompt = PromptOp.load(self.biz_scene, "triple")(
            language=self.language
        )
        # Collect schema types that carry at least one extractable property,
        # skipping concept types and types the KG prompt is told to ignore.
        self.kg_types = []
        for type_name, spg_type in self.schema.items():
            if type_name in SPG_KGPrompt.ignored_types:
                continue
            if spg_type.spg_type_enum == SpgTypeEnum.Concept:
                continue
            properties = list(spg_type.properties.keys())
            for p in properties:
                if p not in SPG_KGPrompt.ignored_properties:
                    self.kg_types.append(type_name)
                    break
        if self.kg_types:
            self.kg_prompt = SPG_KGPrompt(
                self.kg_types, language=self.language, project_id=self.project_id
            )

    @property
    def input_types(self) -> Type[Input]:
        return Chunk

    @property
    def output_types(self) -> Type[Output]:
        return SubGraph

    @retry(stop=stop_after_attempt(3))
    def named_entity_recognition(self, passage: str):
        """
        Performs named entity recognition on a given text passage.

        Args:
            passage (str): The text to perform named entity recognition on.
        Returns:
            The combined results of schema-constrained extraction (if any KG types
            are configured) and open-domain named entity recognition.
        """
        if self.kg_types:
            kg_result = self.llm.invoke({"input": passage}, self.kg_prompt)
        else:
            kg_result = []
        ner_result = self.llm.invoke({"input": passage}, self.ner_prompt)
        return kg_result + ner_result
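
    # Illustrative only: downstream methods in this file read the NER output as a
    # list of dicts keyed by "entity" and "category", optionally carrying
    # "properties", "description" and "type", e.g.
    #   [{"entity": "Marie Curie", "category": "Person",
    #     "description": "physicist", "type": "Scientist", "properties": {}}]
    # The exact fields depend on the configured ner/kg prompts.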

    @retry(stop=stop_after_attempt(3))
    def named_entity_standardization(self, passage: str, entities: List[Dict]):
        """
        Standardizes named entities.

        Args:
            passage (str): The input text passage.
            entities (List[Dict]): A list of recognized named entities.
        Returns:
            Standardized entity information.
        """
        return self.llm.invoke(
            {"input": passage, "named_entities": entities}, self.std_prompt
        )

    @retry(stop=stop_after_attempt(3))
    def triples_extraction(self, passage: str, entities: List[Dict]):
        """
        Extracts triples (subject-predicate-object structures) from a given text passage based on identified entities.

        Args:
            passage (str): The text to extract triples from.
            entities (List[Dict]): A list of entities identified in the text.
        Returns:
            The result of the triples extraction operation.
        """
        return self.llm.invoke(
            {"input": passage, "entity_list": entities}, self.triple_prompt
        )
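
    # Illustrative only: assemble_sub_graph_with_triples below expects the LLM to
    # return a list of 3-element lists, e.g.
    #   [["Marie Curie", "was awarded", "Nobel Prize"]]
    # Triples of any other arity are silently skipped during assembly.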

    def assemble_sub_graph_with_spg_records(self, entities: List[Dict]):
        """
        Builds a subgraph from schema-constrained (SPG) entity records, lifting
        non-basic-typed properties into separate nodes and edges.

        Args:
            entities (List[Dict]): A list of SPG entity records.
        Returns:
            A (SubGraph, entities) tuple; the entity records are updated in place
            so that properties lifted into the graph are removed.
        """
        sub_graph = SubGraph([], [])
        for record in entities:
            s_name = record.get("entity", "")
            s_label = record.get("category", "")
            properties = record.get("properties", {})
            tmp_properties = copy.deepcopy(properties)
            spg_type = self.schema.get(s_label)
            if spg_type is None:
                # The LLM produced a category that is not in the schema; keep the
                # record as-is rather than failing on spg_type.properties below.
                sub_graph.add_node(id=s_name, name=s_name, label=s_label, properties=tmp_properties)
                continue
            for prop_name, prop_value in properties.items():
                if prop_value == "NAN":
                    tmp_properties.pop(prop_name)
                    continue
                if prop_name in spg_type.properties:
                    prop: Property = spg_type.properties.get(prop_name)
                    o_label = prop.object_type_name_en
                    if o_label not in BASIC_TYPES:
                        # Non-basic property values become nodes linked to the subject.
                        if isinstance(prop_value, str):
                            prop_value = [prop_value]
                        for o_name in prop_value:
                            sub_graph.add_node(id=o_name, name=o_name, label=o_label)
                            sub_graph.add_edge(s_id=s_name, s_label=s_label, p=prop_name, o_id=o_name, o_label=o_label)
                        tmp_properties.pop(prop_name)
            record["properties"] = tmp_properties
            # Use the filtered properties so that NAN values and properties already
            # lifted into edges are not duplicated on the node.
            sub_graph.add_node(id=s_name, name=s_name, label=s_label, properties=tmp_properties)
        return sub_graph, entities
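
    # Illustrative only: given a hypothetical schema where Person.spouse points at
    # Person (a non-basic type), a record like
    #   {"entity": "Pierre", "category": "Person", "properties": {"spouse": "Marie"}}
    # yields a "Marie" Person node plus a Pierre -[spouse]-> Marie edge, and the
    # "spouse" key is removed from the record's remaining properties.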

    @staticmethod
    def assemble_sub_graph_with_triples(
        sub_graph: SubGraph, entities: List[Dict], triples: List[list]
    ):
        """
        Assembles edges in the subgraph based on a list of triples and entities.

        Args:
            sub_graph (SubGraph): The subgraph to add edges to.
            entities (List[Dict]): A list of entities, for looking up category information.
            triples (List[list]): A list of triples, each representing a relationship to be added to the subgraph.
        Returns:
            SubGraph: The subgraph with the triple nodes and edges added.
        """

        def get_category(entities_data, entity_name):
            for entity in entities_data:
                if entity["entity"] == entity_name:
                    return entity["category"]
            return None

        for tri in triples:
            if len(tri) != 3:
                continue
            s_category = get_category(entities, tri[0])
            tri[0] = processing_phrases(tri[0])
            if s_category is None:
                s_category = OTHER_TYPE
            sub_graph.add_node(tri[0], tri[0], s_category)
            o_category = get_category(entities, tri[2])
            tri[2] = processing_phrases(tri[2])
            if o_category is None:
                o_category = OTHER_TYPE
            sub_graph.add_node(tri[2], tri[2], o_category)
            sub_graph.add_edge(
                tri[0], s_category, to_camel_case(tri[1]), tri[2], o_category
            )
        return sub_graph
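
    # Illustrative only: predicates are normalized with to_camel_case, so a triple
    # like ["Marie Curie", "was awarded", "Nobel Prize"] is expected to produce an
    # edge labeled "wasAwarded" between the two (possibly OTHER_TYPE) entity nodes.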

    @staticmethod
    def assemble_sub_graph_with_chunk(sub_graph: SubGraph, chunk: Chunk):
        """
        Associates a Chunk object with the subgraph, adding it as a node and connecting it with existing nodes.

        Args:
            sub_graph (SubGraph): The subgraph to add the chunk information to.
            chunk (Chunk): The chunk object containing the text and metadata.
        Returns:
            SubGraph: The subgraph with a chunk node and "source" edges added.
        """
        # Every extracted node gets a "source" edge back to the chunk it came from.
        for node in sub_graph.nodes:
            sub_graph.add_edge(node.id, node.label, "source", chunk.id, CHUNK_TYPE)
        sub_graph.add_node(
            chunk.id,
            chunk.name,
            CHUNK_TYPE,
            {
                "id": chunk.id,
                "name": chunk.name,
                "content": f"{chunk.name}\n{chunk.content}",
                **chunk.kwargs,
            },
        )
        sub_graph.id = chunk.id
        return sub_graph

    def assemble_sub_graph(
        self, sub_graph: SubGraph, chunk: Chunk, entities: List[Dict], triples: List[list]
    ):
        """
        Integrates entity and triple information into a subgraph, and associates it with a chunk of text.

        Args:
            sub_graph (SubGraph): The subgraph to be assembled.
            chunk (Chunk): The chunk of text the subgraph is about.
            entities (List[Dict]): A list of entities identified in the chunk.
            triples (List[list]): A list of triples representing relationships between entities.
        Returns:
            SubGraph: The constructed subgraph.
        """
        self.assemble_sub_graph_with_entities(sub_graph, entities)
        self.assemble_sub_graph_with_triples(sub_graph, entities, triples)
        self.assemble_sub_graph_with_chunk(sub_graph, chunk)
        return sub_graph

    def assemble_sub_graph_with_entities(
        self, sub_graph: SubGraph, entities: List[Dict]
    ):
        """
        Assembles a subgraph using named entities.

        Args:
            sub_graph (SubGraph): The subgraph object to be assembled.
            entities (List[Dict]): A list containing entity information.
        """
        for ent in entities:
            name = processing_phrases(ent["entity"])
            sub_graph.add_node(
                name,
                name,
                ent["category"],
                {
                    "desc": ent.get("description", ""),
                    "semanticType": ent.get("type", ""),
                    **ent.get("properties", {}),
                },
            )
            if "official_name" in ent:
                official_name = processing_phrases(ent["official_name"])
                if official_name != name:
                    # Add the standardized name as its own node and link the surface
                    # form to it with an OfficialName edge.
                    sub_graph.add_node(
                        official_name,
                        official_name,
                        ent["category"],
                        {
                            "desc": ent.get("description", ""),
                            "semanticType": ent.get("type", ""),
                            **ent.get("properties", {}),
                        },
                    )
                    sub_graph.add_edge(
                        name,
                        ent["category"],
                        "OfficialName",
                        official_name,
                        ent["category"],
                    )

    def append_official_name(
        self, source_entities: List[Dict], entities_with_official_name: List[Dict]
    ):
        """
        Appends official names to entities.

        Args:
            source_entities (List[Dict]): A list of source entities.
            entities_with_official_name (List[Dict]): A list of entities with official names.
        """
        tmp_dict = {}
        for tmp_entity in entities_with_official_name:
            name = tmp_entity["entity"]
            category = tmp_entity["category"]
            official_name = tmp_entity["official_name"]
            key = f"{category}{name}"
            tmp_dict[key] = official_name

        for tmp_entity in source_entities:
            name = tmp_entity["entity"]
            category = tmp_entity["category"]
            key = f"{category}{name}"
            if key in tmp_dict:
                official_name = tmp_dict[key]
                tmp_entity["official_name"] = official_name

    def invoke(self, input: Input, **kwargs) -> List[Output]:
        """
        Invokes the semantic extractor to process input data.

        Args:
            input (Input): Input data containing name and content.
            **kwargs: Additional keyword arguments.
        Returns:
            List[Output]: A list of processed results, containing subgraph information.
        """
        title = input.name
        passage = title + "\n" + input.content
        try:
            entities = self.named_entity_recognition(passage)
            sub_graph, entities = self.assemble_sub_graph_with_spg_records(entities)
            filtered_entities = [
                {k: v for k, v in ent.items() if k in ["entity", "category"]}
                for ent in entities
            ]
            triples = self.triples_extraction(passage, filtered_entities)
            std_entities = self.named_entity_standardization(passage, filtered_entities)
            self.append_official_name(entities, std_entities)
            self.assemble_sub_graph(sub_graph, input, entities, triples)
            return [sub_graph]
        except Exception as e:
            # Log the full traceback at error level instead of printing it and
            # logging the message at info level; extraction failures yield no output.
            logger.exception(f"failed to extract subgraph from chunk {input.id}: {e}")
            return []
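
# A minimal usage sketch (assumptions: a configured OpenSPG project reachable by
# SchemaClient, LLM credentials picked up by _init_llm(), a Chunk constructor
# taking id/name/content, and a SubGraph exposing nodes/edges lists — none of
# which are verified here, so the sketch is left commented out):
#
# if __name__ == "__main__":
#     extractor = KAGExtractor()
#     chunk = Chunk(id="doc-0", name="Marie Curie",
#                   content="Marie Curie was a physicist who won the Nobel Prize.")
#     for graph in extractor.invoke(chunk):
#         print(len(graph.nodes), "nodes,", len(graph.edges), "edges")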