Remove sensitive information

庄舟 2024-10-24 11:46:15 +08:00
parent 95354d0f0b
commit d0d352d635
429 changed files with 224242 additions and 0 deletions

KAG_VERSION Normal file
@@ -0,0 +1 @@
0.0.3.20241022.2

LEGAL.md Normal file
@@ -0,0 +1,7 @@
Legal Disclaimer
Within this source code, the comments in Chinese shall be the original, governing version. Any comments in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail.
法律免责声明
关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。

MANIFEST.in Normal file
@@ -0,0 +1,2 @@
recursive-include kag *
recursive-exclude kag/examples *

kag/__init__.py Normal file
@@ -0,0 +1,18 @@
# Copyright 2024 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
__package_name__ = "openspg-kag"
__version__ = "0.0.3.20241022.2"
from kag.common.env import init_env
init_env()

kag/builder/__init__.py Normal file
@@ -0,0 +1,10 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

kag/builder/component/__init__.py Normal file
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.builder.component.mapping.spg_type_mapping import SPGTypeMapping
from kag.builder.component.mapping.relation_mapping import RelationMapping
from kag.builder.component.writer.kg_writer import KGWriter
__all__ = [
"SPGTypeMapping",
"RelationMapping",
"KGWriter",
]

kag/builder/component/aligner/__init__.py Normal file
@@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.builder.component.aligner.semantic_aligner import SemanticAligner
__all__ = [
'SemanticAligner',
]

kag/builder/component/aligner/kag_post_processor.py Normal file
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from typing import List, Sequence, Dict, Type
from kag.builder.model.sub_graph import SubGraph
from kag.interface.builder import AlignerABC
from knext.common.base.runnable import Input, Output
class KAGPostProcessorAligner(AlignerABC):
def __init__(self, **kwargs):
super().__init__(**kwargs)
@property
def input_types(self) -> Type[Input]:
return SubGraph
@property
def output_types(self) -> Type[Output]:
return SubGraph
def invoke(self, input: List[SubGraph], **kwargs) -> SubGraph:
merged_sub_graph = SubGraph(nodes=[], edges=[])
for sub_graph in input:
for node in sub_graph.nodes:
if node not in merged_sub_graph.nodes:
merged_sub_graph.nodes.append(node)
for edge in sub_graph.edges:
if edge not in merged_sub_graph.edges:
merged_sub_graph.edges.append(edge)
return merged_sub_graph
def _handle(self, input: Sequence[Dict]) -> Dict:
_input = [self.input_types.from_dict(i) for i in input]
_output = self.invoke(_input)
return _output.to_dict()
def batch(self, inputs: List[Input], **kwargs) -> List[Output]:
pass
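
For orientation, a minimal usage sketch with hypothetical empty subgraphs, assuming the aligner needs no project-specific constructor kwargs: invoke merges a list of SubGraphs into one, dropping duplicate nodes and edges.

from kag.builder.model.sub_graph import SubGraph

aligner = KAGPostProcessorAligner()  # assumes no extra kwargs are required
g1 = SubGraph(nodes=[], edges=[])
g2 = SubGraph(nodes=[], edges=[])
merged = aligner.invoke([g1, g2])  # a single SubGraph with duplicates removed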

kag/builder/component/aligner/semantic_aligner.py Normal file
@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from typing import List, Type
from kag.interface.builder import AlignerABC
from kag.builder.model.sub_graph import SubGraph
from knext.common.base.runnable import Input, Output
from kag.common.semantic_infer import SemanticEnhance
class SemanticAligner(AlignerABC, SemanticEnhance):
"""
    A class for semantic alignment and enhancement, inheriting from AlignerABC and SemanticEnhance.
"""
def __init__(self, **kwargs):
AlignerABC.__init__(self, **kwargs)
SemanticEnhance.__init__(self, **kwargs)
@property
def input_types(self) -> Type[Input]:
return SubGraph
@property
def output_types(self) -> Type[Output]:
return SubGraph
def invoke(self, input: SubGraph, **kwargs) -> List[SubGraph]:
"""
Generates and adds concept nodes based on extracted entities and their context.
Args:
input (SubGraph): The input subgraph.
**kwargs: Additional keyword arguments.
Returns:
List[SubGraph]: A list containing the updated subgraph.
"""
expanded_concept_nodes = []
expanded_concept_edges = []
context = [
node.properties.get("content")
for node in input.nodes if node.label == 'Chunk'
]
context = context[0] if context else None
_dedup_keys = set()
for node in input.nodes:
if node.id == "" or node.name == "" or node.label == 'Chunk':
continue
if node.name in _dedup_keys:
continue
_dedup_keys.add(node.name)
expand_dict = self.expand_semantic_concept(node.name, context=context, target=None)
expand_nodes = [
{
"id": info["name"], "name": info["name"],
"label": self.concept_label,
"properties": {"desc": info["desc"]}
}
for info in expand_dict
]
expanded_concept_nodes.extend(expand_nodes)
path_nodes = [node.to_dict()] + expand_nodes
# entity -> concept, concept -> concept
for ix, concept in enumerate(path_nodes):
if ix == 0:
continue
expanded_concept_edges.append({
"s_id": path_nodes[ix-1]["id"],
"s_label": path_nodes[ix-1]["label"],
"p": self.hyper_edge,
"o_id": path_nodes[ix]["id"],
"o_label": path_nodes[ix]["label"]
})
        for n in expanded_concept_nodes:
            input.add_node(**n)
        for e in expanded_concept_edges:
            input.add_edge(**e)
return [input]

kag/builder/component/aligner/spg_post_processor.py Normal file
@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from typing import List, Type, Dict
from kag.interface.builder import AlignerABC
from knext.schema.client import BASIC_TYPES
from kag.builder.model.spg_record import SPGRecord
from kag.builder.model.sub_graph import SubGraph
from knext.common.base.runnable import Input, Output
from knext.schema.client import SchemaClient
from knext.schema.model.base import ConstraintTypeEnum, BaseSpgType
class SPGPostProcessorAligner(AlignerABC):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.spg_types = SchemaClient(project_id=self.project_id).load()
@property
def input_types(self) -> Type[Input]:
return SPGRecord
@property
def output_types(self) -> Type[Output]:
return SubGraph
def merge(self, spg_records: List[SPGRecord]):
merged_spg_records = {}
for record in spg_records:
key = f"{record.spg_type_name}#{record.get_property('name', '')}"
if key not in merged_spg_records:
merged_spg_records[key] = record
else:
old_record = merged_spg_records[key]
for prop_name, prop_value in record.properties.items():
if prop_name not in old_record.properties:
old_record.properties[prop_name] = prop_value
else:
prop = self.spg_types.get(record.spg_type_name).properties.get(
prop_name
)
if not prop:
continue
if (
prop.object_type_name not in BASIC_TYPES
or prop.constraint.get(ConstraintTypeEnum.MultiValue)
):
old_value = old_record.properties.get(prop_name)
if not prop_value:
prop_value = ""
prop_value_list = (
prop_value + "," + old_value
if old_value
else prop_value
).split(",")
old_record.properties[prop_name] = ",".join(
list(set(prop_value_list))
)
else:
old_record.properties[prop_name] = prop_value
return list(merged_spg_records.values())
@staticmethod
def from_spg_record(
spg_types: Dict[str, BaseSpgType], spg_records: List[SPGRecord]
):
sub_graph = SubGraph([], [])
for record in spg_records:
s_id = record.id
s_name = record.name
s_label = record.spg_type_name
properties = record.properties
spg_type = spg_types.get(record.spg_type_name)
for prop_name, prop_value in record.properties.items():
if prop_name in spg_type.properties:
from knext.schema.model.property import Property
prop: Property = spg_type.properties.get(prop_name)
o_label = prop.object_type_name
if o_label not in BASIC_TYPES:
prop_value_list = prop_value.split(",")
for o_id in prop_value_list:
sub_graph.add_edge(
s_id=s_id,
s_label=s_label,
p=prop_name,
o_id=o_id,
o_label=o_label,
)
properties.pop(prop_name)
sub_graph.add_node(
id=s_id, name=s_name, label=s_label, properties=properties
)
return sub_graph
def invoke(self, input: Input, **kwargs) -> List[Output]:
subgraph = SubGraph.from_spg_record(self.spg_types, [input])
return [subgraph]
def batch(self, inputs: List[Input], **kwargs) -> List[Output]:
merged_records = self.merge(inputs)
subgraph = SubGraph.from_spg_record(self.spg_types, merged_records)
return [subgraph]

kag/builder/component/base.py Normal file
@@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
from abc import ABC
from typing import List, Dict
import logging
from knext.common.base.component import Component
from knext.common.base.runnable import Input, Output
from knext.project.client import ProjectClient
from kag.common.llm.client import LLMClient
class BuilderComponent(Component, ABC):
"""
    Abstract base class for all builder components.
"""
project_id: str = None
def _init_llm(self) -> LLMClient:
"""
Initializes the Large Language Model (LLM) client.
This method retrieves the LLM configuration from environment variables and the project ID.
It then fetches the project configuration using the project ID and updates the LLM configuration
with any additional settings from the project. Finally, it creates and initializes the LLM client
using the updated configuration.
Args:
None
Returns:
LLMClient
"""
llm_config = eval(os.getenv("KAG_LLM", "{}"))
project_id = self.project_id or os.getenv("KAG_PROJECT_ID")
if project_id:
try:
config = ProjectClient().get_config(project_id)
llm_config.update(config.get("llm", {}))
            except Exception:
logging.warning(
f"Failed to get project config for project id: {project_id}"
)
llm = LLMClient.from_config(llm_config)
return llm
@property
def type(self):
"""
Get the type label of the object.
Returns:
str: The type label of the object, fixed as "BUILDER".
"""
return "BUILDER"
def batch(self, inputs: List[Input], **kwargs) -> List[Output]:
results = []
for input in inputs:
results.extend(self.invoke(input, **kwargs))
return results
def _handle(self, input: Dict) -> List[Dict]:
_input = self.input_types.from_dict(input) if isinstance(input, dict) else input
_output = self.invoke(_input)
return [_o.to_dict() for _o in _output if _o]
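
For reference, a hedged sketch of how a component picks up its LLM settings: _init_llm reads the KAG_LLM environment variable as a Python-literal dict (parsed with eval) and overlays the project-level "llm" config when a project id resolves. The keys below are illustrative assumptions, not a documented schema.

import os

os.environ["KAG_PROJECT_ID"] = "1"  # hypothetical project id
os.environ["KAG_LLM"] = "{'client_type': 'maas', 'model': 'my-model'}"  # keys are assumptions; read back via eval()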

kag/builder/component/extractor/__init__.py Normal file
@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.builder.component.extractor.kag_extractor import KAGExtractor
from kag.builder.component.extractor.spg_extractor import SPGExtractor
from kag.builder.component.extractor.user_defined_extractor import (
UserDefinedExtractor,
)
__all__ = [
"KAGExtractor",
"SPGExtractor",
"UserDefinedExtractor",
]

kag/builder/component/extractor/kag_extractor.py Normal file
@@ -0,0 +1,324 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import copy
import logging
import os
from typing import Dict, Type, List
from tenacity import stop_after_attempt, retry
from kag.builder.prompt.spg_prompt import SPG_KGPrompt
from kag.interface.builder import ExtractorABC
from kag.common.base.prompt_op import PromptOp
from knext.schema.client import OTHER_TYPE, CHUNK_TYPE, BASIC_TYPES
from kag.common.utils import processing_phrases, to_camel_case
from kag.builder.model.chunk import Chunk
from kag.builder.model.sub_graph import SubGraph
from knext.common.base.runnable import Input, Output
from knext.schema.client import SchemaClient
from knext.schema.model.base import SpgTypeEnum
logger = logging.getLogger(__name__)
class KAGExtractor(ExtractorABC):
"""
A class for extracting knowledge graph subgraphs from text using a large language model (LLM).
Inherits from the Extractor base class.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.llm = self._init_llm()
self.biz_scene = os.getenv("KAG_PROMPT_BIZ_SCENE", "default")
self.language = os.getenv("KAG_PROMPT_LANGUAGE", "en")
self.schema = SchemaClient(project_id=self.project_id).load()
self.ner_prompt = PromptOp.load(self.biz_scene, "ner")(language=self.language, project_id=self.project_id)
self.std_prompt = PromptOp.load(self.biz_scene, "std")(language=self.language)
self.triple_prompt = PromptOp.load(self.biz_scene, "triple")(
language=self.language
)
self.kg_types = []
for type_name, spg_type in self.schema.items():
if type_name in SPG_KGPrompt.ignored_types:
continue
if spg_type.spg_type_enum == SpgTypeEnum.Concept:
continue
properties = list(spg_type.properties.keys())
for p in properties:
if p not in SPG_KGPrompt.ignored_properties:
self.kg_types.append(type_name)
break
if self.kg_types:
self.kg_prompt = SPG_KGPrompt(self.kg_types, language=self.language, project_id=self.project_id)
@property
def input_types(self) -> Type[Input]:
return Chunk
@property
def output_types(self) -> Type[Output]:
return SubGraph
@retry(stop=stop_after_attempt(3))
def named_entity_recognition(self, passage: str):
"""
Performs named entity recognition on a given text passage.
Args:
passage (str): The text to perform named entity recognition on.
Returns:
The result of the named entity recognition operation.
"""
if self.kg_types:
kg_result = self.llm.invoke({"input": passage}, self.kg_prompt)
else:
kg_result = []
ner_result = self.llm.invoke({"input": passage}, self.ner_prompt)
return kg_result + ner_result
@retry(stop=stop_after_attempt(3))
def named_entity_standardization(self, passage: str, entities: List[Dict]):
"""
Standardizes named entities.
Args:
passage (str): The input text passage.
entities (List[Dict]): A list of recognized named entities.
Returns:
Standardized entity information.
"""
return self.llm.invoke(
{"input": passage, "named_entities": entities}, self.std_prompt
)
@retry(stop=stop_after_attempt(3))
def triples_extraction(self, passage: str, entities: List[Dict]):
"""
Extracts triples (subject-predicate-object structures) from a given text passage based on identified entities.
Args:
passage (str): The text to extract triples from.
entities (List[Dict]): A list of entities identified in the text.
Returns:
The result of the triples extraction operation.
"""
return self.llm.invoke(
{"input": passage, "entity_list": entities}, self.triple_prompt
)
def assemble_sub_graph_with_spg_records(self, entities: List[Dict]):
sub_graph = SubGraph([], [])
for record in entities:
s_name = record.get("entity", "")
s_label = record.get("category", "")
properties = record.get("properties", {})
tmp_properties = copy.deepcopy(properties)
spg_type = self.schema.get(s_label)
for prop_name, prop_value in properties.items():
if prop_value == "NAN":
tmp_properties.pop(prop_name)
continue
if prop_name in spg_type.properties:
from knext.schema.model.property import Property
prop: Property = spg_type.properties.get(prop_name)
o_label = prop.object_type_name_en
if o_label not in BASIC_TYPES:
if isinstance(prop_value, str):
prop_value = [prop_value]
for o_name in prop_value:
sub_graph.add_node(id=o_name, name=o_name, label=o_label)
sub_graph.add_edge(s_id=s_name, s_label=s_label, p=prop_name, o_id=o_name, o_label=o_label)
tmp_properties.pop(prop_name)
record["properties"] = tmp_properties
sub_graph.add_node(id=s_name, name=s_name, label=s_label, properties=properties)
return sub_graph, entities
@staticmethod
def assemble_sub_graph_with_triples(
sub_graph: SubGraph, entities: List[Dict], triples: List[list]
):
"""
Assembles edges in the subgraph based on a list of triples and entities.
Args:
sub_graph (SubGraph): The subgraph to add edges to.
entities (List[Dict]): A list of entities, for looking up category information.
triples (List[list]): A list of triples, each representing a relationship to be added to the subgraph.
"""
def get_category(entities_data, entity_name):
for entity in entities_data:
if entity["entity"] == entity_name:
return entity["category"]
return None
for tri in triples:
if len(tri) != 3:
continue
s_category = get_category(entities, tri[0])
tri[0] = processing_phrases(tri[0])
if s_category is None:
s_category = OTHER_TYPE
sub_graph.add_node(tri[0], tri[0], s_category)
o_category = get_category(entities, tri[2])
tri[2] = processing_phrases(tri[2])
if o_category is None:
o_category = OTHER_TYPE
sub_graph.add_node(tri[2], tri[2], o_category)
sub_graph.add_edge(
tri[0], s_category, to_camel_case(tri[1]), tri[2], o_category
)
return sub_graph
@staticmethod
def assemble_sub_graph_with_chunk(sub_graph: SubGraph, chunk: Chunk):
"""
Associates a Chunk object with the subgraph, adding it as a node and connecting it with existing nodes.
Args:
sub_graph (SubGraph): The subgraph to add the chunk information to.
chunk (Chunk): The chunk object containing the text and metadata.
"""
for node in sub_graph.nodes:
sub_graph.add_edge(node.id, node.label, "source", chunk.id, CHUNK_TYPE)
sub_graph.add_node(
chunk.id,
chunk.name,
CHUNK_TYPE,
{
"id": chunk.id,
"name": chunk.name,
"content": f"{chunk.name}\n{chunk.content}",
**chunk.kwargs
},
)
sub_graph.id = chunk.id
return sub_graph
def assemble_sub_graph(
self, sub_graph: SubGraph, chunk: Chunk, entities: List[Dict], triples: List[list]
):
"""
Integrates entity and triple information into a subgraph, and associates it with a chunk of text.
Args:
sub_graph (SubGraph): The subgraph to be assembled.
chunk (Chunk): The chunk of text the subgraph is about.
entities (List[Dict]): A list of entities identified in the chunk.
triples (List[list]): A list of triples representing relationships between entities.
Returns:
SubGraph: The constructed subgraph.
"""
self.assemble_sub_graph_with_entities(sub_graph, entities)
self.assemble_sub_graph_with_triples(sub_graph, entities, triples)
self.assemble_sub_graph_with_chunk(sub_graph, chunk)
return sub_graph
def assemble_sub_graph_with_entities(
self, sub_graph: SubGraph, entities: List[Dict]
):
"""
Assembles a subgraph using named entities.
Args:
sub_graph (SubGraph): The subgraph object to be assembled.
entities (List[Dict]): A list containing entity information.
"""
for ent in entities:
name = processing_phrases(ent["entity"])
sub_graph.add_node(
name,
name,
ent["category"],
{
"desc": ent.get("description", ""),
"semanticType": ent.get("type", ""),
**ent.get("properties", {}),
},
)
if "official_name" in ent:
official_name = processing_phrases(ent["official_name"])
if official_name != name:
sub_graph.add_node(
official_name,
official_name,
ent["category"],
{
"desc": ent.get("description", ""),
"semanticType": ent.get("type", ""),
**ent.get("properties", {}),
},
)
sub_graph.add_edge(
name,
ent["category"],
"OfficialName",
official_name,
ent["category"],
)
def append_official_name(
self, source_entities: List[Dict], entities_with_official_name: List[Dict]
):
"""
Appends official names to entities.
Args:
source_entities (List[Dict]): A list of source entities.
entities_with_official_name (List[Dict]): A list of entities with official names.
"""
tmp_dict = {}
for tmp_entity in entities_with_official_name:
name = tmp_entity["entity"]
category = tmp_entity["category"]
official_name = tmp_entity["official_name"]
key = f"{category}{name}"
tmp_dict[key] = official_name
for tmp_entity in source_entities:
name = tmp_entity["entity"]
category = tmp_entity["category"]
key = f"{category}{name}"
if key in tmp_dict:
official_name = tmp_dict[key]
tmp_entity["official_name"] = official_name
def invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Invokes the semantic extractor to process input data.
Args:
input (Input): Input data containing name and content.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list of processed results, containing subgraph information.
"""
title = input.name
passage = title + "\n" + input.content
try:
entities = self.named_entity_recognition(passage)
sub_graph, entities = self.assemble_sub_graph_with_spg_records(entities)
filtered_entities = [{k: v for k, v in ent.items() if k in ["entity", "category"]} for ent in entities]
triples = self.triples_extraction(passage, filtered_entities)
std_entities = self.named_entity_standardization(passage, filtered_entities)
self.append_official_name(entities, std_entities)
self.assemble_sub_graph(sub_graph, input, entities, triples)
return [sub_graph]
except Exception as e:
import traceback
traceback.print_exc()
logger.info(e)
return []
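
A hedged end-to-end sketch of the extractor: one text chunk goes through NER, triple extraction, and entity standardization, and comes back as a SubGraph. It assumes KAG_LLM and the project schema are already configured, and that project_id is accepted as a constructor keyword.

from kag.builder.model.chunk import Chunk

extractor = KAGExtractor(project_id="1")  # hypothetical project id
chunk = Chunk(id="c1", name="demo",
              content="Marie Curie won the Nobel Prize in Physics in 1903.")
subgraphs = extractor.invoke(chunk)  # [SubGraph] on success, [] on failure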

kag/builder/component/extractor/spg_extractor.py Normal file
@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import copy
import logging
from typing import List, Dict
from tenacity import retry, stop_after_attempt
from kag.builder.component.extractor import KAGExtractor
from kag.builder.model.sub_graph import SubGraph
from kag.builder.prompt.spg_prompt import SPG_KGPrompt
from kag.common.base.prompt_op import PromptOp
from knext.common.base.runnable import Input, Output
from knext.schema.client import BASIC_TYPES
logger = logging.getLogger(__name__)
class SPGExtractor(KAGExtractor):
"""
    A builder component that extracts structured data from long texts by invoking a large language model.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.spg_ner_types, self.kag_ner_types = [], []
for type_name, spg_type in self.schema.items():
properties = list(spg_type.properties.keys())
for p in properties:
if p not in SPG_KGPrompt.ignored_properties:
self.spg_ner_types.append(type_name)
                    break  # one qualifying property marks this type as an SPG NER type
self.kag_ner_types.append(type_name)
self.kag_ner_prompt = PromptOp.load(self.biz_scene, "ner")(language=self.language, project_id=self.project_id)
self.spg_ner_prompt = SPG_KGPrompt(self.spg_ner_types, self.language)
@retry(stop=stop_after_attempt(3))
def named_entity_recognition(self, passage: str):
"""
Performs named entity recognition on a given text passage.
Args:
passage (str): The text to perform named entity recognition on.
Returns:
The result of the named entity recognition operation.
"""
spg_ner_result = self.llm.batch({"input": passage}, self.spg_ner_prompt)
kag_ner_result = self.llm.invoke({"input": passage}, self.kag_ner_prompt)
return spg_ner_result + kag_ner_result
def assemble_sub_graph_with_spg_records(self, entities: List[Dict]):
sub_graph = SubGraph([], [])
for record in entities:
s_name = record.get("entity", "")
s_label = record.get("category", "")
properties = record.get("properties", {})
tmp_properties = copy.deepcopy(properties)
spg_type = self.schema.get(s_label)
for prop_name, prop_value in properties.items():
if prop_value == "NAN":
tmp_properties.pop(prop_name)
continue
if prop_name in spg_type.properties:
from knext.schema.model.property import Property
prop: Property = spg_type.properties.get(prop_name)
o_label = prop.object_type_name_en
if o_label not in BASIC_TYPES:
if isinstance(prop_value, str):
prop_value = [prop_value]
for o_name in prop_value:
sub_graph.add_node(id=o_name, name=o_name, label=o_label)
sub_graph.add_edge(s_id=s_name, s_label=s_label, p=prop_name, o_id=o_name, o_label=o_label)
tmp_properties.pop(prop_name)
record["properties"] = tmp_properties
sub_graph.add_node(id=s_name, name=s_name, label=s_label, properties=properties)
return sub_graph, entities
def invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Invokes the semantic extractor to process input data.
Args:
input (Input): Input data containing name and content.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list of processed results, containing subgraph information.
"""
title = input.name
passage = title + "\n" + input.content
try:
entities = self.named_entity_recognition(passage)
sub_graph, entities = self.assemble_sub_graph_with_spg_records(entities)
filtered_entities = [{k: v for k, v in ent.items() if k in ["entity", "category"]} for ent in entities]
triples = self.triples_extraction(passage, filtered_entities)
std_entities = self.named_entity_standardization(passage, filtered_entities)
self.append_official_name(entities, std_entities)
self.assemble_sub_graph(sub_graph, input, entities, triples)
return [sub_graph]
except Exception as e:
import traceback
traceback.print_exc()
logger.info(e)
return []

kag/builder/component/extractor/user_defined_extractor.py Normal file
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from typing import Dict, List
from knext.common.base.runnable import Input, Output
from kag.interface.builder import ExtractorABC
class UserDefinedExtractor(ExtractorABC):
@property
def input_types(self) -> Input:
return Dict[str, str]
@property
def output_types(self) -> Output:
return Dict[str, str]
def invoke(self, input: Input, **kwargs) -> List[Output]:
return input
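
UserDefinedExtractor is effectively a pass-through extension point: subclass it and override invoke. A minimal hypothetical subclass:

class UpperCaseExtractor(UserDefinedExtractor):
    """Toy extractor: upper-cases every value in the input record."""
    def invoke(self, input: Input, **kwargs) -> List[Output]:
        return [{k: v.upper() for k, v in input.items()}]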

kag/builder/component/mapping/__init__.py Normal file
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.builder.component.mapping.spg_type_mapping import SPGTypeMapping
from kag.builder.component.mapping.relation_mapping import RelationMapping
from kag.builder.component.mapping.spo_mapping import SPOMapping
__all__ = [
"SPGTypeMapping",
"RelationMapping",
"SPOMapping",
]

kag/builder/component/mapping/relation_mapping.py Normal file
@@ -0,0 +1,157 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from collections import defaultdict
from typing import Dict, List
from kag.builder.model.sub_graph import SubGraph
from knext.common.base.runnable import Input, Output
from knext.schema.client import SchemaClient
from knext.schema.model.schema_helper import (
SPGTypeName,
RelationName,
)
from kag.interface.builder.mapping_abc import MappingABC
class RelationMapping(MappingABC):
"""
A class that handles relation mappings by assembling subgraphs based on given subject, predicate, and object names.
    This class extends MappingABC.
Args:
subject_name (SPGTypeName): The name of the subject type.
predicate_name (RelationName): The name of the predicate.
object_name (SPGTypeName): The name of the object type.
"""
def __init__(
self,
subject_name: SPGTypeName,
predicate_name: RelationName,
object_name: SPGTypeName,
**kwargs
):
super().__init__(**kwargs)
schema = SchemaClient(project_id=self.project_id).load()
assert subject_name in schema, f"{subject_name} is not a valid SPG type name"
assert object_name in schema, f"{object_name} is not a valid SPG type name"
self.subject_type = schema.get(subject_name)
self.object_type = schema.get(object_name)
assert predicate_name in self.subject_type.properties or predicate_name in set(
[key.split("_")[0] for key in self.subject_type.relations.keys()]
), f"{predicate_name} is not a valid SPG property/relation name"
self.predicate_name = predicate_name
self.src_id_field = None
self.dst_id_field = None
self.property_mapping: Dict = defaultdict(list)
self.linking_strategies: Dict = dict()
def add_src_id_mapping(self, source_name: str):
"""
Adds a field mapping from source data to the subject's ID property.
Args:
source_name (str): The name of the source field to map.
Returns:
self
"""
self.src_id_field = source_name
return self
def add_dst_id_mapping(self, source_name: str):
"""
Adds a field mapping from source data to the object's ID property.
Args:
source_name (str): The name of the source field to map.
Returns:
self
"""
self.dst_id_field = source_name
return self
def add_sub_property_mapping(self, source_name: str, target_name: str):
"""
Adds a field mapping from source data to a property of the subject type.
Args:
source_name (str): The source field to be mapped.
target_name (str): The target field to map the source field to.
Returns:
self
"""
self.property_mapping[target_name].append(source_name)
return self
@property
def input_types(self) -> Input:
return Dict[str, str]
@property
def output_types(self) -> Output:
return SubGraph
def assemble_sub_graph(self, record: Dict[str, str]) -> SubGraph:
"""
Assembles a subgraph from the provided record.
Args:
record (Dict[str, str]): The record containing the data to assemble into a subgraph.
Returns:
SubGraph: The assembled subgraph.
"""
sub_graph = SubGraph(nodes=[], edges=[])
if self.property_mapping:
s_id = record.get(self.src_id_field or "srcId")
o_id = record.get(self.dst_id_field or "dstId")
sub_properties = {}
for target_name, source_names in self.property_mapping.items():
for source_name in source_names:
value = record.get(source_name)
sub_properties[target_name] = value
else:
s_id = record.pop(self.src_id_field or "srcId")
o_id = record.pop(self.dst_id_field or "dstId")
sub_properties = record
sub_graph.add_edge(
s_id=s_id,
s_label=self.subject_type.name_en,
p=self.predicate_name,
o_id=o_id,
o_label=self.object_type.name_en,
properties=sub_properties,
)
return sub_graph
def invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Invokes the assembly process to create a subgraph from the input data.
Args:
input (Input): The input data to assemble into a subgraph.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list containing the assembled subgraph.
"""
sub_graph = self.assemble_sub_graph(input)
return [sub_graph]
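
A usage sketch with placeholder schema names (the constructor asserts that the subject/object types and the predicate exist in the real project schema). With no sub-property mappings registered, the two id fields are popped from the record and every remaining field becomes an edge property:

mapping = (
    RelationMapping(
        subject_name="DEFAULT.Company",  # placeholder SPG type
        predicate_name="hasEmployee",    # placeholder relation
        object_name="DEFAULT.Person",    # placeholder SPG type
    )
    .add_src_id_mapping("company_id")
    .add_dst_id_mapping("person_id")
)
[graph] = mapping.invoke({"company_id": "C1", "person_id": "P7", "since": "2020"})
# the edge C1 -hasEmployee-> P7 carries {"since": "2020"} as properties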

kag/builder/component/mapping/spg_type_mapping.py Normal file
@@ -0,0 +1,199 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from collections import defaultdict
from typing import Dict, List, Callable
import pandas
from knext.schema.client import BASIC_TYPES
from kag.builder.model.sub_graph import SubGraph, Node
from knext.common.base.runnable import Input, Output
from knext.schema.client import SchemaClient
from knext.schema.model.base import SpgTypeEnum
from knext.schema.model.schema_helper import (
SPGTypeName,
PropertyName,
)
from kag.interface.builder.mapping_abc import MappingABC
FuseFunc = Callable[[SubGraph], List[SubGraph]]
LinkFunc = Callable[[str, Node], List[Node]]
class SPGTypeMapping(MappingABC):
"""
    A class for mapping SPG (Semantic-enhanced Programmable Graph) types and handling their properties and strategies.
Attributes:
spg_type_name (SPGTypeName): The name of the SPG type.
        fuse_func (FuseFunc, optional): The user-defined fuse function. Defaults to None.
"""
def __init__(self, spg_type_name: SPGTypeName, fuse_func: FuseFunc = None, **kwargs):
super().__init__(**kwargs)
self.schema = SchemaClient(project_id=self.project_id).load()
assert (
spg_type_name in self.schema
), f"SPG type [{spg_type_name}] does not exist."
self.spg_type = self.schema.get(spg_type_name)
self.property_mapping: Dict = defaultdict(list)
self.link_funcs: Dict = dict()
self.fuse_func = fuse_func
def add_property_mapping(
self,
source_name: str,
target_name: PropertyName,
link_func: LinkFunc = None,
):
"""
Adds a property mapping from a source name to a target name within the SPG type.
Args:
source_name (str): The source name of the property.
target_name (PropertyName): The target name of the property within the SPG type.
link_func (LinkFunc, optional): The user-defined link operator. Defaults to None.
Returns:
self
"""
if (
target_name not in ["id", "name"]
and target_name not in self.spg_type.properties
):
raise ValueError(
f"Property [{target_name}] does not exist in [{self.spg_type.name}]."
)
self.property_mapping[target_name].append(source_name)
if link_func is not None:
self.link_funcs[target_name] = link_func
return self
@property
def input_types(self) -> Input:
return Dict[str, str]
@property
def output_types(self) -> Output:
return SubGraph
def field_mapping(self, record: Dict[str, str]) -> Dict[str, str]:
"""
Maps fields from a record based on the defined property mappings.
Args:
record (Dict[str, str]): The input record containing source names and values.
Returns:
Dict[str, str]: A mapped record with target names and corresponding values.
"""
mapped_record = {}
for target_name, source_names in self.property_mapping.items():
for source_name in source_names:
value = record.get(source_name)
mapped_record[target_name] = value
return mapped_record
def assemble_sub_graph(self, properties: Dict[str, str]):
"""
Assembles a sub-graph based on the provided properties and linking strategies.
Args:
properties (Dict[str, str]): The properties to be used for assembling the sub-graph.
Returns:
SubGraph: The assembled sub-graph.
"""
sub_graph = SubGraph(nodes=[], edges=[])
s_id = properties.get("id", "")
s_name = properties.get("name", s_id)
s_label = self.spg_type.name_en
for prop_name, prop_value in properties.items():
            if not prop_value or pandas.isna(prop_value):  # missing values (NaN/NaT) never compare equal, so use isna()
continue
if prop_name in self.spg_type.properties:
prop = self.spg_type.properties.get(prop_name)
o_label = prop.object_type_name_en
if o_label not in BASIC_TYPES:
prop_value_list = prop_value.split(",")
for o_id in prop_value_list:
if prop_name in self.link_funcs:
link_func = self.link_funcs.get(prop_name)
o_ids = link_func(o_id, properties)
for _o_id in o_ids:
sub_graph.add_edge(
s_id=s_id,
s_label=s_label,
p=prop_name,
o_id=_o_id,
o_label=o_label,
)
else:
sub_graph.add_edge(
s_id=s_id,
s_label=s_label,
p=prop_name,
o_id=o_id,
o_label=o_label,
)
if self.spg_type.spg_type_enum == SpgTypeEnum.Concept:
self.hypernym_predicate(sub_graph, s_id)
else:
sub_graph.add_node(
id=s_id, name=s_name, label=s_label, properties=properties
)
return sub_graph
def hypernym_predicate(self, sub_graph: SubGraph, concept_id: str):
"""
Adds hypernym predicates to the sub-graph based on the provided concept ID.
Args:
sub_graph (SubGraph): The sub-graph to which hypernym predicates will be added.
concept_id (str): The ID of the concept.
"""
p = getattr(self.spg_type, "hypernym_predicate") or "isA"
label = self.spg_type.name_en
concept_list = concept_id.split("-")
father_id = ""
for concept_name in concept_list:
concept_id = father_id + "-" + concept_name if father_id else concept_name
sub_graph.add_node(id=concept_id, name=concept_name, label=label)
if father_id:
sub_graph.add_edge(
s_id=concept_id, s_label=label, p=p, o_id=father_id, o_label=label
)
father_id = concept_id
def invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Invokes the mapping process on the given input and returns the resulting sub-graphs.
Args:
input (Input): The input data to be processed.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list of resulting sub-graphs.
"""
if self.property_mapping:
properties = self.field_mapping(input)
else:
properties = input
sub_graph = self.assemble_sub_graph(properties)
return [sub_graph]
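
A sketch under assumed schema names: property mappings rename source fields onto schema properties, and a link function may rewrite the object ids of a non-basic-type property (here it merely trims whitespace). "DEFAULT.Person" and "workAt" are assumed names that would have to exist in the project schema:

def link_employer(o_id, properties):
    # resolves the raw cell value to candidate object ids; here it only trims whitespace
    return [o_id.strip()]

mapping = (
    SPGTypeMapping(spg_type_name="DEFAULT.Person")  # placeholder type name
    .add_property_mapping("person_id", "id")
    .add_property_mapping("full_name", "name")
    .add_property_mapping("employer", "workAt", link_func=link_employer)  # assumed property
)
[graph] = mapping.invoke({"person_id": "P7", "full_name": "Alice", "employer": " C1 "})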

kag/builder/component/mapping/spo_mapping.py Normal file
@@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
from collections import defaultdict
from typing import List, Type, Dict
from kag.interface.builder.mapping_abc import MappingABC
from kag.builder.model.sub_graph import SubGraph
from knext.common.base.runnable import Input, Output
from knext.schema.client import OTHER_TYPE
class SPOMapping(MappingABC):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
self.s_type_col = None
self.s_id_col = None
self.p_type_col = None
self.o_type_col = None
self.o_id_col = None
self.sub_property_mapping = defaultdict(list)
self.sub_property_col = None
@property
def input_types(self) -> Type[Input]:
return Dict[str, str]
@property
def output_types(self) -> Type[Output]:
return SubGraph
def add_field_mappings(self, s_id_col: str, p_type_col: str, o_id_col: str, s_type_col: str = None, o_type_col: str = None):
self.s_type_col = s_type_col
self.s_id_col = s_id_col
self.p_type_col = p_type_col
self.o_type_col = o_type_col
self.o_id_col = o_id_col
return self
def add_sub_property_mapping(self, source_name: str, target_name: str = None):
"""
Adds a field mapping from source data to a property of the subject type.
Args:
source_name (str): The source field to be mapped.
target_name (str): The target field to map the source field to.
Returns:
self
"""
if self.sub_property_col:
raise ValueError("Fail to add sub property mapping.")
if not target_name:
self.sub_property_col = source_name
else:
self.sub_property_mapping[target_name].append(source_name)
return self
def assemble_sub_graph(self, record: Dict[str, str]):
"""
Assembles a subgraph from the provided record.
Args:
record (Dict[str, str]): The record containing the data to assemble into a subgraph.
Returns:
SubGraph: The assembled subgraph.
"""
sub_graph = SubGraph(nodes=[], edges=[])
s_type = record.get(self.s_type_col) or OTHER_TYPE
s_id = record.get(self.s_id_col) or ""
p = record.get(self.p_type_col) or ""
o_type = record.get(self.o_type_col) or OTHER_TYPE
o_id = record.get(self.o_id_col) or ""
sub_graph.add_node(id=s_id, name=s_id, label=s_type)
sub_graph.add_node(id=o_id, name=o_id, label=o_type)
sub_properties = {}
if self.sub_property_col:
sub_properties = json.loads(record.get(self.sub_property_col, '{}'))
sub_properties = {k: str(v) for k, v in sub_properties.items()}
else:
for target_name, source_names in self.sub_property_mapping.items():
for source_name in source_names:
value = record.get(source_name)
sub_properties[target_name] = value
sub_graph.add_edge(s_id=s_id, s_label=s_type, p=p, o_id=o_id, o_label=o_type, properties=sub_properties)
return sub_graph
def invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Invokes the mapping process on the given input and returns the resulting sub-graphs.
Args:
input (Input): The input data to be processed.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list of resulting sub-graphs.
"""
record: Dict[str, str] = input
sub_graph = self.assemble_sub_graph(record)
return [sub_graph]
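
A usage sketch with hypothetical column names: SPOMapping turns one record with explicit subject/predicate/object columns into a two-node, one-edge subgraph:

mapping = SPOMapping().add_field_mappings(
    s_id_col="head", p_type_col="relation", o_id_col="tail"
)
[graph] = mapping.invoke({"head": "Alice", "relation": "worksAt", "tail": "C1"})
# both nodes fall back to OTHER_TYPE because no type columns were mapped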

kag/builder/component/reader/__init__.py Normal file
@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.builder.component.reader.csv_reader import CSVReader
from kag.builder.component.reader.pdf_reader import PDFReader
from kag.builder.component.reader.json_reader import JSONReader
from kag.builder.component.reader.markdown_reader import MarkDownReader
from kag.builder.component.reader.docx_reader import DocxReader
from kag.builder.component.reader.txt_reader import TXTReader
from kag.builder.component.reader.dataset_reader import HotpotqaCorpusReader, TwowikiCorpusReader, MusiqueCorpusReader
from kag.builder.component.reader.yuque_reader import YuqueReader
__all__ = [
"TXTReader",
"PDFReader",
"MarkDownReader",
"JSONReader",
"HotpotqaCorpusReader",
"MusiqueCorpusReader",
"TwowikiCorpusReader",
"YuqueReader",
"CSVReader",
"DocxReader",
]

kag/builder/component/reader/csv_reader.py Normal file
@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
from typing import List, Type, Dict
import pandas as pd
from kag.builder.model.chunk import Chunk
from kag.interface.builder.reader_abc import SourceReaderABC
from knext.common.base.runnable import Input, Output
class CSVReader(SourceReaderABC):
"""
    A class for reading CSV files, inheriting from `SourceReaderABC`.
Supports converting CSV data into either a list of dictionaries or a list of Chunk objects.
Args:
        output_type (str): Specifies the output type, which can be "Dict" or "Chunk".
**kwargs: Additional keyword arguments passed to the parent class constructor.
"""
def __init__(self, output_type="Chunk", **kwargs):
super().__init__(**kwargs)
if output_type == "Dict":
self.output_types = Dict[str, str]
else:
self.output_types = Chunk
self.id_col = kwargs.get("id_col", "id")
self.name_col = kwargs.get("name_col", "name")
self.content_col = kwargs.get("content_col", "content")
@property
def input_types(self) -> Type[Input]:
return str
@property
def output_types(self) -> Type[Output]:
return self._output_types
@output_types.setter
def output_types(self, output_types):
self._output_types = output_types
def invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Reads a CSV file and converts the data format based on the output type.
Args:
input (Input): Input parameter, expected to be a string representing the path to the CSV file.
            **kwargs: Additional keyword arguments; the column names are configured via `id_col`, `name_col` and `content_col` on the constructor.
Returns:
List[Output]:
- If `output_types` is `Chunk`, returns a list of Chunk objects.
- If `output_types` is `Dict`, returns a list of dictionaries.
"""
try:
data = pd.read_csv(input)
data = data.astype(str)
except Exception as e:
raise IOError(f"Failed to read the file: {e}")
if self.output_types == Chunk:
chunks = []
basename, _ = os.path.splitext(os.path.basename(input))
for idx, row in enumerate(data.to_dict(orient="records")):
kwargs = {k: v for k, v in row.items() if k not in [self.id_col, self.name_col, self.content_col]}
chunks.append(
Chunk(
id=row.get(self.id_col) or Chunk.generate_hash_id(f"{input}#{idx}"),
name=row.get(self.name_col) or f"{basename}#{idx}",
content=row[self.content_col],
**kwargs
)
)
return chunks
else:
return data.to_dict(orient="records")
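
A usage sketch with an assumed local file and column name: in Chunk mode every CSV row becomes one Chunk, and columns other than id/name/content are carried along as chunk kwargs:

reader = CSVReader(output_type="Chunk", content_col="text")  # assumes a "text" column
chunks = reader.invoke("corpus.csv")                         # hypothetical local file
rows = CSVReader(output_type="Dict").invoke("corpus.csv")    # raw dict rows instead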

kag/builder/component/reader/dataset_reader.py Normal file
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import os
from typing import List, Type
from kag.builder.model.chunk import Chunk
from kag.interface.builder import SourceReaderABC
from knext.common.base.runnable import Input, Output
class HotpotqaCorpusReader(SourceReaderABC):
@property
def input_types(self) -> Type[Input]:
"""The type of input this Runnable object accepts specified as a type annotation."""
return str
@property
def output_types(self) -> Type[Output]:
"""The type of output this Runnable object produces specified as a type annotation."""
return Chunk
def invoke(self, input: str, **kwargs) -> List[Output]:
if os.path.exists(str(input)):
with open(input, "r") as f:
corpus = json.load(f)
else:
corpus = json.loads(input)
chunks = []
for item_key, item_value in corpus.items():
chunk = Chunk(
id=item_key,
name=item_key,
content="\n".join(item_value),
)
chunks.append(chunk)
return chunks
class MusiqueCorpusReader(SourceReaderABC):
@property
def input_types(self) -> Type[Input]:
"""The type of input this Runnable object accepts specified as a type annotation."""
return str
@property
def output_types(self) -> Type[Output]:
"""The type of output this Runnable object produces specified as a type annotation."""
return Chunk
def get_basename(self, file_name: str):
base, ext = os.path.splitext(os.path.basename(file_name))
return base
def invoke(self, input: str, **kwargs) -> List[Output]:
id_column = kwargs.get("id_column", "title")
name_column = kwargs.get("name_column", "title")
content_column = kwargs.get("content_column", "text")
if os.path.exists(str(input)):
with open(input, "r") as f:
corpusList = json.load(f)
else:
            corpusList = json.loads(input)  # non-path input is treated as a JSON string
chunks = []
for item in corpusList:
chunk = Chunk(
id=item[id_column],
name=item[name_column],
content=item[content_column],
)
chunks.append(chunk)
return chunks
class TwowikiCorpusReader(MusiqueCorpusReader):
@property
def input_types(self) -> Type[Input]:
"""The type of input this Runnable object accepts specified as a type annotation."""
return str
@property
def output_types(self) -> Type[Output]:
"""The type of output this Runnable object produces specified as a type annotation."""
return Chunk
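
For reference, the corpus shape HotpotqaCorpusReader expects, shown with hypothetical data and assuming the reader constructs with defaults: a JSON object mapping a document title to its list of sentences, one Chunk per entry:

import json

corpus = {"Some Article": ["First sentence.", "Second sentence."]}  # hypothetical data
chunks = HotpotqaCorpusReader().invoke(json.dumps(corpus))
# MusiqueCorpusReader expects a JSON array of {"title": ..., "text": ...} items instead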

kag/builder/component/reader/docx_reader.py Normal file
@@ -0,0 +1,179 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
from typing import List, Tuple, Type, Union
from docx import Document
from kag.builder.component.reader import MarkDownReader
from kag.builder.model.chunk import Chunk
from kag.interface.builder import SourceReaderABC
from knext.common.base.runnable import Input, Output
from kag.common.llm.client import LLMClient
from kag.builder.prompt.outline_prompt import OutlinePrompt
def split_txt(content):
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
p = pipeline(
task=Tasks.document_segmentation,
model='damo/nlp_bert_document-segmentation_chinese-base')
result = p(documents=content)
result = result[OutputKeys.TEXT]
res = [r for r in result.split('\n\t') if len(r) > 0]
return res
class DocxReader(SourceReaderABC):
"""
    A class for reading Docx files, inheriting from SourceReaderABC.
This class is specifically designed to extract text content from Docx files and generate Chunk objects based on the extracted content.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.split_level = kwargs.get("split_level", 3)
self.split_using_outline = kwargs.get("split_using_outline", True)
self.outline_flag = True
self.llm = self._init_llm()
language = os.getenv("KAG_PROMPT_LANGUAGE", "zh")
self.prompt = OutlinePrompt(language)
@property
def input_types(self) -> Type[Input]:
return str
@property
def output_types(self) -> Type[Output]:
return Chunk
    def outline_chunk(self, chunk: Union[Chunk, List[Chunk]], basename) -> List[Chunk]:
        if isinstance(chunk, Chunk):
            chunk = [chunk]
        outlines = []
        for c in chunk:
            outline = self.llm.invoke({"input": c.content}, self.prompt)
            outlines.extend(outline)
        content = "\n".join([c.content for c in chunk])
        chunks = self.sep_by_outline(content, outlines, basename)
        return chunks
    def sep_by_outline(self, content, outlines, basename):
        position_check = []
        for outline in outlines:
            start = content.find(outline)
            position_check.append((outline, start))
        chunks = []
        for idx, pc in enumerate(position_check):
            # each chunk runs from this outline to the next one, or to the end of the content
            end = position_check[idx + 1][1] if idx + 1 < len(position_check) else len(content)
            chunk = Chunk(
                id=Chunk.generate_hash_id(f"{basename}#{pc[0]}"),
                name=f"{basename}#{pc[0]}",
                content=content[pc[1]:end],
            )
            chunks.append(chunk)
        return chunks
@staticmethod
    def _extract_text_from_docx(doc: Document) -> List[str]:
        """
        Extracts text from a Docx document.
        This method iterates through all paragraphs in the provided Docx document
        and collects each paragraph's text into a list.
        Args:
            doc (Document): A Document object representing the Docx file from which
            text is to be extracted.
        Returns:
            List[str]: The text of each paragraph in the document, in order.
        """
        full_text = []
        for para in doc.paragraphs:
            full_text.append(para.text)
        return full_text
    def _get_title_from_text(self, text: str) -> Tuple[str, str]:
        text = text.strip()
        title = text.split("\n")[0]
        return title, text
def invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Processes the input Docx file, extracts its text content, and generates a Chunk object.
Args:
input (Input): The file path of the Docx file to be processed.
**kwargs: Additional keyword arguments, not used in the current implementation.
Returns:
List[Output]: A list containing a single Chunk object with the extracted text.
Raises:
ValueError: If the input is empty.
IOError: If the file cannot be read or the text extraction fails.
"""
if not input:
raise ValueError("Input cannot be empty")
chunks = []
try:
doc = Document(input)
full_text = self._extract_text_from_docx(doc)
content = "\n".join(full_text)
except OSError as e:
raise IOError(f"Failed to read file: {input}") from e
basename, _ = os.path.splitext(os.path.basename(input))
for text in full_text:
title,text = self._get_title_from_text(text)
chunk = Chunk(
id=Chunk.generate_hash_id(f"{basename}#{title}"),
name=f"{basename}#{title}",
content=text,
)
chunks.append(chunk)
        if len(chunks) < 2:
            chunks = self.outline_chunk(chunks, basename)
        if len(chunks) < 2:
            semantic_res = split_txt(content)
            chunks = [
                Chunk(
                    id=Chunk.generate_hash_id(input + "#" + r[:10]),
                    name=basename + "#" + r[:10],
                    content=r,
                )
                for r in semantic_res
            ]
        return chunks
if __name__== "__main__":
reader = DocxReader()
print(reader.output_types)
file_path = os.path.dirname(__file__)
res = reader.invoke(os.path.join(file_path,"../../../../tests/builder/data/test_docx.docx"))
print(res)

kag/builder/component/reader/json_reader.py Normal file
@@ -0,0 +1,164 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import os
from typing import List, Type, Dict, Union
from kag.builder.component.reader.markdown_reader import MarkDownReader
from kag.builder.model.chunk import Chunk
from kag.interface.builder.reader_abc import SourceReaderABC
from knext.common.base.runnable import Input, Output
from kag.common.llm.client import LLMClient
class JSONReader(SourceReaderABC):
"""
    A class for reading JSON files, inheriting from `SourceReaderABC`.
    Supports converting JSON data into either a list of dictionaries or a list of Chunk objects.
    Args:
        output_type (str): Specifies the output type, which can be "Dict" or "Chunk".
**kwargs: Additional keyword arguments passed to the parent class constructor.
"""
def __init__(self, output_type="Chunk", **kwargs):
super().__init__(**kwargs)
if output_type == "Dict":
self.output_types = Dict[str, str]
else:
self.output_types = Chunk
self.id_col = kwargs.get("id_col", "id")
self.name_col = kwargs.get("name_col", "name")
self.content_col = kwargs.get("content_col", "content")
@property
def input_types(self) -> Type[Input]:
return str
@property
def output_types(self) -> Type[Output]:
return self._output_types
@output_types.setter
def output_types(self, output_types):
self._output_types = output_types
@staticmethod
def _read_from_file(file_path: str) -> Union[dict, list]:
"""
Safely reads JSON from a file and returns its content.
Args:
file_path (str): The path to the JSON file.
Returns:
Union[dict, list]: The parsed JSON content.
Raises:
ValueError: If there is an error reading the JSON file.
"""
try:
with open(file_path, "r") as file:
return json.load(file)
except json.JSONDecodeError as e:
raise ValueError(f"Error reading JSON from file: {e}")
except FileNotFoundError as e:
raise ValueError(f"File not found: {e}")
@staticmethod
def _parse_json_string(json_string: str) -> Union[dict, list]:
"""
Parses a JSON string and returns its content.
Args:
json_string (str): The JSON string to parse.
Returns:
Union[dict, list]: The parsed JSON content.
Raises:
ValueError: If there is an error parsing the JSON string.
"""
try:
return json.loads(json_string)
except json.JSONDecodeError as e:
raise ValueError(f"Error parsing JSON string: {e}")
def invoke(self, input: str, **kwargs) -> List[Output]:
"""
Parses the input string data and generates a list of Chunk objects or returns the original data.
This method supports receiving JSON-formatted strings. It extracts specific fields based on provided keyword arguments.
It can read from a file or directly parse a string. If the input data is in the expected format, it generates a list of Chunk objects;
otherwise, it throws a ValueError if the input is not a JSON array or object.
Args:
input (str): The input data, which can be a JSON string or a file path.
**kwargs: Keyword arguments used to specify the field names for ID, name, and content.
Returns:
List[Output]: A list of Chunk objects or the original data.
Raises:
ValueError: If the input data format is incorrect or parsing fails.
"""
id_col = kwargs.get("id_col", "id")
name_col = kwargs.get("name_col", "name")
content_col = kwargs.get("content_col", "content")
self.id_col = id_col
self.name_col = name_col
self.content_col = content_col
if os.path.exists(input):
corpus = self._read_from_file(input)
else:
corpus = self._parse_json_string(input)
if not isinstance(corpus, (list, dict)):
raise ValueError("Expected input to be a JSON array or object")
if isinstance(corpus, dict):
corpus = [corpus]
if self.output_types == Chunk:
chunks = []
basename, _ = os.path.splitext(os.path.basename(input))
for idx, item in enumerate(corpus):
if not isinstance(item, dict):
continue
chunk = Chunk(
id=item.get(self.id_col) or Chunk.generate_hash_id(f"{input}#{idx}"),
name=item.get(self.name_col) or f"{basename}#{idx}",
content=item.get(self.content_col),
)
chunks.append(chunk)
return chunks
else:
return corpus
if __name__ == "__main__":
reader = JSONReader()
json_string = '''[
{
"title": "test_json",
"text": "Test content"
}
]'''
chunks = reader.invoke(json_string, name_col="title", content_col="text")
print(chunks)
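# A minimal sketch of the "Dict" output mode (editor's illustration): the
# reader then returns the parsed JSON records unchanged instead of wrapping
# them in Chunk objects.
dict_reader = JSONReader(output_type="Dict")
records = dict_reader.invoke('[{"id": "1", "name": "doc", "content": "text"}]')
# records == [{"id": "1", "name": "doc", "content": "text"}]
print(records)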

View File

@@ -0,0 +1,414 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
import bs4.element
import markdown
from bs4 import BeautifulSoup, Tag
from typing import List, Type
import logging
import re
import requests
import pandas as pd
from io import StringIO
from tenacity import stop_after_attempt, retry
from kag.interface.builder import SourceReaderABC
from kag.builder.model.chunk import Chunk, ChunkTypeEnum
from knext.common.base.runnable import Output, Input
from kag.builder.prompt.analyze_table_prompt import AnalyzeTablePrompt
class MarkDownReader(SourceReaderABC):
"""
A class for reading MarkDown files, inheriting from `SourceReader`.
Supports converting MarkDown data into a list of Chunk objects.
Args:
cut_depth (int): The depth of cutting, determining the level of detail in parsing. Default is 1.
"""
ALL_LEVELS = [f"h{x}" for x in range(1, 7)]
TABLE_CHUCK_FLAG = "<<<table_chuck>>>"
def __init__(self, cut_depth: int = 1, **kwargs):
super().__init__(**kwargs)
self.cut_depth = int(cut_depth)
self.llm_module = kwargs.get("llm_module", None)
self.analyze_table_prompt = AnalyzeTablePrompt(language="zh")
self.analyze_img_prompt = AnalyzeTablePrompt(language="zh")
@property
def input_types(self) -> Type[Input]:
return str
@property
def output_types(self) -> Type[Output]:
return Chunk
def to_text(self, level_tags):
"""
Converts parsed hierarchical tags into text content.
Args:
level_tags (list): Parsed tags organized by Markdown heading levels and other tags.
Returns:
str: Text content derived from the parsed tags.
"""
content = []
for item in level_tags:
if isinstance(item, list):
content.append(self.to_text(item))
else:
header, tag = item
if not isinstance(tag, Tag):
continue
elif tag.name in self.ALL_LEVELS:
content.append(
f"{header}-{tag.text}" if len(header) > 0 else tag.text
)
else:
content.append(self.tag_to_text(tag))
return "\n".join(content)
def tag_to_text(self, tag: bs4.element.Tag):
"""
Converts an HTML tag to text.
For a table, outputs markdown wrapped in TABLE_CHUCK_FLAG markers so a table
Chunk can be built later.
:param tag:
:return:
"""
if tag.name == "table":
try:
html_table = str(tag)
table_df = pd.read_html(html_table)[0]
return f"{self.TABLE_CHUCK_FLAG}{table_df.to_markdown(index=False)}{self.TABLE_CHUCK_FLAG}"
except Exception:
logging.warning("parse table tag to text error", exc_info=True)
return tag.text
@retry(stop=stop_after_attempt(5))
def analyze_table(self, table, analyze_method="human"):
if analyze_method == "llm":
if self.llm_module is None:
logging.info("llm_module is None, cannot use analyze_table")
return table
variables = {
"table": table
}
response = self.llm_module.invoke(
variables = variables,
prompt_op = self.analyze_table_prompt,
with_json_parse=False
)
if response is None or response == "" or response == []:
raise Exception("llm_module return None")
return response
else:
try:
df = pd.read_html(StringIO(table))[0]
except Exception as e:
logging.warning(f"analyze_table error: {e}")
return table
content = ""
for index, row in df.iterrows():
content += f"第{index+1}行的数据如下:"
for col_name, value in row.items():
content += f"{col_name}的值为{value},"
content += "\n"
return content
@retry(stop=stop_after_attempt(5))
def analyze_img(self, img_url):
response = requests.get(img_url)
response.raise_for_status()
image_data = response.content
# Image understanding is not implemented yet; return the URL as a textual
# placeholder so replace_img never splices None into the content.
return img_url
def replace_table(self, content: str):
pattern = r"<table[^>]*>([\s\S]*?)<\/table>"
for match in re.finditer(pattern, content):
table = match.group(0)
table = self.analyze_table(table)
content = content.replace(match.group(0), table)
return content
def replace_img(self, content: str):
pattern = r"<img[^>]*src=[\"\']([^\"\']*)[\"\']"
for match in re.finditer(pattern, content):
img_url = match.group(1)
img_msg = self.analyze_img(img_url)
content = content.replace(match.group(0), img_msg)
return content
def extract_table(self, level_tags, header=""):
"""
Extracts tables from the parsed hierarchical tags along with their headers.
Args:
level_tags (list): Parsed tags organized by Markdown heading levels and other tags.
header (str): Current header text being processed.
Returns:
list: A list of tuples, each containing the table's header, context text, and the table tag.
"""
tables = []
for idx, item in enumerate(level_tags):
if isinstance(item, list):
tables += self.extract_table(item, header)
else:
tag = item[1]
if not isinstance(tag, Tag):
continue
if tag.name in self.ALL_LEVELS:
header = f"{header}-{tag.text}" if len(header) > 0 else tag.text
if tag.name == "table":
if idx - 1 >= 0:
context = level_tags[idx - 1]
if isinstance(context, tuple):
tables.append((header, context[1].text, tag))
else:
tables.append((header, "", tag))
return tables
def parse_level_tags(
self,
level_tags: list,
level: str,
parent_header: str = "",
cur_header: str = "",
):
"""
Recursively parses level tags to organize them into a structured format.
Args:
level_tags (list): A list of tags to be parsed.
level (str): The current level being processed.
parent_header (str): The header of the parent tag.
cur_header (str): The header of the current tag.
Returns:
list: A structured representation of the parsed tags.
"""
if len(level_tags) == 0:
return []
output = []
prefix_tags = []
while len(level_tags) > 0:
tag = level_tags[0]
if tag.name in self.ALL_LEVELS:
break
else:
prefix_tags.append((parent_header, level_tags.pop(0)))
if len(prefix_tags) > 0:
output.append(prefix_tags)
cur = []
while len(level_tags) > 0:
tag = level_tags[0]
if tag.name not in self.ALL_LEVELS:
cur.append((parent_header, level_tags.pop(0)))
else:
if tag.name > level:
cur += self.parse_level_tags(
level_tags,
tag.name,
f"{parent_header}-{cur_header}"
if len(parent_header) > 0
else cur_header,
tag.name,
)
elif tag.name == level:
if len(cur) > 0:
output.append(cur)
cur = [(parent_header, level_tags.pop(0))]
cur_header = tag.text
else:
if len(cur) > 0:
output.append(cur)
return output
if len(cur) > 0:
output.append(cur)
return output
def cut(self, level_tags, cur_level, final_level):
"""
Cuts the provided level tags into chunks based on the specified levels.
Args:
level_tags (list): A list of tags to be cut.
cur_level (int): The current level in the hierarchy.
final_level (int): The final level to which the tags should be cut.
Returns:
list: A list of cut chunks.
"""
output = []
if cur_level == final_level:
cur_prefix = []
for sublevel_tags in level_tags:
if isinstance(sublevel_tags, tuple):
cur_prefix.append(self.to_text([sublevel_tags]))
else:
break
cur_prefix = "\n".join(cur_prefix)
if len(cur_prefix) > 0:
output.append(cur_prefix)
for sublevel_tags in level_tags:
if isinstance(sublevel_tags, list):
output.append(cur_prefix + "\n" + self.to_text(sublevel_tags))
return output
else:
cur_prefix = []
for sublevel_tags in level_tags:
if isinstance(sublevel_tags, tuple):
cur_prefix.append(sublevel_tags[1].text)
else:
break
cur_prefix = "\n".join(cur_prefix)
if len(cur_prefix) > 0:
output.append(cur_prefix)
for sublevel_tags in level_tags:
if isinstance(sublevel_tags, list):
output += self.cut(sublevel_tags, cur_level + 1, final_level)
return output
def solve_content(self, id: str, title: str, content: str, **kwargs) -> List[Output]:
"""
Converts Markdown content into structured chunks.
Args:
id (str): An identifier for the content.
title (str): The title of the content.
content (str): The Markdown formatted content to be processed.
Returns:
List[Output]: A list of processed content chunks.
"""
html_content = markdown.markdown(
content, extensions=["markdown.extensions.tables"]
)
# html_content = self.replace_table(html_content)
soup = BeautifulSoup(html_content, "html.parser")
if soup is None:
raise ValueError("The MarkDown file appears to be empty or unreadable.")
top_level = None
for level in self.ALL_LEVELS:
tmp = soup.find_all(level)
if len(tmp) > 0:
top_level = level
break
if top_level is None:
chunk = Chunk(
id=Chunk.generate_hash_id(str(id)),
name=title,
content=soup.text,
ref=kwargs.get("ref", ""),
)
return [chunk]
tags = [tag for tag in soup.children if isinstance(tag, Tag)]
level_tags = self.parse_level_tags(tags, top_level)
cutted = self.cut(level_tags, 0, self.cut_depth)
chunks = []
for idx, content in enumerate(cutted):
chunk = None
if self.TABLE_CHUCK_FLAG in content:
chunk = self.get_table_chuck(content, title, id, idx)
chunk.ref = kwargs.get("ref", "")
else:
chunk = Chunk(
id=Chunk.generate_hash_id(f"{id}#{idx}"),
name=f"{title}#{idx}",
content=content,
ref=kwargs.get("ref", ""),
)
chunks.append(chunk)
return chunks
def get_table_chuck(self, table_chunk_str: str, title: str, id: str, idx: int) -> Chunk:
"""
convert table chunk
:param table_chunk_str:
:return:
"""
table_chunk_str = table_chunk_str.replace("\\N", "")
pattern = f"{self.TABLE_CHUCK_FLAG}(.*){self.TABLE_CHUCK_FLAG}"
matches = re.findall(pattern, table_chunk_str, re.DOTALL)
if not matches:
# No table found; treat it as a plain text chunk.
return Chunk(
id=Chunk.generate_hash_id(f"{id}#{idx}"),
name=f"{title}#{idx}",
content=table_chunk_str,
)
table_markdown_str = matches[0]
html_table_str = markdown.markdown(table_markdown_str, extensions=["markdown.extensions.tables"])
try:
df = pd.read_html(html_table_str)[0]
except Exception as e:
logging.warning(f"get_table_chuck error: {e}")
df = pd.DataFrame()
# Confirmed table chunk; strip the TABLE_CHUCK_FLAG markers from the content.
replaced_table_text = re.sub(pattern, f'\n{table_markdown_str}\n', table_chunk_str, flags=re.DOTALL)
return Chunk(
id=Chunk.generate_hash_id(f"{id}#{idx}"),
name=f"{title}#{idx}",
content=replaced_table_text,
type=ChunkTypeEnum.Table,
csv_data=df.to_csv(index=False),
)
def invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Processes a Markdown file and returns its content as structured chunks.
Args:
input (Input): The path to the Markdown file.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list of processed content chunks.
"""
file_path: str = input
if not file_path.endswith(".md"):
raise ValueError(f"Please provide a markdown file, got {file_path}")
if not os.path.isfile(file_path):
raise FileNotFoundError(f"The file {file_path} does not exist.")
with open(file_path, "r") as reader:
content = reader.read()
basename, _ = os.path.splitext(os.path.basename(file_path))
chunks = self.solve_content(input, basename, content)
return chunks
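if __name__ == "__main__":
# Minimal usage sketch (editor's illustration; the markdown path is
# hypothetical). With cut_depth=1 the reader emits roughly one chunk per
# top-level heading; deeper headings stay inside their parent chunk, with
# header paths joined by "-" in to_text().
reader = MarkDownReader(cut_depth=1)
chunks = reader.invoke("./docs/guide.md")
for c in chunks:
print(c.name, len(c.content))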

View File

@@ -0,0 +1,254 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
import re
import logging
from typing import List, Sequence, Type, Union
from langchain_community.document_loaders import PyPDFLoader
import pdfminer
import pdfminer.layout
from pdfminer.high_level import extract_text, extract_pages
from pdfminer.layout import LAParams, LTTextBox, LTTextContainer, LTPage
from pdfminer.pdfparser import PDFParser
from pdfminer.pdfdocument import PDFDocument
from pdfminer.pdfpage import PDFPage, PDFTextExtractionNotAllowed
from pdfminer.pdfinterp import PDFResourceManager, PDFPageInterpreter
from pdfminer.converter import PDFPageAggregator
from kag.builder.model.chunk import Chunk
from kag.interface.builder import SourceReaderABC
from knext.common.base.runnable import Input, Output
from kag.builder.prompt.outline_prompt import OutlinePrompt
logger = logging.getLogger(__name__)
class PDFReader(SourceReaderABC):
"""
A PDF reader class that inherits from SourceReader.
Attributes:
split_level (int): The outline level used when splitting. Default is 3.
split_using_outline (bool): Whether to split the content by the PDF outline. Default is True.
outline_flag (bool): Set to False when the PDF provides no readable outline.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.split_level = kwargs.get("split_level", 3)
self.split_using_outline = kwargs.get("split_using_outline", True)
self.outline_flag = True
self.llm = self._init_llm()
language = os.getenv("KAG_PROMPT_LANGUAGE", "zh")
self.prompt = OutlinePrompt(language)
@property
def input_types(self) -> Type[Input]:
return str
@property
def output_types(self) -> Type[Output]:
return Chunk
def outline_chunk(self, chunk: Union[Chunk, List[Chunk]], basename) -> List[Chunk]:
if isinstance(chunk, Chunk):
chunk = [chunk]
outlines = []
for c in chunk:
outline = self.llm.invoke({"input": c.content}, self.prompt)
outlines.extend(outline)
content = "\n".join([c.content for c in chunk])
chunks = self.sep_by_outline(content, outlines,basename)
return chunks
def sep_by_outline(self, content, outlines, basename):
position_check = []
for outline in outlines:
start = content.find(outline)
if start == -1:
# Skip outline titles that do not occur in the extracted text.
continue
position_check.append((outline, start))
chunks = []
for idx, pc in enumerate(position_check):
chunk = Chunk(
id=Chunk.generate_hash_id(f"{basename}#{pc[0]}"),
name=f"{basename}#{pc[0]}",
content=content[pc[1]:position_check[idx + 1][1] if idx + 1 < len(position_check) else len(content)],
)
chunks.append(chunk)
return chunks
@staticmethod
def _process_single_page(
page: str,
watermark: str,
remove_header: bool = False,
remove_footnote: bool = False,
remove_lists: List[str] = None,
) -> list:
"""
Processes a single page of text, removing headers, footnotes, watermarks, and specified lists.
Args:
page (str): The text content of a single page.
watermark (str): The watermark text to be removed.
remove_header (bool): Whether to remove the header. Default is False.
remove_footnote (bool): Whether to remove the footnote. Default is False.
remove_lists (List[str]): A list of strings to be removed. Default is None.
Returns:
list: A list of processed text lines.
"""
lines = page.split("\n")
if remove_header and len(lines) > 0:
lines = lines[1:]
if remove_footnote and len(lines) > 0:
lines = lines[:-1]
cleaned = [line.strip().replace(watermark, "") for line in lines]
if remove_lists is None:
return cleaned
for s in remove_lists:
cleaned = [line.strip().replace(s, "") for line in cleaned]
return cleaned
@staticmethod
def _extract_text_from_page(page_layout: LTPage) -> str:
"""
Extracts text from a given page layout.
Args:
page_layout (LTPage): The layout of the page containing text elements.
Returns:
str: The extracted text.
"""
text = ""
for element in page_layout:
if isinstance(element, LTTextContainer):
text += element.get_text()
return text
def invoke(self, input: str, **kwargs) -> Sequence[Output]:
"""
Processes a PDF file, splitting or extracting content based on configuration.
Args:
input (str): The path to the PDF file.
**kwargs: Additional keyword arguments, such as `clean_list`.
Returns:
Sequence[Output]: A sequence of processed outputs.
Raises:
ValueError: If the file is not a PDF file or the content is empty/unreadable.
FileNotFoundError: If the file does not exist.
"""
if not input.endswith(".pdf"):
raise ValueError(f"Please provide a pdf file, got {input}")
if not os.path.isfile(input):
raise FileNotFoundError(f"The file {input} does not exist.")
self.fd = open(input, "rb")
self.parser = PDFParser(self.fd)
self.document = PDFDocument(self.parser)
chunks = []
basename, _ = os.path.splitext(os.path.basename(input))
# get outline
try:
outlines = self.document.get_outlines()
except Exception as e:
logger.warning(f"no readable outline in PDF file: {e}")
self.outline_flag = False
if not self.outline_flag:
with open(input, "rb") as file:
for idx, page_layout in enumerate(extract_pages(file)):
content = ""
for element in page_layout:
if hasattr(element, "get_text"):
content = content + element.get_text()
chunk = Chunk(
id=Chunk.generate_hash_id(f"{basename}#{idx}"),
name=f"{basename}#{idx}",
content=content,
)
chunks.append(chunk)
try:
outline_chunks = self.outline_chunk(chunks, basename)
except Exception as e:
raise RuntimeError(f"Error loading PDF file: {e}")
if len(outline_chunks) > 0:
chunks = outline_chunks
else:
split_words = []
for item in outlines:
level, title, dest, a, se = item
split_words.append(title.strip().replace(" ",""))
# save the outline position in content
try:
text = extract_text(input)
except Exception as e:
raise RuntimeError(f"Error loading PDF file: {e}")
# extract_text joins pages with form-feed characters; split on "\f"
# so each element is one page rather than one character.
cleaned_pages = [
self._process_single_page(x, "", False, False) for x in text.split("\f")
]
sentences = []
for cleaned_page in cleaned_pages:
sentences += cleaned_page
content = "".join(sentences)
positions = [(input, 0)]
for split_word in split_words:
pattern = re.compile(split_word)
for i, match in enumerate(re.finditer(pattern, content)):
# Use the second occurrence: the first one usually sits in the
# table of contents rather than the body text.
if i == 1:
start, end = match.span()
positions.append((split_word, start))
for idx, position in enumerate(positions):
chunk = Chunk(
id=Chunk.generate_hash_id(f"{basename}#{position[0]}"),
name=f"{basename}#{position[0]}",
content=content[position[1]:positions[idx + 1][1] if idx + 1 < len(positions) else None],
)
chunks.append(chunk)
chunks.append(chunk)
return chunks
if __name__ == '__main__':
reader = PDFReader(split_using_outline=True)
pdf_path = os.path.join(os.path.dirname(__file__), "../../../../tests/builder/data/aiwen.pdf")
chunk = reader.invoke(pdf_path)
print(chunk)

View File

@@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
from typing import List, Type
from kag.builder.model.chunk import Chunk
from kag.interface.builder import SourceReaderABC
from knext.common.base.runnable import Input, Output
class TXTReader(SourceReaderABC):
"""
A PDF reader class that inherits from SourceReader.
"""
@property
def input_types(self) -> Type[Input]:
return str
@property
def output_types(self) -> Type[Output]:
return Chunk
def invoke(self, input: Input, **kwargs) -> List[Output]:
"""
The main method for processing text reading. This method reads the content of the input (which can be a file path or text content) and converts it into a Chunk object.
Args:
input (Input): The input string, which can be the path to a text file or direct text content.
**kwargs: Additional keyword arguments, currently unused but kept for potential future expansion.
Returns:
List[Output]: A list containing Chunk objects, each representing a piece of text read.
Raises:
ValueError: If the input is empty.
IOError: If there is an issue reading the file specified by the input.
"""
if not input:
raise ValueError("Input cannot be empty")
try:
if os.path.exists(input):
with open(input, "r", encoding='utf-8') as f:
content = f.read()
else:
content = input
except OSError as e:
raise IOError(f"Failed to read file: {input}") from e
basename, _ = os.path.splitext(os.path.basename(input))
chunk = Chunk(
id=Chunk.generate_hash_id(input),
name=basename,
content=content,
)
return [chunk]
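if __name__ == "__main__":
# Minimal sketch (editor's illustration): TXTReader accepts either a file
# path or raw text; a string that is not an existing path is treated as
# literal content and wrapped into a single Chunk.
reader = TXTReader()
chunks = reader.invoke("Raw text is wrapped into a single Chunk.")
print(chunks[0].name, chunks[0].content)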

View File

@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import requests
from typing import Type, List
from kag.builder.component.reader import MarkDownReader
from kag.builder.model.chunk import Chunk
from kag.interface.builder import SourceReaderABC
from knext.common.base.runnable import Input, Output
from kag.common.llm.client import LLMClient
class YuqueReader(SourceReaderABC):
def __init__(self, token: str, **kwargs):
super().__init__(**kwargs)
self.token = token
self.markdown_reader = MarkDownReader(**kwargs)
@property
def input_types(self) -> Type[Input]:
"""The type of input this Runnable object accepts specified as a type annotation."""
return str
@property
def output_types(self) -> Type[Output]:
"""The type of output this Runnable object produces specified as a type annotation."""
return Chunk
@staticmethod
def get_yuque_api_data(token, url):
headers = {"X-Auth-Token": token}
try:
response = requests.get(url, headers=headers)
response.raise_for_status() # Raise an HTTPError for bad responses (4xx and 5xx)
return response.json()["data"] # Assuming the API returns JSON data
except requests.exceptions.HTTPError as http_err:
print(f"HTTP error occurred: {http_err}")
except requests.exceptions.RequestException as err:
print(f"Error occurred: {err}")
except Exception as err:
print(f"An error occurred: {err}")
def invoke(self, input: str, **kwargs) -> List[Output]:
if not input:
raise ValueError("Input cannot be empty")
url: str = input
data = self.get_yuque_api_data(self.token, url)
if data is None:
raise ValueError(f"Failed to fetch Yuque document from: {url}")
id = data.get("id", "")
title = data.get("title", "")
content = data.get("body", "")
chunks = self.markdown_reader.solve_content(id, title, content)
return chunks
if __name__ == "__main__":
llm_module = LLMClient.from_config("/Users/zhangxinhong.zxh/workspace/openspgapp/openspg/python/kag/tests/llm/config/ollama.yaml")
reader = YuqueReader("N400SPQifX4GkPVbHYekCRqklyQ0hqNMt4xiPSf5",llm_module = llm_module)
res = reader.invoke("https://yuque-api.antfin-inc.com/api/v2/repos/fg4znh/kg7h1z/docs/gokiiu")
a = 1

View File

@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.builder.component.splitter.length_splitter import LengthSplitter
from kag.builder.component.splitter.semantic_splitter import SemanticSplitter
from kag.builder.component.splitter.pattern_splitter import PatternSplitter
from kag.builder.component.splitter.outline_splitter import OutlineSplitter
__all__ = [
"LengthSplitter",
"SemanticSplitter",
"PatternSplitter",
"OutlineSplitter",
]

View File

@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from abc import ABC
from typing import Type, List, Union
from kag.builder.model.chunk import Chunk
from kag.interface.builder import SplitterABC
class BaseTableSplitter(SplitterABC):
"""
A base class for splitting table, inheriting from Splitter.
"""
def split_table(self, org_chunk: Chunk, chunk_size: int = 2000, sep: str = "\n"):
"""
Splits a markdown-format table into smaller markdown tables.
"""
try:
return self._split_table(org_chunk=org_chunk, chunk_size=chunk_size, sep=sep)
except Exception:
return None
def _split_table(self, org_chunk: Chunk, chunk_size: int = 2000, sep: str = "\n"):
content = org_chunk.content
table_start = content.find("|")
table_end = content.rfind("|") + 1
# str.find returns -1 (not None) when no "|" is present.
if table_start < 0 or table_start == table_end:
return None
prefix = content[0:table_start].strip("\n ")
table_rows = content[table_start:table_end].split("\n")
table_header = table_rows[0]
table_header_segmentation = table_rows[1]
suffix = content[table_end:].strip("\n ")
splitted = []
cur = [prefix, table_header, table_header_segmentation]
cur_len = len(prefix)
for row in table_rows[2:]:
if cur_len > chunk_size:
cur.append(suffix)
splitted.append(cur)
cur = [prefix, table_header, table_header_segmentation]
cur_len = len(prefix)
cur.append(row)
cur_len += len(row)
if len(cur) > 3:
# Flush the trailing rows that never reached chunk_size.
cur.append(suffix)
splitted.append(cur)
output = []
for idx, sentences in enumerate(splitted):
chunk = Chunk(
id=f"{org_chunk.id}#{chunk_size}#table#{idx}#LEN",
name=f"{org_chunk.name}#{idx}",
content=sep.join(sentences),
type=org_chunk.type,
**org_chunk.kwargs
)
output.append(chunk)
return output
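# Editor's illustration of the splitting contract: given chunk content such as
#
# Quarterly prices
# | item | cost |
# | ---- | ---- |
# | a | 1 |
# ... many more rows ...
#
# _split_table re-emits the prefix line, the header row and the "----"
# separator row at the top of every sub-chunk (and the suffix at its end), so
# each emitted piece still parses as a standalone markdown table.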

View File

@@ -0,0 +1,149 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from typing import Type, List, Union
from kag.builder.model.chunk import Chunk, ChunkTypeEnum
from knext.common.base.runnable import Input, Output
from kag.builder.component.splitter.base_table_splitter import BaseTableSplitter
class LengthSplitter(BaseTableSplitter):
"""
A class for splitting text based on length, inheriting from Splitter.
Attributes:
split_length (int): The maximum length of each split chunk.
window_length (int): The length of the overlap between chunks.
"""
def __init__(self, split_length: int = 500, window_length: int = 100, **kwargs):
super().__init__(**kwargs)
self.split_length = int(split_length)
self.window_length = int(window_length)
@property
def input_types(self) -> Type[Input]:
return Chunk
@property
def output_types(self) -> Type[Output]:
return Chunk
def split_sentence(self, content):
"""
Splits the given content into sentences based on delimiters.
Args:
content (str): The content to be split.
Returns:
list: A list of sentences.
"""
sentence_delimiters = ".。??!"
output = []
start = 0
for idx, char in enumerate(content):
if char in sentence_delimiters:
end = idx
tmp = content[start: end + 1].strip()
if len(tmp) > 0:
output.append(tmp)
start = idx + 1
res = content[start:]
if len(res) > 0:
output.append(res)
return output
def slide_window_chunk(
self,
org_chunk: Chunk,
chunk_size: int = 2000,
window_length: int = 300,
sep: str = "\n",
) -> List[Chunk]:
"""
Splits the content into chunks using a sliding window approach.
Args:
org_chunk (Chunk): The original chunk to be split.
chunk_size (int, optional): The maximum size of each chunk. Defaults to 2000.
window_length (int, optional): The length of the overlap between chunks. Defaults to 300.
sep (str, optional): The separator used to join sentences. Defaults to "\n".
Returns:
List[Chunk]: A list of Chunk objects.
"""
if org_chunk.type == ChunkTypeEnum.Table:
table_chunks = self.split_table(org_chunk=org_chunk, chunk_size=chunk_size, sep=sep)
if table_chunks is not None:
return table_chunks
content = self.split_sentence(org_chunk.content)
splitted = []
cur = []
cur_len = 0
for sentence in content:
if cur_len + len(sentence) > chunk_size:
if cur:
splitted.append(cur)
tmp = []
cur_len = 0
for item in cur[::-1]:
if cur_len >= window_length:
break
tmp.append(item)
cur_len += len(item)
cur = tmp[::-1]
cur.append(sentence)
cur_len += len(sentence)
if len(cur) > 0:
splitted.append(cur)
output = []
for idx, sentences in enumerate(splitted):
chunk = Chunk(
id=f"{org_chunk.id}#{chunk_size}#{window_length}#{idx}#LEN",
name=f"{org_chunk.name}",
content=sep.join(sentences),
type=org_chunk.type,
**org_chunk.kwargs
)
output.append(chunk)
return output
def invoke(self, input: Union[Chunk, List[Chunk]], **kwargs) -> List[Output]:
"""
Invokes the splitter on the given input chunk or list of chunks.
Args:
input (Union[Chunk, List[Chunk]]): The input chunk(s) to be split.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list of split chunks.
"""
cutted = []
if isinstance(input, list):
for item in input:
cutted.extend(
self.slide_window_chunk(item, self.split_length, self.window_length)
)
else:
cutted.extend(
self.slide_window_chunk(input, self.split_length, self.window_length)
)
return cutted
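if __name__ == "__main__":
# Minimal sketch (editor's illustration; assumes the KAG environment and
# config have been initialized). Consecutive chunks overlap by roughly
# window_length characters of trailing sentences.
text = "".join(f"这是第{i}句话。" for i in range(100))
demo = Chunk(id="demo", name="demo", content=text)
pieces = LengthSplitter(split_length=300, window_length=50).invoke(demo)
print(len(pieces), [len(p.content) for p in pieces])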

View File

@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import logging
import os
import re
from typing import List, Type,Union
from kag.interface.builder import SplitterABC
from kag.builder.prompt.outline_prompt import OutlinePrompt
from kag.builder.model.chunk import Chunk
from knext.common.base.runnable import Input, Output
from kag.common.llm.client.llm_client import LLMClient
logger = logging.getLogger(__name__)
class OutlineSplitter(SplitterABC):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.llm = self._init_llm()
language = os.getenv("KAG_PROMPT_LANGUAGE", "zh")
self.prompt = OutlinePrompt(language)
@property
def input_types(self) -> Type[Input]:
return Chunk
@property
def output_types(self) -> Type[Output]:
return Chunk
def outline_chunk(self, chunk: Union[Chunk, List[Chunk]]) -> List[Chunk]:
if isinstance(chunk, Chunk):
chunk = [chunk]
outlines = []
for c in chunk:
outline = self.llm.invoke({"input": c.content}, self.prompt)
outlines.extend(outline)
content = "\n".join([c.content for c in chunk])
chunks = self.sep_by_outline(content, outlines)
return chunks
def sep_by_outline(self, content, outlines):
position_check = []
for outline in outlines:
start = content.find(outline)
if start == -1:
# Skip outline titles that do not occur in the content.
continue
position_check.append((outline, start))
chunks = []
for idx, pc in enumerate(position_check):
chunk = Chunk(
id=Chunk.generate_hash_id(f"{pc[0]}#{idx}"),
name=f"{pc[0]}#{idx}",
content=content[pc[1]:position_check[idx + 1][1] if idx + 1 < len(position_check) else len(content)],
)
chunks.append(chunk)
return chunks
def invoke(self, input: Input, **kwargs) -> List[Chunk]:
chunks = self.outline_chunk(input)
return chunks
if __name__ == "__main__":
from kag.builder.component.splitter.length_splitter import LengthSplitter
from kag.builder.component.splitter.outline_splitter import OutlineSplitter
from kag.builder.component.reader.docx_reader import DocxReader
from kag.common.env import init_kag_config
init_kag_config(os.path.join(os.path.dirname(__file__), "../../../../tests/builder/component/test_config.cfg"))
docx_reader = DocxReader()
length_splitter = LengthSplitter(split_length=8000)
outline_splitter = OutlineSplitter()
docx_path = os.path.join(os.path.dirname(__file__), "../../../../tests/builder/data/test_docx.docx")
# chain = docx_reader >> length_splitter >> outline_splitter
chunk = docx_reader.invoke(docx_path)
chunks = length_splitter.invoke(chunk)
chunks = outline_splitter.invoke(chunks)
print(chunks)

View File

@@ -0,0 +1,187 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from typing import Type, List, Union
import re
import os
from kag.builder.model.chunk import Chunk, ChunkTypeEnum
from kag.interface.builder.splitter_abc import SplitterABC
from knext.common.base.runnable import Input, Output
class PatternSplitter(SplitterABC):
def __init__(self, pattern_dict: dict = None, chunk_cut_num=None):
"""
pattern_dict:
{
"pattern": 匹配pattern,
"group": {
"header":1,
"name":2,
"content":3
}
}
"""
super().__init__()
if pattern_dict is None:
pattern_dict = {
"pattern": r"(\d+)\.([^0-9]+?)([^0-9第版].*?)(?=\d+\.|$)",
"group": {"header": 2, "name": 2, "content": 0},
}
self.pattern = pattern_dict["pattern"]
self.group = pattern_dict["group"]
self.chunk_cut_num = chunk_cut_num
@property
def input_types(self) -> Type[Input]:
"""The type of input this Runnable object accepts specified as a type annotation."""
return Chunk
@property
def output_types(self) -> Type[Output]:
"""The type of output this Runnable object produces specified as a type annotation."""
return List[Chunk]
def split_sentence(self, content):
sentence_delimiters = "。??!;\n"
output = []
start = 0
for idx, char in enumerate(content):
if char in sentence_delimiters:
end = idx
tmp = content[start : end + 1].strip()
if len(tmp) > 0:
output.append(tmp)
start = idx + 1
res = content[start:]
if len(res) > 0:
output.append(res)
return output
def slide_window_chunk(
self,
content: Union[str, List[str]],
chunk_size: int = 2000,
window_length: int = 300,
sep: str = "\n",
prefix: str = "SlideWindow",
) -> List[Chunk]:
if isinstance(content, str):
content = self.split_sentence(content)
splitted = []
cur = []
cur_len = 0
for sentence in content:
if cur_len + len(sentence) > chunk_size:
splitted.append(cur)
tmp = []
cur_len = 0
for item in cur[::-1]:
if cur_len >= window_length:
break
tmp.append(item)
cur_len += len(item)
cur = tmp[::-1]
cur.append(sentence)
cur_len += len(sentence)
if len(cur) > 0:
splitted.append(cur)
output = []
for idx, sentences in enumerate(splitted):
chunk_name = f"{prefix}#{idx}"
chunk = Chunk(
id=Chunk.generate_hash_id(chunk_name),
name=chunk_name,
content=sep.join(sentences),
)
output.append(chunk)
return output
def chunk_split(
self,
chunk: Chunk,
) -> List[Chunk]:
text = chunk.content
pattern = re.compile(self.pattern, re.DOTALL)
# Find all matches
matches = pattern.finditer(text)
# Iterate over the matches and build one chunk per section
chunks = []
for match in matches:
chunk = Chunk(
chunk_header=match.group(self.group["header"]),
name=match.group(self.group["name"]),
id=Chunk.generate_hash_id(match.group(self.group["content"])),
content=match.group(self.group["content"]),
)
chunk = [chunk]
if self.chunk_cut_num:
chunk = self.slide_window_chunk(
content=chunk[0].content,
chunk_size=self.chunk_cut_num,
window_length=self.chunk_cut_num // 4,
sep="\n",
prefix=chunk[0].name,
)
chunks.extend(chunk)
return chunks
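# Editor's note: with the default pattern_dict, text such as
# "1.概述 正文甲 2.背景 正文乙" yields one chunk per numbered section; the chunk
# name and header come from group 2 ("概述", "背景"), while group 0 (the whole
# match) supplies the content.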
def invoke(self, input: Chunk, **kwargs) -> List[Output]:
chunks = self.chunk_split(input)
return chunks
def to_rest(self):
pass
@classmethod
def from_rest(cls, rest_model):
pass
class LayeredPatternSpliter(PatternSplitter):
pass
def _test():
pattern_dict = {
"pattern": r"(\d+)\.([^0-9]+?)([^0-9第版].*?)(?=\d+\.|$)",
"group": {"header": 2, "name": 2, "content": 0},
}
ds = PatternSplitter(pattern_dict=pattern_dict)
from kag.builder.component.reader.pdf_reader import PDFReader
reader = PDFReader()
file_path = os.path.dirname(__file__)
test_file_path = os.path.join(file_path, "../../../../tests/builder/data/aiwen.pdf")
pre_output = reader._handle(test_file_path)
handle_input = pre_output[0]
handle_result = ds._handle(handle_input)
print("handle_result", handle_result)
return handle_result
if __name__ == "__main__":
res = _test()
print(res)

View File

@@ -0,0 +1,776 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import logging
import os
import re
from typing import List, Type
from kag.interface.builder import SplitterABC
from kag.builder.prompt.semantic_seg_prompt import SemanticSegPrompt
from kag.builder.model.chunk import Chunk
from knext.common.base.runnable import Input, Output
from kag.common.llm.client.llm_client import LLMClient
logger = logging.getLogger(__name__)
class SemanticSplitter(SplitterABC):
"""
A class for semantically splitting text into smaller chunks based on the content's structure and meaning.
Inherits from the Splitter class.
Attributes:
kept_char_pattern (re.Pattern): Regex pattern to match Chinese/ASCII characters.
split_length (int): The maximum length of each chunk after splitting.
llm (LLMClient): LLM client instance initialized from the model config.
semantic_seg_op (SemanticSegPrompt): Instance of SemanticSegPrompt for semantic segmentation.
"""
def __init__(self, split_length: int = 1000, **kwargs):
super().__init__(**kwargs)
# Chinese/ASCII characters
self.kept_char_pattern = re.compile(
r"[^\u4e00-\u9fa5\u3000-\u303F\uFF01-\uFF0F\uFF1A-\uFF20\uFF3B-\uFF40\uFF5B-\uFF65\x00-\x7F]+"
)
self.split_length = int(split_length)
self.llm = self._init_llm()
language = os.getenv("KAG_PROMPT_LANGUAGE", "zh")
self.semantic_seg_op = SemanticSegPrompt(language)
@property
def input_types(self) -> Type[Input]:
return Chunk
@property
def output_types(self) -> Type[Output]:
return Chunk
@staticmethod
def parse_llm_output(content: str, llm_output: list):
"""
Parses the output from the LLM to generate segmented information.
Args:
content (str): The original content being split.
llm_output (list): Output from the LLM indicating segment locations and abstracts.
Returns:
list: A list of dictionaries containing segment names, contents, and lengths.
"""
seg_info = llm_output
seg_info.sort()
locs = [x[0] for x in seg_info]
abstracts = [x[1] for x in seg_info]
locs.append(len(content))
splitted = []
for idx in range(len(abstracts)):
start = locs[idx]
end = locs[idx + 1]
splitted.append(
{
"name": abstracts[idx],
"content": content[start:end],
"length": end - start,
}
)
return splitted
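# Editor's illustration of the expected llm_output shape: for
# content = "AAAABBBB" and llm_output = [(0, "part a"), (4, "part b")],
# parse_llm_output returns
# [{"name": "part a", "content": "AAAA", "length": 4},
#  {"name": "part b", "content": "BBBB", "length": 4}].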
def semantic_chunk(
self,
org_chunk: Chunk,
chunk_size: int = 1000,
) -> List[Chunk]:
"""
Splits the given content into semantic chunks using an LLM.
Args:
org_chunk (Chunk): The original chunk to be split.
chunk_size (int, optional): Maximum size of each chunk. Defaults to 1000.
Returns:
List[Chunk]: A list of Chunk objects representing the split content.
"""
result = self.llm.invoke({"input": org_chunk.content}, self.semantic_seg_op)
splitted = self.parse_llm_output(org_chunk.content, result)
logger.debug(f"splitted = {splitted}")
chunks = []
for idx, item in enumerate(splitted):
split_name = item["name"]
if len(item["content"]) < chunk_size:
chunk = Chunk(
id=f"{org_chunk.id}#{chunk_size}#{idx}#SEM",
name=f"{org_chunk.name}#{split_name}",
content=item["content"],
abstract=item["name"],
**org_chunk.kwargs
)
chunks.append(chunk)
else:
logger.info("chunk over size, splitting recursively")
inner_chunk = Chunk(
id=Chunk.generate_hash_id(item["content"]),
name=f"{org_chunk.name}#{split_name}",
content=item["content"],
)
chunks.extend(self.semantic_chunk(inner_chunk, chunk_size))
return chunks
def invoke(self, input: Input, **kwargs) -> List[Output]:
"""
Invokes the splitting process on the provided input.
Args:
input (Input): The input to be processed.
**kwargs: Additional keyword arguments.
Returns:
List[Output]: A list of outputs generated from the input.
"""
chunks = self.semantic_chunk(input, self.split_length)
return chunks
if __name__ == "__main__":
chunk = Chunk(
id="1",
name="test",
content="""
与贸易有关的知识产权协定
各成员方
本着减少国际贸易中的扭曲及障碍的愿望考虑到有必要促进对知识产权有效和充分的保护以及确保实施保护产权的措施及程序本身不致成为合法贸易的障碍认识到为此目的有必要制定关于下列的新规则及规范
11994关贸总协定基本原则及有关的国际知识产权协议和公约的适用性
2关于与贸易有关的知识产权的效力范围和使用的适当标准及原则的规定
3关于在考虑到各国法律体系差异的同时使用有效并适当的方法实施与贸易有关的知识产权的规定
4关于采取多边性的防止和解决各国间争端的有效并迅捷的程序的规定
5旨在使谈判结果有最广泛的参加者的过渡安排
认识到建立应付国际仿冒商品贸易的原则规则及规范的多边框架的必要性
认识到知识产权为私有权
认识到保护知识产权的国家体制基本的公共政策目标包括发展和技术方面的目标
还认识到最不发达国家成员方为建立一个稳固可行的技术基础而在国内实施法律和条例方面对最大限度的灵活性具有特殊需要
强调通过多边程序为解决与贸易有关的知识财产问题争端作出更加有力的承诺以缓解紧张局势的重要性
希望在世界贸易组织及世界知识产权组织本协议中称"WIPO"之间以及其他有关国际组织之间建立一种相互支持的关系
兹协议如下
第一部分 总则和基本原则
第1条 义务的性质和范围
1.各成员方应使本协议的规定生效各成员方可以但不应受强制地在其本国法律中实行比本协议所要求的更加广泛的保护只要这种保护不与本协议条款相抵触各成员方应在各自的法律体系及惯例范围内自由确定实施本协议各条款的适当方法
2.本协议所称的"知识产权"一词系指第二部分第1至第7节所列举所有种类的知识财产
3.各成员方应给予其他成员方国民以本协议所规定的待遇就相关的知识产权而言如果所有世界贸易组织成员方已是这些公约的成员方则其他成员方国民应被理解为符合1967巴黎公约1971伯尔尼公约罗马公约有关集成电路知识产权条约所规定的受保护资格标准的自然人或法人任何利用由罗马公约第5条第3款或第6条第2款所提供之可能性的成员方应如那些条款所预见的那样向与贸易有关的知识产权理事会作出通报
第2条 知识产权公约
1.关于本协议第二第三及第四部分各成员方应遵守巴黎公约1967第l条至第12条以及第19条规定
2.本协议第一至第四部分的所有规定均不得减损各成员方按照巴黎公约伯尔尼公约罗马公约有关集成电路知识产权条约而可能相互承担的现行义务
第3条 国民待遇
1.在服从分别在1967巴黎公约1971伯尔尼公约罗马公约有关集成电路知识产权条约中已作的例外规定的条件下在保护知识产权方面每一成员方应给予其他成员方的待遇其优惠不得少于它给予自己国民的优惠对于录音及广播机构的表演者制作者本项义务只对本协议中规定的权利适用任何利用由1971伯尔尼公约第6条或罗马公约第16条第1款第2子款所规定之可能性的成员方均应向与贸易有关的知识产权理事会作出在那些条款中预知的通报
2.第1款所允许的与司法及行政程序有关的例外包括在一成员方司法管辖权范围内指定服务地址或委任代理人只有在为确保与本协议规定不一致的法律规章得到遵守所必要的并且此种作法不以一种可能对贸易构成变相限制的方式被采用的条件下各成员方方可利用
第4条 最惠国待遇
在知识产权的保护方面由一成员方授予任一其他国家国民的任何利益优惠特权或豁免均应立即无条件地给予所有其他成员方的国民一成员方给予的下列利益优惠特权或豁免免除此项义务
1得自国际司法协助协定或一种一般性的并非专门限于保护知识产权的法律实施的
2按照认可所给予的待遇只起在另一国所给予的待遇的作用而不起国民待遇作用的1971伯尔尼公约罗马公约的规定授予的
3有关本协议未作规定的录音与广播组织的表演者及制作者权利的
4得自世界贸易组织协定生效之前已生效的与知识产权保护有关的国际协定的条件是此类协定已通报与贸易有关的知识产权理事会并且不得构成一种对其他各成员方国民随意的或不公正的歧视
第5条 关于保护的获得或保持的多边协定
第3条和第4条规定的义务不适用于在世界知识产权组织主持下达成的有关知识产权获得或保持的多边协定规定的程序
第6条 
就本协议下争端的解决而言按照第3条及第4条本协议中的任何条款均不得用以提出知识产权失效问题
第7条 
知识产权的保护和实施应当对促进技术革新以及技术转让和传播作出贡献对技术知识的生产者和使用者的共同利益作出贡献并应当以一种有助于社会和经济福利以及有助于权利与义务平衡的方式进行
第8条 
1.各成员方在制订或修正其法律和规章时可采取必要措施以保护公众健康和营养并促进对其社会经济和技术发展至关重要部门的公众利益只要该措施符合本协议规定
2.可能需要采取与本协议的规定相一致的适当的措施以防止知识产权所有者滥用知识产权或藉以对贸易进行不合理限制或实行对国际间的技术转让产生不利影响的作法
第二部分 关于知识产权的效力范围及使用的标准
第1节 版权及相关权利
第9条 伯尔尼公约的关系
1.各成员方应遵守1971伯尔尼公约第l至第21条及其附件的规定然而各成员方根据本协议对公约第6条副则授予的权利或由其引伸出的权利没有权利和义务
2.对版权的保护可延伸到公式但不得延伸到思想程序操作方法或数学上的概念等
第10条 计算机程序和数据汇编
1.计算机程序无论是信源代码还是目标代码均应根据1971伯尔尼公约的规定作为文献著作而受到保护
2.不论是机读的还是其他形式的数据或其他材料的汇编其内容的选择和安排如构成了智力创造即应作为智力创造加以保护这种不得延及数据或材料本身的保护不应妨碍任何存在于数据或材料本身的版权
第11条 出租权
至少在计算机程序和电影艺术作品方面一成员方应给予作者及其权利继承人以授权或禁止将其拥有版权的作品原著或复制品向公众作商业性出租的权利除非此类出租已导致了对该作品的广泛复制而这种复制严重损害了该成员方给予作者及其权利继承人的独家再版权否则在电影艺术作品方面一成员方可免除此项义务在计算机程序方面当程序本身不是出租的主要对象时此项义务不适用于出租
第12条 保护期
电影艺术作品或实用艺术作品以外作品的保护期应以不同于自然人的寿命计算此期限应为自授权出版的日历年年终起算的不少于50年或者若作品在创作后50年内未被授权出版则应为自创作年年终起算的50年
第13条 限制和例外
各成员方应将对独占权的限制和例外规定限于某些特殊情况而不影响作品的正常利用也不无理妨碍权利所有者的合法利益
第14条 对录音音响录音制品的保护
1.在表演者的表演在录制品上的录制方面表演者应能阻止下列未经其许可的行为录制和翻录其尚未录制的表演表演者也应能阻止下列未经其许可的行为将其现场表演作无线电广播和向公众传播
2.录音制品制作者应享有授权或禁止直接或间接翻制其录音制品的权利
3.广播机构应有权禁止下列未经其许可的行为录制翻录以无线广播手段转播以及向公众传播同一录音制品的电视广播若各成员方未向广播机构授予此种权利则应依照伯尔尼公约1971向广播内容的版权所有者提供阻止上述行为的可能性
4.第11条关于计算机程序的规定经对细节作必要修改后应适用于录音制品的制作者及经一成员方法律确认的录音制品的任何其他版权所有者若一成员方在1994年4月15日实行了在出租录音制品方面向版权所有者提供合理补偿的制度则它可在录音制品的商业性出租未对版权所有者的独占翻录权造成重大损害的条件下维持该项制度
5.录音制品制作者和表演者根据本协议可以获得的保护期至少应持续到从录音制品被制作或演出进行的日历年年终起算的50年期结束时按照第3款给予的保护期至少应从广播播出的日历年年终起算持续20年
6.有关按第2款及第3款授予的权利任何成员方可在罗马公约允许的范围内对按第2款及第3款授予的权利规定条件限制例外及保留
但是1971伯尔尼公约第18条的规定经对细节作必要修改后也应适用于录音制品表演者和制作者的权利
第2节
第15条 保护事项
1.任何能够将一个企业的商品和服务与另一企业的商品和服务区别开来的标志或标志组合均应能够构成商标此种标志尤其是包含有个人姓名的词字母数目字图形要素和色彩组合以及诸如此类的标志组合应有资格注册为商标若标志没有固有的能够区别有关商品及服务的特征则各成员方可将其通过使用而得到的独特性作为或给予注册的依据各成员方可要求标志在视觉上是可以感知的以此作为注册的一项条件
2.第1款不得理解为阻止一成员以其他理由拒绝商标注册只要这些理由不减损巴黎公约1967的规定
3.各成员方可将使用作为给予注册的依据然而商标的实际使用不应是提出注册申请的一项条件申请不得仅由于在自申请之日起的3年期期满之前未如所计划那样地加以使用而遭拒绝
4.商标所适用的商品或服务的性质在任何情况下均不得构成对商标注册的障碍
5.各成员方应在每一商标注册之前或之后立即将其公布并应为请求取消注册提供合理机会此外各成员方可为反对一个商标的注册提供机会
第16条 授予权利
1.已注册商标所有者应拥有阻止所有未经其同意的第三方在贸易中使用与已注册商标相同或相似的商品或服务的其使用有可能招致混淆的相同或相似的标志在对相同商品或服务使用相同标志的情况下应推定存在混淆之可能上述权利不应妨碍任何现行的优先权也不应影响各成员方以使用为条件获得注册权的可能性
2.1967巴黎公约第6条副则经对细节作必要修改后应适用于服务在确定一个商标是否为知名商标时各成员方应考虑到有关部分的公众对该商标的了解包括由于该商标的推行而在有关成员方得到的了解
3.1967巴黎公约第6条副则经对细节作必要修改后应适用于与已注册商标的商品和服务不相似的商品或服务条件是该商标与该商品和服务有关的使用会表明该商品或服务与已注册商标所有者之间的联系而且已注册商标所有者的利益有可能为此种使用所破坏
第17条 
各成员方可对商标所赋予的权利作些有限的例外规定诸如公正使用说明性术语条件是此种例外要考虑到商标所有者和第三方的合法利益
第18条 保护期
商标首次注册及每次续期注册的期限不得少于7年商标注册允许无限期地续期
第19条 使用规定
1.如果注册的保持要求以商标付诸使用为条件则除非商标所有者提出了此类使用存在障碍的充分理由否则注册只有在商标至少连续三年以上未予使用的情况下方可取消
2.当商标由其他人的使用是处在该商标所有者的控制之下时这种使用应按是为保持注册目的之使用而予以承认
第20条 其他要求
商标在贸易当中的使用不得受到一些特殊要求不正当的妨碍比如与另一商标一道使用以特殊形式使用或以有害于该商标将一个企业的商品或服务与其他企业的商品或服务区分开来的能力之方式使用等这并不排除规定识别生产某种商品或服务的企业的商标与识别该企业同类特殊商品或服务的商标一道但不联在一起使用的要求
第21条 许可与转让
各成员方可以确定商标许可与转让的条件同时不言而喻强制性的商标许可是不应允许的已注册商标的所有者有权将商标所属企业与商标一同转让或只转让商标不转让企业
第3节 地理标志
第22条 对地理标志的保护
1.本协议所称的地理标志是识别一种原产于一成员方境内或境内某一区域或某一地区的商品的标志而该商品特定的质量声誉或其他特性基本上可归因于它的地理来源
2.在地理标志方面各成员方应向各利益方提供法律手段以阻止
1使用任何手段在商品的设计和外观上以在商品地理标志上误导公众的方式标志或暗示该商品原产于并非其真正原产地的某个地理区域
2作任何在1967巴黎公约第10条副则意义内构成一种不公平竞争行为的使用
3.若某种商品不产自于某个地理标志所指的地域而其商标又包含了该地理标志或由其组成如果该商品商标中的该标志具有在商品原产地方面误导公众的性质则成员方在其法律许可的条件下或应利益方之请求应拒绝或注销该商标的注册
4.上述第123款规定的保护应适用于下述地理标志该地理标志虽然所表示的商品原产地域地区或所在地字面上无误但却向公众错误地表明商品是原产于另一地域
第23条 对葡萄酒和烈性酒地理标志的额外保护
1.每一成员方应为各利益方提供法律手段以阻止不产自于某一地理标志所指地方的葡萄酒或烈性酒使用该地理标志即使在标明了商品真正原产地或在翻译中使用了该地理标志或伴以"种类""类型""风味""仿制"等字样的情况下也不例外
2.对于不产自于由某一地理标志所指的原产地而又含有该产地地理标志的葡萄酒或烈性酒如果一成员方的立法允许或应某一利益方之请求应拒绝或注销其商标注册
3.如果不同的葡萄酒使用了同名的地理标志则根据上述第22条第4款规定每一种标志均受到保护每一成员方应确定使同名地理标志能够相互区别开来的现实条件同时应考虑到确保有关的生产者受到公正待遇并不致使消费者产生误解混淆
4.为了便于对葡萄酒地理标志进行保护应在与贸易有关的知识产权理事会内就建立对参加体系的那些成员方有资格受到保护的葡萄酒地理标志进行通报与注册的多边体系进行谈判
第24条 国际谈判与例外
1.成员方同意进行旨在加强第23条规定的对独特地理标志的保护的谈判成员方不得援用下述第4至8款的规定拒绝进行谈判或缔结双边或多边协定在此类谈判中成员方应愿意考虑这些规定对关于其使用是此类谈判之议题的独特地理标志的连续适用性
2.与贸易有关的知识产权理事会应对本节规定之适用情况实行审查首次此类审查应在世界贸易组织协定生效2年之内进行任何影响履行该规定义务的事项均可提请理事会注意应一成员方之请求理事会应就经有关成员方之间双边磋商或多组双边磋商仍无法找到令人满意的解决办法的问题同任何一个或多个成员方进行磋商理事会应采取可能被一致认为有助于本节之实施及促进本节目标之实现的行动
3.在实施本节规定时成员方不得在世界贸易组织协定生效日即将来临之际减少对该成员方境内的地理标志的保护
4.本节中无任何规定要求一成员方阻止其国民或居民继续或类似地使用另一成员方与商品或服务有关的用以区别葡萄酒或烈性酒的特殊地理标志这些国民或居民在该成员方境内
1在1994年4月15日之前至少已有10年
2在上述日期之前已诚实守信地连续使用了标示相同或相关商品或服务的地理标志
5.若一商标已被诚实守信地使用或注册
1在如第六部分中所确定的这些规定在那一成员方适用之日以前
2在该地理标志在其原产国获得保护之前通过诚实守信的使用而获得一商标的权利则为实施本节而采取的措施就不得以该商标同某一地理标志相同或类似为由而损害其注册的合格性和合法性或使用该商标的权利
6.本节中无任何规定要求一成员方适用其关于任何其他成员方的商品和服务的地理标志的规定这些商品或服务的相关标志与作为那一成员方境内此类商品或服务的普通名称在一般用语中是惯用的名词完全相同本节中无任何规定要求一成员方适用其关于任何其他成员方的葡萄制品的地理标志的规定这些葡萄制品与在世界贸易组织协定生效之日存在于该成员方境内的葡萄品种的惯用名称完全相同
7.一成员方可以规定任何根据本节所提出的有关商标使用或注册的请求必须在对该受保护标志的非法使用被公布后5年之内提出或在商标在那一成员方境内注册之日以后提出条件是在注册之日商标已被公告如果该日期早于非法使用被公布的日期则条件就应是地理标志未被欺诈地使用或注册
8.本节规定丝毫不得妨碍任何个人在贸易中使用其姓名或其前任者姓名的权利若该姓名的使用导致公众的误解则除外
第4节 工业设计
第25条 保护的要求
1.成员方应为新的或原始的独立创造的工业设计提供保护成员方可以规定设计如果与已知的设计或已知的设计要点的组合没有重大区别则不视其为新的或原始的成员方可以规定此类保护不应延伸至实质上是由技术或功能上的考虑所要求的设计
2.每一成员方应保证对于获取对纺织品设计保护的规定不得无理损害寻求和获得此类保护的机会特别是在费用检查或发表方面各成员方可自行通过工业设计法或版权法履行该项义务
第26条 
1.受保护工业设计的所有者应有权阻止未经所有者同意的第三方为商业目的生产销售或进口含有或体现为是受保护设计的复制品或实为复制品的设计的物品
2.成员方可以对工业设计的保护规定有限的例外条件是这种例外没有无理地与对受保护工业设计的正常利用相冲突且没有无理损害受保护设计所有者的合法利益同时考虑到第三方的合法利益
3.有效保护期限至少为10年
第5节
第27条 可取得专利的事项
1.根据下述第23款的规定所有技术领域内的任何发明无论是产品还是工艺均可取得专利只要它们是新的包含一个发明性的步骤工业上能够适用根据第65条第4款第70条第8款和本条第3款的规定专利的取得和专利权的享受应不分发明地点技术领域以及产品是进口的还是当地生产的
2.若阻止某项发明在境内的商业利用对保护公共秩序或公共道德包括保护人类动物或植物的生命或健康或避免对环境造成严重污染是必要的则成员方可拒绝给予该项发明以专利权条件是不是仅因为其国内法禁止这种利用而作出此种拒绝行为
3.以下情况成员方也可不授予专利
1对人类或动物的医学治疗的诊断治疗和外科手术方法
2微生物以外的动植物非生物和微生物生产方法以外的动物或植物的实为生物的生产方法然而成员方应或以专利形式或以一种特殊有效的体系或以综合形式对植物种类提供保护应在世界贸易组织协定生效4年之后对本子款的规定进行审查
第28条 授予的权利
1.一项专利应授予其所有者以下独占权
1若一项专利的标的事项是一种产品则专利所有者有权阻止未得到专利所有者同意的第三方制造使用出卖销售或为这些目的而进口被授予专利的产品
2若专利的标的事项是一种方法则专利所有者有权阻止未得到专利所有者同意的第三方使用该方法或使用出卖销售或至少是为这些目的而进口直接以此方法获得的产品
2.专利所有者还应有权转让或通过继承转让该项专利及签订专利权使用契约
第29条 专利申请者的条件
1.成员方应要求专利申请者用足够清晰与完整的方式披露其发明以便于为熟悉该门技术者所运用并要求申请者在申请之日指明发明者已知的运用该项发明的最佳方式若是要求取得优先权则需在优先权申请之日指明
2.成员方可要求专利申请者提供关于该申请者在国外相同的申请与授予情况的信息
第30条 授予权利的例外
成员方可对专利授予的独占权规定有限的例外条件是该例外规定没有无理地与专利的正常利用相冲突也未损害专利所有者的合法利益同时考虑到第三者的合法利益
第31条 未经权利人授权的其他使用
若一成员方的法律允许未经权利人授权而对专利的标的事项作其他使用包括政府或经政府许可的第三者的使用则应遵守以下规定
1此类使用的授权应根据专利本身的条件来考虑
2只有在拟议中的使用者在此类使用前已作出以合理的商业条件获得权利人授权的努力而该项努力在一段合理时间内又未获成功时方可允许此类使用在发生全国性紧急状态或其他极端紧急状态或为公共的非商业性目的而使用的情况下成员方可放弃上述要求即使是在发生全国性紧急状态或其他极端紧急状态的情况下仍应合理地尽早通报权利人至于公共的非商业性使用若政府或订约人在未查专利状况的情况下得知或有根据得知一项有效的专利正在或将要被政府使用或为政府而使用则应及时通知权利人
3此类使用的范围和期限应限制在被授权的意图之内至于半导体技术只应用于公共的非商业性目的或用于抵销在司法或行政程序后被确定的反竞争的做法
4此类使用应是非独占性的
5此类使用应是不可转让的除非是同享有此类使用的那部分企业或信誉一道转让
6任何此类使用之授权均应主要是为授权此类使用的成员方国内市场供应之目的
7在被授权人的合法利益受到充分保护的条件下当导致此类使用授权的情况下不复存在和可能不再产生时有义务将其终止应有动机的请求主管当局应有权对上述情况的继续存在进行检查
8考虑到授权的经济价值应视具体情况向权利人支付充分的补偿金
9任何与此类使用之授权有关的决定其法律效力应接受该成员方境内更高当局的司法审查或其他独立审查
10任何与为此类使用而提供的补偿金有关的决定应接受成员方境内更高当局的司法审查或其他独立审查
11若是为抵销在司法或行政程序后被确定为反竞争做法而允许此类使用则成员方没有义务适用上述第2和第6子款规定的条件在决定此种情况中补偿金的数额时可以考虑纠正反竞争做法的需要若导致此项授权的条件可能重新出现则主管当局应有权拒绝终止授权
12若此类使用被授权允许利用一项不侵犯另一项专利第一项专利就不能加以利用的专利第二项专利则下列附加条件应适用
第二项专利中要求予以承认的发明应包括比第一项专利中要求予以承认的发明经济意义更大的重要的技术进步
第一项专利的所有者应有权以合理的条件享有使用第二项专利中要求予以承认之发明的相互特许权
除非同第二项专利一道转让否则第一项专利所授权的使用应是不可转让的
第32条 撤销收回
应提供对撤销或收回专利的决定进行司法审查的机会
第33条 保护的期限
有效的保护期限自登记之日起不得少于20年
第34条 工艺专利的举证责任
1.在第28条第l款第2子款所述及关于侵犯所有者权利的民事诉讼中若一项专利的标的事项是获取某种产品的工艺则司法当局应有权令被告证明获取相同产品的工艺不同于取得专利的工艺因此各成员方应规定在下列情况中至少一种情况下任何未经专利所有者同意而生产的相同产品若无相反的证据应被视为是以取得专利的工艺获取的
1如果以该项取得专利的工艺获取的产品是新的
2如果该相同产品极有可能是以该工艺生产的而专利所有者又不能通过合理的努力确定实际使用的工艺
2.只要上述第1或第2子款所述及的条件得到满足任何成员方均应有权规定上述第1款所指明的举证责任应由有嫌疑的侵权者承担
3.在举出相反证据时应考虑被告保护其生产和商业秘密的合法权益
第6节 集成电路的外观设计
第35条 与有关集成电路的知识产权条约的关系
各成员方同意按有关集成电路的知识产权条约中第2条至第7条第6条中第3款除外第12条和第16条第3款规定对集成电路的外观设计提供保护此外还服从以下规定
第36条 保护范围
根据下述第37条第1款的规定成员方应认为下列未经权利所有人授权的行为是非法的进口销售或为商业目的分售受保护的外观设计含有受保护设计的集成电路或仅在继续含有非法复制的外观设计的范围内含有这种集成电路的产品
第37条 不需要权利人授权的行为
1.尽管有上述第36条的规定但若从事或指令从事上条中所述及的关于含有非法复制的外观设计的集成电路或含有此种集成电路的任何产品的任何行为的人在获取该集成电路或含有此种集成电路的产品时未得知且没有合理的根据得知它含有非法复制的外观设计则成员方不应认为这种行为是非法的成员方应规定该行为人在接到关于复制该设计是非法行为的充分警告之后仍可从事与在此之前的存货和订单有关的任何行为但有责任向权利人支付一笔与对自由商谈而取得的通过关于该设计的专利使用权所应付费用相当的合理的专利权税
2.若对该外观设计的专利权使用授权是非自愿的或者被政府使用或为政府而使用是未经权利人授权的则上述第31条第1-11子款规定的条件在对细节作必要修改后应适用
第38条 保护的期限
1.在要求将登记作为保护条件的成员方对外观设计的保护期限自填写登记申请表之日或自在世界上任何地方首次进行商业开发之日计起应不少于10年时间
2.在不要求将登记作为保护条件的成员方对外观设计的保护期限自在世界上任何地方首次进行商业开发之日计起应不少于10年时间
3.尽管有上述第1第2款规定成员方仍可规定在外观设计被发明15年后保护应自动消失
第7节 对未泄露之信息的保护
第39条
1.在确保有效的保护以对付1967巴黎公约第10条副则所述及的不公平竞争的过程中各成员方应对下述第2款所规定的未泄露之信息和下述第3款所规定的提交给政府或政府机构的数据提供保护
2.自然人和法人应有可能阻止由其合法掌握的信息在未得到其同意的情况下被以违反诚信商业作法的方式泄露获得或使用只要此信息
1在作为一个实体或其组成部分的精确形状及组合不为正规地处理此种信息的那部分人所共知或不易被其得到的意义上说是秘密的
2由于是秘密的而具有商业价值
3被其合法的掌握者根据情况采取了合理的保密措施
3.成员方当被要求呈交本公开的试验或其他所获得需要付出相当劳动的数据以作为同意使用新型化学物质生产的药品或农用化学品在市场上销售的一项条件时应保护该数据免受不公平的商业利用此外成员方应保护该数据免于泄露除非是出于保护公共利益的需要或采取了保证该数据免受不公平商业利用的措施
第8节 在契约性专利权使用中对反竞争性行为的控制
第40条
1.各成员方一致认为与限制竞争的知识产权有关的一些专利权使用作法或条件对贸易可能产生不利影响可能妨碍技术的转让和传播
2.本协议中无任何规定阻止成员方在其立法中详细载明在特定情况下可能构成对有关市场中的竞争具有不利影响的知识产权滥用的专利权使用作法或条件如上述所规定一成员方可按照本协议的其他规定根据国内有关法律和规定采取适当措施阻止或控制此种作法这些措施可能包括例如独占性回授条件阻止否认合法性的条件和强制性的一揽子许可证交易
3.若一成员方有理由认为是另一成员方国民或居民的知识产权所有者正在从事违反该成员方关于本节主题事项的法律规章的活动并希望使该另一成员方遵守此类法规则在不妨碍两个成员方中任何一方依法采取任何行动和作出最终决定的充分自由的条件下该另一成员方在接到该成员方的请求后应与之进行磋商被请求的成员方对与提出请求的成员方进行磋商应给予充分的同情的考虑为此提供充分的机会并应在服从国内法和令双方满意的关于提出请求的成员方保护资料机密性的协议之最后决定的条件下通过提供与该问题有关的可以公开利用的非机密性资料和可供该成员方利用的其他资料进行合作
4.其国民或居民正在另一成员方接受关于所断言的违反该成员方关于本节主题事项的法律规章的诉讼的成员方根据请求应由另一成员方给予按照与上述第3款相同的条件进行磋商的机会
第三部分 知识产权的实施
第1节 一般义务
第41条
1.成员方应保证由本部分所具体规定的实施程序根据国内法是有效的以便允许对任何对本协议所涉及的知识产权的侵犯行为采取有效行动包括及时地阻止侵权的补救措施和对进一步侵权构成一种威慑的补救措施在运用这些程序时应避免对合法贸易构成障碍并规定防止其滥用的保障措施
2.知识产权的实施程序应公平合理不应不必要地繁琐消耗资财也不应有不合理的时限及毫无道理的拖延
3.对一案件案情实质的裁决最好应以书面形式作出并陈述理由至少应使诉讼各方没有不适当延迟地获知裁决结果对该案件案情实质的判定应只以各方被提供机会了解的证据为依据
4.诉讼各方应有机会让司法当局对最终行政决定及根据一成员方法律中关于一案件重要程度的司法规定对至少是对一案件案情实质最初司法裁决的法律方面进行审查然而没有义务为对刑事案件中的判定无罪进行审查提供机会
5.显然本部分未规定任何关于设置与实施大多数法律不同的实施知识产权的司法制度的义务也不影响成员方实施其大多数法律的能力本部分也未规定任何关于在实施知识产权和实施大多数法律之间进行资源分配的义务
第2节 民事和行政程序及补救
第42条 公平合理的程序
成员方应使权利所有人可以利用关于本协议所涉及的任何知识产权之实施的民事司法程序被告应有权及时获得内容充实的书面通知其中包括控告的依据应允许成员方由独立的法律辩护人充当其代表关于强制性的亲自出庭程序中不应规定过多烦琐的要求该程序所涉及的各方应有充分的权利证实其要求并提出所有相关的证据该程序应规定一种识别和保护机密性资料的方法除非该规定与现行的宪法要求相抵触
第43条 
1.若一当事方已提交了足以支持其要求的合理有效的证据并具体指明了由对方掌握的与证实其要求有关的证据则司法当局应有权决定按照在此类案件中确保对机密性资料保护的条件令对方出示该证据
2.若诉讼一当事方有意地并无正当理由地拒绝有关方面使用或在合理期限内未提供必要的资料或严重地妨碍与某一强制行动有关的程序则一成员方可授权司法当局根据呈交上来的信息包括因被拒绝使用信息而受到不利影响的一方呈交的申诉和事实陈述作出或是肯定的或是否定的最初和最终裁决这一切须在向各方提供机会听到断言或证据的情况下进行
第44条 
1.司法当局应有权命令一当事方停止侵权行为特别是在涉及对知识产权有侵权行为的进口货物结关之后立即阻止这些货物进入其司法管辖区内的商业渠道各成员方对涉及由个人在得知或有合理的根据得知经营受保护产品会构成对知识产权的侵犯之前获得或订购的该产品不必提供此项授权
2.尽管有本部分的其他规定若第二部分中专门阐述的关于未经权利人授权的政府使用或由政府授权的第三方的使用的各项规定得到遵守则各成员方可将针对此类使用的可资利用的补救措施限制在依据第31条第8子款的补偿金支付上在其他情况下本部分的补救措施应适用或者若这些补救措施与成员方的法律不符则应适用宣告性判决和适当的补偿金
第45条 
1.司法当局有权令故意从事侵权活动或有合理的根据知道是在从事侵权活动的侵权人就因侵权人对权利所有人知识产权的侵犯而对权利所有人造成的损害向其支付适当的补偿
2.司法当局有权令侵权人向权利所有人支付费用可能包括聘请律师的有关费用在有关案件中即使侵权人并非故意地从事侵权活动或有合理的根据知道其正在从事侵权活动成员方仍可授权司法当局下令追偿利润和/或支付预先确定的损失
第46条 其他补救
为了对侵权行为造成有效的威慑司法当局有权令其发现正在授权的货物避免对权利所有人造成损害的方式不作任何补偿地在商业渠道以外予以处置或者在不与现行法律要求相抵触的情况下予以销毁司法当局还有权令在侵权物品生产中主要使用的材料和工具以减少进一步侵权危险的方式不作任何补偿地在商业渠道以外予以处置在考虑此类请求时应考虑侵权的严重程度与被决定的补救两者相称的必要性以及第三者的利益对于仿冒商标产品除例外情况仅仅除去非法所贴商标还不足以允许将该产品放行到商业渠道之中
第47条 告知权
成员方可规定司法当局有权令侵权人将与侵权产品或服务的生产销售有牵连的第三方的身份及其销售渠道告知权利所有人除非这种授权与侵权的危害程度不成比例
第48条 被告的赔偿
1.司法当局有权令应其请求而采取措施并滥用实施程序的申诉方受到错误命令或抑制的被告方因此种滥用而遭受的损害向其提供补偿司法当局还应有权令申诉方支付被告的费用可能包括聘请律师的有关费用
2.关于与知识产权的保护或实施有关的任何法律的实施若政府机构和政府官员在实施法律过程中有诚意地采取了行动或打算采取行动则应仅免除其对适当补救措施的责任
第49条 行政程序
在能够决定以任何民事补救作为关于一案件案情实质的行政程序之结果的范围内此类程序应遵守与本节中所规定的那些原则大体相等的原则
第3节 临时措施
第50条
1.司法当局应有权决定及时有效的临时措施
1阻止任何对知识产权侵权行为的发生特别是阻止包括刚刚结关的进口商品在内的侵权商品进入其司法管辖区内的商业渠道
2保护关于被断言的侵权行为的有关证据
2.在适当的情况下特别是在任何延迟可能会给权利人带来不可弥补的损害或证据极有毁灭危险的情况下司法当局有权采取适当的措施
3.司法当局有权要求申诉人提供合理有效的证据以便司法当局充分肯定地确认申诉人就是权利人申诉人的权利正在受到侵犯或者这种侵犯即将发生同时司法当局应要求申诉人提供足以保护被告和防止滥用的保证金或同等的担保
4.若临时措施已经采取应在实施措施后最短时间内通知受影响的成员方应被告之请求应对这些措施进行重新审查包括听取被告陈述以便在关于措施的通报发出后的合理时间内决定这些措施是否应加以修正撤销或确认
5.将要实施临时措施令的司法当局可以要求申诉人提供为鉴别相关产品所必需的其他资料
6.在成员方法律允许的情况下导致对案件案情实质作出裁决的合理诉讼时间由发布措施令的司法当局决定在没有此种决定的情况下该合理时间为不超过20个工作日或者31天以长者为准若此类诉讼在该合理时间内没有开始则在不妨碍上述第4款规定的同时按照上述第1第2款所采取的临时措施应根据被告的请求予以撤销或使其停止生效
7.若临时措施被撤销或由于申诉人的任何作为或疏忽而失效或随后发现对知识产权的侵犯或侵犯的威胁并不存在则应被告之请求司法当局应有权令申诉人就这些措施对被告造成的任何损害向被告提供适当的赔偿
8.在能够决定以任何临时措施作为行政程序之结果的范围内此类程序应遵守与本节中所规定的那些原则大体相等的原则
第4节 与边境措施相关的特殊要求
第51条 海关当局的中止放行
成员方应依照以下规定采纳程序使有确凿根据怀疑仿冒商标商品或盗版商品的进口可能发生的权利人能够以书面形式向主管的行政或司法当局提出由海关当局中止放行该货物进入自由流通的申请在本节的要求得到满足的条件下成员方可使对含有其他侵犯知识产权行为货物的申请能够被提出成员方还可规定关于海关当局中止放行从其境内出口的侵权货物的相应程序
第52条 
应要求任何启动上述第51条程序的权利人提供适当的证据以使主管当局确信根据进口国的法律确有对权利人知识产权无可争辩的侵犯并提供对该货物充分详细的描述以使海关当局可以迅速地对其加以识别主管当局应在一个合理的时间内通知申请人是否接受其请求若主管当局决定受理应通知申诉人海关当局采取行动的时间
第53条 保证金或同等担保
1.主管当局应有权要求申请人提供一笔足以保护被告和有关当局并阻止滥用的保证金或同等担保该保证金或同等担保不应无理地阻碍对这些程序的援用
2.假如根据本节关于申请的规定对涉及工业设计专利外观设计或未泄露信息的货物进入自由流通的放行已由海关当局根据非由司法或其他独立机构作出的裁决中止下述第55条规定的期限已到期而仍未获得主管当局暂时放行的许可而且假如关于进口的所有其他条件均得到了遵从则该货物的所有人进口商或收货人在提交了一笔其数额足以保护权利人不受侵权损害的保证金的条件下应有权使该货物放行保证金的支付不应妨碍向权利人作其他有效的补偿显然如果权利人在一段合理的时间内没有寻求起诉权则保证金应予免除
第54条 中止通知
根据上述第51条的规定货物放行一旦被中止应立即通知进口商和申请人
第55条 中止的持续期限
若在申请人被送达中止通知后不超过10个工作日之内海关当局仍未接到关于被告以外的一方已开始将会导致对案件的案情实质作出裁决的诉讼或者主管当局已采取延长对货物放行中止的临时措施的通知则只要进口或出口的所有其他条件均已得到了遵从该货物便应予放行在适当的情况下上述期限可再延长10个工作日若导致对一案件案情实质作出裁决的诉讼已经开始则应被告之请求应进行审查包括听取被告的陈述以便决定在一段合理的时间内这些措施是否应加以修正撤销或确认尽管有上述规定或货物的放行已经被中止或根据临时司法措施继续被中止则第50条第6款的规定应适用
第56条 对商品进口商和货主的补偿
有关当局应有权令申请人就因错误扣押货物或扣押根据上述第55条规定应予放行的货物而对进口商收货人和货主所造成的损害向其支付适当的赔偿
第57条 资料和调查权
在不妨碍对机密资料进行保护的同时成员方应授权主管当局给予权利人使被海关当局扣押的货物接受检查以证实权利人要求的充分机会有关当局也应有权给予进口商使任何此类货物接受检查的同等机会若对案件的案情实质已作出了积极的裁决成员方可以授权主管当局将有关发货人进口商和收货人的姓名和地址以及有关货物的数量通知权利人
第58条 依职权之行为
若成员方要求主管当局主动采取行动中止放行其已获得无可争辩的证据证明知识产权正在受到侵犯的货物
1主管当局在任何时候均可从权利人处寻求任何有助于其行使权力的资料
2应迅速地将该中止通知进口商和权利人若进口商已就中止一事向主管当局呈交了上诉则中止应按上述第55条规定的经对细节作了必要修改的条件进行
3若政府机构和政府官员有诚意地采取了行动或者打算采取行动则成员方仅应免除其对适当补救措施所应承担的责任
第59条 补救
在不妨碍权利人其他行动权和被告向司法当局寻求审查权利的同时主管当局根据上述第40条的规定应有权下令销毁或处理侵权货物对于仿冒商标的货物当局应不允许侵权货物原封不动地再出口若使其按照不同的海关程序办理例外情况除外
第60条 少量进口
对于旅游者和私人行李中携带的或少量寄存的非商业性质的少量货物成员方可免除上述条款的适用
第5节 刑事程序
第61条
成员方应规定刑事程序和惩罚至少适用于具有商业规模的故意的商标仿冒和盗版案件可资利用的补救措施应包括足以构成一种威慑的与对相应程度的刑事犯罪适用的处罚水平相同的监禁和/或罚款措施在适当的案件中可资利用的补救措施还包括对侵犯货物及在从事此种违法行为时主要使用的材料和工具予以扣押没收和销毁成员方可规定适用于其他侵犯知识产权案件的刑事程序和惩罚特别是对于故意和具有商业规模的侵权案件
第四部分 知识产权的取得和保持及相关程序
第62条
1.成员方可要求遵循合理的程序和手续以此作为第二部分第2至第6节所规定的知识产权的取得和保持的一项条件此类程序和手续应符合本协议的规定
2.若知识产权的取得以知识产权被授予或登记为准则成员方应确保在符合取得知识产权的实质性条件的情况下有关授予或登记的程序允许在一段合理时间内授予或登记权利以避免保护期限被不适当地缩短
3.1967巴黎公约第4条应在对细节作必要修改之后适用于服务标记
4.有关知识产权之取得和保持的程序有关行政撤销的程序若成员方的法律规定了这样的程序有关诸如抗辩撤销和废除等的程序应服从第41条第2款和第3款规定的总原则
5.上述第4款所涉及的任何程序中的最终行政决定应接受司法当局或准司法当局的审查然而在抗辩和行政撤销不成功的情况下假若此类程序的基础可能成为程序无效的原因则没有任何义务为对裁决进行此类审查提供机会
第五部分 争端的预防和解决
第63条 透明度
1.由一成员方制度实施的关于本协议主题事项知识产权的效力范围取得实施和防止滥用问题的法律和规章对一般申请的最终司法裁决和行政裁决应以该国官方语言以使各成员方政府和权利人能够熟悉的方式予以公布若此种公布不可行则应使之可以公开利用正在实施中的一成员方的政府或一政府机构与另一成员方政府或一政府机构之间关于本协议主题事项的各项协议也应予以公布
2.成员方应将上述第1款所述及的法律和规章通报与贸易有关的知识产权理事会以协助理事会对本协议的执行情况进行检查理事会应努力去最大限度地减轻成员方在履行该项义务方面的负担若与世界知识产权组织就建立一份含有这些法律和规章的共同登记簿一事进行的磋商取得成功理事会便可决定免除直接向理事会通报此类法律和规章的义务理事会在这方面还应考虑采取本协议从1967巴黎公约第6条的各项规定派生出来的各项义务所要求的与通报有关的任何行动
3.应另一成员方的书面请求每一成员方应准备提供上述第1款所述及的资料一成员方在有理由相信知识产权领域中某个特定的司法裁决行政裁决或双边协议影响到其由本协议所规定的权利时也可以书面形式要求向其提供或充分详尽地告知该特定的司法裁决行政裁决或双边协议
4.上述第1至第3款中无任何规定要求成员方泄露将会妨碍法律实施或违背公共利益或损害特定的国营或私营企业合法商业利益的资料
第64条 争端解决
1.由争端解决谅解所详细阐释并运用的1994关贸总协定第22条和第23条的各项规定应运用于本协议下的争端磋商与解决本协议中另有规定者除外
2.在自世界贸易组织协定生效之日起的5年之内1994关贸总协定第23条第1款第2和第3子款不应适用于本协议下的争端解决
3.在第2款所述及的期限内与贸易有关的知识产权理事会应检查根据本协议提出的由1994关贸总协定第23条第1款第2和第3子款所规定的那种类型控诉的规模和形式并向部长级会议提交建议请其批准部长级会议关于批准此类建议或延长第2款中所述及时限的任何决定只应以全体一致的方式作出被批准的建议应对所有成员方生效无须进一步的正式接受程序
第六部分 过渡期安排
第65条 过渡期安排
1.根据下述第2第3和第4款的规定成员方无义务在世界贸易组织协定生效之日后一般1年期满之前适用于本协议的规定
2.发展中国家成员方有权将第1款中所确定的本协议除第3第4和第5条以外的各项规定的适用日期推迟4年时间
3.处于由中央计划经济向市场私营企业经济转换进程中的和正在进行知识产权制度结构改革并在制定和实施知识产权法方面面临特殊困难的任何其他成员方也可从上述第2款所述及的期限推迟中获益
4.一发展中国家成员方如果按本协议有义务在上述第2款所规定的本协议对该成员方适用之日将对产品专利的保护扩大到在其境内无法加以保护的技术领域则可将第二部分第5节关于产品专利的规定对此类技术领域的适用再推迟5年时间
5利用上述第1第2第3和第4款所规定过渡期的成员方应保证在该时期内其法律规章和作法中的任何变更不导致它们与本协议规定相一致的程度降低
第66条 最不发达国家成员方
1.鉴于最不发达国家成员方的特殊需要和要求其经济财政和行政的压力以及其对创造一个可行的技术基础的灵活性的需要不应要求这些成员方在自上述第65条第1款所规定的适用日期起的10年内适用本协议第3第4和第5条除外应最不发达国家无可非议的请求与贸易有关的知识产权理事会应将此期限予以延长
2.发达国家缔约方应给境内的企业和机构提供奖励以促进和鼓励对最不发达国家成员方转让技术使其能够建立一个稳固可行的技术基础
第67条 技术合作
为便于本协议的实施发达国家成员方应根据请求和双边达成的条件向发展中国家和最不发达国家成员方提供对其有利的技术和金融合作此类合作应包括协助制定对有关知识产权保护实施及阻止滥用的法律和规章还应包括对设立和加强与这些事项有关的国内机关和机构包括人员培训提供支持
 
第七部分 机构安排和最后条款
第68条 与贸易有关的知识产权理事会
与贸易有关的知识产权理事会应对本协议的执行情况尤其是成员方履行本协议所规定义务的情况进行监督并应为成员方提供与贸易有关的知识产权事宜进行磋商的机会理事会应履行成员方指定给它的其他职责并尤其应在争端解决程序方面对成员方提出的请求提供帮助在行使职能时与贸易有关的知识产权理事会可与它认为合适的方面进行磋商并从那里寻找资料在与世界知识产权组织进行磋商时理事会应谋求在其第一次会议的1年内作出与该组织所属各机构进行合作的适当安排
第69条 国际合作
成员方同意相互进行合作以消除侵犯知识产权商品的国际贸易为此它们应在其行政范围内设立和通报联络站并随时交流关于侵权商品贸易的情报它们尤其应促进海关当局之间在仿冒商标商品和盗版商品贸易方面的情报交流与合作
第70条 对现有标的事项的保护
1.对于某个成员方在本协议对其适用之日以前发生的行为本协议不规定该成员方承担任何义务
2.本协议另有规定者除外对于在本协议对有关成员方适用之日已存在的及在该日期在该成员方受到保护的或在本协议规定的期限内达到或以后将要达到保护标准的所有标的事项本协议规定义务与本款和下述第3第4款有关的关于现有著作的版权义务仅应根据1971伯尔尼公约第18条来决定关于现有唱片的唱片制作商和表演者的权利仅应根据1971伯尔尼公约第18条来决定该条的适用办法由本协议第14条第6款作了具体规定
3.对于在本协议适用之日处于无专利权状态的标的事项没有对其恢复保护的义务
4.关于在按照与本协议相一致立法的条件已构成侵权的并在有关成员方接受世界贸易组织协定之日以前已开始的或已对其进行了大量投资的体现受保护标的事项的特定对象方面的任何行为任何成员方可为在本协议对那一成员方生效之日以后此类行为继续发生的情况下权利人可以利用的补救措施规定一个限度然而在此情况下该成员方至少应规定支付合理的补偿
5.成员方没有义务在本协议对其适用之日以前适用关于购买的原物或复制品的第11条和第14条第4款规定
6.不应要求成员方将第31条或第27条第1款关于对技术领域专利权的享用应一视同仁的规定适用于本协议生效之日以前未经权利人许可而经政府授权的使用
7.在以登记为保护条件的知识产权方面对于在本协议对有关成员方适用之日仍未得到批准的保护申请应允许对其作正式修改以要求根据本协议的规定加强保护此类修改不得包括新事项
8.若世界贸易组织协定生效之日已到而一成员方仍未能对药品和农用化学品提供与其根据第27条所承担义务相当的有效的专利保护则该成员方应
1尽管有第五部分的规定仍自建立世界贸易组织协定生效之日起规定一种使关于此类发明的专利申请得以提出的方式
2自本协议适用之日起将本协议所规定的授予专利权标准适用于这些申请视这些标准已由成员方在申请提出之日或若优先权有效并已被提出权利要求则在优先权申请之日予以适用
3自专利被批准时起并在自依照本协议第33条的申请提出之日计起的专利期的其余部分对符合上述第2子款保护标准的申请提供专利保护
9.若按第8款第1子款一项产品是一成员方内的专利申请对象则尽管有第五部分的规定应自在那一成员方获准进行市场销售之时计起给予其独占的市场销售权5年或直到一项产品专利在那一成员方被批准或拒绝之时以时间较短者为准条件是在世界贸易组织协定生效之后在另一成员方那项产品的专利申请已被提出专利已被批准并获准在该另一成员方进行市场销售
第71条 审查和修正
1.与贸易有关的知识产权理事会应在上述第65条第2款所述及的过渡期期满之后对本协议的履行情况进行审查理事会应参考在履行中获得的经验在过渡期期满之日2年后对本协议的履行情况进行审查并在此后每隔两年审查一次理事会也可根据可能成为对本协议进行修改或修正之理由的任何有关的新情况进行审查
2.仅为适应在已生效的其他国际协议中已达到的和根据那些国际协议为世界贸易组织所有成员方所接受的对知识产权更高的保护程度而提出的修正案可提交部长级会议以便其根据与贸易有关的知识产权理事会一致同意的建议采取与世界贸易组织协定第10条第6款相符的行动
第72条 保留
未经其他成员方的同意不得对本协议的任何条款作出保留
第73条 保障的例外规定
本协议中的任何内容均不应解释为
1要求一成员方提供他认为其泄露违背其根本安全利益的任何资料
2阻止任一成员方采取他认为是对保护其根本安全利益所必需的行动
与裂变物质或从中获取裂变物质的物质有关的
与枪支弹药和战争工具走私有关的以及与直接或间接为供给军方之目的而从事的其他货物和物资的走私有关的
在战时或在国际关系中出现其他紧急情况时采取的
3阻止任何成员方为根据联合国宪章所承担的维持国际和平与安全的义务而采取的行动
"""
)
splitter = SemanticSplitter(1000)
res = splitter.invoke(chunk)
print(res)

View File

@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

View File

@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
from collections import defaultdict
from typing import List
from kag.builder.model.sub_graph import SubGraph
from knext.common.base.runnable import Input, Output
from kag.common.vectorizer import Vectorizer, Neo4jBatchVectorizer
from kag.interface.builder.vectorizer_abc import VectorizerABC
from knext.schema.client import SchemaClient
from knext.project.client import ProjectClient
from knext.schema.model.base import IndexTypeEnum
class BatchVectorizer(VectorizerABC):
def __init__(self, project_id: str = None, **kwargs):
super().__init__(**kwargs)
self.project_id = project_id or os.getenv("KAG_PROJECT_ID")
self._init_graph_store()
self.vec_meta = self._init_vec_meta()
self.vectorizer = Vectorizer.from_config(self.vectorizer_config)
def _init_graph_store(self):
"""
Initializes the Graph Store client.
This method retrieves the graph store configuration from environment variables and the project ID.
It then fetches the project configuration using the project ID and updates the graph store configuration
with any additional settings from the project. Finally, it creates and initializes the graph store client
using the updated configuration.
Args:
project_id (str): The id of project.
Returns:
GraphStore
"""
graph_store_config = eval(os.getenv("KAG_GRAPH_STORE", "{}"))
vectorizer_config = eval(os.getenv("KAG_VECTORIZER", "{}"))
config = ProjectClient().get_config(self.project_id)
graph_store_config.update(config.get("graph_store", {}))
vectorizer_config.update(config.get("vectorizer", {}))
self.vectorizer_config = vectorizer_config
def _init_vec_meta(self):
vec_meta = defaultdict(list)
schema_client = SchemaClient(project_id=self.project_id)
spg_types = schema_client.load()
for type_name, spg_type in spg_types.items():
for prop_name, prop in spg_type.properties.items():
if prop_name == "name" or prop.index_type in [IndexTypeEnum.Vector, IndexTypeEnum.TextAndVector]:
vec_meta[type_name].append(self._create_vector_field_name(prop_name))
return vec_meta
def _create_vector_field_name(self, property_key):
from kag.common.utils import to_snake_case
name = f"{property_key}_vector"
name = to_snake_case(name)
return "_" + name
def _neo4j_batch_vectorize(self, vectorizer: Vectorizer, input: SubGraph) -> SubGraph:
node_list = []
node_batch = []
for node in input.nodes:
if not node.id or not node.name:
continue
properties = {"id": node.id, "name": node.name}
properties.update(node.properties)
node_list.append((node, properties))
node_batch.append((node.label, properties.copy()))
batch_vectorizer = Neo4jBatchVectorizer(vectorizer, self.vec_meta)
batch_vectorizer.batch_vectorize(node_batch)
for (node, properties), (_node_label, new_properties) in zip(
node_list, node_batch
):
for key, value in properties.items():
if key in new_properties and new_properties[key] == value:
del new_properties[key]
node.properties.update(new_properties)
return input
def invoke(self, input: Input, **kwargs) -> List[Output]:
modified_input = self._neo4j_batch_vectorize(self.vectorizer, input)
return [modified_input]
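
For orientation, a minimal usage sketch of the vectorizer above. Assumptions: KAG_PROJECT_ID and the vectorizer settings are already configured in the environment or project, and the "Chunk" label and node values are hypothetical.

from kag.builder.model.sub_graph import SubGraph

sub_graph = SubGraph(nodes=[], edges=[])
sub_graph.add_node(id="chunk-1", name="chunk-1", label="Chunk",
                   properties={"content": "some text to embed"})
vectorizer = BatchVectorizer()  # resolves the project id from KAG_PROJECT_ID
vectorized = vectorizer.invoke(sub_graph)[0]
# nodes now carry vector fields such as _name_vector for vector-indexed properties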

View File

@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.builder.component.writer.kg_writer import KGWriter
__all__ = [
"KGWriter",
]

View File

@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import logging
import os
from enum import Enum
from typing import Type, Dict, List
from knext.graph_algo.client import GraphAlgoClient
from kag.builder.model.sub_graph import SubGraph
from kag.interface.builder.writer_abc import SinkWriterABC
from knext.common.base.runnable import Input, Output
logger = logging.getLogger(__name__)
class AlterOperationEnum(str, Enum):
Upsert = "UPSERT"
Delete = "DELETE"
class KGWriter(SinkWriterABC):
"""
A class that extends `SinkWriter` to handle writing data into a Neo4j knowledge graph.
This class is responsible for configuring the graph store based on environment variables and
an optional project ID, initializing the Neo4j client, and setting up the schema.
It also manages semantic indexing and multi-threaded operations.
"""
def __init__(self, project_id: str = None, **kwargs):
super().__init__(**kwargs)
self.project_id = project_id or os.getenv("KAG_PROJECT_ID")
        self.client = GraphAlgoClient(project_id=self.project_id)
@property
def input_types(self) -> Type[Input]:
return SubGraph
@property
def output_types(self) -> Type[Output]:
return None
def invoke(
self, input: Input, alter_operation: str = AlterOperationEnum.Upsert, lead_to_builder: bool = False
) -> List[Output]:
"""
Invokes the specified operation (upsert or delete) on the graph store.
Args:
input (Input): The input object representing the subgraph to operate on.
alter_operation (str): The type of operation to perform (Upsert or Delete).
            lead_to_builder (bool): Whether to enable the lead-to event inference builder. Defaults to False.
Returns:
List[Output]: A list of output objects (currently always [None]).
"""
self.client.write_graph(sub_graph=input.to_dict(), operation=alter_operation, lead_to_builder=lead_to_builder)
return [None]
def _handle(self, input: Dict, alter_operation: str, **kwargs):
"""The calling interface provided for SPGServer."""
_input = self.input_types.from_dict(input)
_output = self.invoke(_input, alter_operation)
return None
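
A usage sketch for the writer (assumes a reachable OpenSPG server; the project ID, label, and values are illustrative):

from kag.builder.model.sub_graph import SubGraph

sub_graph = SubGraph(nodes=[], edges=[])
sub_graph.add_node(id="e1", name="Aspirin", label="Medicine", properties={})
writer = KGWriter(project_id="1")
writer.invoke(sub_graph, alter_operation=AlterOperationEnum.Upsert)  # returns [None]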

View File

@ -0,0 +1,157 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import logging
import importlib
import os
from kag.builder.component import SPGTypeMapping, KGWriter
from kag.builder.component.extractor import KAGExtractor
from kag.builder.component.splitter import LengthSplitter
from kag.builder.component.vectorizer.batch_vectorizer import BatchVectorizer
from knext.common.base.chain import Chain
from knext.builder.builder_chain_abc import BuilderChainABC
logger = logging.getLogger(__name__)
READER_MAPPING = {
    "csv": "kag.builder.component.reader.csv_reader.CSVReader",
    "json": "kag.builder.component.reader.json_reader.JSONReader",
    "txt": "kag.builder.component.reader.txt_reader.TXTReader",
    "pdf": "kag.builder.component.reader.pdf_reader.PDFReader",
    "docx": "kag.builder.component.reader.docx_reader.DocxReader",
    "md": "kag.builder.component.reader.markdown_reader.MarkdownReader",
}
def get_reader(file_path: str):
    file = os.path.basename(file_path)
    suffix = file.split(".")[-1].lower()
    assert suffix in READER_MAPPING, f"{suffix} is not supported. Supported suffixes are: {list(READER_MAPPING.keys())}"
    reader_path = READER_MAPPING[suffix]
    mod_path, class_name = reader_path.rsplit(".", 1)
    module = importlib.import_module(mod_path)
    reader_class = getattr(module, class_name)
    return reader_class
class DefaultStructuredBuilderChain(BuilderChainABC):
"""
A class representing a default SPG builder chain, used to import structured data based on schema definitions
Steps:
0. Initializing by a give SpgType name, which indicates the target of import.
1. SourceReader: Reading structured dicts from a given file.
2. SPGTypeMapping: Mapping source fields to the properties of target type, and assemble a sub graph.
By default, the same name mapping is used, which means importing the source field into a property with the same name.
3. KGWriter: Writing sub graph into KG storage.
Attributes:
spg_type_name (str): The name of the SPG type.
"""
def __init__(self, spg_type_name: str, **kwargs):
super().__init__(**kwargs)
self.spg_type_name = spg_type_name
def build(self, **kwargs):
"""
Builds the processing chain for the SPG.
Args:
**kwargs: Additional keyword arguments.
Returns:
chain: The constructed processing chain.
"""
file_path = kwargs.get("file_path")
source = get_reader(file_path)()
mapping = SPGTypeMapping(spg_type_name=self.spg_type_name)
sink = KGWriter()
chain = source >> mapping >> sink
return chain
    def invoke(self, file_path, max_workers=10, **kwargs):
        """
        Invokes the processing chain with the given file path and optional parameters.
        Args:
            file_path (str): The path to the input file.
            max_workers (int, optional): The maximum number of workers. Defaults to 10.
            **kwargs: Additional keyword arguments.
        Returns:
            The result of invoking the processing chain.
        """
        logger.info(f"begin processing file_path: {file_path}")
        return super().invoke(file_path=file_path, max_workers=max_workers, **kwargs)
class DefaultUnstructuredBuilderChain(BuilderChainABC):
"""
A class representing a default KAG builder chain, used to extract graph from documents and import unstructured data.
Steps:
0. Initializing.
1. SourceReader: Reading chunks from a given file.
2. LengthSplitter: Splitting chunk to smaller chunks. The chunk size can be adjusted through parameters.
3. KAGExtractor: Extracting entities and relations from chunks, and assembling a sub graph.
By default,the extraction process includes NER and SPO Extraction.
4. KGWriter: Writing sub graph into KG storage.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def build(self, **kwargs) -> Chain:
"""
Builds the processing chain for the KAG.
Args:
**kwargs: Additional keyword arguments.
Returns:
chain: The constructed processing chain.
"""
file_path = kwargs.get("file_path")
split_length = kwargs.get("split_length")
window_length = kwargs.get("window_length")
source = get_reader(file_path)()
splitter = LengthSplitter(split_length, window_length)
extractor = KAGExtractor()
vectorizer = BatchVectorizer()
sink = KGWriter()
chain = source >> splitter >> extractor >> vectorizer >> sink
return chain
    def invoke(self, file_path: str, split_length: int = 500, window_length: int = 100, max_workers=10, **kwargs):
        """
        Invokes the processing chain with the given file path and optional parameters.
        Args:
            file_path (str): The path to the input file.
            split_length (int, optional): The length at which the file should be split. Defaults to 500.
            window_length (int, optional): The length of the processing window. Defaults to 100.
            max_workers (int, optional): The maximum number of worker threads. Defaults to 10.
            **kwargs: Additional keyword arguments.
        Returns:
            The result of invoking the processing chain.
        """
        logger.info(f"begin processing file_path: {file_path}")
        return super().invoke(file_path=file_path, max_workers=max_workers, split_length=split_length, window_length=window_length, **kwargs)
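
A hypothetical end-to-end invocation of the two chains (file paths, the SPG type name, and parameter values are illustrative; a configured OpenSPG project is assumed):

# structured import into a given SPG type
DefaultStructuredBuilderChain(spg_type_name="Medicine.Disease").invoke(
    file_path="./data/disease.csv"
)
# unstructured extraction with explicit chunking parameters
DefaultUnstructuredBuilderChain().invoke(
    file_path="./data/notes.txt", split_length=500, window_length=100, max_workers=10
)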

View File

@ -0,0 +1,10 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

View File

@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import hashlib
from enum import Enum
from typing import Dict, Any
class ChunkTypeEnum(str, Enum):
Table = "TABLE"
Text = "TEXT"
class Chunk:
def __init__(
self,
id: str,
name: str,
content: str,
type: ChunkTypeEnum = ChunkTypeEnum.Text,
**kwargs
):
self.id = id
self.name = name
self.type = type
self.content = content
self.kwargs = kwargs
@staticmethod
def generate_hash_id(value):
if isinstance(value, str):
value = value.encode("utf-8")
hasher = hashlib.sha256()
hasher.update(value)
return hasher.hexdigest()
def __str__(self):
tmp = {
"id": self.id,
"name": self.name,
"content": self.content
if len(self.content) <= 64
else self.content[:64] + " ...",
}
return f"<Chunk>: {tmp}"
__repr__ = __str__
def to_dict(self):
return {
"id": self.id,
"name": self.name,
"content": self.content,
"type": self.type.value if isinstance(self.type, ChunkTypeEnum) else self.type,
"properties": self.kwargs,
}
@classmethod
def from_dict(cls, input_: Dict[str, Any]):
return cls(
id=input_.get("id"),
name=input_.get("name"),
content=input_.get("content"),
type=input_.get("type"),
**input_.get("properties", {}),
)
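
A small round-trip sketch of the Chunk container (values are arbitrary):

content = "TRIPS Part III sets out enforcement obligations."
chunk = Chunk(id=Chunk.generate_hash_id(content), name="demo-chunk", content=content)
print(chunk)  # <Chunk>: {'id': '...', 'name': 'demo-chunk', 'content': '...'}
restored = Chunk.from_dict(chunk.to_dict())
assert restored.id == chunk.id and restored.content == chunk.content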

View File

@ -0,0 +1,276 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import pprint
from typing import Dict, Any, List, Tuple
from knext.schema.model.schema_helper import (
SPGTypeName,
PropertyName,
RelationName,
)
class SPGRecord:
"""Data structure in operator, used to store entity information."""
def __init__(self, spg_type_name: SPGTypeName):
self._spg_type_name = spg_type_name
self._properties = {}
self._relations = {}
@property
def id(self) -> str:
return self.get_property("id", "")
@property
def name(self) -> str:
return self.get_property("name", self.id)
@property
def spg_type_name(self) -> SPGTypeName:
"""Gets the spg_type_name of this SPGRecord. # noqa: E501
:return: The spg_type_name of this SPGRecord. # noqa: E501
:rtype: str
"""
return self._spg_type_name
@spg_type_name.setter
def spg_type_name(self, spg_type_name: SPGTypeName):
"""Sets the spg_type_name of this SPGRecord.
:param spg_type_name: The spg_type_name of this SPGRecord. # noqa: E501
:type: str
"""
self._spg_type_name = spg_type_name
@property
def properties(self) -> Dict[PropertyName, str]:
"""Gets the properties of this SPGRecord. # noqa: E501
:return: The properties of this SPGRecord. # noqa: E501
:rtype: dict
"""
return self._properties
@properties.setter
def properties(self, properties: Dict[PropertyName, str]):
"""Sets the properties of this SPGRecord.
:param properties: The properties of this SPGRecord. # noqa: E501
:type: dict
"""
self._properties = properties
@property
def relations(self) -> Dict[str, str]:
"""Gets the relations of this SPGRecord. # noqa: E501
:return: The relations of this SPGRecord. # noqa: E501
:rtype: dict
"""
return self._relations
@relations.setter
def relations(self, relations: Dict[str, str]):
"""Sets the properties of this SPGRecord.
:param relations: The relations of this SPGRecord. # noqa: E501
:type: dict
"""
self._relations = relations
def get_property(
self, property_name: PropertyName, default_value: str = None
) -> str:
"""Gets a property of this SPGRecord by name. # noqa: E501
:param property_name: The property name. # noqa: E501
        :param default_value: If the property value is None, the default_value will be returned. # noqa: E501
:return: A property value. # noqa: E501
:rtype: str
"""
return self.properties.get(property_name, default_value)
def upsert_property(self, property_name: PropertyName, value: str):
"""Upsert a property of this SPGRecord. # noqa: E501
:param property_name: The updated property name. # noqa: E501
:param value: The updated property value. # noqa: E501
:type: str
"""
self.properties[property_name] = value
return self
def append_property(self, property_name: PropertyName, value: str):
"""Append a property of this SPGRecord. # noqa: E501
:param property_name: The updated property name. # noqa: E501
:param value: The updated property value. # noqa: E501
:type: str
"""
property_value = self.get_property(property_name)
if property_value:
property_value_list = property_value.split(',')
if value not in property_value_list:
self.properties[property_name] = property_value + ',' + value
else:
self.properties[property_name] = value
return self
def upsert_properties(self, properties: Dict[PropertyName, str]):
"""Upsert properties of this SPGRecord. # noqa: E501
:param properties: The updated properties. # noqa: E501
:type: dict
"""
self.properties.update(properties)
return self
def remove_property(self, property_name: PropertyName):
"""Removes a property of this SPGRecord. # noqa: E501
:param property_name: The property name. # noqa: E501
:type: str
"""
self.properties.pop(property_name)
return self
def remove_properties(self, property_names: List[PropertyName]):
"""Removes properties by given names. # noqa: E501
:param property_names: A list of property names. # noqa: E501
:type: list
"""
for property_name in property_names:
self.properties.pop(property_name)
return self
def get_relation(
self,
relation_name: RelationName,
object_type_name: SPGTypeName,
default_value: str = None,
) -> str:
"""Gets a relation of this SPGRecord by name. # noqa: E501
:param relation_name: The relation name. # noqa: E501
:param object_type_name: The object SPG type name. # noqa: E501
        :param default_value: If the property value is None, the default_value will be returned. # noqa: E501
:return: A relation value. # noqa: E501
:rtype: str
"""
return self.relations.get(relation_name + "#" + object_type_name, default_value)
def upsert_relation(
self, relation_name: RelationName, object_type_name: SPGTypeName, value: str
):
"""Upsert a relation of this SPGRecord. # noqa: E501
:param relation_name: The updated relation name. # noqa: E501
:param object_type_name: The object SPG type name. # noqa: E501
:param value: The updated relation value. # noqa: E501
:type: str
"""
self.relations[relation_name + "#" + object_type_name] = value
return self
def upsert_relations(self, relations: Dict[Tuple[RelationName, SPGTypeName], str]):
"""Upsert relations of this SPGRecord. # noqa: E501
:param relations: The updated relations. # noqa: E501
:type: dict
"""
for (relation_name, object_type_name), value in relations.items():
self.relations[relation_name + "#" + object_type_name] = value
return self
def remove_relation(
self, relation_name: RelationName, object_type_name: SPGTypeName
):
"""Removes a relation of this SPGRecord. # noqa: E501
:param relation_name: The relation name. # noqa: E501
:param object_type_name: The object SPG type name. # noqa: E501
:type: str
"""
self.relations.pop(relation_name + "#" + object_type_name)
return self
def remove_relations(self, relation_names: List[Tuple[RelationName, SPGTypeName]]):
"""Removes relations by given names. # noqa: E501
:param relation_names: A list of relation names. # noqa: E501
:type: list
"""
for (relation_name, object_type_name) in relation_names:
self.relations.pop(relation_name + "#" + object_type_name)
return self
def to_str(self):
"""Returns the string representation of the model"""
return pprint.pformat(self.__dict__())
def to_dict(self):
"""Returns the model properties as a dict"""
return {
"spgTypeName": self.spg_type_name,
"properties": {
**self.properties,
**self.relations,
},
}
def __dict__(self):
"""Returns this SPGRecord as a dict"""
return {
"spgTypeName": self.spg_type_name,
"properties": self.properties,
"relations": self.relations,
}
@classmethod
def from_dict(cls, input: Dict[str, Any]):
"""Returns the model from a dict"""
spg_type_name = input.get("spgTypeName")
_cls = cls(spg_type_name)
properties = input.get("properties")
for k, v in properties.items():
if "#" in k:
relation_name, object_type_name = k.split("#")
_cls.relations.update({relation_name + "#" + object_type_name: v})
else:
_cls.properties.update({k: v})
return _cls
def __repr__(self):
"""For `print` and `pprint`"""
return pprint.pformat(self.__dict__())
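
The "#"-joined relation keys and the comma-separated multi-value convention can be seen in a short sketch (the type, property, and relation names are hypothetical):

record = SPGRecord("Medicine.Disease")
record.upsert_properties({"id": "d1", "name": "pneumonia"})
record.append_property("symptom", "fever")
record.append_property("symptom", "cough")  # value becomes "fever,cough"
record.upsert_relation("commonSymptom", "Medicine.Symptom", "fever")
print(record.to_dict())
# {'spgTypeName': 'Medicine.Disease',
#  'properties': {'id': 'd1', 'name': 'pneumonia', 'symptom': 'fever,cough',
#                 'commonSymptom#Medicine.Symptom': 'fever'}}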

View File

@ -0,0 +1,190 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import pprint
from typing import Dict, List, Any
from knext.schema.client import BASIC_TYPES
from kag.builder.model.spg_record import SPGRecord
from knext.schema.model.base import BaseSpgType
class Node(object):
id: str
name: str
label: str
properties: Dict[str, str]
hash_map: Dict[int, str] = dict()
def __init__(self, _id: str, name: str, label: str, properties: Dict[str, str]):
self.name = name
self.label = label
self.properties = properties
self.id = _id
@classmethod
def from_spg_record(cls, idx, spg_record: SPGRecord):
return cls(
_id=idx,
name=spg_record.get_property("name"),
label=spg_record.spg_type_name,
properties=spg_record.properties,
)
@staticmethod
def unique_key(spg_record):
return spg_record.spg_type_name + '_' + spg_record.get_property("name", "")
def to_dict(self):
return {
"id": self.id,
"name": self.name,
"label": self.label,
"properties": self.properties,
}
@classmethod
def from_dict(cls, input: Dict):
return cls(
_id=input["id"],
name=input["name"],
label=input["label"],
properties=input["properties"],
)
def __eq__(self, other):
return self.name == other.name and self.label == other.label and self.properties == other.properties
class Edge(object):
id: str
from_id: str
from_type: str
to_id: str
to_type: str
label: str
properties: Dict[str, str]
def __init__(
self, _id: str, from_node: Node, to_node: Node, label: str, properties: Dict[str, str]
):
self.from_id = from_node.id
self.from_type = from_node.label
self.to_id = to_node.id
self.to_type = to_node.label
self.label = label
self.properties = properties
if not _id:
_id = id(self)
self.id = _id
@classmethod
def from_spg_record(
cls, s_idx, subject_record: SPGRecord, o_idx, object_record: SPGRecord, label: str
):
from_node = Node.from_spg_record(s_idx, subject_record)
to_node = Node.from_spg_record(o_idx, object_record)
return cls(_id="", from_node=from_node, to_node=to_node, label=label, properties={})
def to_dict(self):
return {
"id": self.id,
"from": self.from_id,
"to": self.to_id,
"fromType": self.from_type,
"toType": self.to_type,
"label": self.label,
"properties": self.properties,
}
@classmethod
def from_dict(cls, input: Dict):
return cls(
_id=input["id"],
            from_node=Node(_id=input["from"], name=input["from"], label=input["fromType"], properties={}),
to_node=Node(_id=input["to"], name=input["to"], label=input["toType"], properties={}),
label=input["label"],
properties=input["properties"],
)
def __eq__(self, other):
return self.from_id == other.from_id and self.to_id == other.to_id and self.label == other.label and self.properties == other.properties and self.from_type == other.from_type and self.to_type == other.to_type
class SubGraph(object):
id: str
nodes: List[Node] = list()
edges: List[Edge] = list()
def __init__(self, nodes: List[Node], edges: List[Edge]):
self.nodes = nodes
self.edges = edges
def add_node(self, id: str, name: str, label: str, properties=None):
if not properties:
properties = dict()
self.nodes.append(Node(_id=id, name=name, label=label, properties=properties))
return self
def add_edge(self, s_id: str, s_label: str, p: str, o_id: str, o_label: str, properties=None):
if not properties:
properties = dict()
s_node = Node(_id=s_id, name=s_id, label=s_label, properties={})
o_node = Node(_id=o_id, name=o_id, label=o_label, properties={})
self.edges.append(Edge(_id="", from_node=s_node, to_node=o_node, label=p, properties=properties))
return self
def to_dict(self):
return {
"resultNodes": [n.to_dict() for n in self.nodes],
"resultEdges": [e.to_dict() for e in self.edges],
}
def __repr__(self):
return pprint.pformat(self.to_dict())
def merge(self, sub_graph: 'SubGraph'):
self.nodes.extend(sub_graph.nodes)
self.edges.extend(sub_graph.edges)
@classmethod
def from_spg_record(
cls, spg_types: Dict[str, BaseSpgType], spg_records: List[SPGRecord]
):
sub_graph = cls([], [])
for record in spg_records:
s_id = record.id
s_name = record.name
s_label = record.spg_type_name.split('.')[-1]
properties = record.properties
spg_type = spg_types.get(record.spg_type_name)
            for prop_name, prop_value in list(record.properties.items()):
if prop_name in spg_type.properties:
from knext.schema.model.property import Property
prop: Property = spg_type.properties.get(prop_name)
o_label = prop.object_type_name.split('.')[-1]
if o_label not in BASIC_TYPES:
prop_value_list = prop_value.split(',')
for o_id in prop_value_list:
sub_graph.add_edge(s_id=s_id, s_label=s_label, p=prop_name, o_id=o_id, o_label=o_label)
properties.pop(prop_name)
sub_graph.add_node(id=s_id, name=s_name, label=s_label, properties=properties)
return sub_graph
@classmethod
def from_dict(cls, input: Dict[str, Any]):
return cls(
nodes=[Node.from_dict(node) for node in input["resultNodes"]],
edges=[Edge.from_dict(edge) for edge in input["resultEdges"]],
)
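
A quick sketch of assembling a SubGraph by hand (labels and ids are hypothetical):

g = SubGraph(nodes=[], edges=[])
g.add_node(id="d1", name="pneumonia", label="Disease")
g.add_node(id="s1", name="fever", label="Symptom")
g.add_edge(s_id="d1", s_label="Disease", p="commonSymptom", o_id="s1", o_label="Symptom")
print(g.to_dict()["resultEdges"][0]["label"])  # commonSymptom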

View File

@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

View File

@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from abc import ABC, abstractmethod
from typing import List
from knext.builder.operator.base import BaseOp
from kag.builder.model.sub_graph import Node, SubGraph
class FuseOpABC(BaseOp, ABC):
"""
    Interface for fusing mapped sub graphs with the data already in storage.
    It is usually used in the mapping component of the SPG builder.
"""
@abstractmethod
def link(self, source: SubGraph) -> List[SubGraph]:
raise NotImplementedError(
f"{self.__class__.__name__} need to implement `link` method."
)
@abstractmethod
def merge(self, source: SubGraph, target: List[SubGraph]) -> List[SubGraph]:
raise NotImplementedError(
f"{self.__class__.__name__} need to implement `merge` method."
)
def invoke(self, source: SubGraph) -> List[SubGraph]:
target = self.link(source)
return self.merge(source, target)
class LinkOpABC(BaseOp, ABC):
"""
    Interface for recalling nodes in storage by mapped property values.
    It is usually used in the mapping component of the SPG builder.
"""
@abstractmethod
def invoke(self, source: Node, prop_value: str, target_type: str) -> List[Node]:
raise NotImplementedError(
f"{self.__class__.__name__} need to implement `invoke` method."
)
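
To illustrate the LinkOpABC contract, a minimal hypothetical link operator that resolves the mapped property value to a node of the target type without any storage lookup (assuming BaseOp needs no extra constructor arguments):

class IdentityLinkOp(LinkOpABC):
    """Hypothetical example: treat the property value itself as the linked node."""
    def invoke(self, source: Node, prop_value: str, target_type: str) -> List[Node]:
        return [Node(_id=prop_value, name=prop_value, label=target_type, properties={})]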

View File

View File

@ -0,0 +1,46 @@
#
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import logging
from kag.common.base.prompt_op import PromptOp
logger = logging.getLogger(__name__)
class AnalyzeTablePrompt(PromptOp):
template_zh: str = """你是一个分析表格的专家, 从table中提取信息并分析最后返回表格有效信息"""
template_en: str = """You are an expert in knowledge graph extraction. Based on the schema defined by the constraint, extract all entities and their attributes from the input. Return NAN for attributes not explicitly mentioned in the input. Output the results in standard JSON format, as a list."""
def __init__(
self,
language: str = "zh",
):
super().__init__(
language=language,
)
def build_prompt(self, variables) -> str:
return json.dumps(
{
"instruction": self.template,
"table": variables.get("table",""),
},
ensure_ascii=False,
)
def parse_response(self, response: str, **kwargs):
return response
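
A sketch of how this prompt renders (assuming, as the signature above suggests, that it can be constructed with just a language; the table text is arbitrary):

prompt = AnalyzeTablePrompt(language="zh")
message = prompt.build_prompt({"table": "| drug | dose |\n| aspirin | 100mg |"})
# message is a JSON string embedding the instruction and the raw table text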

View File

View File

@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
from string import Template
from typing import List, Optional
from kag.common.base.prompt_op import PromptOp
from knext.schema.client import SchemaClient
class OpenIENERPrompt(PromptOp):
template_en = """
{
"instruction": "You're a very effective entity extraction system. Please extract all the entities that are important for knowledge build and question, along with type, category and a brief description of the entity. The description of the entity is based on your OWN KNOWLEDGE AND UNDERSTANDING and does not need to be limited to the context. the entity's category belongs taxonomically to one of the items defined by schema, please also output the category. Note: Type refers to a specific, well-defined classification, such as Professor, Actor, while category is a broader group or class that may contain more than one type, such as Person, Works. Return an empty list if the entity type does not exist. Please respond in the format of a JSON string.You can refer to the example for extraction.",
"schema": $schema,
"example": [
{
"input": "The Rezort\nThe Rezort is a 2015 British zombie horror film directed by Steve Barker and written by Paul Gerstenberger.\n It stars Dougray Scott, Jessica De Gouw and Martin McCann.\n After humanity wins a devastating war against zombies, the few remaining undead are kept on a secure island, where they are hunted for sport.\n When something goes wrong with the island's security, the guests must face the possibility of a new outbreak.",
"output": [
{
"entity": "The Rezort",
"type": "Movie",
"category": "Works",
"description": "A 2015 British zombie horror film directed by Steve Barker and written by Paul Gerstenberger."
},
{
"entity": "2015",
"type": "Year",
"category": "Date",
"description": "The year the movie 'The Rezort' was released."
},
{
"entity": "British",
"type": "Nationality",
"category": "GeographicLocation",
"description": "Great Britain, the island that includes England, Scotland, and Wales."
},
{
"entity": "Steve Barker",
"type": "Director",
"category": "Person",
"description": "Steve Barker is an English film director and screenwriter."
},
{
"entity": "Paul Gerstenberger",
"type": "Writer",
"category": "Person",
"description": "Paul is a writer and producer, known for The Rezort (2015), Primeval (2007) and House of Anubis (2011)."
},
{
"entity": "Dougray Scott",
"type": "Actor",
"category": "Person",
"description": "Stephen Dougray Scott (born 26 November 1965) is a Scottish actor."
},
{
"entity": "Jessica De Gouw",
"type": "Actor",
"category": "Person",
"description": "Jessica Elise De Gouw (born 15 February 1988) is an Australian actress. "
},
{
"entity": "Martin McCann",
"type": "Actor",
"category": "Person",
"description": "Martin McCann is an actor from Northern Ireland. In 2020, he was listed as number 48 on The Irish Times list of Ireland's greatest film actors"
}
]
}
],
"input": "$input"
}
"""
template_zh = """
{
"instruction": "你是命名实体识别的专家。请从输入中提取与模式定义匹配的实体。如果不存在该类型的实体请返回一个空列表。请以JSON字符串格式回应。你可以参照example进行抽取。",
"schema": $schema,
"example": [
{
"input": "《Rezort》\n《Rezort》是一部 2015 年英国僵尸恐怖片,由史蒂夫·巴克执导,保罗·格斯滕伯格编剧。\n 该片由道格瑞·斯科特、杰西卡·德·古维和马丁·麦凯恩主演。\n 在人类赢得与僵尸的毁灭性战争后,剩下的少数不死生物被关在一个安全的岛屿上,在那里他们被猎杀作为消遣。\n 当岛上的安全出现问题时,客人们必须面对新一轮疫情爆发的可能性。",
"output": [
{
"entity": "The Rezort",
"type": "Movie",
"category": "Works",
"description": "一部 2015 年英国僵尸恐怖片,由史蒂夫·巴克执导,保罗·格斯滕伯格编剧。"
},
{
"entity": "2015",
"type": "Year",
"category": "Date",
"description": "电影《The Rezort》上映的年份。"
},
{
"entity": "英国",
"type": "Nationality",
"category": "GeographicLocation",
"description": "大不列颠,包括英格兰、苏格兰和威尔士的岛屿。"
},
{
"entity": "史蒂夫·巴克",
"type": "Director",
"category": "Person",
"description": "史蒂夫·巴克 是一名英国电影导演和剧作家"
},
{
"entity": "保罗·格斯滕伯格",
"type": "Writer",
"category": "Person",
"description": "保罗·格斯滕伯格 (Paul Gerstenberger) 是一名作家和制片人因《The Rezort》2015 年、《Primeval》2007 年和《House of Anubis》2011 年)而闻名。"
},
{
"entity": "道格雷·斯科特",
"type": "Actor",
"category": "Person",
"description": "斯蒂芬·道格雷·斯科特 (Stephen Dougray Scott1965 年 11 月 26 日出生) 是一位苏格兰演员。"
},
{
"entity": "杰西卡·德·古维",
"type": "Actor",
"category": "Person",
"description": "杰西卡·伊莉斯·德·古维 (Jessica Elise De Gouw1988 年 2 月 15 日出生) 是一位澳大利亚女演员。"
},
{
"entity": "马丁·麦肯",
"type": "Actor",
"category": "Person",
"description": "马丁·麦肯是来自北爱尔兰的演员。2020 年,他在《爱尔兰时报》爱尔兰最伟大电影演员名单中排名第 48 位"
}
]
}
],
"input": "$input"
}
"""
def __init__(
self, language: Optional[str] = "en", **kwargs
):
super().__init__(language, **kwargs)
self.schema = SchemaClient(project_id=self.project_id).extract_types()
self.template = Template(self.template).safe_substitute(schema=self.schema)
@property
def template_variables(self) -> List[str]:
return ["input"]
def parse_response(self, response: str, **kwargs):
rsp = response
if isinstance(rsp, str):
rsp = json.loads(rsp)
if isinstance(rsp, dict) and "output" in rsp:
rsp = rsp["output"]
if isinstance(rsp, dict) and "named_entities" in rsp:
entities = rsp["named_entities"]
else:
entities = rsp
return entities

View File

@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
from typing import Optional, List
from kag.common.base.prompt_op import PromptOp
class OpenIEEntitystandardizationdPrompt(PromptOp):
template_en = """
{
"instruction": "The `input` field contains a user provided context. The `named_entities` field contains extracted named entities from the context, which may be unclear abbreviations, aliases, or slang. To eliminate ambiguity, please attempt to provide the official names of these entities based on the context and your own knowledge. Note that entities with the same meaning can only have ONE official name. Please respond in the format of a single JSONArray string without any explanation, as shown in the `output` field of the provided example.",
"example": {
"input": "American History\nWhen did the political party that favored harsh punishment of southern states after the Civil War, gain control of the House? Republicans regained control of the chamber they had lost in the 2006 midterm elections.",
"named_entities": [
{"entity": "American", "category": "GeographicLocation"},
{"entity": "political party", "category": "Organization"},
{"entity": "southern states", "category": "GeographicLocation"},
{"entity": "Civil War", "category": "Keyword"},
{"entity": "House", "category": "Organization"},
{"entity": "Republicans", "category": "Organization"},
{"entity": "chamber", "category": "Organization"},
{"entity": "2006 midterm elections", "category": "Date"}
],
"output": [
{
"entity": "American",
"category": "GeographicLocation",
"official_name": "United States of America"
},
{
"entity": "political party",
"category": "Organization",
"official_name": "Radical Republicans"
},
{
"entity": "southern states",
"category": "GeographicLocation",
"official_name": "Confederacy"
},
{
"entity": "Civil War",
"category": "Keyword",
"official_name": "American Civil War"
},
{
"entity": "House",
"category": "Organization",
"official_name": "United States House of Representatives"
},
{
"entity": "Republicans",
"category": "Organization",
"official_name": "Republican Party"
},
{
"entity": "chamber",
"category": "Organization",
"official_name": "United States House of Representatives"
},
{
"entity": "midterm elections",
"category": "Date",
"official_name": "United States midterm elections"
}
]
},
"input": "$input",
"named_entities": $named_entities
}
"""
template_zh = """
{
"instruction": "input字段包含用户提供的上下文。命名实体字段包含从上下文中提取的命名实体这些可能是含义不明的缩写、别名或俚语。为了消除歧义请尝试根据上下文和您自己的知识提供这些实体的官方名称。请注意具有相同含义的实体只能有一个官方名称。请按照提供的示例中的输出字段格式以单个JSONArray字符串形式回复无需任何解释。",
"example": {
"input": "烦躁不安、语妄、失眠酌用镇静药,禁用抑制呼吸的镇静药。\n3.并发症的处理经抗菌药物治疗后高热常在24小时内消退或数日内逐渐下降。\n若体温降而复升或3天后仍不降者应考虑SP的肺外感染如腋胸、心包炎或关节炎等。治疗接胸腔压力调节管吸引机负压吸引水瓶装置闭式负压吸引宜连续如经12小时后肺仍未复张应查找原因。",
"named_entities": [
{"entity": "烦躁不安", "category": "Symptom"},
{"entity": "语妄", "category": "Symptom"},
{"entity": "失眠", "category": "Symptom"},
{"entity": "镇静药", "category": "Medicine"},
{"entity": "肺外感染", "category": "Disease"},
{"entity": "胸腔压力调节管", "category": "MedicalEquipment"},
{"entity": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment"},
{"entity": "闭式负压吸引", "category": "SurgicalOperation"}
],
"output": [
{"entity": "烦躁不安", "category": "Symptom", "official_name": "焦虑不安"},
{"entity": "语妄", "category": "Symptom", "official_name": "谵妄"},
{"entity": "失眠", "category": "Symptom", "official_name": "失眠症"},
{"entity": "镇静药", "category": "Medicine", "official_name": "镇静剂"},
{"entity": "肺外感染", "category": "Disease", "official_name": "肺外感染"},
{"entity": "胸腔压力调节管", "category": "MedicalEquipment", "official_name": "胸腔引流管"},
{"entity": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment", "official_name": "负压吸引装置"},
{"entity": "闭式负压吸引", "category": "SurgicalOperation", "official_name": "闭式负压引流"}
]
},
"input": $input,
"named_entities": $named_entities,
}
"""
def __init__(self, language: Optional[str] = "en"):
super().__init__(language)
@property
def template_variables(self) -> List[str]:
return ["input", "named_entities"]
def parse_response(self, response: str, **kwargs):
rsp = response
if isinstance(rsp, str):
rsp = json.loads(rsp)
if isinstance(rsp, dict) and "output" in rsp:
rsp = rsp["output"]
if isinstance(rsp, dict) and "named_entities" in rsp:
standardized_entity = rsp["named_entities"]
else:
standardized_entity = rsp
        entities_with_official_name = set()
        merged = []
        entities = kwargs.get("named_entities", [])
        for entity in standardized_entity:
            merged.append(entity)
            entities_with_official_name.add(entity["entity"])
        # in case the LLM ignores some entities, fall back to the original entity name
        for entity in entities:
            if entity["entity"] not in entities_with_official_name:
                entity["official_name"] = entity["entity"]
                merged.append(entity)
        return merged

View File

@ -0,0 +1,210 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
from typing import Optional, List
from kag.common.base.prompt_op import PromptOp
class OpenIETriplePrompt(PromptOp):
template_en = """
{
"instruction": "You are an expert specializing in carrying out open information extraction (OpenIE). Please extract any possible relations (including subject, predicate, object) from the given text, and list them following the json format {\"triples\": [[\"subject\", \"predicate\", \"object\"]]}\n. If there are none, do not list them.\n.\n\nPay attention to the following requirements:\n- Each triple should contain at least one, but preferably two, of the named entities in the entity_list.\n- Clearly resolve pronouns to their specific names to maintain clarity.",
"entity_list": $entity_list,
"input": "$input",
"example": {
"input": "The Rezort\nThe Rezort is a 2015 British zombie horror film directed by Steve Barker and written by Paul Gerstenberger.\n It stars Dougray Scott, Jessica De Gouw and Martin McCann.\n After humanity wins a devastating war against zombies, the few remaining undead are kept on a secure island, where they are hunted for sport.\n When something goes wrong with the island's security, the guests must face the possibility of a new outbreak.",
"entity_list": [
{
"entity": "The Rezort",
"category": "Works"
},
{
"entity": "2015",
"category": "Others"
},
{
"entity": "British",
"category": "GeographicLocation"
},
{
"entity": "Steve Barker",
"category": "Person"
},
{
"entity": "Paul Gerstenberger",
"category": "Person"
},
{
"entity": "Dougray Scott",
"category": "Person"
},
{
"entity": "Jessica De Gouw",
"category": "Person"
},
{
"entity": "Martin McCann",
"category": "Person"
},
{
"entity": "zombies",
"category": "Creature"
},
{
"entity": "zombie horror film",
"category": "Concept"
},
{
"entity": "humanity",
"category": "Concept"
},
{
"entity": "secure island",
"category": "GeographicLocation"
}
],
"output": [
[
"The Rezort",
"is",
"zombie horror film"
],
[
"The Rezort",
"publish at",
"2015"
],
[
"The Rezort",
"released",
"British"
],
[
"The Rezort",
"is directed by",
"Steve Barker"
],
[
"The Rezort",
"is written by",
"Paul Gerstenberger"
],
[
"The Rezort",
"stars",
"Dougray Scott"
],
[
"The Rezort",
"stars",
"Jessica De Gouw"
],
[
"The Rezort",
"stars",
"Martin McCann"
],
[
"humanity",
"wins",
"a devastating war against zombies"
],
[
"the few remaining undead",
"are kept on",
"a secure island"
],
[
"they",
"are hunted for",
"sport"
],
[
"something",
"goes wrong with",
"the island's security"
],
[
"the guests",
"must face",
"the possibility of a new outbreak"
]
]
}
}
"""
template_zh = """
{
"instruction": "您是一位专门从事开放信息提取OpenIE的专家。请从input字段的文本中提取任何可能的关系包括主语、谓语、宾语并按照JSON格式列出它们须遵循example字段的示例格式。请注意以下要求1. 每个三元组应至少包含entity_list实体列表中的一个但最好是两个命名实体。2. 明确地将代词解析为特定名称,以保持清晰度。",
"entity_list": $entity_list,
"input": "$input",
"example": {
"input": "烦躁不安、语妄、失眠酌用镇静药,禁用抑制呼吸的镇静药。\n3.并发症的处理经抗菌药物治疗后高热常在24小时内消退或数日内逐渐下降。\n若体温降而复升或3天后仍不降者应考虑SP的肺外感染如腋胸、心包炎或关节炎等。治疗接胸腔压力调节管吸引机负压吸引水瓶装置闭式负压吸引宜连续如经12小时后肺仍未复张应查找原因。",
"entity_list": [
{"entity": "烦躁不安", "category": "Symptom"},
{"entity": "语妄", "category": "Symptom"},
{"entity": "失眠", "category": "Symptom"},
{"entity": "镇静药", "category": "Medicine"},
{"entity": "肺外感染", "category": "Disease"},
{"entity": "胸腔压力调节管", "category": "MedicalEquipment"},
{"entity": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment"},
{"entity": "闭式负压吸引", "category": "SurgicalOperation"}
],
"output":[
["烦躁不安", "酌用", "镇静药"],
["语妄", "酌用", "镇静药"],
["失眠", "酌用", "镇静药"],
["镇静药", "禁用", "抑制呼吸的镇静药"],
["高热", "消退", "24小时内"],
["高热", "下降", "数日内"],
["体温", "降而复升或3天后仍不降", "肺外感染"],
["肺外感染", "考虑", "腋胸、心包炎或关节炎"],
["胸腔压力调节管", "", "吸引机负压吸引水瓶装置"],
["闭式负压吸引", "宜连续", "如经12小时后肺仍未复张"]
]
}
}
"""
def __init__(self, language: Optional[str] = "en"):
super().__init__(language)
@property
def template_variables(self) -> List[str]:
return ["entity_list", "input"]
def parse_response(self, response: str, **kwargs):
rsp = response
if isinstance(rsp, str):
rsp = json.loads(rsp)
if isinstance(rsp, dict) and "output" in rsp:
rsp = rsp["output"]
if isinstance(rsp, dict) and "triples" in rsp:
triples = rsp["triples"]
else:
triples = rsp
standardized_triples = []
for triple in triples:
if isinstance(triple, list):
standardized_triples.append(triple)
elif isinstance(triple, dict):
s = triple.get("subject")
p = triple.get("predicate")
o = triple.get("object")
if s and p and o:
standardized_triples.append([s, p, o])
return standardized_triples
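
parse_response above accepts both list-shaped and dict-shaped triples; a sketch of the normalization on a sample payload (assuming the prompt can be constructed with only a language argument; the triples are arbitrary):

prompt = OpenIETriplePrompt(language="en")
sample = '{"triples": [["aspirin", "treats", "fever"], {"subject": "aspirin", "predicate": "dosedAt", "object": "100mg"}]}'
print(prompt.parse_response(sample))
# [['aspirin', 'treats', 'fever'], ['aspirin', 'dosedAt', '100mg']]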

View File

View File

@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
from string import Template
from typing import List, Optional
from kag.common.base.prompt_op import PromptOp
from knext.schema.client import SchemaClient
class OpenIENERPrompt(PromptOp):
template_zh = """
{
"instruction": "你是命名实体识别的专家。请从输入中提取与模式定义匹配的实体。如果不存在该类型的实体请返回一个空列表。请以JSON字符串格式回应。你可以参照example进行抽取。",
"schema": $schema,
"example": [
{
"input": "烦躁不安、语妄、失眠酌用镇静药,禁用抑制呼吸的镇静药。\n3.并发症的处理经抗菌药物治疗后高热常在24小时内消退或数日内逐渐下降。\n若体温降而复升或3天后仍不降者应考虑SP的肺外感染。\n治疗接胸腔压力调节管吸引机负压吸引水瓶装置闭式负压吸引宜连续如经12小时后肺仍未复张应查找原因。",
"output": [
{"entity": "烦躁不安", "category": "Symptom"},
{"entity": "语妄", "category": "Symptom"},
{"entity": "失眠", "category": "Symptom"},
{"entity": "镇静药", "category": "Medicine"},
{"entity": "肺外感染", "category": "Disease"},
{"entity": "胸腔压力调节管", "category": "MedicalEquipment"},
{"entity": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment"},
{"entity": "闭式负压吸引", "category": "SurgicalOperation"}
]
}
],
"input": "$input"
}
"""
template_en = template_zh
def __init__(
self, language: Optional[str] = "en", **kwargs
):
super().__init__(language, **kwargs)
self.schema = SchemaClient(project_id=self.project_id).extract_types()
self.template = Template(self.template).safe_substitute(schema=self.schema)
@property
def template_variables(self) -> List[str]:
return ["input"]
def parse_response(self, response: str, **kwargs):
rsp = response
if isinstance(rsp, str):
rsp = json.loads(rsp)
if isinstance(rsp, dict) and "output" in rsp:
rsp = rsp["output"]
if isinstance(rsp, dict) and "named_entities" in rsp:
entities = rsp["named_entities"]
else:
entities = rsp
return entities

View File

@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
from typing import Optional, List
from kag.common.base.prompt_op import PromptOp
class OpenIEEntitystandardizationdPrompt(PromptOp):
template_zh = """
{
"instruction": "input字段包含用户提供的上下文。命名实体字段包含从上下文中提取的命名实体这些可能是含义不明的缩写、别名或俚语。为了消除歧义请尝试根据上下文和您自己的知识提供这些实体的官方名称。请注意具有相同含义的实体只能有一个官方名称。请按照提供的示例中的输出字段格式以单个JSONArray字符串形式回复无需任何解释。",
"example": {
"input": "烦躁不安、语妄、失眠酌用镇静药,禁用抑制呼吸的镇静药。\n3.并发症的处理经抗菌药物治疗后高热常在24小时内消退或数日内逐渐下降。\n若体温降而复升或3天后仍不降者应考虑SP的肺外感染如腋胸、心包炎或关节炎等。治疗接胸腔压力调节管吸引机负压吸引水瓶装置闭式负压吸引宜连续如经12小时后肺仍未复张应查找原因。",
"named_entities": [
{"entity": "烦躁不安", "category": "Symptom"},
{"entity": "语妄", "category": "Symptom"},
{"entity": "失眠", "category": "Symptom"},
{"entity": "镇静药", "category": "Medicine"},
{"entity": "肺外感染", "category": "Disease"},
{"entity": "胸腔压力调节管", "category": "MedicalEquipment"},
{"entity": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment"},
{"entity": "闭式负压吸引", "category": "SurgicalOperation"}
],
"output": [
{"entity": "烦躁不安", "category": "Symptom", "official_name": "焦虑不安"},
{"entity": "语妄", "category": "Symptom", "official_name": "谵妄"},
{"entity": "失眠", "category": "Symptom", "official_name": "失眠症"},
{"entity": "镇静药", "category": "Medicine", "official_name": "镇静剂"},
{"entity": "肺外感染", "category": "Disease", "official_name": "肺外感染"},
{"entity": "胸腔压力调节管", "category": "MedicalEquipment", "official_name": "胸腔引流管"},
{"entity": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment", "official_name": "负压吸引装置"},
{"entity": "闭式负压吸引", "category": "SurgicalOperation", "official_name": "闭式负压引流"}
]
},
"input": $input,
"named_entities": $named_entities,
}
"""
template_en = template_zh
def __init__(self, language: Optional[str] = "en"):
super().__init__(language)
@property
def template_variables(self) -> List[str]:
return ["input", "named_entities"]
def parse_response(self, response: str, **kwargs):
rsp = response
if isinstance(rsp, str):
rsp = json.loads(rsp)
if isinstance(rsp, dict) and "output" in rsp:
rsp = rsp["output"]
if isinstance(rsp, dict) and "named_entities" in rsp:
standardized_entity = rsp["named_entities"]
else:
standardized_entity = rsp
        entities_with_official_name = set()
        merged = []
        entities = kwargs.get("named_entities", [])
        for entity in standardized_entity:
            merged.append(entity)
            entities_with_official_name.add(entity["entity"])
        # in case the LLM ignores some entities, fall back to the original entity name
        for entity in entities:
            if entity["entity"] not in entities_with_official_name:
                entity["official_name"] = entity["entity"]
                merged.append(entity)
        return merged

View File

@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
from typing import Optional, List, Dict, Any
from kag.common.base.prompt_op import PromptOp
class OpenIETriplePrompt(PromptOp):
template_zh = """
{
"instruction": "您是一位专门从事开放信息提取OpenIE的专家。请从input字段的文本中提取任何可能的关系包括主语、谓语、宾语并按照JSON格式列出它们须遵循example字段的示例格式。请注意以下要求1. 每个三元组应至少包含entity_list实体列表中的一个但最好是两个命名实体。2. 明确地将代词解析为特定名称,以保持清晰度。",
"entity_list": $entity_list,
"input": "$input",
"example": {
"input": "烦躁不安、语妄、失眠酌用镇静药,禁用抑制呼吸的镇静药。\n3.并发症的处理经抗菌药物治疗后高热常在24小时内消退或数日内逐渐下降。\n若体温降而复升或3天后仍不降者应考虑SP的肺外感染如腋胸、心包炎或关节炎等。治疗接胸腔压力调节管吸引机负压吸引水瓶装置闭式负压吸引宜连续如经12小时后肺仍未复张应查找原因。",
"entity_list": [
{"entity": "烦躁不安", "category": "Symptom"},
{"entity": "语妄", "category": "Symptom"},
{"entity": "失眠", "category": "Symptom"},
{"entity": "镇静药", "category": "Medicine"},
{"entity": "肺外感染", "category": "Disease"},
{"entity": "胸腔压力调节管", "category": "MedicalEquipment"},
{"entity": "吸引机负压吸引水瓶装置", "category": "MedicalEquipment"},
{"entity": "闭式负压吸引", "category": "SurgicalOperation"}
],
"output":[
["烦躁不安", "酌用", "镇静药"],
["语妄", "酌用", "镇静药"],
["失眠", "酌用", "镇静药"],
["镇静药", "禁用", "抑制呼吸的镇静药"],
["高热", "消退", "24小时内"],
["高热", "下降", "数日内"],
["体温", "降而复升或3天后仍不降", "肺外感染"],
["肺外感染", "考虑", "腋胸、心包炎或关节炎"],
["胸腔压力调节管", "", "吸引机负压吸引水瓶装置"],
["闭式负压吸引", "宜连续", "如经12小时后肺仍未复张"]
]
}
}
"""
template_en = template_zh
def __init__(self, language: Optional[str] = "en"):
super().__init__(language)
@property
def template_variables(self) -> List[str]:
return ["entity_list", "input"]
def parse_response(self, response: str, **kwargs):
rsp = response
if isinstance(rsp, str):
rsp = json.loads(rsp)
if isinstance(rsp, dict) and "output" in rsp:
rsp = rsp["output"]
if isinstance(rsp, dict) and "triples" in rsp:
triples = rsp["triples"]
else:
triples = rsp
standardized_triples = []
for triple in triples:
if isinstance(triple, list):
standardized_triples.append(triple)
elif isinstance(triple, dict):
s = triple.get("subject")
p = triple.get("predicate")
o = triple.get("object")
if s and p and o:
standardized_triples.append([s, p, o])
return standardized_triples
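if __name__ == "__main__":
    # Editor's sketch, not part of the original source: parse_response accepts
    # triples either as [s, p, o] lists or as {"subject", "predicate", "object"}
    # dicts, and normalizes both to [s, p, o].
    prompt = OpenIETriplePrompt(language="zh")
    rsp = (
        '{"output": [["烦躁不安", "酌用", "镇静药"],'
        ' {"subject": "高热", "predicate": "消退", "object": "24小时内"}]}'
    )
    print(prompt.parse_response(rsp))
    # -> [['烦躁不安', '酌用', '镇静药'], ['高热', '消退', '24小时内']]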


@ -0,0 +1,518 @@
#
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import logging
import re
from abc import ABC
from typing import List, Dict, Any
from collections import defaultdict
from knext.schema.model.schema_helper import SPGTypeName
from kag.builder.model.spg_record import SPGRecord
from kag.builder.prompt.spg_prompt import SPGPrompt
import uuid
logger = logging.getLogger(__name__)
class OneKEPrompt(SPGPrompt, ABC):
template_zh: str = ""
template_en: str = ""
def __init__(self, **kwargs):
types_list = kwargs.get("types_list", [])
language = kwargs.get("language", "zh")
with_description = kwargs.get("with_description", False)
split_num = kwargs.get("split_num", 4)
super().__init__(types_list, **kwargs)
self.language = language
if language == "zh":
self.template = self.template_zh
else:
self.template = self.template_en
self.with_description = with_description
self.split_num = split_num
self._init_render_variables()
self._render()
self.params = kwargs
def build_prompt(self, variables: Dict[str, str]) -> List[str]:
instructions = []
for schema in self.schema_list:
instructions.append(
json.dumps(
{
"instruction": self.template,
"schema": schema,
"input": variables.get("input"),
},
ensure_ascii=False,
)
)
return instructions
def parse_response(self, response: str) -> List[SPGRecord]:
raise NotImplementedError
def _render(self):
raise NotImplementedError
def multischema_split_by_num(self, split_num, schemas: List[Any]):
negative_length = max(len(schemas) // split_num, 1) * split_num
total_schemas = []
for i in range(0, negative_length, split_num):
total_schemas.append(schemas[i : i + split_num])
remain_len = max(1, split_num // 2)
tmp_schemas = schemas[negative_length:]
if len(schemas) - negative_length >= remain_len and len(tmp_schemas) > 0:
total_schemas.append(tmp_schemas)
elif len(tmp_schemas) > 0:
total_schemas[-1].extend(tmp_schemas)
return total_schemas
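# Editor's illustration, not part of the original source: the chunking behaviour
# of multischema_split_by_num. A trailing remainder with fewer than
# max(1, split_num // 2) items is merged into the previous chunk instead of
# being emitted as a tiny group of its own, e.g. with split_num=4:
#   schemas = list(range(10)) -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
#   schemas = list(range(9))  -> [[0, 1, 2, 3], [4, 5, 6, 7, 8]]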
class OneKE_NERPrompt(OneKEPrompt):
template_zh: str = (
"你是专门进行实体抽取的专家。请从input中抽取出符合schema定义的实体不存在的实体类型返回空列表。请按照JSON字符串的格式回答。"
)
template_en: str = "You are an expert in named entity recognition. Please extract entities that match the schema definition from the input. Return an empty list if the entity type does not exist. Please respond in the format of a JSON string."
def __init__(
self,
entity_types: List[SPGTypeName],
language: str = "zh",
with_description: bool = False,
split_num: int = 4,
**kwargs,
):
super().__init__(
types_list=entity_types,
language=language,
with_description=with_description,
split_num=split_num,
**kwargs,
)
def parse_response(self, response: str) -> List[SPGRecord]:
if isinstance(response, list) and len(response) > 0:
response = response[0]
try:
ent_obj = json.loads(response)
except json.decoder.JSONDecodeError:
logger.error("OneKE_NERPrompt response JSON decode error.")
return []
if not isinstance(ent_obj, dict):
logger.error("OneKE_NERPrompt response type error.")
return []
spg_records = []
for type_zh, values in ent_obj.items():
if type_zh not in self.spg_type_schema_info_zh:
logger.warning(f"Unrecognized entity_type: {type_zh}")
continue
type_en, _ = self.spg_type_schema_info_zh[type_zh]
for value in values:
spg_record = SPGRecord(type_en)
spg_record.upsert_properties({"id": value, "name": value})
spg_records.append(spg_record)
return spg_records
def _render(self):
entity_list = []
for spg_type in self.spg_types.values():
entity_list.append(spg_type.name_zh)
self.schema_list = self.multischema_split_by_num(self.split_num, entity_list)
class OneKE_SPOPrompt(OneKEPrompt):
template_zh: str = (
"你是专门进行SPO三元组抽取的专家。请从input中抽取出符合schema定义的spo关系三元组不存在的关系返回空列表。请按照JSON字符串的格式回答。"
)
template_en: str = "You are an expert in spo(subject, predicate, object) triples extraction. Please extract SPO relationship triples that match the schema definition from the input. Return an empty list for relationships that do not exist. Please respond in the format of a JSON string."
def __init__(
self,
spo_types: List[SPGTypeName],
language: str = "zh",
with_description: bool = False,
split_num: int = 4,
**kwargs,
):
# the mappers must exist before super().__init__, which invokes self._render()
self.properties_mapper = {}
self.relations_mapper = {}
super().__init__(
types_list=spo_types,
language=language,
with_description=with_description,
split_num=split_num,
**kwargs,
)
def parse_response(self, response: str) -> List[SPGRecord]:
if isinstance(response, list) and len(response) > 0:
response = response[0]
try:
re_obj = json.loads(response)
except json.decoder.JSONDecodeError:
logger.error("OneKE_SPOPrompt response JSON decode error.")
return []
if not isinstance(re_obj, dict):
logger.error("OneKE_SPOPrompt response type error.")
return []
relation_dict = defaultdict(list)
for relation_zh, values in re_obj.items():
if relation_zh not in self.properties_mapper and relation_zh not in self.relations_mapper:
logger.warning(f"Unrecognized relation: {relation_zh}")
continue
if values and isinstance(values, list):
for value in values:
if (
not isinstance(value, dict)
or "subject" not in value
or "object" not in value
):
logger.warning("OneKE_SPOPrompt response item type error.")
continue
s_zh, o_zh = value.get("subject", ""), value.get("object", "")
relation_dict[relation_zh].append((s_zh, o_zh))
spg_records = []
for relation_zh, sub_obj_list in relation_dict.items():
sub_dict = defaultdict(list)
for s_zh, o_zh in sub_obj_list:
sub_dict[s_zh].append(o_zh)
for s_zh, o_list in sub_dict.items():
if s_zh in self.spg_type_schema_info_zh:
logger.warning(f"Unrecognized subject_type: {s_zh}")
continue
object_value = ",".join(o_list)
s_type_zh = self.properties_mapper.get(relation_zh, None)
if s_type_zh is not None:
s_type_en, _ = self.spg_type_schema_info_zh[s_type_zh]
relation_en, _ = self.property_info_zh[relation_zh]
spg_record = SPGRecord(s_type_en).upsert_properties(
{"id": s_zh, "name": s_zh}
)
spg_record.upsert_property(relation_en, object_value)
else:
s_type_zh, o_type_zh = self.relations_mapper.get(
relation_zh, [None, None]
)
if s_type_zh is None or o_type_zh is None:
logger.warning(f"Unrecognized relation: {relation_zh}")
continue
s_type_en, _ = self.spg_type_schema_info_zh[s_type_zh]
spg_record = SPGRecord(s_type_en).upsert_properties(
{"id": s_zh, "name": s_zh}
)
relation_en, _, object_type = self.relation_info_zh[s_type_zh][
relation_zh
]
spg_record.upsert_relation(relation_en, object_type, object_value)
spg_records.append(spg_record)
return spg_records
def _render(self):
spo_list = []
for spg_type in self.spg_types.values():
type_en, _ = self.spg_type_schema_info_zh[spg_type.name_zh]
for v in spg_type.properties.values():
spo_list.append(
{
"subject_type": spg_type.name_zh,
"predicate": v.name_zh,
"object_type": "文本",
}
)
self.properties_mapper[v.name_zh] = spg_type.name_zh
for v in spg_type.relations.values():
_, _, object_type = self.relation_info_en[type_en][v.name]
spo_list.append(
{
"subject_type": spg_type.name_zh,
"predicate": v.name_zh,
"object_type": object_type,
}
)
self.relations_mapper[v.name_zh] = [spg_type.name_zh, object_type]
self.schema_list = self.multischema_split_by_num(self.split_num, spo_list)
class OneKE_REPrompt(OneKE_SPOPrompt):
template_zh: str = (
"你是专门进行关系抽取的专家。请从input中抽取出符合schema定义的关系三元组不存在的关系返回空列表。请按照JSON字符串的格式回答。"
)
template_en: str = "You are an expert in relationship extraction. Please extract relationship triples that match the schema definition from the input. Return an empty list for relationships that do not exist. Please respond in the format of a JSON string."
def __init__(
self,
relation_types: List[SPGTypeName],
language: str = "zh",
with_description: bool = False,
split_num: int = 4,
**kwargs,
):
super().__init__(
relation_types, language, with_description, split_num, **kwargs
)
def _render(self):
re_list = []
for spg_type in self.spg_types.values():
type_en, _ = self.spg_type_schema_info_zh[spg_type.name_zh]
for v in spg_type.properties.values():
re_list.append(v.name_zh)
self.properties_mapper[v.name_zh] = spg_type.name_zh
for v in spg_type.relations.values():
v_zh, _, object_type = self.relation_info_en[type_en][v.name]
re_list.append(v.name_zh)
self.relations_mapper[v.name_zh] = [spg_type.name_zh, object_type]
self.schema_list = self.multischema_split_by_num(self.split_num, re_list)
class OneKE_KGPrompt(OneKEPrompt):
template_zh: str = "你是一个图谱实体知识结构化专家。根据输入实体类型(entity type)的schema描述从文本中抽取出相应的实体实例和其属性信息不存在的属性不输出, 属性存在多值就返回列表并输出为可解析的json格式。"
template_en: str = "You are an expert in structured knowledge systems for graph entities. Based on the schema description of the input entity type, you extract the corresponding entity instances and their attribute information from the text. Attributes that do not exist should not be output. If an attribute has multiple values, a list should be returned. The results should be output in a parsable JSON format."
def __init__(
self,
entity_types: List[SPGTypeName],
language: str = "zh",
with_description: bool = False,
split_num: int = 4,
**kwargs,
):
super().__init__(
types_list=entity_types,
language=language,
with_description=with_description,
split_num=split_num,
**kwargs,
)
def parse_response(self, response: str) -> List[SPGRecord]:
if isinstance(response, list) and len(response) > 0:
response = response[0]
try:
re_obj = json.loads(response)
except json.decoder.JSONDecodeError:
logger.error("OneKE_KGPrompt response JSON decode error.")
return []
if not isinstance(re_obj, dict):
logger.error("OneKE_KGPrompt response type error.")
return []
spg_records = []
for type_zh, type_value in re_obj.items():
if type_zh not in self.spg_type_schema_info_zh:
logger.warning(f"Unrecognized entity_type: {type_zh}")
continue
type_en, _ = self.spg_type_schema_info_zh[type_zh]
if type_value and isinstance(type_value, dict):
for name, attrs in type_value.items():
spg_record = SPGRecord(type_en).upsert_properties(
{"id": name, "name": name}
)
for attr_zh, attr_value in attrs.items():
if isinstance(attr_value, list):
attr_value = ",".join(attr_value)
if attr_zh in self.property_info_zh[type_zh]:
attr_en, _, object_type = self.property_info_zh[type_zh][
attr_zh
]
spg_record.upsert_property(attr_en, attr_value)
elif attr_zh in self.relation_info_zh[type_zh]:
attr_en, _, object_type = self.relation_info_zh[type_zh][
attr_zh
]
spg_record.upsert_relation(attr_en, object_type, attr_value)
else:
logger.warning(f"Unrecognized attribute: {attr_zh}")
continue
if object_type == "Integer":
matches = re.findall(r"\d+", attr_value)
if matches:
spg_record.upsert_property(attr_en, matches[0])
elif object_type == "Float":
matches = re.findall(r"\d+(?:\.\d+)?", attr_value)
if matches:
spg_record.upsert_property(attr_en, matches[0])
spg_records.append(spg_record)
return spg_records
def _render(self):
spo_list = []
for spg_type in self.spg_types.values():
if not self.with_description:
attributes = []
attributes.extend(
[
v.name_zh
for k, v in spg_type.properties.items()
if k not in self.ignored_properties
]
)
attributes.extend(
[
v.name_zh
for k, v in spg_type.relations.items()
if v.name_zh not in attributes
and k not in self.ignored_relations
]
)
else:
attributes = {}
attributes.update(
{
v.name_zh: v.desc or ""
for k, v in spg_type.properties.items()
if k not in self.ignored_properties
}
)
attributes.update(
{
v.name_zh: v.desc or ""
for k, v in spg_type.relations.items()
if v.name_zh not in attributes
and k not in self.ignored_relations
}
)
entity_type = spg_type.name_zh
spo_list.append({"entity_type": entity_type, "attributes": attributes})
self.schema_list = self.multischema_split_by_num(self.split_num, spo_list)
class OneKE_EEPrompt(OneKEPrompt):
template_zh: str = "你是专门进行事件提取的专家。请从input中抽取出符合schema定义的事件不存在的事件返回空列表不存在的论元返回NAN如果论元存在多值请返回列表。请按照JSON字符串的格式回答。"
template_en: str = "You are an expert in event extraction. Please extract events from the input that conform to the schema definition. Return an empty list for events that do not exist, and return NAN for arguments that do not exist. If an argument has multiple values, please return a list. Respond in the format of a JSON string."
def __init__(
self,
event_types: List[SPGTypeName],
language: str = "zh",
with_description: bool = False,
split_num: int = 4,
**kwargs,
):
super().__init__(
types_list=event_types,
language=language,
with_description=with_description,
split_num=split_num,
**kwargs,
)
def parse_response(self, response: str) -> List[SPGRecord]:
if isinstance(response, list) and len(response) > 0:
response = response[0]
try:
ee_obj = json.loads(response)
except json.decoder.JSONDecodeError:
logger.error("OneKE_EEPrompt response JSON decode error.")
return []
if not isinstance(ee_obj, dict):
logger.error("OneKE_EEPrompt response type error.")
return []
spg_records = []
for type_zh, type_values in ee_obj.items():
if type_zh not in self.spg_type_schema_info_zh:
logger.warning(f"Unrecognized event_type: {type_zh}")
continue
type_en, _ = self.spg_type_schema_info_zh[type_zh]
if type_values and isinstance(type_values, list):
for type_value in type_values:
uuid_4 = uuid.uuid4()
spg_record = (
SPGRecord(type_en)
.upsert_property("id", str(uuid_4))
.upsert_property("name", type_zh)
)
arguments = type_value.get("arguments")
if arguments and isinstance(arguments, dict):
for attr_zh, attr_value in arguments.items():
if isinstance(attr_value, list):
attr_value = ",".join(attr_value)
if attr_zh in self.property_info_zh[type_zh]:
attr_en, _, object_type = self.property_info_zh[
type_zh
][attr_zh]
spg_record.upsert_property(attr_en, attr_value)
elif attr_zh in self.relation_info_zh[type_zh]:
attr_en, _, object_type = self.relation_info_zh[
type_zh
][attr_zh]
spg_record.upsert_relation(
attr_en, object_type, attr_value
)
else:
logger.warning(f"Unrecognized attribute: {attr_zh}")
continue
if object_type == "Integer":
matches = re.findall(r"\d+", attr_value)
if matches:
spg_record.upsert_property(attr_en, matches[0])
elif object_type == "Float":
matches = re.findall(r"\d+(?:\.\d+)?", attr_value)
if matches:
spg_record.upsert_property(attr_en, matches[0])
spg_records.append(spg_record)
return spg_records
def _render(self):
event_list = []
for spg_type in self.spg_types.values():
if not self.with_description:
arguments = []
arguments.extend(
[
v.name_zh
for k, v in spg_type.properties.items()
if k not in self.ignored_properties
]
)
arguments.extend(
[
v.name_zh
for k, v in spg_type.relations.items()
if v.name_zh not in arguments
and k not in self.ignored_relations
]
)
else:
arguments = {}
arguments.update(
{
v.name_zh: v.desc or ""
for k, v in spg_type.properties.items()
if k not in self.ignored_properties
}
)
arguments.update(
{
v.name_zh: v.desc or ""
for k, v in spg_type.relations.items()
if v.name_zh not in arguments
and k not in self.ignored_relations
}
)
event_type = spg_type.name_zh
event_list.append(
{"event_type": event_type, "trigger": True, "arguments": arguments}
)
self.schema_list = self.multischema_split_by_num(self.split_num, event_list)
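if __name__ == "__main__":
    # Editor's sketch, not part of the original source: build_prompt emits one
    # JSON instruction per schema chunk. __init__ needs a live SchemaClient, so
    # it is bypassed here purely to illustrate the output shape; the schema
    # values below are made up.
    demo = object.__new__(OneKE_NERPrompt)
    demo.template = OneKE_NERPrompt.template_zh
    demo.schema_list = [["疾病", "药物"], ["症状"]]
    for instruction in demo.build_prompt({"input": "甲状腺结节可由多种病因引起。"}):
        print(instruction)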


@ -0,0 +1,162 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
from typing import Optional, List
from kag.common.base.prompt_op import PromptOp
class OutlinePrompt(PromptOp):
template_zh = """
{
"instruction": "\n请理解input字段中的文本内容识别文本的结构和组成部分并帮我提取出以下内容的标题可能有多个标题分散在文本的各个地方仅返属于原文的回标题文本即可不要返回其他任何内容须按照python list的格式回答具体形式请遵从example字段中给出的若干例子。",
"input": "$input",
"example": [
{
"input": "第8条 原 则
1.各成员方在制订或修正其法律和规章时可采取必要措施以保护公众健康和营养并促进对其社会经济和技术发展至关重要部门的公众利益只要该措施符合本协议规定
2.可能需要采取与本协议的规定相一致的适当的措施以防止知识产权所有者滥用知识产权或藉以对贸易进行不合理限制或实行对国际间的技术转让产生不利影响的作法
第二部分 关于知识产权的效力范围及使用的标准
第1节 版权及相关权利
第9条 伯尔尼公约的关系",
"output": [
"第8条 原 则",
"第二部分 关于知识产权的效力、范围及使用的标准",
"第1节 版权及相关权利",
"第9条 与《伯尔尼公约》的关系"
],
},
{
"input": "第16条 授予权利
1.已注册商标所有者应拥有阻止所有未经其同意的第三方在贸易中使用与已注册商标相同或相似的商品或服务的其使用有可能招致混淆的相同或相似的标志在对相同商品或服务使用相同标志的情况下应推定存在混淆之可能上述权利不应妨碍任何现行的优先权也不应影响各成员方以使用为条件获得注册权的可能性
2.1967巴黎公约第6条副则经对细节作必要修改后应适用于服务在确定一个商标是否为知名商标时各成员方应考虑到有关部分的公众对该商标的了解包括由于该商标的推行而在有关成员方得到的了解
3.1967巴黎公约第6条副则经对细节作必要修改后应适用于与已注册商标的商品和服务不相似的商品或服务条件是该商标与该商品和服务有关的使用会表明该商品或服务与已注册商标所有者之间的联系而且已注册商标所有者的利益有可能为此种使用所破坏
第17条  \n ",
"output": [
"第16条 授予权利",
"第17条 例 外"
],
},
{
"input":"的做法。
4此类使用应是非独占性的
5此类使用应是不可转让的除非是同享有此类使用的那部分企业或信誉一道转让
6任何此类使用之授权均应主要是为授权此类使用的成员方国内市场供应之目的
7在被授权人的合法利益受到充分保护的条件下当导致此类使用授权的情况下不复存在和可能不再产生时有义务将其终止应有动机的请求主管当局应有权对上述情况的继续存在进行检查
8考虑到授权的经济价值应视具体情况向权利人支付充分的补偿金
9任何与此类使用之授权有关的决定其法律效力应接受该成员方境内更高当局的司法审查或其他独立审查
10任何与为此类使用而提供的补偿金有关的决定应接受成员方境内更高当局的司法审查或其他独立审查
",
"output": [],
},
]
}
"""
template_en = """
{
"instruction": "\nUnderstand the text content in the input field, identify the structure and components of the text, and help me extract the titles from the following content. There may be multiple titles scattered throughout the text. Only return the title texts that belong to the original text, and do not return any other content. The response must be in the format of a Python list, and the specific form should follow the examples given in the example field.",
"input": "$input",
"example": [
{
"input": "Article 8 Principles
1. In formulating or amending their laws and regulations, Members may take necessary measures to protect public health and nutrition, and to promote the public interest in sectors of vital importance to their socio-economic and technological development, provided that such measures are consistent with the provisions of this Agreement.
2. Appropriate measures may be needed to prevent the abuse of intellectual property rights by owners, or to prevent practices that restrict trade unjustifiably or adversely affect the international transfer of technology, in conformity with the provisions of this Agreement.
Part Two: Standards Concerning the Availability, Scope and Use of Intellectual Property Rights
Section 1 Copyright and Related Rights
Article 9 Relationship with the Berne Convention
",
"output": [
"Article 8 Principles",
"Part Two: Standards Concerning the Availability, Scope and Use of Intellectual Property Rights",
"Section 1 Copyright and Related Rights",
"Article 9 Relationship with the Berne Convention"
],
},
{
"input": "Article 16 Grant of Rights
1. Owners of registered trademarks shall have the right to prevent all third parties from using, without their consent, in the course of trade, any identical or similar signs for goods or services that are identical or similar to those for which the trademark is registered, where such use is likely to cause confusion. In the case of identical signs being used for identical goods or services, a likelihood of confusion shall be presumed. The above rights shall not prejudice any existing priority rights, nor shall they affect the possibility for Members to obtain registration rights conditional upon use.
2. The provisions of Article 6bis of the Paris Convention of 1967 shall apply to services with necessary modifications to the details. In determining whether a trademark is well-known, Members shall take into account the knowledge of the relevant public about that trademark, including the knowledge acquired in the relevant Member due to the promotion of that trademark.
3. The provisions of Article 6bis of the Paris Convention of 1967 shall apply to goods or services that are not similar to those for which the registered trademark is granted, provided that the use of the trademark in relation to those goods or services indicates a connection between the goods or services and the owner of the registered trademark, and that the interests of the owner of the registered trademark are likely to be harmed by such use.
Article 17 Exceptions
",
"output": [
"Article 16 Grant of Rights",
"Article 17 Exceptions"
],
},
{
"input": "by doing so.
(4) The use of this category should be non-exclusive.
(5) The use of this category should be non-transferable, unless it is transferred together with the part of the enterprise or reputation that enjoys the use of this category.
(6) Any authorization for such use should primarily be for the purpose of domestic market supply by the member party authorizing such use.
(7) There is an obligation to terminate it when the circumstances leading to the authorization for such use no longer exist and are unlikely to reoccur; upon motivated request, the competent authorities should have the right to examine the continued existence of the above circumstances.
(8) Adequate compensation should be paid to the right holder, taking into account the economic value of the authorization.
(9) Any decisions related to the authorization for such use should be subject to judicial review or other independent review by a higher authority within the territory of the member party.
(10) Any decisions related to the compensation provided for such use should be subject to judicial review or other independent review by a higher authority within the territory of the member party.",
"output": [],
},
]
}
"""
def __init__(self, language: Optional[str] = "zh"):
super().__init__(language)
@property
def template_variables(self) -> List[str]:
return ["input"]
def parse_response(self, response: str, **kwargs):
if isinstance(response, str):
response = json.loads(response)
if isinstance(response, dict) and "output" in response:
response = response["output"]
outline = kwargs.get("outline", [])
for r in response:
outline.append(r)
return outline
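if __name__ == "__main__":
    # Editor's sketch, not part of the original source: parse_response extends a
    # previously collected outline with the newly extracted titles.
    prompt = OutlinePrompt(language="zh")
    rsp = '{"output": ["第1节 版权及相关权利", "第9条 与《伯尔尼公约》的关系"]}'
    print(prompt.parse_response(rsp, outline=["第8条 原 则"]))
    # -> ['第8条 原 则', '第1节 版权及相关权利', '第9条 与《伯尔尼公约》的关系']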


@ -0,0 +1,149 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
from typing import Optional, List
from kag.common.base.prompt_op import PromptOp
class SemanticSegPrompt(PromptOp):
template_zh = """
{
"instruction": "\n请理解input字段中的文本内容识别文本的结构和组成部分并按照语义主题确定分割点将其切分成互不重叠的若干小节。如果文章有章节等可识别的结构信息请直接按照顶层结构进行切分。\n请按照schema定义的字段返回包含小节摘要和小节起始点。须按照JSON字符串的格式回答。具体形式请遵从example字段中给出的若干例子。",
"schema": {
"小节摘要": "该小节文本的简单概括",
"小节起始点": "该小节包含的原文的起点控制在20个字左右。该分割点将被用于分割原文因此必须可以在原文中找到"
},
"input": "$input",
"example": [
{
"input": "周杰伦Jay Chou1979年1月18日出生于台湾省新北市祖籍福建省永春县华语流行乐男歌手、音乐人、演员、导演、编剧毕业于淡江中学。\n2000年在杨峻荣的推荐下周杰伦开始演唱自己创作的歌曲。",
"output": [
{
"小节摘要": "个人简介",
"小节起始点": "周杰伦Jay Chou1979年1月18"
},
{
"小节摘要": "演艺经历",
"小节起始点": "\n2000年在杨峻荣的推荐下"
}
]
},
{
"input": "杭州市灵活就业人员缴存使用住房公积金管理办法(试行)\n为扩大住房公积金制度受益面,支持灵活就业人员解决住房问题,根据国务院《住房公积金管理条例》、《浙江省住房公积金条例》以及住房和城乡建设部、浙江省住房和城乡建设厅关于灵活就业人员参加住房公积金制度的有关规定和要求,结合杭州市实际,制订本办法。\n一、本办法适用于本市行政区域内灵活就业人员住房公积金的自愿缴存、使用和管理。\n二、本办法所称灵活就业人员是指在本市行政区域内年满16周岁且男性未满60周岁、女性未满55周岁具有完全民事行为能力以非全日制、个体经营、新就业形态等灵活方式就业的人员。\n三、灵活就业人员申请缴存住房公积金,应向杭州住房公积金管理中心(以下称公积金中心)申请办理缴存登记手续,设立个人账户。\n ",
"output": [
{
"小节摘要": "管理办法的制定背景和依据",
"小节起始点": "为扩大住房公积金制度受益面"
},
{
"小节摘要": "管理办法的适用范围",
"小节起始点": "一、本办法适用于本市行政区域内"
},
{
"小节摘要": "灵活就业人员的定义",
"小节起始点": "二、本办法所称灵活就业人员是指"
},
{
"小节摘要": "灵活就业人员缴存登记手续",
"小节起始点": "三、灵活就业人员申请缴存住房公积金",
}
]
}
]
}
"""
template_en = """
{
"instruction": "\nPlease understand the content of the text in the input field, recognize the structure and components of the text, and determine the segmentation points according to the semantic theme, dividing it into several non-overlapping sections. If the article has recognizable structural information such as chapters, please divide it according to the top-level structure.\nPlease return the results according to the schema definition, including summaries and starting points of the sections. The format must be a JSON string. Please follow the examples given in the example field.",
"schema": {
"Section Summary": "A brief summary of the section text",
"Section Starting Point": "The starting point of the section in the original text, limited to about 20 characters. This segmentation point will be used to split the original text, so it must be found in the original text!"
},
"input": "$input",
"example": [
{
"input": "Jay Chou (Jay Chou), born on January 18, 1979, in Xinbei City, Taiwan Province, originally from Yongchun County, Fujian Province, is a Mandopop male singer, musician, actor, director, screenwriter, and a graduate of Tamkang Senior High School.\nIn 2000, recommended by Yang Junrong, Jay Chou started singing his own compositions.",
"output": [
{
"Section Summary": "Personal Introduction",
"Section Starting Point": "Jay Chou (Jay Chou), born on January 18"
},
{
"Section Summary": "Career Start",
"Section Starting Point": "\nIn 2000, recommended by Yang Junrong"
}
]
},
{
"input": "Hangzhou Flexible Employment Personnel Housing Provident Fund Management Measures (Trial)\nTo expand the benefits of the housing provident fund system and support flexible employment personnel to solve housing problems, according to the State Council's 'Housing Provident Fund Management Regulations', 'Zhejiang Province Housing Provident Fund Regulations' and the relevant provisions and requirements of the Ministry of Housing and Urban-Rural Development and the Zhejiang Provincial Department of Housing and Urban-Rural Development on flexible employment personnel participating in the housing provident fund system, combined with the actual situation in Hangzhou, this method is formulated.\n1. This method applies to the voluntary deposit, use, and management of the housing provident fund for flexible employment personnel within the administrative region of this city.\n2. The flexible employment personnel referred to in this method are those who are within the administrative region of this city, aged 16 and above, and males under 60 and females under 55, with full civil capacity, and employed in a flexible manner such as part-time, self-employed, or in new forms of employment.\n3. Flexible employment personnel applying to deposit the housing provident fund should apply to the Hangzhou Housing Provident Fund Management Center (hereinafter referred to as the Provident Fund Center) for deposit registration procedures and set up personal accounts.",
"output": [
{
"Section Summary": "Background and Basis for Formulating the Management Measures",
"Section Starting Point": "To expand the benefits of the housing provident fund system"
},
{
"Section Summary": "Scope of Application of the Management Measures",
"Section Starting Point": "1. This method applies to the voluntary deposit"
},
{
"Section Summary": "Definition of Flexible Employment Personnel",
"Section Starting Point": "2. The flexible employment personnel referred to in this method"
},
{
"Section Summary": "Procedures for Flexible Employment Personnel to Register for Deposit",
"Section Starting Point": "3. Flexible employment personnel applying to deposit the housing provident fund"
}
]
}
]
}
"""
def __init__(self, language: Optional[str] = "zh"):
super().__init__(language)
@property
def template_variables(self) -> List[str]:
return ["input"]
def parse_response(self, response: str, **kwargs):
if isinstance(response, str):
response = json.loads(response)
if isinstance(response, dict) and "output" in response:
response = response["output"]
content = kwargs.get("input", "")
seg_info = []
for seg_point in response:
if not isinstance(seg_point, dict):
continue
start = seg_point.get(
"小节起始点" if self.language == "zh" else "Section Starting Point",
)
if not isinstance(start, str):
continue
start = start.strip()
# locate the section's starting phrase in the original content
loc = content.find(start)
if loc == -1:
print(f"incorrect seg: {seg_point}")
continue
abstract = seg_point.get(
"小节摘要" if self.language == "zh" else "Section Summary", None
)
seg_info.append((loc, abstract))
return seg_info
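if __name__ == "__main__":
    # Editor's sketch, not part of the original source: parse_response maps each
    # returned starting phrase back to its character offset in the original
    # text, yielding (offset, summary) pairs; unmatched phrases are skipped.
    prompt = SemanticSegPrompt(language="en")
    text = "Part one. Some details here. Part two begins now."
    rsp = (
        '{"output": [{"Section Summary": "intro", "Section Starting Point": "Part one."},'
        ' {"Section Summary": "body", "Section Starting Point": "Part two"}]}'
    )
    print(prompt.parse_response(rsp, input=text))
    # -> [(0, 'intro'), (29, 'body')]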


@ -0,0 +1,205 @@
#
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import logging
from abc import ABC
from typing import List, Dict
from kag.common.base.prompt_op import PromptOp
from knext.schema.client import SchemaClient
from knext.schema.model.base import BaseSpgType, SpgTypeEnum
from knext.schema.model.schema_helper import SPGTypeName
from kag.builder.model.spg_record import SPGRecord
logger = logging.getLogger(__name__)
class SPGPrompt(PromptOp, ABC):
spg_types: Dict[str, BaseSpgType]
ignored_types: List[str] = ["Chunk"]
ignored_properties: List[str] = ["id", "name", "description", "stdId", "eventTime", "desc", "semanticType"]
ignored_relations: List[str] = ["isA"]
basic_types = {"Text": "文本", "Integer": "整型", "Float": "浮点型"}
def __init__(
self,
spg_type_names: List[SPGTypeName],
language: str = "zh",
**kwargs,
):
super().__init__(language=language, **kwargs)
self.all_schema_types = SchemaClient(project_id=self.project_id).load()
self.spg_type_names = spg_type_names
if not spg_type_names:
self.spg_types = self.all_schema_types
else:
self.spg_types = {k: v for k, v in self.all_schema_types.items() if k in spg_type_names}
self.schema_list = []
self._init_render_variables()
@property
def template_variables(self) -> List[str]:
return ["schema", "input"]
def _init_render_variables(self):
self.type_en_to_zh = {"Text": "文本", "Integer": "整型", "Float": "浮点型"}
self.type_zh_to_en = {
"文本": "Text",
"整型": "Integer",
"浮点型": "Float",
}
self.prop_en_to_zh = {}
self.prop_zh_to_en = {}
for type_name, spg_type in self.all_schema_types.items():
self.type_en_to_zh[type_name] = spg_type.name_zh
self.type_zh_to_en[spg_type.name_zh] = type_name
self.prop_zh_to_en[type_name] = {}
self.prop_en_to_zh[type_name] = {}
for _prop in spg_type.properties.values():
if _prop.name in self.ignored_properties:
continue
self.prop_en_to_zh[type_name][_prop.name] = _prop.name_zh
self.prop_zh_to_en[type_name][_prop.name_zh] = _prop.name
for _rel in spg_type.relations.values():
if _rel.is_dynamic:
continue
self.prop_en_to_zh[type_name][_rel.name] = _rel.name_zh
self.prop_zh_to_en[type_name][_rel.name_zh] = _rel.name
def _render(self):
raise NotImplementedError
class SPG_KGPrompt(SPGPrompt):
template_zh: str = """
{
"instruction": "你是一个图谱知识抽取的专家, 基于constraint 定义的schema从input 中抽取出所有的实体及其属性input中未明确提及的属性返回NAN以标准json 格式输出结果返回list",
"schema": $schema,
"example": [
{
"input": "甲状腺结节是指在甲状腺内的肿块可随吞咽动作随甲状腺而上下移动是临床常见的病症可由多种病因引起。临床上有多种甲状腺疾病如甲状腺退行性变、炎症、自身免疫以及新生物等都可以表现为结节。甲状腺结节可以单发也可以多发多发结节比单发结节的发病率高但单发结节甲状腺癌的发生率较高。患者通常可以选择在普外科甲状腺外科内分泌科头颈外科挂号就诊。有些患者可以触摸到自己颈部前方的结节。在大多情况下甲状腺结节没有任何症状甲状腺功能也是正常的。甲状腺结节进展为其它甲状腺疾病的概率只有1%。有些人会感觉到颈部疼痛、咽喉部异物感,或者存在压迫感。当甲状腺结节发生囊内自发性出血时,疼痛感会更加强烈。治疗方面,一般情况下可以用放射性碘治疗,复方碘口服液(Lugol液)等,或者服用抗甲状腺药物来抑制甲状腺激素的分泌。目前常用的抗甲状腺药物是硫脲类化合物,包括硫氧嘧啶类的丙基硫氧嘧啶(PTU)和甲基硫氧嘧啶(MTU)及咪唑类的甲硫咪唑和卡比马唑。",
"schema": {
"Disease": {
"properties": {
"complication": "并发症",
"commonSymptom": "常见症状",
"applicableMedicine": "适用药品",
"department": "就诊科室",
"diseaseSite": "发病部位",
}
},"Medicine": {
"properties": {
}
}
},
"output": [
{
"entity": "甲状腺结节",
"category":"Disease"
"properties": {
"complication": "甲状腺癌",
"commonSymptom": ["颈部疼痛", "咽喉部异物感", "压迫感"],
"applicableMedicine": ["复方碘口服液(Lugol液)", "丙基硫氧嘧啶(PTU)", "甲基硫氧嘧啶(MTU)", "甲硫咪唑", "卡比马唑"],
"department": ["普外科", "甲状腺外科", "内分泌科", "头颈外科"],
"diseaseSite": "甲状腺",
}
},{
"entity":"复方碘口服液(Lugol液)",
"category":"Medicine"
},{
"entity":"丙基硫氧嘧啶(PTU)",
"category":"Medicine"
},{
"entity":"甲基硫氧嘧啶(MTU)",
"category":"Medicine"
},{
"entity":"甲硫咪唑",
"category":"Medicine"
},{
"entity":"卡比马唑",
"category":"Medicine"
}
]
}
],
"input": "$input"
}
"""
template_en: str = template_zh
def __init__(
self,
spg_type_names: List[SPGTypeName],
language: str = "zh",
**kwargs
):
super().__init__(
spg_type_names=spg_type_names,
language=language,
**kwargs
)
self._render()
def build_prompt(self, variables: Dict[str, str]) -> str:
schema = {}
for tmpSchema in self.schema_list:
schema.update(tmpSchema)
return super().build_prompt({"schema": schema, "input": variables.get("input")})
def parse_response(self, response: str, **kwargs) -> List[SPGRecord]:
rsp = response
if isinstance(rsp, str):
rsp = json.loads(rsp)
if isinstance(rsp, dict) and "output" in rsp:
rsp = rsp["output"]
if isinstance(rsp, dict) and "named_entities" in rsp:
entities = rsp["named_entities"]
else:
entities = rsp
return entities
def _render(self):
spo_list = []
for type_name, spg_type in self.spg_types.items():
if spg_type.spg_type_enum not in [SpgTypeEnum.Entity, SpgTypeEnum.Concept, SpgTypeEnum.Event]:
continue
constraint = {}
properties = {}
properties.update(
{
v.name: (f"{v.name_zh}" if not v.desc else f"{v.name_zh}{v.desc}") if self.language == "zh" else (f"{v.name}" if not v.desc else f"{v.name}, {v.desc}")
for k, v in spg_type.properties.items()
if k not in self.ignored_properties
}
)
properties.update(
{
f"{v.name}#{v.object_type_name_en}": (
f"{v.name_zh},类型是{v.object_type_name_zh}"
if not v.desc
else f"{v.name_zh}{v.desc},类型是{v.object_type_name_zh}"
) if self.language == "zh" else (
f"{v.name}, the type is {v.object_type_name_en}"
if not v.desc
else f"{v.name}{v.desc}, the type is {v.object_type_name_en}"
)
for k, v in spg_type.relations.items()
if not v.is_dynamic and k not in self.ignored_relations
}
)
constraint.update({"properties": properties})
spo_list.append({type_name: constraint})
self.schema_list = spo_list
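if __name__ == "__main__":
    # Editor's sketch, not part of the original source: parse_response tolerates
    # a bare list as well as dicts wrapped under "output" or "named_entities".
    # The constructor needs a live SchemaClient, so it is bypassed here.
    demo = object.__new__(SPG_KGPrompt)
    rsp = '{"output": [{"entity": "甲状腺结节", "category": "Disease"}]}'
    print(demo.parse_response(rsp))
    # -> [{'entity': '甲状腺结节', 'category': 'Disease'}]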

12 kag/common/__init__.py Normal file

@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

1267 kag/common/arks_pb2.py Normal file

File diff suppressed because one or more lines are too long


@ -0,0 +1,10 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.


@ -0,0 +1,184 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import importlib
import inspect
import os
import sys
from abc import ABC
from string import Template
from typing import List
BUILDER_PROMPT_PATH = "kag.builder.prompt"
SOLVER_PROMPT_PATH = "kag.solver.prompt"
class PromptOp(ABC):
"""
Provides a template for generating and parsing prompts related to specific business scenes.
Subclasses must implement the template strings for specific languages (English or Chinese)
and override the `template_variables` and `parse_response` methods.
"""
"""English template string"""
template_en: str = ""
"""Chinese template string"""
template_zh: str = ""
def __init__(self, language: str, **kwargs):
"""
Initializes the PromptOp instance with the selected language.
Args:
language (str): The language for the prompt, should be either "en" or "zh".
Raises:
AssertionError: If the provided language is not supported.
"""
assert language in ["en", "zh"], f"language[{language}] is not supported."
self.template = self.template_en if language == "en" else self.template_zh
self.language = language
self.template_variables_value = {}
if "project_id" in kwargs:
self.project_id = kwargs["project_id"]
@property
def template_variables(self) -> List[str]:
"""
Gets the list of template variables.
Must be implemented by subclasses.
Returns:
- List[str]: A list of template variable names.
Raises:
- NotImplementedError: If the subclass does not implement this method.
"""
raise NotImplementedError(
f"{self.__class__.__name__} need to implement `template_variables` method."
)
def process_template_string_to_avoid_dollar_problem(self, template_string):
new_template_str = template_string.replace('$', '$$')
for var in self.template_variables:
new_template_str = new_template_str.replace(f'$${var}', f'${var}')
return new_template_str
def build_prompt(self, variables) -> str:
"""
Build a prompt based on the template and provided variables.
This method replaces placeholders in the template with actual variable values.
If a variable is not provided, it defaults to an empty string.
Parameters:
- variables: A dictionary containing variable names and their corresponding values.
Returns:
- A string or list of strings, depending on the template content.
"""
self.template_variables_value = variables
template_string = self.process_template_string_to_avoid_dollar_problem(self.template)
template = Template(template_string)
return template.substitute(**variables)
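# Editor's illustration, not part of the original source: the escaping above
# doubles every "$" so string.Template treats literal dollar signs as plain
# text, then restores the "$var" form for declared template variables only.
# For a template 'price is $9.9: $input' with template_variables == ["input"]:
#   after escaping -> 'price is $$9.9: $input'
#   after substitute({"input": "hi"}) -> 'price is $9.9: hi'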
def parse_response(self, response: str, **kwargs):
"""
Parses the response string.
Must be implemented by subclasses.
Parameters:
- response (str): The response string to be parsed.
Raises:
- NotImplementedError: If the subclass does not implement this method.
"""
raise NotImplementedError(
f"{self.__class__.__name__} need to implement `parse_response` method."
)
@classmethod
def load(cls, biz_scene: str, type: str):
"""
Dynamically loads the corresponding PromptOp subclass object based on the business scene and type.
Parameters:
- biz_scene (str): The name of the business scene.
- type (str): The type of prompt.
Returns:
- subclass of PromptOp: The loaded PromptOp subclass object.
Raises:
- ImportError: If the specified module or class does not exist.
"""
dir_paths = [
os.path.join(os.getenv("KAG_PROJECT_ROOT_PATH", ""), "builder", "prompt"),
os.path.join(os.getenv("KAG_PROJECT_ROOT_PATH", ""), "solver", "prompt"),
]
module_paths = [
'.'.join([BUILDER_PROMPT_PATH, biz_scene, type]),
'.'.join([SOLVER_PROMPT_PATH, biz_scene, type]),
'.'.join([BUILDER_PROMPT_PATH, 'default', type]),
'.'.join([SOLVER_PROMPT_PATH, 'default', type]),
]
def find_class_from_dir(dir, type):
sys.path.append(dir)
for root, dirs, files in os.walk(dir):
for file in files:
if file.endswith(".py") and file.startswith(f"{type}."):
module_name = file[:-3]
try:
module = importlib.import_module(module_name)
except ImportError:
continue
cls_found = find_class_from_module(module)
if cls_found:
return cls_found
return None
def find_class_from_module(module):
classes = inspect.getmembers(module, inspect.isclass)
for class_name, class_obj in classes:
import kag
if issubclass(class_obj, kag.common.base.prompt_op.PromptOp) and inspect.getmodule(class_obj) == module:
return class_obj
return None
for dir_path in dir_paths:
try:
cls_found = find_class_from_dir(dir_path, type)
if cls_found:
return cls_found
except ImportError:
continue
for module_path in module_paths:
try:
module = importlib.import_module(module_path)
cls_found = find_class_from_module(module)
if cls_found:
return cls_found
except ModuleNotFoundError:
continue
raise ValueError(f'Not support prompt with biz_scene[{biz_scene}] and type[{type}]')
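if __name__ == "__main__":
    # Editor's sketch, not part of the original source: `load` resolves a prompt
    # class by business scene and type, searching the project's builder/solver
    # prompt directories first and then the packaged defaults. The scene/type
    # names below are hypothetical; they must match an installed prompt module.
    prompt_cls = PromptOp.load("default", "ner")
    print(prompt_cls.__name__)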


@ -0,0 +1,10 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.


@ -0,0 +1,112 @@
import re
import string
from collections import Counter
def normalize_answer(s):
"""
Normalizes the answer string.
This function standardizes the answer string through a series of steps including removing articles,
fixing whitespace, removing punctuation, and converting text to lowercase. This ensures consistency
and fairness when comparing answers.
Parameters:
s (str): The answer string to be standardized.
Returns:
str: The standardized answer string.
"""
def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return str(text).lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def f1_score(prediction, ground_truth):
"""
Calculates the F1 score between the predicted answer and the ground truth.
The F1 score is the harmonic mean of precision and recall, used to evaluate the model's performance in question answering tasks.
Parameters:
prediction (str): The predicted answer from the model.
ground_truth (str): The actual ground truth answer.
Returns:
tuple: A tuple containing the F1 score, precision, and recall.
"""
normalized_prediction = normalize_answer(prediction)
normalized_ground_truth = normalize_answer(ground_truth)
ZERO_METRIC = (0, 0, 0)
if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
return ZERO_METRIC
if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
return ZERO_METRIC
prediction_tokens = normalized_prediction.split()
ground_truth_tokens = normalized_ground_truth.split()
# Calculate the number of matching words between the predicted and ground truth answers
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return ZERO_METRIC
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1, precision, recall
def exact_match_score(prediction, ground_truth):
"""
Calculates the exact match score between a predicted answer and the ground truth answer.
This function normalizes both the predicted answer and the ground truth answer before comparing them.
Normalization is performed to ensure that non-essential differences such as spaces and case are ignored.
Parameters:
prediction (str): The predicted answer string.
ground_truth (str): The ground truth answer string.
Returns:
int: 1 if the predicted answer exactly matches the ground truth answer, otherwise 0.
"""
return 1 if normalize_answer(prediction) == normalize_answer(ground_truth) else 0
def get_em_f1(prediction, gold):
"""
Calculates the Exact Match (EM) score and F1 score between the prediction and the gold standard.
This function evaluates the performance of a model in text similarity tasks by calculating the EM score and F1 score to measure the accuracy of the predictions.
Parameters:
prediction (str): The output predicted by the model.
gold (str): The gold standard output (i.e., the correct output).
Returns:
tuple: A tuple containing two floats, the EM score and the F1 score. The EM score represents the exact match accuracy, while the F1 score is a combination of precision and recall.
"""
em = exact_match_score(prediction, gold)
f1, precision, recall = f1_score(prediction, gold)
return float(em), f1
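if __name__ == "__main__":
    # Editor's sketch, not part of the original source: EM requires an exact
    # match after normalization, while F1 rewards partial token overlap.
    print(get_em_f1("The Eiffel Tower", "eiffel tower"))  # -> (1.0, 1.0)
    print(get_em_f1("Paris, France", "Paris"))            # -> (0.0, 0.666...)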


@ -0,0 +1,65 @@
from typing import List
from .evaUtils import get_em_f1
class Evaluate():
"""
Provide evaluation for benchmarks, covering metrics such as EM, F1, answer_similarity, and answer_correctness.
"""
def __init__(self, embedding_factory = "text-embedding-ada-002"):
self.embedding_factory = embedding_factory
def evaForSimilarity(self, predictionlist: List[str], goldlist: List[str]):
"""
evaluate the similarity between prediction and gold #TODO
"""
# data_samples = {
# 'question': [],
# 'answer': predictionlist,
# 'ground_truth': goldlist
# }
# dataset = Dataset.from_dict(data_samples)
# run_config = RunConfig(timeout=240, thread_timeout=240, max_workers=16)
# embeddings = embedding_factory(self.embedding_factory, run_config)
#
# score = evaluate(dataset, metrics=[answer_similarity], embeddings = embeddings, run_config=run_config)
# return np.average(score.to_pandas()[['answer_similarity']])
return 0.0
def getBenchMark(self, predictionlist: List[str], goldlist: List[str]):
"""
Calculates and returns evaluation metrics between predictions and ground truths.
This function evaluates the match between predictions and ground truths by calculating
the exact match (EM) and F1 score, as well as answer similarity.
Parameters:
predictionlist (List[str]): List of predicted values from the model.
goldlist (List[str]): List of actual ground truth values.
Returns:
dict: Dictionary containing EM, F1 score, and answer similarity.
"""
# Initialize total metrics
total_metrics = {'em': 0.0, 'f1': 0.0, 'answer_similarity': 0.0}
# Iterate over prediction and gold lists to calculate EM and F1 scores
for prediction, gold in zip(predictionlist, goldlist):
em, f1 = get_em_f1(prediction, gold) # Call external function to calculate EM and F1
total_metrics['em'] += em # Accumulate EM score
total_metrics['f1'] += f1 # Accumulate F1 score
# Calculate average EM and F1 scores
total_metrics['em'] /= len(predictionlist)
total_metrics['f1'] /= len(predictionlist)
# Call method to calculate answer similarity
total_metrics['answer_similarity'] = self.evaForSimilarity(predictionlist, goldlist)
# Return evaluation metrics dictionary
return total_metrics
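if __name__ == "__main__":
    # Editor's sketch, not part of the original source: getBenchMark averages EM
    # and F1 over the lists; answer_similarity stays 0.0 until the TODO in
    # evaForSimilarity is implemented.
    evaluator = Evaluate()
    print(evaluator.getBenchMark(["eiffel tower", "Paris, France"],
                                 ["The Eiffel Tower", "Paris"]))
    # -> {'em': 0.5, 'f1': 0.833..., 'answer_similarity': 0.0}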


@ -0,0 +1,33 @@
[project]
with_server = True
host_addr = http://127.0.0.1:8887
[vectorizer]
vectorizer = knext.common.vectorizer.OpenAIVectorizer
model = bge-m3
api_key = EMPTY
base_url = http://127.0.0.1:11434/v1
vector_dimensions = 1024
[llm]
client_type = ollama
base_url = http://localhost:11434/api/generate
model = llama3.1
[indexer]
with_semantic = False
similarity_threshold = 0.8
[retriever]
with_semantic = False
pagerank_threshold = 0.9
match_threshold = 0.8
top_k = 10
[schedule]
interval_minutes = -1
[log]
level = INFO

117 kag/common/env.py Normal file

@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import logging
import os
import sys
from configparser import ConfigParser as CP
from pathlib import Path
from typing import Union, Optional
import kag.common as common
class ConfigParser(CP):
def __init__(self, defaults=None):
super().__init__(defaults=defaults)
def optionxform(self, optionstr):
# keep option names case-sensitive (ConfigParser lowercases them by default)
return optionstr
LOCAL_SCHEMA_URL = "http://localhost:8887"
DEFAULT_KAG_CONFIG_FILE_NAME = "default_config.cfg"
DEFAULT_KAG_CONFIG_PATH = os.path.join(common.__path__[0], DEFAULT_KAG_CONFIG_FILE_NAME)
KAG_CFG_PREFIX = "KAG"
def init_env():
"""Initialize environment to use command-line tool from inside a project
dir. This sets the Scrapy settings module and modifies the Python path to
be able to locate the project module.
"""
project_cfg, root_path = get_config()
init_kag_config(Path(root_path) / "kag_config.cfg")
def get_config():
"""
Get kag config file as a ConfigParser.
"""
local_cfg_path = _closest_cfg()
local_cfg = ConfigParser()
local_cfg.read(local_cfg_path)
projdir = ""
if local_cfg_path:
projdir = str(Path(local_cfg_path).parent)
if projdir not in sys.path:
sys.path.append(projdir)
return local_cfg, projdir
def _closest_cfg(
path: Union[str, os.PathLike] = ".",
prev_path: Optional[Union[str, os.PathLike]] = None,
) -> str:
"""
Return the path to the closest kag_config.cfg file by traversing the current
directory and its parents
"""
if prev_path is not None and str(path) == str(prev_path):
return ""
path = Path(path).resolve()
cfg_file = path / "kag_config.cfg"
if cfg_file.exists():
return str(cfg_file)
return _closest_cfg(path.parent, path)
def get_cfg_files():
"""
Get global and local kag config files and paths.
"""
local_cfg_path = _closest_cfg()
local_cfg = ConfigParser()
local_cfg.read(local_cfg_path)
if local_cfg_path:
projdir = str(Path(local_cfg_path).parent)
if projdir not in sys.path:
sys.path.append(projdir)
return local_cfg, local_cfg_path
def init_kag_config(config_path: Union[str, Path] = None):
if not config_path or not Path(config_path).exists():
config_path = DEFAULT_KAG_CONFIG_PATH
kag_cfg = ConfigParser()
kag_cfg.read(config_path)
os.environ["KAG_PROJECT_ROOT_PATH"] = os.path.abspath(os.path.dirname(config_path))
for section in kag_cfg.sections():
sec_cfg = {}
for key, value in kag_cfg.items(section):
item_cfg_key = f"{KAG_CFG_PREFIX}_{section}_{key}".upper()
os.environ[item_cfg_key] = value
sec_cfg[key] = value
sec_cfg_key = f"{KAG_CFG_PREFIX}_{section}".upper()
os.environ[sec_cfg_key] = str(sec_cfg)
if section == "log":
for key, value in kag_cfg.items(section):
if key == "level":
logging.basicConfig(level=logging.getLevelName(value))
# reduce neo4j driver log verbosity
logging.getLogger("neo4j.notifications").setLevel(logging.ERROR)
logging.getLogger("neo4j.io").setLevel(logging.INFO)
logging.getLogger("neo4j.pool").setLevel(logging.INFO)


@ -0,0 +1,10 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.


@ -0,0 +1,318 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from abc import ABC, abstractmethod
class GraphStore(ABC):
"""
Abstract base class for a graph store that defines standard interfaces for graph data operations.
This class specifies abstract methods to ensure subclasses implement specific graph operations such as node CRUD, relationship handling, and index management.
"""
@abstractmethod
def close(self):
"""
Close the graph store resources.
"""
pass
@abstractmethod
def initialize_schema(self, schema):
"""
Initialize the graph schema.
Parameters:
- schema: Definition of the graph schema.
"""
pass
@abstractmethod
def upsert_node(self, label, properties, id_key="id", extra_labels=("Entity",)):
"""
Insert or update a single node.
Parameters:
- label: Label of the node.
- properties: Properties of the node.
- id_key: Property key used as the unique identifier.
- extra_labels: Additional labels for the node.
"""
pass
@abstractmethod
def upsert_nodes(self, label, properties_list, id_key="id", extra_labels=("Entity",)):
"""
Insert or update multiple nodes.
Parameters:
- label: Label of the nodes.
- properties_list: List of properties for the nodes.
- id_key: Property key used as the unique identifier.
- extra_labels: Additional labels for the nodes.
"""
pass
@abstractmethod
def batch_preprocess_node_properties(self, node_batch, extra_labels=("Entity",)):
"""
Batch preprocess node properties.
Parameters:
- node_batch: A batch of nodes.
- extra_labels: Additional labels for the nodes.
"""
pass
@abstractmethod
def get_node(self, label, id_value, id_key="id"):
"""
Get a node by label and identifier.
Parameters:
- label: Label of the node.
- id_value: Unique identifier value of the node.
- id_key: Property key used as the unique identifier.
Returns:
- The matching node.
"""
pass
@abstractmethod
def delete_node(self, label, id_value, id_key="id"):
"""
Delete a specified node.
Parameters:
- label: Label of the node.
- id_value: Unique identifier value of the node.
- id_key: Property key used as the unique identifier.
"""
pass
@abstractmethod
def delete_nodes(self, label, id_values, id_key="id"):
"""
Delete multiple nodes.
Parameters:
- label: Label of the nodes.
- id_values: List of unique identifier values for the nodes.
- id_key: Property key used as the unique identifier.
"""
pass
@abstractmethod
def upsert_relationship(self, start_node_label, start_node_id_value,
end_node_label, end_node_id_value,
rel_type, properties, upsert_nodes=True,
start_node_id_key="id", end_node_id_key="id"):
"""
Insert or update a relationship.
Parameters:
- start_node_label: Label of the start node.
- start_node_id_value: Unique identifier value of the start node.
- end_node_label: Label of the end node.
- end_node_id_value: Unique identifier value of the end node.
- rel_type: Type of the relationship.
- properties: Properties of the relationship.
- upsert_nodes: Whether to insert or update nodes.
- start_node_id_key: Property key used as the unique identifier for the start node.
- end_node_id_key: Property key used as the unique identifier for the end node.
"""
pass
@abstractmethod
def upsert_relationships(self, start_node_label, end_node_label, rel_type,
relationships, upsert_nodes=True, start_node_id_key="id",
end_node_id_key="id"):
"""
Insert or update multiple relationships.
Parameters:
- start_node_label: Label of the start node.
- end_node_label: Label of the end node.
- rel_type: Type of the relationship.
- relationships: List of relationships.
- upsert_nodes: Whether to insert or update nodes.
- start_node_id_key: Property key used as the unique identifier for the start node.
- end_node_id_key: Property key used as the unique identifier for the end node.
"""
pass
@abstractmethod
def delete_relationship(self, start_node_label, start_node_id_value,
end_node_label, end_node_id_value,
rel_type, start_node_id_key="id", end_node_id_key="id"):
"""
Delete a specified relationship.
Parameters:
- start_node_label: Label of the start node.
- start_node_id_value: Unique identifier value of the start node.
- end_node_label: Label of the end node.
- end_node_id_value: Unique identifier value of the end node.
- rel_type: Type of the relationship.
- start_node_id_key: Property key used as the unique identifier for the start node.
- end_node_id_key: Property key used as the unique identifier for the end node.
"""
pass
@abstractmethod
def delete_relationships(self, start_node_label, start_node_id_values,
end_node_label, end_node_id_values, rel_type,
start_node_id_key="id", end_node_id_key="id"):
"""
Delete multiple relationships.
Parameters:
- start_node_label: Label of the start node.
- start_node_id_values: List of unique identifier values for the start nodes.
- end_node_label: Label of the end node.
- end_node_id_values: List of unique identifier values for the end nodes.
- rel_type: Type of the relationship.
- start_node_id_key: Property key used as the unique identifier for the start node.
- end_node_id_key: Property key used as the unique identifier for the end node.
"""
pass
@abstractmethod
def create_index(self, label, property_key, index_name=None):
"""
Create a node index.
Parameters:
- label: Label of the node.
- property_key: Property key used for indexing.
- index_name: Name of the index (optional).
"""
pass
@abstractmethod
def create_text_index(self, labels, property_keys, index_name=None):
"""
Create a text index.
Parameters:
- labels: List of node labels.
- property_keys: List of property keys used for indexing.
- index_name: Name of the index (optional).
"""
pass
@abstractmethod
def create_vector_index(self, label, property_key, index_name=None,
vector_dimensions=768, metric_type="cosine",
hnsw_m=None, hnsw_ef_construction=None):
"""
Create a vector index.
Parameters:
- label: Label of the node.
- property_key: Property key used for indexing.
- index_name: Name of the index (optional).
- vector_dimensions: Dimensionality of the vectors, default is 768.
- metric_type: Type of distance measure, default is "cosine".
- hnsw_m: m parameter of the HNSW algorithm; defaults to None (equivalent to m=16).
- hnsw_ef_construction: ef_construction parameter of the HNSW algorithm; defaults to None (equivalent to ef_construction=100).
"""
pass
@abstractmethod
def delete_index(self, index_name):
"""
Delete a specified index.
Parameters:
- index_name: Name of the index.
"""
pass
@abstractmethod
def text_search(self, query_string, label_constraints=None, topk=10, index_name=None):
"""
Perform a text search.
Parameters:
- query_string: Query string.
- label_constraints: Label constraints (optional).
- topk: Number of top results to return, default is 10.
- index_name: Name of the index (optional).
Returns:
- List of search results.
"""
pass
@abstractmethod
def vector_search(self, label, property_key, query_text_or_vector, topk=10, index_name=None, ef_search=None):
"""
Perform a vector search.
Parameters:
- label: Label of the node.
- property_key: Property key used for indexing.
- query_text_or_vector: Query text or vector.
- topk: Number of top results to return, default is 10.
- index_name: Name of the index (optional).
- ef_search: ef_search parameter of the HNSW algorithm; specifies the number of candidates to consider during the search.
Returns:
- List of search results.
"""
pass
@abstractmethod
def execute_pagerank(self, iterations=20, damping_factor=0.85):
"""
Execute the PageRank algorithm.
Parameters:
- iterations: Number of iterations, default is 20.
- damping_factor: Damping factor, default is 0.85.
"""
pass
@abstractmethod
def get_pagerank_scores(self, start_nodes, target_type):
"""
Get PageRank scores.
Parameters:
- start_nodes: Source nodes for personalized PageRank, each a dict with 'type' and 'name'.
- target_type: Type of the target nodes whose scores are returned.
Returns:
- PageRank scores.
"""
pass
@abstractmethod
def run_script(self, script):
"""
Execute a script.
Parameters:
- script: Script to be executed.
"""
pass
@abstractmethod
def get_all_entity_labels(self):
"""
Get all entity labels.
Returns:
- List of entity labels.
"""
pass
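Any backend that implements the methods above can serve as the KAG graph store. As a minimal usage sketch (MyGraphStore is a hypothetical concrete subclass, and the connection details and schema object are assumptions for illustration):

    # MyGraphStore and its constructor arguments are assumptions, not part of this interface.
    store = MyGraphStore(uri="bolt://localhost:7687", user="neo4j", password="...")
    store.initialize_schema(schema)

    # Write two nodes and connect them.
    store.upsert_node("Person", {"id": "p1", "name": "Alice"})
    store.upsert_node("Person", {"id": "p2", "name": "Bob"})
    store.upsert_relationship("Person", "p1", "Person", "p2", "knows", {"since": 2020})

    # Read back through the text and vector indexes.
    hits = store.text_search("alice", label_constraints=["Person"], topk=5)
    similar = store.vector_search("Person", "name", "Alice", topk=5)
    store.close()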

View File

@ -0,0 +1,903 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import logging
import os
import re
import threading
import time
from abc import ABCMeta
import schedule
from neo4j import GraphDatabase
from kag.common.graphstore.graph_store import GraphStore
from kag.common.semantic_infer import SemanticEnhance
from kag.common.utils import escape_single_quotes
from knext.schema.model.base import IndexTypeEnum
logger = logging.getLogger(__name__)
class SingletonMeta(ABCMeta):
"""
Thread-safe Singleton metaclass
"""
_instances = {}
_lock = threading.Lock()
def __call__(cls, *args, **kwargs):
uri = kwargs.get('uri')
user = kwargs.get('user')
password = kwargs.get('password')
database = kwargs.get('database', 'neo4j')
key = (cls, uri, user, password, database)
with cls._lock:
if key not in cls._instances:
cls._instances[key] = super().__call__(*args, **kwargs)
return cls._instances[key]
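# Note: instances are keyed by (class, uri, user, password, database), so two
# constructions with identical keyword arguments share one client:
#
#   a = Neo4jClient(uri="bolt://host:7687", user="u", password="p")  # illustrative values
#   b = Neo4jClient(uri="bolt://host:7687", user="u", password="p")
#   assert a is b
#
# The connection arguments must be passed as keywords, since the key above is
# built from kwargs only.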
class Neo4jClient(GraphStore, metaclass=SingletonMeta):
def __init__(self, uri, user, password, database="neo4j", init_type="write", interval_minutes=10):
self._driver = GraphDatabase.driver(uri, auth=(user, password))
logger.info(f"init Neo4jClient uri: {uri} database: {database}")
self._database = database
self._lucene_special_chars = "\\+-!():^[]\"{}~*?|&/"
self._lucene_pattern = self._get_lucene_pattern()
self._simple_ident = "[A-Za-z_][A-Za-z0-9_]*"
self._simple_ident_pattern = re.compile(self._simple_ident)
self._vec_meta = dict()
self._vec_meta_ts = 0.0
self._vec_meta_timeout = 60.0
self._vectorizer = None
self._allGraph = "allGraph_0"
if init_type == "write":
self._labels = self._create_unique_constraint()
self._create_all_graph(self._allGraph)
self.schedule_constraint(interval_minutes)
# self.create_text_index(["Chunk"], ["content"])
self.refresh_vector_index_meta(force=True)
def close(self):
self._driver.close()
def schedule_constraint(self, interval_minutes):
def job():
try:
self._labels = self._create_unique_constraint()
self._update_pagerank_graph()
except Exception as e:
import traceback
logger.error(f"Error run scheduled job: {traceback.format_exc()}")
def run_scheduled_tasks():
while True:
schedule.run_pending()
time.sleep(1)
if interval_minutes > 0:
schedule.every(interval_minutes).minutes.do(job)
scheduler_thread = threading.Thread(target=run_scheduled_tasks, daemon=True)
scheduler_thread.start()
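# With interval_minutes=10, job() re-creates the uniqueness constraints and swaps
# the PageRank projection every ten minutes; the daemon thread exits with the process.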
def get_all_entity_labels(self):
with self._driver.session(database=self._database) as session:
result = session.run("CALL db.labels()")
labels = [record[0] for record in result]
return labels
def run_script(self, script):
with self._driver.session(database=self._database) as session:
return list(session.run(script))
def _create_unique_constraint(self):
with self._driver.session(database=self._database) as session:
result = session.run("CALL db.labels()")
labels = [record[0] for record in result if record[0] != "Entity"]
for label in labels:
self._create_unique_index_constraint(self, label, session)
return labels
@staticmethod
def _create_unique_index_constraint(self, label, session):
constraint_name = f"uniqueness_{label}_id"
create_constraint_query = f"CREATE CONSTRAINT {self._escape_neo4j(constraint_name)} IF NOT EXISTS FOR (n:{self._escape_neo4j(label)}) REQUIRE n.id IS UNIQUE"
try:
result = session.run(create_constraint_query)
result.consume()
logger.debug(f"Unique constraint created for constraint_name: {constraint_name}")
except Exception as e:
logger.debug(f"warn creating constraint for {constraint_name}: {e}")
self._create_index_constraint(self, label, session)
@staticmethod
def _create_index_constraint(self, label, session):
index_name = f"index_{label}_id"
create_constraint_query = f"CREATE INDEX {self._escape_neo4j(index_name)} IF NOT EXISTS FOR (n:{self._escape_neo4j(label)}) ON (n.id)"
try:
result = session.run(create_constraint_query)
result.consume()
logger.debug(f"index constraint created for constraint_name: {index_name}")
except Exception as e:
logger.warn(f"warn creating index constraint for {index_name}: {e}")
def _update_pagerank_graph(self):
all_graph_0 = "allGraph_0"
all_graph_1 = "allGraph_1"
if self._allGraph == all_graph_0:
all_graph = all_graph_1
else:
all_graph = all_graph_0
logger.debug(f"update pagerank graph for {all_graph}")
self._create_all_graph(all_graph)
logger.debug(f"drop old pagerank graph for {self._allGraph}")
self._drop_all_graph(self._allGraph)
self._allGraph = all_graph
def create_pagerank_graph(self):
self._drop_all_graph(self._allGraph)
self._create_all_graph(self._allGraph)
def initialize_schema(self, schema_types):
for spg_type in schema_types:
label = spg_type
properties = schema_types[spg_type].properties
if properties:
for property_key in properties:
if property_key == "name":
self.create_vector_index(label, property_key)
index_type = properties[property_key].index_type
if index_type:
if index_type == IndexTypeEnum.Text:
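# Text indexes are created in bulk below via _collect_text_index_info.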
pass
elif index_type == IndexTypeEnum.Vector:
self.create_vector_index(label, property_key)
elif index_type == IndexTypeEnum.TextAndVector:
self.create_vector_index(label, property_key)
else:
logger.info(f"Undefined IndexTypeEnum {index_type}")
labels, property_keys = self._collect_text_index_info(schema_types)
self.create_text_index(labels, property_keys)
self.create_vector_index("Entity", "name")
self.create_vector_index("Entity", "desc")
if bool(os.getenv("KAG_RETRIEVER_SEMANTIC_ENHANCE")):
self.create_vector_index(label=SemanticEnhance.concept_label, property_key="name")
self.refresh_vector_index_meta(force=True)
def _collect_text_index_info(self, schema_types):
labels = {}
property_keys = {}
for spg_type in schema_types:
label = spg_type
properties = schema_types[spg_type].properties
if properties:
label_property_keys = {}
for property_key in properties:
index_type = properties[property_key].index_type
if property_key == "name" or index_type and index_type in (IndexTypeEnum.Text, IndexTypeEnum.TextAndVector):
label_property_keys[property_key] = True
if label_property_keys:
labels[label] = True
property_keys.update(label_property_keys)
return tuple(labels.keys()), tuple(property_keys.keys())
def upsert_node(self, label, properties, id_key="id", extra_labels=("Entity",)):
self._preprocess_node_properties(label, properties, extra_labels)
with self._driver.session(database=self._database) as session:
if label not in self._labels:
self._create_unique_index_constraint(self, label, session)
try:
return session.execute_write(self._upsert_node, self, label, id_key, properties, extra_labels)
except Exception as e:
logger.error(f"upsert_node label:{label} properties:{properties} Exception: {e}")
return None
@staticmethod
def _upsert_node(tx, self, label, id_key, properties, extra_labels):
if not label:
logger.warning("label cannot be None or empty strings")
return None
query = (f"MERGE (n:{self._escape_neo4j(label)} {{{self._escape_neo4j(id_key)}: $properties.{self._escape_neo4j(id_key)}}}) "
"SET n += $properties ")
if extra_labels:
query += f", n:{':'.join(self._escape_neo4j(extra_label) for extra_label in extra_labels)} "
query += "RETURN n"
result = tx.run(query, properties=properties)
return result.single()[0]
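# For label="Chunk", id_key="id", extra_labels=("Entity",), the generated query is:
#   MERGE (n:Chunk {id: $properties.id}) SET n += $properties , n:Entity RETURN n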
def upsert_nodes(self, label, properties_list, id_key="id", extra_labels=("Entity",)):
self._preprocess_node_properties_list(label, properties_list, extra_labels)
with self._driver.session(database=self._database) as session:
if label not in self._labels:
self._create_unique_index_constraint(self, label, session)
try:
return session.execute_write(self._upsert_nodes, self, label, properties_list, id_key, extra_labels)
except Exception as e:
logger.error(f"upsert_nodes label:{label} properties:{properties_list} Exception: {e}")
return None
@staticmethod
def _upsert_nodes(tx, self, label, properties_list, id_key, extra_labels):
if not label:
logger.warning("label cannot be None or empty strings")
return None
query = ("UNWIND $properties_list AS properties "
f"MERGE (n:{self._escape_neo4j(label)} {{{self._escape_neo4j(id_key)}: properties.{self._escape_neo4j(id_key)}}}) "
"SET n += properties ")
if extra_labels:
query += f", n:{':'.join(self._escape_neo4j(extra_label) for extra_label in extra_labels)} "
query += "RETURN n"
result = tx.run(query, properties_list=properties_list)
return [record['n'] for record in result]
def _get_embedding_vector(self, properties, vector_field):
for property_key, property_value in properties.items():
field_name = self._create_vector_field_name(property_key)
if field_name != vector_field:
continue
if not property_value:
return None
if not isinstance(property_value, str):
message = f"property {property_key!r} must be string to generate embedding vector"
raise RuntimeError(message)
try:
vector = self.vectorizer.vectorize(property_value)
return vector
except Exception as e:
logger.info(f"An error occurred while vectorizing property {property_key!r}: {e}")
return None
return None
def _preprocess_node_properties(self, label, properties, extra_labels):
if self._vectorizer is None:
return
self.refresh_vector_index_meta()
vec_meta = self._vec_meta
labels = [label]
if extra_labels:
labels.extend(extra_labels)
for label in labels:
if label not in vec_meta:
continue
for vector_field in vec_meta[label]:
if vector_field in properties:
continue
embedding_vector = self._get_embedding_vector(properties, vector_field)
if embedding_vector is not None:
properties[vector_field] = embedding_vector
def _preprocess_node_properties_list(self, label, properties_list, extra_labels):
for properties in properties_list:
self._preprocess_node_properties(label, properties, extra_labels)
def batch_preprocess_node_properties(self, node_batch, extra_labels=("Entity",)):
if self._vectorizer is None:
return
class EmbeddingVectorPlaceholder(object):
def __init__(self, number, properties, vector_field, property_key, property_value):
self._number = number
self._properties = properties
self._vector_field = vector_field
self._property_key = property_key
self._property_value = property_value
self._embedding_vector = None
def replace(self):
if self._embedding_vector is not None:
self._properties[self._vector_field] = self._embedding_vector
def __repr__(self):
return repr(self._number)
class EmbeddingVectorManager(object):
def __init__(self):
self._placeholders = []
def get_placeholder(self, graph_store, properties, vector_field):
for property_key, property_value in properties.items():
field_name = graph_store._create_vector_field_name(property_key)
if field_name != vector_field:
continue
if not property_value:
return None
if not isinstance(property_value, str):
message = f"property {property_key!r} must be string to generate embedding vector"
raise RuntimeError(message)
num = len(self._placeholders)
placeholder = EmbeddingVectorPlaceholder(num, properties, vector_field, property_key, property_value)
self._placeholders.append(placeholder)
return placeholder
return None
def _get_text_batch(self):
text_batch = dict()
for placeholder in self._placeholders:
property_value = placeholder._property_value
if property_value not in text_batch:
text_batch[property_value] = list()
text_batch[property_value].append(placeholder)
return text_batch
def _generate_vectors(self, vectorizer, text_batch):
texts = list(text_batch)
vectors = vectorizer.vectorize(texts)
return vectors
def _fill_vectors(self, vectors, text_batch):
for vector, (_text, placeholders) in zip(vectors, text_batch.items()):
for placeholder in placeholders:
placeholder._embedding_vector = vector
def batch_vectorize(self, vectorizer):
text_batch = self._get_text_batch()
vectors = self._generate_vectors(vectorizer, text_batch)
self._fill_vectors(vectors, text_batch)
def patch(self):
for placeholder in self._placeholders:
placeholder.replace()
manager = EmbeddingVectorManager()
self.refresh_vector_index_meta()
vec_meta = self._vec_meta
for node_item in node_batch:
label, properties = node_item
labels = [label]
if extra_labels:
labels.extend(extra_labels)
for label in labels:
if label not in vec_meta:
continue
for vector_field in vec_meta[label]:
if vector_field in properties:
continue
placeholder = manager.get_placeholder(self, properties, vector_field)
if placeholder is not None:
properties[vector_field] = placeholder
manager.batch_vectorize(self._vectorizer)
manager.patch()
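# Usage sketch: node_batch is a sequence of (label, properties) pairs. After the
# call each properties dict carries the computed "_<field>_vector" entries:
#
#   batch = [("Chunk", {"id": "c1", "content": "first chunk"}),
#            ("Chunk", {"id": "c2", "content": "second chunk"})]
#   client.batch_preprocess_node_properties(batch)
#
# Identical texts are deduplicated and vectorized only once per batch.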
def get_node(self, label, id_value, id_key="id"):
with self._driver.session(database=self._database) as session:
return session.execute_read(self._get_node, self, label, id_key, id_value)
@staticmethod
def _get_node(tx, self, label, id_key, id_value):
query = f"MATCH (n:{self._escape_neo4j(label)} {{{self._escape_neo4j(id_key)}: $id_value}}) RETURN n"
result = tx.run(query, id_value=id_value)
single_result = result.single()
# print(f"single_result: {single_result}")
if single_result is not None:
return single_result[0]
else:
return None
def delete_node(self, label, id_value, id_key="id"):
with self._driver.session(database=self._database) as session:
try:
session.execute_write(self._delete_node, self, label, id_key, id_value)
except Exception as e:
logger.error(f"delete_node label:{label} Exception: {e}")
@staticmethod
def _delete_node(tx, self, label, id_key, id_value):
query = f"MATCH (n:{self._escape_neo4j(label)} {{{self._escape_neo4j(id_key)}: $id_value}}) DETACH DELETE n"
tx.run(query, id_value=id_value)
def delete_nodes(self, label, id_values, id_key="id"):
with self._driver.session(database=self._database) as session:
session.execute_write(self._delete_nodes, self, label, id_key, id_values)
@staticmethod
def _delete_nodes(tx, self, label, id_key, id_values):
query = f"UNWIND $id_values AS id_value MATCH (n:{self._escape_neo4j(label)} {{{self._escape_neo4j(id_key)}: id_value}}) DETACH DELETE n"
tx.run(query, id_values=id_values)
def upsert_relationship(self, start_node_label, start_node_id_value,
end_node_label, end_node_id_value, rel_type,
properties, upsert_nodes=True, start_node_id_key="id", end_node_id_key="id"):
rel_type = self._escape_neo4j(rel_type)
with self._driver.session(database=self._database) as session:
try:
return session.execute_write(self._upsert_relationship, self, start_node_label, start_node_id_key,
start_node_id_value, end_node_label, end_node_id_key,
end_node_id_value, rel_type, properties, upsert_nodes)
except Exception as e:
logger.error(f"upsert_relationship rel_type:{rel_type} properties:{properties} Exception: {e}")
return None
@staticmethod
def _upsert_relationship(tx, self, start_node_label, start_node_id_key, start_node_id_value,
end_node_label, end_node_id_key, end_node_id_value,
rel_type, properties, upsert_nodes):
if not start_node_label or not end_node_label or not rel_type:
logger.warning("start_node_label, end_node_label, and rel_type cannot be None or empty strings")
return None
if upsert_nodes:
query = (
f"MERGE (a:{self._escape_neo4j(start_node_label)} {{{self._escape_neo4j(start_node_id_key)}: $start_node_id_value}}) "
f"MERGE (b:{self._escape_neo4j(end_node_label)} {{{self._escape_neo4j(end_node_id_key)}: $end_node_id_value}}) "
f"MERGE (a)-[r:{self._escape_neo4j(rel_type)}]->(b) SET r += $properties RETURN r"
)
else:
query = (
f"MATCH (a:{self._escape_neo4j(start_node_label)} {{{self._escape_neo4j(start_node_id_key)}: $start_node_id_value}}), "
f"(b:{self._escape_neo4j(end_node_label)} {{{self._escape_neo4j(end_node_id_key)}: $end_node_id_value}}) "
f"MERGE (a)-[r:{self._escape_neo4j(rel_type)}]->(b) SET r += $properties RETURN r"
)
result = tx.run(query, start_node_id_value=start_node_id_value,
end_node_id_value=end_node_id_value, properties=properties)
return result.single()
def upsert_relationships(self, start_node_label, end_node_label, rel_type, relations,
upsert_nodes=True, start_node_id_key="id", end_node_id_key="id"):
with self._driver.session(database=self._database) as session:
try:
return session.execute_write(self._upsert_relationships, self, relations, start_node_label,
start_node_id_key, end_node_label, end_node_id_key, rel_type, upsert_nodes)
except Exception as e:
logger.error(f"upsert_relationships rel_type:{rel_type} relations:{relations} Exception: {e}")
return None
@staticmethod
def _upsert_relationships(tx, self, relations, start_node_label, start_node_id_key,
end_node_label, end_node_id_key, rel_type, upsert_nodes):
if not start_node_label or not end_node_label or not rel_type:
logger.warning("start_node_label, end_node_label, and rel_type cannot be None or empty strings")
return None
if upsert_nodes:
query = (
"UNWIND $relations AS relationship "
f"MERGE (a:{self._escape_neo4j(start_node_label)} {{{self._escape_neo4j(start_node_id_key)}: relationship.start_node_id}}) "
f"MERGE (b:{self._escape_neo4j(end_node_label)} {{{self._escape_neo4j(end_node_id_key)}: relationship.end_node_id}}) "
f"MERGE (a)-[r:{self._escape_neo4j(rel_type)}]->(b) SET r += relationship.properties RETURN r"
)
else:
query = (
"UNWIND $relations AS relationship "
f"MATCH (a:{self._escape_neo4j(start_node_label)} {{{self._escape_neo4j(start_node_id_key)}: relationship.start_node_id}}) "
f"MATCH (b:{self._escape_neo4j(end_node_label)} {{{self._escape_neo4j(end_node_id_key)}: relationship.end_node_id}}) "
f"MERGE (a)-[r:{self._escape_neo4j(rel_type)}]->(b) SET r += relationship.properties RETURN r"
)
result = tx.run(query, relations=relations,
start_node_label=start_node_label, start_node_id_key=start_node_id_key,
end_node_label=end_node_label, end_node_id_key=end_node_id_key,
rel_type=rel_type)
return [record['r'] for record in result]
def delete_relationship(self, start_node_label, start_node_id_value,
end_node_label, end_node_id_value, rel_type,
start_node_id_key="id", end_node_id_key="id"):
with self._driver.session(database=self._database) as session:
try:
session.execute_write(self._delete_relationship, self, start_node_label, start_node_id_key,
start_node_id_value, end_node_label, end_node_id_key,
end_node_id_value, rel_type)
except Exception as e:
logger.error(f"delete_relationship rel_type:{rel_type} Exception: {e}")
@staticmethod
def _delete_relationship(tx, self, start_node_label, start_node_id_key, start_node_id_value,
end_node_label, end_node_id_key, end_node_id_value, rel_type):
query = (
f"MATCH (a:{self._escape_neo4j(start_node_label)} {{{self._escape_neo4j(start_node_id_key)}: $start_node_id_value}})-[r:{self._escape_neo4j(rel_type)}]->"
f"(b:{self._escape_neo4j(end_node_label)} {{{self._escape_neo4j(end_node_id_key)}: $end_node_id_value}}) DELETE r"
)
tx.run(query, start_node_id_value=start_node_id_value, end_node_id_value=end_node_id_value)
def delete_relationships(self, start_node_label, start_node_id_values,
end_node_label, end_node_id_values, rel_type,
start_node_id_key="id", end_node_id_key="id"):
with self._driver.session(database=self._database) as session:
session.execute_write(self._delete_relationships, self,
start_node_label, start_node_id_key, start_node_id_values,
end_node_label, end_node_id_key, end_node_id_values, rel_type)
@staticmethod
def _delete_relationships(tx, self, start_node_label, start_node_id_key, start_node_id_values,
end_node_label, end_node_id_key, end_node_id_values, rel_type):
query = (
"UNWIND $start_node_id_values AS start_node_id_value "
"UNWIND $end_node_id_values AS end_node_id_value "
f"MATCH (a:{self._escape_neo4j(start_node_label)} {{{self._escape_neo4j(start_node_id_key)}: start_node_id_value}})-[r:{self._escape_neo4j(rel_type)}]->"
f"(b:{self._escape_neo4j(end_node_label)} {{{self._escape_neo4j(end_node_id_key)}: end_node_id_value}}) DELETE r"
)
tx.run(query, start_node_id_values=start_node_id_values, end_node_id_values=end_node_id_values)
def _get_lucene_pattern(self):
string = re.escape(self._lucene_special_chars)
pattern = "([" + string + "])"
pattern = re.compile(pattern)
return pattern
def _escape_lucene(self, string):
result = self._lucene_pattern.sub(r"\\\1", string)
return result
def _make_lucene_query(self, string):
string = self._escape_lucene(string)
result = string.lower()
return result
def _get_utf16_codepoints(self, string):
result = []
for ch in string:
data = ch.encode("utf-16-le")
for i in range(0, len(data), 2):
value = int.from_bytes(data[i:i+2], "little")
result.append(value)
return tuple(result)
def _escape_neo4j(self, name):
match = self._simple_ident_pattern.fullmatch(name)
if match is not None:
return name
string = "`"
for ch in name:
if ch == "`":
string += "``"
elif ch.isascii() and ch.isprintable():
string += ch
else:
values = self._get_utf16_codepoints(ch)
for value in values:
string += "\\u%04X" % value
string += "`"
return string
def _to_snake_case(self, name):
words = re.findall("[A-Za-z][a-z0-9]*", name)
result = "_".join(words).lower()
return result
def _create_vector_index_name(self, label, property_key):
name = f"{label}_{property_key}_vector_index"
name = self._to_snake_case(name)
return "_" + name
def _create_vector_field_name(self, property_key):
name = f"{property_key}_vector"
name = self._to_snake_case(name)
return "_" + name
def create_index(self, label, property_key, index_name=None):
with self._driver.session(database=self._database) as session:
session.execute_write(self._create_index, self, label, property_key, index_name)
@staticmethod
def _create_index(tx, self, label, property_key, index_name):
if not label or not property_key:
return
if index_name is None:
query = f"CREATE INDEX IF NOT EXISTS FOR (n:{self._escape_neo4j(label)}) ON (n.{self._escape_neo4j(property_key)})"
else:
query = f"CREATE INDEX {self._escape_neo4j(index_name)} IF NOT EXISTS FOR (n:{self._escape_neo4j(label)}) ON (n.{self._escape_neo4j(property_key)})"
tx.run(query)
def create_text_index(self, labels, property_keys, index_name=None):
if not labels or not property_keys:
return
if index_name is None:
index_name = "_default_text_index"
label_spec = "|".join(self._escape_neo4j(label) for label in labels)
property_spec = ", ".join(f"n.{self._escape_neo4j(key)}" for key in property_keys)
query = (
f"CREATE FULLTEXT INDEX {self._escape_neo4j(index_name)} IF NOT EXISTS "
f"FOR (n:{label_spec}) ON EACH [{property_spec}]"
)
def do_create_text_index(tx):
tx.run(query)
with self._driver.session(database=self._database) as session:
session.execute_write(do_create_text_index)
return index_name
def create_vector_index(self, label, property_key, index_name=None,
vector_dimensions=768, metric_type="cosine",
hnsw_m=None, hnsw_ef_construction=None):
if index_name is None:
index_name = self._create_vector_index_name(label, property_key)
if not property_key.lower().endswith("vector"):
property_key = self._create_vector_field_name(property_key)
with self._driver.session(database=self._database) as session:
session.execute_write(self._create_vector_index, self, label, property_key, index_name,
vector_dimensions, metric_type, hnsw_m, hnsw_ef_construction)
self.refresh_vector_index_meta(force=True)
return index_name
@staticmethod
def _create_vector_index(tx, self, label, property_key, index_name, vector_dimensions, metric_type, hnsw_m, hnsw_ef_construction):
query = (
f"CREATE VECTOR INDEX {self._escape_neo4j(index_name)} IF NOT EXISTS FOR (n:{self._escape_neo4j(label)}) ON (n.{self._escape_neo4j(property_key)}) "
"OPTIONS { indexConfig: {"
" `vector.dimensions`: $vector_dimensions,"
" `vector.similarity_function`: $metric_type"
)
if hnsw_m is not None:
query += ", `vector.hnsw.m`: $hnsw_m"
if hnsw_ef_construction is not None:
query += ", `vector.hnsw.ef_construction`: $hnsw_ef_construction"
query += "}}"
tx.run(query, vector_dimensions=vector_dimensions, metric_type=metric_type,
hnsw_m=hnsw_m, hnsw_ef_construction=hnsw_ef_construction)
def refresh_vector_index_meta(self, force=False):
if not force and time.time() - self._vec_meta_ts < self._vec_meta_timeout:
return
def do_refresh_vector_index_meta(tx):
query = "SHOW VECTOR INDEX"
res = tx.run(query)
data = res.data()
meta = dict()
for record in data:
if record["entityType"] == "NODE":
label, = record["labelsOrTypes"]
vector_field, = record["properties"]
if vector_field.startswith("_") and vector_field.endswith("_vector"):
if label not in meta:
meta[label] = []
meta[label].append(vector_field)
self._vec_meta = meta
self._vec_meta_ts = time.time()
with self._driver.session(database=self._database) as session:
session.execute_read(do_refresh_vector_index_meta)
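# After a refresh, self._vec_meta maps node labels to their vector fields, e.g.
# {"Chunk": ["_content_vector", "_name_vector"], "Entity": ["_name_vector"]};
# the result is cached for _vec_meta_timeout (60s) unless force=True.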
def delete_index(self, index_name):
with self._driver.session(database=self._database) as session:
session.execute_write(self._delete_index, self, index_name)
@staticmethod
def _delete_index(tx, self, index_name):
query = f"DROP INDEX {self._escape_neo4j(index_name)} IF EXISTS"
tx.run(query)
@property
def vectorizer(self):
if self._vectorizer is None:
message = "vectorizer is not initialized"
raise RuntimeError(message)
return self._vectorizer
@vectorizer.setter
def vectorizer(self, value):
self._vectorizer = value
def text_search(self, query_string, label_constraints=None, topk=10, index_name=None):
if index_name is None:
index_name = "_default_text_index"
if label_constraints is None:
pass
elif isinstance(label_constraints, str):
label_constraints = self._escape_neo4j(label_constraints)
elif isinstance(label_constraints, (list, tuple)):
label_constraints = "|".join(self._escape_neo4j(label_constraint) for label_constraint in label_constraints)
else:
message = f"invalid label_constraints: {label_constraints!r}"
raise RuntimeError(message)
if label_constraints is None:
query = ("CALL db.index.fulltext.queryNodes($index_name, $query_string) "
"YIELD node AS node, score "
"RETURN node, score")
else:
query = ("CALL db.index.fulltext.queryNodes($index_name, $query_string) "
"YIELD node AS node, score "
f"WHERE (node:{label_constraints}) "
"RETURN node, score")
query += " LIMIT $topk"
query_string = self._make_lucene_query(query_string)
def do_text_search(tx):
res = tx.run(query, query_string=query_string, topk=topk, index_name=index_name)
data = res.data()
return data
with self._driver.session(database=self._database) as session:
return session.execute_read(do_text_search)
def vector_search(self, label, property_key, query_text_or_vector, topk=10, index_name=None, ef_search=None):
if ef_search is not None:
if ef_search < topk:
message = f"ef_search must be greater than or equal to topk; {ef_search!r} is invalid"
raise ValueError(message)
self.refresh_vector_index_meta()
if index_name is None:
vec_meta = self._vec_meta
if label not in vec_meta:
logger.warning(f"vector index not defined for label, return empty. label: {label}, "
f"property_key: {property_key}, query_text_or_vector: {query_text_or_vector}.")
return []
vector_field = self._create_vector_field_name(property_key)
if vector_field not in vec_meta[label]:
logger.warning(f"vector index not defined for field, return empty. label: {label}, "
f"property_key: {property_key}, query_text_or_vector: {query_text_or_vector}.")
return []
if index_name is None:
index_name = self._create_vector_index_name(label, property_key)
if isinstance(query_text_or_vector, str):
query_vector = self.vectorizer.vectorize(query_text_or_vector)
else:
query_vector = query_text_or_vector
def do_vector_search(tx):
if ef_search is not None:
query = ("CALL db.index.vector.queryNodes($index_name, $ef_search, $query_vector) "
"YIELD node, score "
"RETURN node, score, labels(node) as __labels__"
f"LIMIT {topk}")
res = tx.run(query, query_vector=query_vector, ef_search=ef_search, index_name=index_name)
else:
query = ("CALL db.index.vector.queryNodes($index_name, $topk, $query_vector) "
"YIELD node, score "
"RETURN node, score, labels(node) as __labels__")
res = tx.run(query, query_vector=query_vector, topk=topk, index_name=index_name)
data = res.data()
for record in data:
record["node"]["__labels__"] = record["__labels__"]
del record["__labels__"]
return data
with self._driver.session(database=self._database) as session:
return session.execute_read(do_vector_search)
def _create_all_graph(self, graph_name):
with self._driver.session(database=self._database) as session:
logger.debug(f"create pagerank graph graph_name{graph_name} database{self._database}")
result = session.run(f"""
CALL gds.graph.exists('{graph_name}') YIELD exists
WHERE exists
CALL gds.graph.drop('{graph_name}') YIELD graphName
RETURN graphName
""")
summary = result.consume()
logger.debug(f"create pagerank graph exists graph_name{graph_name} database{self._database} succeed "
f"executed{summary.result_available_after} consumed{summary.result_consumed_after}")
result = session.run(f"""
CALL gds.graph.project('{graph_name}','*','*')
YIELD graphName, nodeCount AS nodes, relationshipCount AS rels
RETURN graphName, nodes, rels
""")
summary = result.consume()
logger.debug(f"create pagerank graph graph_name{graph_name} database{self._database} succeed "
f"executed{summary.result_available_after} consumed{summary.result_consumed_after}")
def _drop_all_graph(self, graph_name):
with self._driver.session(database=self._database) as session:
logger.debug(f"drop pagerank graph graph_name{graph_name} database{self._database}")
result = session.run(f"""
CALL gds.graph.exists('{graph_name}') YIELD exists
WHERE exists
CALL gds.graph.drop('{graph_name}') YIELD graphName
RETURN graphName
""")
result.consume()
logger.debug(f"drop pagerank graph graph_name{graph_name} database{self._database} succeed")
def execute_pagerank(self, iterations=20, damping_factor=0.85):
with self._driver.session(database=self._database) as session:
return session.execute_write(self._execute_pagerank, iterations, damping_factor)
@staticmethod
def _execute_pagerank(tx, iterations, damping_factor):
query = (
"CALL algo.pageRank.stream("
"{iterations: $iterations, dampingFactor: $damping_factor}) "
"YIELD nodeId, score "
"RETURN algo.getNodeById(nodeId) AS node, score "
"ORDER BY score DESC"
)
result = tx.run(query, iterations=iterations, damping_factor=damping_factor)
return [{"node": record["node"], "score": record["score"]} for record in result]
def get_pagerank_scores(self, start_nodes, target_type):
with self._driver.session(database=self._database) as session:
all_graph = self._allGraph
self._exists_all_graph(session, all_graph)
data = session.execute_write(self._get_pagerank_scores, self, all_graph, start_nodes, target_type)
return data
@staticmethod
def _get_pagerank_scores(tx, self, graph_name, start_nodes, return_type):
match_clauses = []
match_identify = []
for index, node in enumerate(start_nodes):
node_type, node_name = node['type'], node['name']
node_identify = f"node_{index}"
match_clauses.append(f"MATCH ({node_identify}:{self._escape_neo4j(node_type)} {{name: '{escape_single_quotes(node_name)}'}})")
match_identify.append(node_identify)
match_query = ' '.join(match_clauses)
match_identify_str = ', '.join(match_identify)
pagerank_query = f"""
{match_query}
CALL gds.pageRank.stream('{graph_name}',{{
maxIterations: 20,
dampingFactor: 0.85,
sourceNodes: [{match_identify_str}]
}})
YIELD nodeId, score
MATCH (m:{self._escape_neo4j(return_type)}) WHERE id(m) = nodeId
RETURN id(m) AS g_id, gds.util.asNode(nodeId).id AS id, score
ORDER BY score DESC
"""
result = tx.run(pagerank_query)
return [{"id": record["id"], "score": record["score"]} for record in result]
@staticmethod
def _exists_all_graph(session, graph_name):
try:
logger.debug(f"exists pagerank graph graph_name{graph_name}")
result = session.run(f"""
CALL gds.graph.exists('{graph_name}') YIELD exists
WHERE NOT exists
CALL gds.graph.project('{graph_name}','*','*')
YIELD graphName, nodeCount AS nodes, relationshipCount AS rels
RETURN graphName, nodes, rels
""")
summary = result.consume()
logger.debug(f"exists pagerank graph graph_name{graph_name} succeed "
f"executed{summary.result_available_after} consumed{summary.result_consumed_after}")
except Exception as e:
logger.debug(f"Error exists pagerank graph {graph_name}: {e}")
def count(self, label):
with self._driver.session(database=self._database) as session:
return session.execute_read(self._count, self, label)
@staticmethod
def _count(tx, self, label):
query = f"MATCH (n:{self._escape_neo4j(label)}) RETURN count(n)"
result = tx.run(query)
single_result = result.single()
if single_result is not None:
return single_result[0]
def create_database(self, database):
with self._driver.session(database=self._database) as session:
database = database.lower()
result = session.run(f"CREATE DATABASE {self._escape_neo4j(database)} IF NOT EXISTS")
summary = result.consume()
logger.info(f"create_database {database} succeed "
f"executed{summary.result_available_after} consumed{summary.result_consumed_after}")
def delete_all_data(self, database):
if self._database != database:
raise ValueError(f"Error: Current database ({self._database}) is not the same as the target database ({database}).")
with self._driver.session(database=database) as session:
while True:
result = session.run("MATCH (n) WITH n LIMIT 100000 DETACH DELETE n RETURN count(*)")
count = result.single()[0]
logger.info(f"Deleted {count} nodes in this batch.")
if count == 0:
logger.info("All data has been deleted.")
break
def run_cypher_query(self, database, query, parameters=None):
if database and self._database != database:
raise ValueError(f"Current database ({self._database}) is not the same as the target database ({database}).")
with self._driver.session(database=database) as session:
result = session.run(query, parameters)
return [record for record in result]
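Putting the client together, a minimal end-to-end sketch (the import path, connection details, and vectorizer are assumptions for illustration, not part of this file):

    from kag.common.graphstore.neo4j_graph_store import Neo4jClient  # hypothetical path

    client = Neo4jClient(uri="bolt://localhost:7687", user="neo4j",
                         password="neo4j", database="neo4j")
    client.vectorizer = my_vectorizer  # any object exposing vectorize(text_or_texts)

    client.upsert_node("Chunk", {"id": "c1", "name": "intro", "content": "KAG overview"})
    print(client.get_node("Chunk", "c1"))
    print(client.text_search("overview", label_constraints="Chunk", topk=3))
    print(client.vector_search("Chunk", "content", "knowledge graphs", topk=3))
    client.close()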

View File

@ -0,0 +1,38 @@
# coding: utf-8
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
# flake8: noqa
"""
kag
No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) # noqa: E501
The version of the OpenAPI document: 1.0.0
Generated by: https://openapi-generator.tech
"""
from __future__ import absolute_import
__version__ = "1"
# import apis into sdk package
from kag.common.graphstore.rest.graph_api import GraphApi
# import models into sdk package
from kag.common.graphstore.rest.models.delete_edge_request import DeleteEdgeRequest
from kag.common.graphstore.rest.models.delete_vertex_request import DeleteVertexRequest
from kag.common.graphstore.rest.models.edge_record_instance import EdgeRecordInstance
from kag.common.graphstore.rest.models.upsert_edge_request import UpsertEdgeRequest
from kag.common.graphstore.rest.models.upsert_vertex_request import UpsertVertexRequest
from kag.common.graphstore.rest.models.vertex_record_instance import VertexRecordInstance

View File

@ -0,0 +1,485 @@
# coding: utf-8
"""
kag
No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) # noqa: E501
The version of the OpenAPI document: 1.0.0
Generated by: https://openapi-generator.tech
"""
from __future__ import absolute_import
import re # noqa: F401
# python 2 and python 3 compatibility library
import six
from kag.common.rest.api_client import ApiClient
from kag.common.rest.exceptions import ( # noqa: F401
ApiTypeError,
ApiValueError
)
class GraphApi(object):
"""NOTE: This class is auto generated by OpenAPI Generator
Ref: https://openapi-generator.tech
Do not edit the class manually.
"""
def __init__(self, api_client=None):
if api_client is None:
api_client = ApiClient()
self.api_client = api_client
def graph_delete_edge_post(self, **kwargs): # noqa: E501
"""delete_edge # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
>>> thread = api.graph_delete_edge_post(async_req=True)
>>> result = thread.get()
:param async_req bool: execute request asynchronously
:param DeleteEdgeRequest delete_edge_request:
:param _preload_content: if False, the urllib3.HTTPResponse object will
be returned without reading/decoding response
data. Default is True.
:param _request_timeout: timeout setting for this request. If one
number provided, it will be total request
timeout. It can also be a pair (tuple) of
(connection, read) timeouts.
:return: object
If the method is called asynchronously,
returns the request thread.
"""
kwargs['_return_http_data_only'] = True
return self.graph_delete_edge_post_with_http_info(**kwargs) # noqa: E501
def graph_delete_edge_post_with_http_info(self, **kwargs): # noqa: E501
"""delete_edge # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
>>> thread = api.graph_delete_edge_post_with_http_info(async_req=True)
>>> result = thread.get()
:param async_req bool: execute request asynchronously
:param DeleteEdgeRequest delete_edge_request:
:param _return_http_data_only: response data without head status code
and headers
:param _preload_content: if False, the urllib3.HTTPResponse object will
be returned without reading/decoding response
data. Default is True.
:param _request_timeout: timeout setting for this request. If one
number provided, it will be total request
timeout. It can also be a pair (tuple) of
(connection, read) timeouts.
:return: tuple(object, status_code(int), headers(HTTPHeaderDict))
If the method is called asynchronously,
returns the request thread.
"""
local_var_params = locals()
all_params = [
'delete_edge_request'
]
all_params.extend(
[
'async_req',
'_return_http_data_only',
'_preload_content',
'_request_timeout'
]
)
for key, val in six.iteritems(local_var_params['kwargs']):
if key not in all_params:
raise ApiTypeError(
"Got an unexpected keyword argument '%s'"
" to method graph_delete_edge_post" % key
)
local_var_params[key] = val
del local_var_params['kwargs']
collection_formats = {}
path_params = {}
query_params = []
header_params = {}
form_params = []
local_var_files = {}
body_params = None
if 'delete_edge_request' in local_var_params:
body_params = local_var_params['delete_edge_request']
# HTTP header `Accept`
header_params['Accept'] = self.api_client.select_header_accept(
['application/json']) # noqa: E501
# HTTP header `Content-Type`
header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501
['application/json']) # noqa: E501
# Authentication setting
auth_settings = [] # noqa: E501
return self.api_client.call_api(
'/graph/deleteEdge', 'POST',
path_params,
query_params,
header_params,
body=body_params,
post_params=form_params,
files=local_var_files,
response_type='object', # noqa: E501
auth_settings=auth_settings,
async_req=local_var_params.get('async_req'),
_return_http_data_only=local_var_params.get('_return_http_data_only'), # noqa: E501
_preload_content=local_var_params.get('_preload_content', True),
_request_timeout=local_var_params.get('_request_timeout'),
collection_formats=collection_formats)
def graph_delete_vertex_post(self, **kwargs): # noqa: E501
"""delete_vertex # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
>>> thread = api.graph_delete_vertex_post(async_req=True)
>>> result = thread.get()
:param async_req bool: execute request asynchronously
:param DeleteVertexRequest delete_vertex_request:
:param _preload_content: if False, the urllib3.HTTPResponse object will
be returned without reading/decoding response
data. Default is True.
:param _request_timeout: timeout setting for this request. If one
number provided, it will be total request
timeout. It can also be a pair (tuple) of
(connection, read) timeouts.
:return: object
If the method is called asynchronously,
returns the request thread.
"""
kwargs['_return_http_data_only'] = True
return self.graph_delete_vertex_post_with_http_info(**kwargs) # noqa: E501
def graph_delete_vertex_post_with_http_info(self, **kwargs): # noqa: E501
"""delete_vertex # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
>>> thread = api.graph_delete_vertex_post_with_http_info(async_req=True)
>>> result = thread.get()
:param async_req bool: execute request asynchronously
:param DeleteVertexRequest delete_vertex_request:
:param _return_http_data_only: response data without head status code
and headers
:param _preload_content: if False, the urllib3.HTTPResponse object will
be returned without reading/decoding response
data. Default is True.
:param _request_timeout: timeout setting for this request. If one
number provided, it will be total request
timeout. It can also be a pair (tuple) of
(connection, read) timeouts.
:return: tuple(object, status_code(int), headers(HTTPHeaderDict))
If the method is called asynchronously,
returns the request thread.
"""
local_var_params = locals()
all_params = [
'delete_vertex_request'
]
all_params.extend(
[
'async_req',
'_return_http_data_only',
'_preload_content',
'_request_timeout'
]
)
for key, val in six.iteritems(local_var_params['kwargs']):
if key not in all_params:
raise ApiTypeError(
"Got an unexpected keyword argument '%s'"
" to method graph_delete_vertex_post" % key
)
local_var_params[key] = val
del local_var_params['kwargs']
collection_formats = {}
path_params = {}
query_params = []
header_params = {}
form_params = []
local_var_files = {}
body_params = None
if 'delete_vertex_request' in local_var_params:
body_params = local_var_params['delete_vertex_request']
# HTTP header `Accept`
header_params['Accept'] = self.api_client.select_header_accept(
['application/json']) # noqa: E501
# HTTP header `Content-Type`
header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501
['application/json']) # noqa: E501
# Authentication setting
auth_settings = [] # noqa: E501
return self.api_client.call_api(
'/graph/deleteVertex', 'POST',
path_params,
query_params,
header_params,
body=body_params,
post_params=form_params,
files=local_var_files,
response_type='object', # noqa: E501
auth_settings=auth_settings,
async_req=local_var_params.get('async_req'),
_return_http_data_only=local_var_params.get('_return_http_data_only'), # noqa: E501
_preload_content=local_var_params.get('_preload_content', True),
_request_timeout=local_var_params.get('_request_timeout'),
collection_formats=collection_formats)
def graph_upsert_edge_post(self, **kwargs): # noqa: E501
"""upsert_edge # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
>>> thread = api.graph_upsert_edge_post(async_req=True)
>>> result = thread.get()
:param async_req bool: execute request asynchronously
:param UpsertEdgeRequest upsert_edge_request:
:param _preload_content: if False, the urllib3.HTTPResponse object will
be returned without reading/decoding response
data. Default is True.
:param _request_timeout: timeout setting for this request. If one
number provided, it will be total request
timeout. It can also be a pair (tuple) of
(connection, read) timeouts.
:return: object
If the method is called asynchronously,
returns the request thread.
"""
kwargs['_return_http_data_only'] = True
return self.graph_upsert_edge_post_with_http_info(**kwargs) # noqa: E501
def graph_upsert_edge_post_with_http_info(self, **kwargs): # noqa: E501
"""upsert_edge # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
>>> thread = api.graph_upsert_edge_post_with_http_info(async_req=True)
>>> result = thread.get()
:param async_req bool: execute request asynchronously
:param UpsertEdgeRequest upsert_edge_request:
:param _return_http_data_only: response data without head status code
and headers
:param _preload_content: if False, the urllib3.HTTPResponse object will
be returned without reading/decoding response
data. Default is True.
:param _request_timeout: timeout setting for this request. If one
number provided, it will be total request
timeout. It can also be a pair (tuple) of
(connection, read) timeouts.
:return: tuple(object, status_code(int), headers(HTTPHeaderDict))
If the method is called asynchronously,
returns the request thread.
"""
local_var_params = locals()
all_params = [
'upsert_edge_request'
]
all_params.extend(
[
'async_req',
'_return_http_data_only',
'_preload_content',
'_request_timeout'
]
)
for key, val in six.iteritems(local_var_params['kwargs']):
if key not in all_params:
raise ApiTypeError(
"Got an unexpected keyword argument '%s'"
" to method graph_upsert_edge_post" % key
)
local_var_params[key] = val
del local_var_params['kwargs']
collection_formats = {}
path_params = {}
query_params = []
header_params = {}
form_params = []
local_var_files = {}
body_params = None
if 'upsert_edge_request' in local_var_params:
body_params = local_var_params['upsert_edge_request']
# HTTP header `Accept`
header_params['Accept'] = self.api_client.select_header_accept(
['application/json']) # noqa: E501
# HTTP header `Content-Type`
header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501
['application/json']) # noqa: E501
# Authentication setting
auth_settings = [] # noqa: E501
return self.api_client.call_api(
'/graph/upsertEdge', 'POST',
path_params,
query_params,
header_params,
body=body_params,
post_params=form_params,
files=local_var_files,
response_type='object', # noqa: E501
auth_settings=auth_settings,
async_req=local_var_params.get('async_req'),
_return_http_data_only=local_var_params.get('_return_http_data_only'), # noqa: E501
_preload_content=local_var_params.get('_preload_content', True),
_request_timeout=local_var_params.get('_request_timeout'),
collection_formats=collection_formats)
def graph_upsert_vertex_post(self, **kwargs): # noqa: E501
"""upsert_vertex # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
>>> thread = api.graph_upsert_vertex_post(async_req=True)
>>> result = thread.get()
:param async_req bool: execute request asynchronously
:param UpsertVertexRequest upsert_vertex_request:
:param _preload_content: if False, the urllib3.HTTPResponse object will
be returned without reading/decoding response
data. Default is True.
:param _request_timeout: timeout setting for this request. If one
number provided, it will be total request
timeout. It can also be a pair (tuple) of
(connection, read) timeouts.
:return: object
If the method is called asynchronously,
returns the request thread.
"""
kwargs['_return_http_data_only'] = True
return self.graph_upsert_vertex_post_with_http_info(**kwargs) # noqa: E501
def graph_upsert_vertex_post_with_http_info(self, **kwargs): # noqa: E501
"""upsert_vertex # noqa: E501
This method makes a synchronous HTTP request by default. To make an
asynchronous HTTP request, please pass async_req=True
>>> thread = api.graph_upsert_vertex_post_with_http_info(async_req=True)
>>> result = thread.get()
:param async_req bool: execute request asynchronously
:param UpsertVertexRequest upsert_vertex_request:
:param _return_http_data_only: response data without head status code
and headers
:param _preload_content: if False, the urllib3.HTTPResponse object will
be returned without reading/decoding response
data. Default is True.
:param _request_timeout: timeout setting for this request. If one
number provided, it will be total request
timeout. It can also be a pair (tuple) of
(connection, read) timeouts.
:return: tuple(object, status_code(int), headers(HTTPHeaderDict))
If the method is called asynchronously,
returns the request thread.
"""
local_var_params = locals()
all_params = [
'upsert_vertex_request'
]
all_params.extend(
[
'async_req',
'_return_http_data_only',
'_preload_content',
'_request_timeout'
]
)
for key, val in six.iteritems(local_var_params['kwargs']):
if key not in all_params:
raise ApiTypeError(
"Got an unexpected keyword argument '%s'"
" to method graph_upsert_vertex_post" % key
)
local_var_params[key] = val
del local_var_params['kwargs']
collection_formats = {}
path_params = {}
query_params = []
header_params = {}
form_params = []
local_var_files = {}
body_params = None
if 'upsert_vertex_request' in local_var_params:
body_params = local_var_params['upsert_vertex_request']
# HTTP header `Accept`
header_params['Accept'] = self.api_client.select_header_accept(
['application/json']) # noqa: E501
# HTTP header `Content-Type`
header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501
['application/json']) # noqa: E501
# Authentication setting
auth_settings = [] # noqa: E501
return self.api_client.call_api(
'/graph/upsertVertex', 'POST',
path_params,
query_params,
header_params,
body=body_params,
post_params=form_params,
files=local_var_files,
response_type='object', # noqa: E501
auth_settings=auth_settings,
async_req=local_var_params.get('async_req'),
_return_http_data_only=local_var_params.get('_return_http_data_only'), # noqa: E501
_preload_content=local_var_params.get('_preload_content', True),
_request_timeout=local_var_params.get('_request_timeout'),
collection_formats=collection_formats)
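A request cycle through the generated client looks roughly like this (server configuration and payload values are illustrative):

    from kag.common.rest.api_client import ApiClient
    from kag.common.graphstore.rest import GraphApi, DeleteEdgeRequest

    api = GraphApi(ApiClient())  # default ApiClient configuration is assumed here
    request = DeleteEdgeRequest(project_id=1, edges=[])  # edges: list[EdgeRecordInstance]
    response = api.graph_delete_edge_post(delete_edge_request=request)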

View File

@ -0,0 +1,19 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from __future__ import absolute_import
from kag.common.graphstore.rest.models.delete_edge_request import DeleteEdgeRequest
from kag.common.graphstore.rest.models.delete_vertex_request import DeleteVertexRequest
from kag.common.graphstore.rest.models.edge_record_instance import EdgeRecordInstance
from kag.common.graphstore.rest.models.upsert_edge_request import UpsertEdgeRequest
from kag.common.graphstore.rest.models.upsert_vertex_request import UpsertVertexRequest
from kag.common.graphstore.rest.models.vertex_record_instance import VertexRecordInstance

View File

@ -0,0 +1,148 @@
# coding: utf-8
"""
kag
No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) # noqa: E501
The version of the OpenAPI document: 1.0.0
Generated by: https://openapi-generator.tech
"""
import pprint
import re # noqa: F401
import six
from kag.common.rest.configuration import Configuration
class DeleteEdgeRequest(object):
"""NOTE: This class is auto generated by OpenAPI Generator.
Ref: https://openapi-generator.tech
Do not edit the class manually.
"""
"""
Attributes:
openapi_types (dict): The key is attribute name
and the value is attribute type.
attribute_map (dict): The key is attribute name
and the value is json key in definition.
"""
openapi_types = {
'project_id': 'int',
'edges': 'list[EdgeRecordInstance]'
}
attribute_map = {
'project_id': 'projectId',
'edges': 'edges'
}
def __init__(self, project_id=None, edges=None, local_vars_configuration=None): # noqa: E501
"""DeleteEdgeRequest - a model defined in OpenAPI""" # noqa: E501
if local_vars_configuration is None:
local_vars_configuration = Configuration()
self.local_vars_configuration = local_vars_configuration
self._project_id = None
self._edges = None
self.discriminator = None
self.project_id = project_id
self.edges = edges
@property
def project_id(self):
"""Gets the project_id of this DeleteEdgeRequest. # noqa: E501
:return: The project_id of this DeleteEdgeRequest. # noqa: E501
:rtype: int
"""
return self._project_id
@project_id.setter
def project_id(self, project_id):
"""Sets the project_id of this DeleteEdgeRequest.
:param project_id: The project_id of this DeleteEdgeRequest. # noqa: E501
:type: int
"""
if self.local_vars_configuration.client_side_validation and project_id is None: # noqa: E501
raise ValueError("Invalid value for `project_id`, must not be `None`") # noqa: E501
self._project_id = project_id
@property
def edges(self):
"""Gets the edges of this DeleteEdgeRequest. # noqa: E501
:return: The edges of this DeleteEdgeRequest. # noqa: E501
:rtype: list[EdgeRecordInstance]
"""
return self._edges
@edges.setter
def edges(self, edges):
"""Sets the edges of this DeleteEdgeRequest.
:param edges: The edges of this DeleteEdgeRequest. # noqa: E501
:type: list[EdgeRecordInstance]
"""
if self.local_vars_configuration.client_side_validation and edges is None: # noqa: E501
raise ValueError("Invalid value for `edges`, must not be `None`") # noqa: E501
self._edges = edges
def to_dict(self):
"""Returns the model properties as a dict"""
result = {}
for attr, _ in six.iteritems(self.openapi_types):
value = getattr(self, attr)
if isinstance(value, list):
result[attr] = list(map(
lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
value
))
elif hasattr(value, "to_dict"):
result[attr] = value.to_dict()
elif isinstance(value, dict):
result[attr] = dict(map(
lambda item: (item[0], item[1].to_dict())
if hasattr(item[1], "to_dict") else item,
value.items()
))
else:
result[attr] = value
return result
def to_str(self):
"""Returns the string representation of the model"""
return pprint.pformat(self.to_dict())
def __repr__(self):
"""For `print` and `pprint`"""
return self.to_str()
def __eq__(self, other):
"""Returns true if both objects are equal"""
if not isinstance(other, DeleteEdgeRequest):
return False
return self.to_dict() == other.to_dict()
def __ne__(self, other):
"""Returns true if both objects are not equal"""
if not isinstance(other, DeleteEdgeRequest):
return True
return self.to_dict() != other.to_dict()
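
A minimal usage sketch of the model above, with illustrative values (EdgeRecordInstance is the companion model defined further below):

# Sketch: build a DeleteEdgeRequest and serialize it; all values are made up.
from kag.common.graphstore.rest.models.delete_edge_request import DeleteEdgeRequest
from kag.common.graphstore.rest.models.edge_record_instance import EdgeRecordInstance
edge = EdgeRecordInstance(
    src_type="Person", src_id="p1",
    dst_type="Company", dst_id="c1",
    label="worksFor", properties={},
)
req = DeleteEdgeRequest(project_id=1, edges=[edge])
# to_dict() recursively serializes nested models via their own to_dict().
print(req.to_dict())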

View File

@ -0,0 +1,148 @@
# coding: utf-8
"""
kag
No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) # noqa: E501
The version of the OpenAPI document: 1.0.0
Generated by: https://openapi-generator.tech
"""
import pprint
import re # noqa: F401
import six
from kag.common.rest.configuration import Configuration
class DeleteVertexRequest(object):
"""NOTE: This class is auto generated by OpenAPI Generator.
Ref: https://openapi-generator.tech
Do not edit the class manually.
"""
"""
Attributes:
openapi_types (dict): The key is attribute name
and the value is attribute type.
attribute_map (dict): The key is attribute name
and the value is json key in definition.
"""
openapi_types = {
'project_id': 'int',
'vertices': 'list[VertexRecordInstance]'
}
attribute_map = {
'project_id': 'projectId',
'vertices': 'vertices'
}
def __init__(self, project_id=None, vertices=None, local_vars_configuration=None): # noqa: E501
"""DeleteVertexRequest - a model defined in OpenAPI""" # noqa: E501
if local_vars_configuration is None:
local_vars_configuration = Configuration()
self.local_vars_configuration = local_vars_configuration
self._project_id = None
self._vertices = None
self.discriminator = None
self.project_id = project_id
self.vertices = vertices
@property
def project_id(self):
"""Gets the project_id of this DeleteVertexRequest. # noqa: E501
:return: The project_id of this DeleteVertexRequest. # noqa: E501
:rtype: int
"""
return self._project_id
@project_id.setter
def project_id(self, project_id):
"""Sets the project_id of this DeleteVertexRequest.
:param project_id: The project_id of this DeleteVertexRequest. # noqa: E501
:type: int
"""
if self.local_vars_configuration.client_side_validation and project_id is None: # noqa: E501
raise ValueError("Invalid value for `project_id`, must not be `None`") # noqa: E501
self._project_id = project_id
@property
def vertices(self):
"""Gets the vertices of this DeleteVertexRequest. # noqa: E501
:return: The vertices of this DeleteVertexRequest. # noqa: E501
:rtype: list[VertexRecordInstance]
"""
return self._vertices
@vertices.setter
def vertices(self, vertices):
"""Sets the vertices of this DeleteVertexRequest.
:param vertices: The vertices of this DeleteVertexRequest. # noqa: E501
:type: list[VertexRecordInstance]
"""
if self.local_vars_configuration.client_side_validation and vertices is None: # noqa: E501
raise ValueError("Invalid value for `vertices`, must not be `None`") # noqa: E501
self._vertices = vertices
def to_dict(self):
"""Returns the model properties as a dict"""
result = {}
for attr, _ in six.iteritems(self.openapi_types):
value = getattr(self, attr)
if isinstance(value, list):
result[attr] = list(map(
lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
value
))
elif hasattr(value, "to_dict"):
result[attr] = value.to_dict()
elif isinstance(value, dict):
result[attr] = dict(map(
lambda item: (item[0], item[1].to_dict())
if hasattr(item[1], "to_dict") else item,
value.items()
))
else:
result[attr] = value
return result
def to_str(self):
"""Returns the string representation of the model"""
return pprint.pformat(self.to_dict())
def __repr__(self):
"""For `print` and `pprint`"""
return self.to_str()
def __eq__(self, other):
"""Returns true if both objects are equal"""
if not isinstance(other, DeleteVertexRequest):
return False
return self.to_dict() == other.to_dict()
def __ne__(self, other):
"""Returns true if both objects are not equal"""
if not isinstance(other, DeleteVertexRequest):
return True
return self.to_dict() != other.to_dict()

View File

@ -0,0 +1,256 @@
# coding: utf-8
"""
kag
No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) # noqa: E501
The version of the OpenAPI document: 1.0.0
Generated by: https://openapi-generator.tech
"""
import pprint
import re # noqa: F401
import six
from kag.common.rest.configuration import Configuration
class EdgeRecordInstance(object):
"""NOTE: This class is auto generated by OpenAPI Generator.
Ref: https://openapi-generator.tech
Do not edit the class manually.
"""
"""
Attributes:
openapi_types (dict): The key is attribute name
and the value is attribute type.
attribute_map (dict): The key is attribute name
and the value is json key in definition.
"""
openapi_types = {
'src_type': 'str',
'src_id': 'str',
'dst_type': 'str',
'dst_id': 'str',
'label': 'str',
'properties': 'object'
}
attribute_map = {
'src_type': 'srcType',
'src_id': 'srcId',
'dst_type': 'dstType',
'dst_id': 'dstId',
'label': 'label',
'properties': 'properties'
}
def __init__(self, src_type=None, src_id=None, dst_type=None, dst_id=None, label=None, properties=None, local_vars_configuration=None): # noqa: E501
"""EdgeRecordInstance - a model defined in OpenAPI""" # noqa: E501
if local_vars_configuration is None:
local_vars_configuration = Configuration()
self.local_vars_configuration = local_vars_configuration
self._src_type = None
self._src_id = None
self._dst_type = None
self._dst_id = None
self._label = None
self._properties = None
self.discriminator = None
self.src_type = src_type
self.src_id = src_id
self.dst_type = dst_type
self.dst_id = dst_id
self.label = label
self.properties = properties
@property
def src_type(self):
"""Gets the src_type of this EdgeRecordInstance. # noqa: E501
:return: The src_type of this EdgeRecordInstance. # noqa: E501
:rtype: str
"""
return self._src_type
@src_type.setter
def src_type(self, src_type):
"""Sets the src_type of this EdgeRecordInstance.
:param src_type: The src_type of this EdgeRecordInstance. # noqa: E501
:type: str
"""
if self.local_vars_configuration.client_side_validation and src_type is None: # noqa: E501
raise ValueError("Invalid value for `src_type`, must not be `None`") # noqa: E501
self._src_type = src_type
@property
def src_id(self):
"""Gets the src_id of this EdgeRecordInstance. # noqa: E501
:return: The src_id of this EdgeRecordInstance. # noqa: E501
:rtype: str
"""
return self._src_id
@src_id.setter
def src_id(self, src_id):
"""Sets the src_id of this EdgeRecordInstance.
:param src_id: The src_id of this EdgeRecordInstance. # noqa: E501
:type: str
"""
if self.local_vars_configuration.client_side_validation and src_id is None: # noqa: E501
raise ValueError("Invalid value for `src_id`, must not be `None`") # noqa: E501
self._src_id = src_id
@property
def dst_type(self):
"""Gets the dst_type of this EdgeRecordInstance. # noqa: E501
:return: The dst_type of this EdgeRecordInstance. # noqa: E501
:rtype: str
"""
return self._dst_type
@dst_type.setter
def dst_type(self, dst_type):
"""Sets the dst_type of this EdgeRecordInstance.
:param dst_type: The dst_type of this EdgeRecordInstance. # noqa: E501
:type: str
"""
if self.local_vars_configuration.client_side_validation and dst_type is None: # noqa: E501
raise ValueError("Invalid value for `dst_type`, must not be `None`") # noqa: E501
self._dst_type = dst_type
@property
def dst_id(self):
"""Gets the dst_id of this EdgeRecordInstance. # noqa: E501
:return: The dst_id of this EdgeRecordInstance. # noqa: E501
:rtype: str
"""
return self._dst_id
@dst_id.setter
def dst_id(self, dst_id):
"""Sets the dst_id of this EdgeRecordInstance.
:param dst_id: The dst_id of this EdgeRecordInstance. # noqa: E501
:type: str
"""
if self.local_vars_configuration.client_side_validation and dst_id is None: # noqa: E501
raise ValueError("Invalid value for `dst_id`, must not be `None`") # noqa: E501
self._dst_id = dst_id
@property
def label(self):
"""Gets the label of this EdgeRecordInstance. # noqa: E501
:return: The label of this EdgeRecordInstance. # noqa: E501
:rtype: str
"""
return self._label
@label.setter
def label(self, label):
"""Sets the label of this EdgeRecordInstance.
:param label: The label of this EdgeRecordInstance. # noqa: E501
:type: str
"""
if self.local_vars_configuration.client_side_validation and label is None: # noqa: E501
raise ValueError("Invalid value for `label`, must not be `None`") # noqa: E501
self._label = label
@property
def properties(self):
"""Gets the properties of this EdgeRecordInstance. # noqa: E501
:return: The properties of this EdgeRecordInstance. # noqa: E501
:rtype: object
"""
return self._properties
@properties.setter
def properties(self, properties):
"""Sets the properties of this EdgeRecordInstance.
:param properties: The properties of this EdgeRecordInstance. # noqa: E501
:type: object
"""
if self.local_vars_configuration.client_side_validation and properties is None: # noqa: E501
raise ValueError("Invalid value for `properties`, must not be `None`") # noqa: E501
self._properties = properties
def to_dict(self):
"""Returns the model properties as a dict"""
result = {}
for attr, _ in six.iteritems(self.openapi_types):
value = getattr(self, attr)
if isinstance(value, list):
result[attr] = list(map(
lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
value
))
elif hasattr(value, "to_dict"):
result[attr] = value.to_dict()
elif isinstance(value, dict):
result[attr] = dict(map(
lambda item: (item[0], item[1].to_dict())
if hasattr(item[1], "to_dict") else item,
value.items()
))
else:
result[attr] = value
return result
def to_str(self):
"""Returns the string representation of the model"""
return pprint.pformat(self.to_dict())
def __repr__(self):
"""For `print` and `pprint`"""
return self.to_str()
def __eq__(self, other):
"""Returns true if both objects are equal"""
if not isinstance(other, EdgeRecordInstance):
return False
return self.to_dict() == other.to_dict()
def __ne__(self, other):
"""Returns true if both objects are not equal"""
if not isinstance(other, EdgeRecordInstance):
return True
return self.to_dict() != other.to_dict()

View File

@ -0,0 +1,175 @@
# coding: utf-8
"""
kag
No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) # noqa: E501
The version of the OpenAPI document: 1.0.0
Generated by: https://openapi-generator.tech
"""
import pprint
import re # noqa: F401
import six
from kag.common.rest.configuration import Configuration
class UpsertEdgeRequest(object):
"""NOTE: This class is auto generated by OpenAPI Generator.
Ref: https://openapi-generator.tech
Do not edit the class manually.
"""
"""
Attributes:
openapi_types (dict): The key is attribute name
and the value is attribute type.
attribute_map (dict): The key is attribute name
and the value is json key in definition.
"""
openapi_types = {
'project_id': 'int',
'upsert_adjacent_vertices': 'bool',
'edges': 'list[EdgeRecordInstance]'
}
attribute_map = {
'project_id': 'projectId',
'upsert_adjacent_vertices': 'upsertAdjacentVertices',
'edges': 'edges'
}
def __init__(self, project_id=None, upsert_adjacent_vertices=None, edges=None, local_vars_configuration=None): # noqa: E501
"""UpsertEdgeRequest - a model defined in OpenAPI""" # noqa: E501
if local_vars_configuration is None:
local_vars_configuration = Configuration()
self.local_vars_configuration = local_vars_configuration
self._project_id = None
self._upsert_adjacent_vertices = None
self._edges = None
self.discriminator = None
self.project_id = project_id
self.upsert_adjacent_vertices = upsert_adjacent_vertices
self.edges = edges
@property
def project_id(self):
"""Gets the project_id of this UpsertEdgeRequest. # noqa: E501
:return: The project_id of this UpsertEdgeRequest. # noqa: E501
:rtype: int
"""
return self._project_id
@project_id.setter
def project_id(self, project_id):
"""Sets the project_id of this UpsertEdgeRequest.
:param project_id: The project_id of this UpsertEdgeRequest. # noqa: E501
:type: int
"""
if self.local_vars_configuration.client_side_validation and project_id is None: # noqa: E501
raise ValueError("Invalid value for `project_id`, must not be `None`") # noqa: E501
self._project_id = project_id
@property
def upsert_adjacent_vertices(self):
"""Gets the upsert_adjacent_vertices of this UpsertEdgeRequest. # noqa: E501
:return: The upsert_adjacent_vertices of this UpsertEdgeRequest. # noqa: E501
:rtype: bool
"""
return self._upsert_adjacent_vertices
@upsert_adjacent_vertices.setter
def upsert_adjacent_vertices(self, upsert_adjacent_vertices):
"""Sets the upsert_adjacent_vertices of this UpsertEdgeRequest.
:param upsert_adjacent_vertices: The upsert_adjacent_vertices of this UpsertEdgeRequest. # noqa: E501
:type: bool
"""
if self.local_vars_configuration.client_side_validation and upsert_adjacent_vertices is None: # noqa: E501
raise ValueError("Invalid value for `upsert_adjacent_vertices`, must not be `None`") # noqa: E501
self._upsert_adjacent_vertices = upsert_adjacent_vertices
@property
def edges(self):
"""Gets the edges of this UpsertEdgeRequest. # noqa: E501
:return: The edges of this UpsertEdgeRequest. # noqa: E501
:rtype: list[EdgeRecordInstance]
"""
return self._edges
@edges.setter
def edges(self, edges):
"""Sets the edges of this UpsertEdgeRequest.
:param edges: The edges of this UpsertEdgeRequest. # noqa: E501
:type: list[EdgeRecordInstance]
"""
if self.local_vars_configuration.client_side_validation and edges is None: # noqa: E501
raise ValueError("Invalid value for `edges`, must not be `None`") # noqa: E501
self._edges = edges
def to_dict(self):
"""Returns the model properties as a dict"""
result = {}
for attr, _ in six.iteritems(self.openapi_types):
value = getattr(self, attr)
if isinstance(value, list):
result[attr] = list(map(
lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
value
))
elif hasattr(value, "to_dict"):
result[attr] = value.to_dict()
elif isinstance(value, dict):
result[attr] = dict(map(
lambda item: (item[0], item[1].to_dict())
if hasattr(item[1], "to_dict") else item,
value.items()
))
else:
result[attr] = value
return result
def to_str(self):
"""Returns the string representation of the model"""
return pprint.pformat(self.to_dict())
def __repr__(self):
"""For `print` and `pprint`"""
return self.to_str()
def __eq__(self, other):
"""Returns true if both objects are equal"""
if not isinstance(other, UpsertEdgeRequest):
return False
return self.to_dict() == other.to_dict()
def __ne__(self, other):
"""Returns true if both objects are not equal"""
if not isinstance(other, UpsertEdgeRequest):
return True
return self.to_dict() != other.to_dict()
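
A short sketch of the upsert request above; the upsert_adjacent_vertices flag presumably controls whether missing endpoint vertices are created alongside the edges (values illustrative):

from kag.common.graphstore.rest.models.upsert_edge_request import UpsertEdgeRequest
from kag.common.graphstore.rest.models.edge_record_instance import EdgeRecordInstance
edge = EdgeRecordInstance(
    src_type="Person", src_id="p1",
    dst_type="Company", dst_id="c1",
    label="worksFor", properties={"since": "2020"},
)
req = UpsertEdgeRequest(
    project_id=1,
    upsert_adjacent_vertices=True,  # presumed: create src/dst vertices if absent
    edges=[edge],
)
print(req.to_dict())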

View File

@ -0,0 +1,148 @@
# coding: utf-8
"""
kag
No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) # noqa: E501
The version of the OpenAPI document: 1.0.0
Generated by: https://openapi-generator.tech
"""
import pprint
import re # noqa: F401
import six
from kag.common.rest.configuration import Configuration
class UpsertVertexRequest(object):
"""NOTE: This class is auto generated by OpenAPI Generator.
Ref: https://openapi-generator.tech
Do not edit the class manually.
"""
"""
Attributes:
openapi_types (dict): The key is attribute name
and the value is attribute type.
attribute_map (dict): The key is attribute name
and the value is json key in definition.
"""
openapi_types = {
'project_id': 'int',
'vertices': 'list[VertexRecordInstance]'
}
attribute_map = {
'project_id': 'projectId',
'vertices': 'vertices'
}
def __init__(self, project_id=None, vertices=None, local_vars_configuration=None): # noqa: E501
"""UpsertVertexRequest - a model defined in OpenAPI""" # noqa: E501
if local_vars_configuration is None:
local_vars_configuration = Configuration()
self.local_vars_configuration = local_vars_configuration
self._project_id = None
self._vertices = None
self.discriminator = None
self.project_id = project_id
self.vertices = vertices
@property
def project_id(self):
"""Gets the project_id of this UpsertVertexRequest. # noqa: E501
:return: The project_id of this UpsertVertexRequest. # noqa: E501
:rtype: int
"""
return self._project_id
@project_id.setter
def project_id(self, project_id):
"""Sets the project_id of this UpsertVertexRequest.
:param project_id: The project_id of this UpsertVertexRequest. # noqa: E501
:type: int
"""
if self.local_vars_configuration.client_side_validation and project_id is None: # noqa: E501
raise ValueError("Invalid value for `project_id`, must not be `None`") # noqa: E501
self._project_id = project_id
@property
def vertices(self):
"""Gets the vertices of this UpsertVertexRequest. # noqa: E501
:return: The vertices of this UpsertVertexRequest. # noqa: E501
:rtype: list[VertexRecordInstance]
"""
return self._vertices
@vertices.setter
def vertices(self, vertices):
"""Sets the vertices of this UpsertVertexRequest.
:param vertices: The vertices of this UpsertVertexRequest. # noqa: E501
:type: list[VertexRecordInstance]
"""
if self.local_vars_configuration.client_side_validation and vertices is None: # noqa: E501
raise ValueError("Invalid value for `vertices`, must not be `None`") # noqa: E501
self._vertices = vertices
def to_dict(self):
"""Returns the model properties as a dict"""
result = {}
for attr, _ in six.iteritems(self.openapi_types):
value = getattr(self, attr)
if isinstance(value, list):
result[attr] = list(map(
lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
value
))
elif hasattr(value, "to_dict"):
result[attr] = value.to_dict()
elif isinstance(value, dict):
result[attr] = dict(map(
lambda item: (item[0], item[1].to_dict())
if hasattr(item[1], "to_dict") else item,
value.items()
))
else:
result[attr] = value
return result
def to_str(self):
"""Returns the string representation of the model"""
return pprint.pformat(self.to_dict())
def __repr__(self):
"""For `print` and `pprint`"""
return self.to_str()
def __eq__(self, other):
"""Returns true if both objects are equal"""
if not isinstance(other, UpsertVertexRequest):
return False
return self.to_dict() == other.to_dict()
def __ne__(self, other):
"""Returns true if both objects are not equal"""
if not isinstance(other, UpsertVertexRequest):
return True
return self.to_dict() != other.to_dict()

View File

@ -0,0 +1,202 @@
# coding: utf-8
"""
kag
No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) # noqa: E501
The version of the OpenAPI document: 1.0.0
Generated by: https://openapi-generator.tech
"""
import pprint
import re # noqa: F401
import six
from kag.common.rest.configuration import Configuration
class VertexRecordInstance(object):
"""NOTE: This class is auto generated by OpenAPI Generator.
Ref: https://openapi-generator.tech
Do not edit the class manually.
"""
"""
Attributes:
openapi_types (dict): The key is attribute name
and the value is attribute type.
attribute_map (dict): The key is attribute name
and the value is json key in definition.
"""
openapi_types = {
'type': 'str',
'id': 'str',
'properties': 'object',
'vectors': 'object'
}
attribute_map = {
'type': 'type',
'id': 'id',
'properties': 'properties',
'vectors': 'vectors'
}
def __init__(self, type=None, id=None, properties=None, vectors=None, local_vars_configuration=None): # noqa: E501
"""VertexRecordInstance - a model defined in OpenAPI""" # noqa: E501
if local_vars_configuration is None:
local_vars_configuration = Configuration()
self.local_vars_configuration = local_vars_configuration
self._type = None
self._id = None
self._properties = None
self._vectors = None
self.discriminator = None
self.type = type
self.id = id
self.properties = properties
self.vectors = vectors
@property
def type(self):
"""Gets the type of this VertexRecordInstance. # noqa: E501
:return: The type of this VertexRecordInstance. # noqa: E501
:rtype: str
"""
return self._type
@type.setter
def type(self, type):
"""Sets the type of this VertexRecordInstance.
:param type: The type of this VertexRecordInstance. # noqa: E501
:type: str
"""
if self.local_vars_configuration.client_side_validation and type is None: # noqa: E501
raise ValueError("Invalid value for `type`, must not be `None`") # noqa: E501
self._type = type
@property
def id(self):
"""Gets the id of this VertexRecordInstance. # noqa: E501
:return: The id of this VertexRecordInstance. # noqa: E501
:rtype: str
"""
return self._id
@id.setter
def id(self, id):
"""Sets the id of this VertexRecordInstance.
:param id: The id of this VertexRecordInstance. # noqa: E501
:type: str
"""
if self.local_vars_configuration.client_side_validation and id is None: # noqa: E501
raise ValueError("Invalid value for `id`, must not be `None`") # noqa: E501
self._id = id
@property
def properties(self):
"""Gets the properties of this VertexRecordInstance. # noqa: E501
:return: The properties of this VertexRecordInstance. # noqa: E501
:rtype: object
"""
return self._properties
@properties.setter
def properties(self, properties):
"""Sets the properties of this VertexRecordInstance.
:param properties: The properties of this VertexRecordInstance. # noqa: E501
:type: object
"""
if self.local_vars_configuration.client_side_validation and properties is None: # noqa: E501
raise ValueError("Invalid value for `properties`, must not be `None`") # noqa: E501
self._properties = properties
@property
def vectors(self):
"""Gets the vectors of this VertexRecordInstance. # noqa: E501
:return: The vectors of this VertexRecordInstance. # noqa: E501
:rtype: object
"""
return self._vectors
@vectors.setter
def vectors(self, vectors):
"""Sets the vectors of this VertexRecordInstance.
:param vectors: The vectors of this VertexRecordInstance. # noqa: E501
:type: object
"""
if self.local_vars_configuration.client_side_validation and vectors is None: # noqa: E501
raise ValueError("Invalid value for `vectors`, must not be `None`") # noqa: E501
self._vectors = vectors
def to_dict(self):
"""Returns the model properties as a dict"""
result = {}
for attr, _ in six.iteritems(self.openapi_types):
value = getattr(self, attr)
if isinstance(value, list):
result[attr] = list(map(
lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
value
))
elif hasattr(value, "to_dict"):
result[attr] = value.to_dict()
elif isinstance(value, dict):
result[attr] = dict(map(
lambda item: (item[0], item[1].to_dict())
if hasattr(item[1], "to_dict") else item,
value.items()
))
else:
result[attr] = value
return result
def to_str(self):
"""Returns the string representation of the model"""
return pprint.pformat(self.to_dict())
def __repr__(self):
"""For `print` and `pprint`"""
return self.to_str()
def __eq__(self, other):
"""Returns true if both objects are equal"""
if not isinstance(other, VertexRecordInstance):
return False
return self.to_dict() == other.to_dict()
def __ne__(self, other):
"""Returns true if both objects are not equal"""
if not isinstance(other, VertexRecordInstance):
return True
return self.to_dict() != other.to_dict()
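
A sketch combining the two vertex models above; the shape of the vectors payload is an assumption (property name mapped to an embedding), since the schema only types it as object:

from kag.common.graphstore.rest.models.vertex_record_instance import VertexRecordInstance
from kag.common.graphstore.rest.models.upsert_vertex_request import UpsertVertexRequest
vertex = VertexRecordInstance(
    type="Person",
    id="p1",
    properties={"name": "Alice"},
    vectors={"name": [0.1, 0.2, 0.3]},  # assumed shape: property -> embedding
)
req = UpsertVertexRequest(project_id=1, vertices=[vertex])
print(req.to_dict())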

View File

@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.common.llm.client.openai_client import OpenAIClient
from kag.common.llm.client.vllm_client import VLLMClient
from kag.common.llm.client.llm_client import LLMClient
from kag.common.llm.client.ollama_client import OllamaClient
__all__ = [
"OpenAIClient",
"LLMClient",
"VLLMClient",
"OllamaClient"
]

View File

@ -0,0 +1,181 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
import json
from pathlib import Path
from typing import Union, Dict, List, Any
import logging
import traceback
import yaml
from kag.common.base.prompt_op import PromptOp
from kag.common.llm.config import LLMConfig, OpenAIConfig, VLLMConfig, OllamaConfig
logger = logging.getLogger(__name__)
config_cls_map = {
"maas": OpenAIConfig,
"vllm": VLLMConfig,
"ollama": OllamaConfig,
}
def get_config_cls(config: dict):
    client_type = config.get("client_type", None)
    return config_cls_map.get(client_type, None)
def get_llm_cls(config: LLMConfig):
    from kag.common.llm.client import VLLMClient, OpenAIClient, OllamaClient
return {
VLLMConfig: VLLMClient,
OpenAIConfig: OpenAIClient,
OllamaConfig: OllamaClient,
}[config.__class__]
class LLMClient:
# Define the model type
model: str
def __init__(self, **kwargs):
self.model = kwargs.get("model", None)
@classmethod
def from_config(cls, config: Union[str, dict]):
"""
Initialize an LLMClient instance from a configuration file or dictionary.
:param config: Path to a configuration file or a configuration dictionary
:return: Initialized LLMClient instance
:raises FileNotFoundError: If the configuration file is not found
:raises ValueError: If the model type is unsupported
"""
if isinstance(config, str):
config_path = Path(config)
if config_path.is_file():
try:
with open(config_path, "r") as f:
nn_config = yaml.safe_load(f)
                except Exception as e:
                    logger.error(f"Failed to parse config file {config}: {e}")
                    raise
else:
logger.error(f"Config file not found: {config}")
raise FileNotFoundError(f"Config file not found: {config}")
else:
# If config is already a dictionary, use it directly
nn_config = config
config_cls = get_config_cls(nn_config)
if config_cls is None:
logger.error(f"Unsupported model type: {nn_config.get('client_type', None)}")
raise ValueError(f"Unsupported model type")
llm_config = config_cls(**nn_config)
llm_cls = get_llm_cls(llm_config)
return llm_cls(llm_config)
def __call__(self, prompt: Union[str, dict, list]) -> str:
"""
Perform inference on the given prompt and return the result.
:param prompt: Input prompt for inference
:return: Inference result
:raises NotImplementedError: If the subclass has not implemented this method
"""
raise NotImplementedError
    def call_with_json_parse(self, prompt: Union[str, dict, list]):
        """
        Perform inference on the given prompt and attempt to parse the result as JSON.
        If the response wraps JSON in a ```json ... ``` fence, the fenced content is
        extracted first; if parsing fails, the raw response is returned unchanged.
        :param prompt: Input prompt for inference
        :return: Parsed JSON object, or the raw response string if parsing fails
        """
        res = self(prompt)
        _end = res.rfind("```")
        _start = res.find("```json")
        if _end != -1 and _start != -1:
            json_str = res[_start + len("```json"): _end].strip()
        else:
            json_str = res
        try:
            json_result = json.loads(json_str)
        except json.JSONDecodeError:
            return res
        return json_result
def invoke(self, variables: Dict[str, Any], prompt_op: PromptOp, with_json_parse: bool = True):
"""
Call the model and process the result.
:param variables: Variables used to build the prompt
:param prompt_op: Prompt operation object for building and parsing prompts
:param with_json_parse: Whether to attempt parsing the response as JSON
:return: Processed result list
"""
result = []
prompt = prompt_op.build_prompt(variables)
logger.debug(f"Prompt: {prompt}")
if not prompt:
return result
response = ""
try:
response = self.call_with_json_parse(prompt=prompt) if with_json_parse else self(prompt)
logger.debug(f"Response: {response}")
result = prompt_op.parse_response(response, model=self.model, **variables)
logger.debug(f"Result: {result}")
        except Exception as e:
            logger.debug(f"Error {e} during invocation: {traceback.format_exc()}")
return result
def batch(self, variables: Dict[str, Any], prompt_op: PromptOp, with_json_parse: bool = True) -> List:
"""
Batch process prompts.
:param variables: Variables used to build the prompts
:param prompt_op: Prompt operation object for building and parsing prompts
:param with_json_parse: Whether to attempt parsing the response as JSON
:return: List of all processed results
"""
results = []
prompts = prompt_op.build_prompt(variables)
# If there is only one prompt, call the `invoke` method directly
if isinstance(prompts, str):
return self.invoke(variables, prompt_op, with_json_parse=with_json_parse)
for idx, prompt in enumerate(prompts, start=0):
logger.debug(f"Prompt_{idx}: {prompt}")
try:
response = self.call_with_json_parse(prompt=prompt) if with_json_parse else self(prompt)
logger.debug(f"Response_{idx}: {response}")
result = prompt_op.parse_response(response, idx=idx, model=self.model, **variables)
logger.debug(f"Result_{idx}: {result}")
results.extend(result)
except Exception as e:
logger.error(f"Error processing prompt {idx}: {e}")
logger.debug(traceback.format_exc())
continue
return results
if __name__ == "__main__":
from kag.common.env import init_kag_config
configFilePath = "/ossfs/workspace/workspace/openspgapp/openspg/python/kag/kag/common/default_config.cfg"
init_kag_config(configFilePath)
model = eval(os.getenv("KAG_LLM"))
print(model)
llm = LLMClient.from_config(model)
res = llm("who are you?")
print(res)
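
As a rough sketch of the dispatch above: from_config picks the config class by client_type, validates it, then instantiates the matching client (a reachable local Ollama server is assumed here):

from kag.common.llm.client.llm_client import LLMClient
conf = {
    "client_type": "ollama",  # selects OllamaConfig, hence OllamaClient
    "base_url": "http://localhost:11434/api/generate",
    "model": "llama3.1",
}
llm = LLMClient.from_config(conf)
print(llm("Say hi in one word."))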

View File

@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import logging
import requests
from kag.common.llm.config import OllamaConfig
from kag.common.llm.client.llm_client import LLMClient
logger = logging.getLogger(__name__)
class OllamaClient(LLMClient):
    """LLM client for a local Ollama server, using the /api/generate endpoint."""
    def __init__(self, llm_config: OllamaConfig):
self.model = llm_config.model
self.base_url = llm_config.base_url
self.param = {}
    def sync_request(self, prompt, image=None):
        # Build the payload for Ollama's /api/generate endpoint.
        self.param["prompt"] = prompt
self.param["model"] = self.model
self.param["stream"] = False
if image:
self.param["images"] = [image]
response = requests.post(
self.base_url,
data=json.dumps(self.param),
headers={"Content-Type": "application/json"},
)
data = response.json()
content = data["response"]
content = content.replace("&rdquo;", "").replace("&ldquo;", "")
content = content.replace("&middot;", "")
return content
    def __call__(self, prompt, image=None):
        return self.sync_request(prompt, image)
    def call_with_json_parse(self, prompt):
        # /api/generate expects a plain prompt string, not a chat message list.
        rsp = self.sync_request(prompt)
_end = rsp.rfind("```")
_start = rsp.find("```json")
if _end != -1 and _start != -1:
json_str = rsp[_start + len("```json"): _end].strip()
else:
json_str = rsp
try:
json_result = json.loads(json_str)
        except json.JSONDecodeError:
return rsp
return json_result
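
A minimal direct-usage sketch, assuming a local Ollama server with the llama3.1 model pulled:

from kag.common.llm.config import OllamaConfig
from kag.common.llm.client import OllamaClient
cfg = OllamaConfig(model="llama3.1", base_url="http://localhost:11434/api/generate")
client = OllamaClient(cfg)
print(client("What is a knowledge graph, in one sentence?"))
# For multimodal models, pass a base64-encoded image: client(prompt, image=b64_image)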

View File

@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
from typing import Union
from openai import OpenAI
import logging
from kag.common.llm.client.llm_client import LLMClient
from kag.common.llm.config import OpenAIConfig
logger = logging.getLogger(__name__)
class OpenAIClient(LLMClient):
"""
A client class for interacting with the OpenAI API.
Initializes the client with an API key, base URL, streaming option, temperature parameter, and default model.
Parameters:
api_key (str): The OpenAI API key.
base_url (str): The base URL of the API.
stream (bool, optional): Whether to process responses in a streaming manner. Default is False.
        temperature (float, optional): Sampling temperature to control the randomness of the model's output. Default is 0.7.
model (str, optional): The default model to use.
Attributes:
api_key (str): The OpenAI API key.
base_url (str): The base URL of the API.
model (str): The default model to use.
stream (bool): Whether to process responses in a streaming manner.
temperature (float): Sampling temperature.
client (OpenAI): An instance of the OpenAI API client.
"""
def __init__(
self,
        llm_config: OpenAIConfig
):
# Initialize the OpenAIClient object
self.api_key = llm_config.api_key
self.base_url = llm_config.base_url
self.model = llm_config.model
self.stream = llm_config.stream
self.temperature = llm_config.temperature
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
    def __call__(self, prompt: str, image_url: str = None):
"""
Executes a model request when the object is called and returns the result.
Parameters:
        prompt (str): The prompt provided to the model.
        image_url (str, optional): URL of an image to include for multimodal models.
        Returns:
        str: The response content generated by the model.
        """
        # Build the message payload (with an optional image part) and call the model.
        if image_url:
            user_content = [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": image_url}},
            ]
        else:
            user_content = prompt
        message = [
            {"role": "system", "content": "you are a helpful assistant"},
            {"role": "user", "content": user_content},
        ]
        response = self.client.chat.completions.create(
            model=self.model,
            messages=message,
            stream=self.stream,
            temperature=self.temperature,
        )
        return response.choices[0].message.content
def call_with_json_parse(self, prompt):
"""
Calls the model and attempts to parse the response into JSON format.
Parameters:
prompt (str): The prompt provided to the model.
Returns:
Union[dict, str]: If the response is valid JSON, returns the parsed dictionary; otherwise, returns the original response.
"""
# Call the model and attempt to parse the response into JSON format
rsp = self(prompt)
_end = rsp.rfind("```")
_start = rsp.find("```json")
if _end != -1 and _start != -1:
json_str = rsp[_start + len("```json"): _end].strip()
else:
json_str = rsp
try:
json_result = json.loads(json_str)
        except json.JSONDecodeError:
return rsp
return json_result
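
A minimal sketch of the client above; note that the openai SDK appends the route itself, so base_url should stop at /v1 (the API key is a placeholder):

from kag.common.llm.config import OpenAIConfig
from kag.common.llm.client import OpenAIClient
cfg = OpenAIConfig(
    api_key="sk-...",  # placeholder
    base_url="https://api.openai.com/v1",
    model="gpt-3.5-turbo",
)
client = OpenAIClient(cfg)
print(client("Say hello."))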

View File

@ -0,0 +1,87 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import logging
import requests
from kag.common.llm.config import VLLMConfig
from kag.common.llm.client.llm_client import LLMClient
logger = logging.getLogger(__name__)
class VLLMClient(LLMClient):
    """LLM client for a vLLM server exposing an OpenAI-style chat completions endpoint."""
    def __init__(self, llm_config: VLLMConfig):
self.model = llm_config.model
self.base_url = llm_config.base_url
self.param = {}
def sync_request(self, prompt):
        # POST an OpenAI-style chat payload to the vLLM server.
self.param["messages"] = prompt
self.param["model"] = self.model
response = requests.post(
self.base_url,
data=json.dumps(self.param),
headers={"Content-Type": "application/json"},
)
data = response.json()
content = data["choices"][0]["message"]["content"]
content = content.replace("&rdquo;", "").replace("&ldquo;", "")
content = content.replace("&middot;", "")
return content
def __call__(self, prompt):
content = [
{"role": "user", "content": prompt}
]
return self.sync_request(content)
def call_with_json_parse(self, prompt):
content = [{"role": "user", "content": prompt}]
rsp = self.sync_request(content)
_end = rsp.rfind("```")
_start = rsp.find("```json")
if _end != -1 and _start != -1:
json_str = rsp[_start + len("```json"): _end].strip()
else:
json_str = rsp
try:
json_result = json.loads(json_str)
        except json.JSONDecodeError:
return rsp
return json_result

View File

@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.common.llm.config.openai import OpenAIConfig
from kag.common.llm.config.base import LLMConfig
from kag.common.llm.config.vllm import VLLMConfig
from kag.common.llm.config.ollama import OllamaConfig
__all__ = [
"OpenAIConfig",
"LLMConfig",
"VLLMConfig",
"OllamaConfig"
]

View File

@ -0,0 +1,10 @@
"""LLM Parameters model."""
from pydantic import BaseModel, Field
import kag.common.llm.config.defaults as defs
class LLMConfig(BaseModel):
"""LLM Config model."""

View File

@ -0,0 +1,115 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'PipelineCacheConfig', 'PipelineFileCacheConfig' and 'PipelineMemoryCacheConfig' models."""
from __future__ import annotations
from enum import Enum
class CacheType(str, Enum):
"""The cache configuration type for the pipeline."""
file = "file"
"""The file cache configuration type."""
memory = "memory"
"""The memory cache configuration type."""
none = "none"
"""The none cache configuration type."""
blob = "blob"
"""The blob cache configuration type."""
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'
class InputFileType(str, Enum):
"""The input file type for the pipeline."""
csv = "csv"
"""The CSV input type."""
text = "text"
"""The text input type."""
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'
class InputType(str, Enum):
"""The input type for the pipeline."""
file = "file"
"""The file storage type."""
blob = "blob"
"""The blob storage type."""
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'
class StorageType(str, Enum):
"""The storage type for the pipeline."""
file = "file"
"""The file storage type."""
memory = "memory"
"""The memory storage type."""
blob = "blob"
"""The blob storage type."""
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'
class ReportingType(str, Enum):
"""The reporting configuration type for the pipeline."""
file = "file"
"""The file reporting configuration type."""
console = "console"
"""The console reporting configuration type."""
blob = "blob"
"""The blob reporting configuration type."""
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'
class TextEmbeddingTarget(str, Enum):
"""The target to use for text embeddings."""
all = "all"
required = "required"
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'
class LLMType(str, Enum):
"""LLMType enum class definition."""
# Embeddings
OpenAIEmbedding = "openai_embedding"
AzureOpenAIEmbedding = "azure_openai_embedding"
# Raw Completion
OpenAI = "openai"
AzureOpenAI = "azure_openai"
# Chat Completion
OpenAIChat = "openai_chat"
AzureOpenAIChat = "azure_openai_chat"
# Debug
StaticResponse = "static_response"
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'

View File

@ -0,0 +1,67 @@
#-----------------------------------------------------------------------------------#
# openai SDK maas. client_type = maas
#
# TongYi
# [llm]
# client_type = maas
# base_url = https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions
# api_key = "put your tongyi api key here"
# model = qwen-turbo
#
# Deepseek
# [llm]
# client_type = maas
# base_url = https://api.deepseek.com/beta
# api_key = "put your deepseek api key here"
# model = deepseek-chat
#
# OpenAI
# [llm]
# client_type = maas
# base_url = https://api.openai.com/v1/chat/completions
# api_key = "put your openai api key here"
# model = gpt-3.5-turbo
#-----------------------------------------------------------------------------------#
#-----------------------------------------------------------------------------------#
# local llm service. client_type = vllm
#
# vllm
# [llm]
# client_type = vllm
# base_url = http://localhost:8000/v1/chat/completions
# model = qwen-7b-chat
#-----------------------------------------------------------------------------------#
#-----------------------------------------------------------------------------------#
# maya llm service. client_type = maya
#
# [llm]
# client_type = maya
# scene_name = Qwen2_7B_Instruct_Knowledge
# chain_name = v1
# lora_name = humming-v25
#-----------------------------------------------------------------------------------#
#-----------------------------------------------------------------------------------#
# ollama (active example). client_type = ollama
#
[llm]
client_type = ollama
base_url = http://localhost:11434/api/generate
model = llama3.1
#-----------------------------------------------------------------------------------#
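
A sketch of how such a section can be fed to LLMClient.from_config using the standard-library configparser (init_kag_config presumably performs an equivalent step):

import configparser
from kag.common.llm.client import LLMClient
parser = configparser.ConfigParser(inline_comment_prefixes=("#",))
parser.read("default_config.cfg")
llm_conf = dict(parser["llm"])  # e.g. {'client_type': 'ollama', 'base_url': ..., 'model': ...}
llm = LLMClient.from_config(llm_conf)
print(llm("hello"))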

View File

@ -0,0 +1,11 @@
from pydantic import Field
from kag.common.llm.config.base import LLMConfig
class OllamaConfig(LLMConfig):
    model: str = Field(
        description="Ollama model name, e.g. llama3.1."
    )
    base_url: str = Field(
        description="URL of the Ollama generate endpoint, e.g. http://localhost:11434/api/generate."
    )

View File

@ -0,0 +1,20 @@
from pydantic import Field
from kag.common.llm.config.base import LLMConfig
class OpenAIConfig(LLMConfig):
    api_key: str = Field(
        description="API key for the OpenAI-compatible service."
    )
    stream: bool = Field(
        description="Whether to use streaming mode.", default=False
    )
    model: str = Field(
        description="Model name."
    )
    temperature: float = Field(
        description="Sampling temperature.", default=0.7
    )
    base_url: str = Field(
        description="Base URL of the OpenAI-compatible API."
    )

View File

@ -0,0 +1,9 @@
from kag.common.llm.config.base import LLMConfig
class ProxyLLMConfig(LLMConfig):
    """Base config for proxied LLM services (not defined in base.py, so declared here)."""
    pass
class GPTProxyLLMConfig(ProxyLLMConfig):
    pass
class DeepSeekProxyLLMConfig(ProxyLLMConfig):
    pass

View File

@ -0,0 +1,11 @@
from pydantic import Field
from kag.common.llm.config.base import LLMConfig
class VLLMConfig(LLMConfig):
    model: str = Field(
        description="Model name served by vLLM."
    )
    base_url: str = Field(
        description="URL of the vLLM chat completions endpoint, e.g. http://localhost:8000/v1/chat/completions."
    )

View File

@ -0,0 +1,19 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.common.reranker.bge_reranker import BGEReranker
from kag.common.reranker.reranker import Reranker
__all__ = [
"BGEReranker",
"Reranker"
]

View File

@ -0,0 +1,79 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import numpy as np
from typing import List
from .reranker import Reranker
def rrf_score(length, r: int = 1):
"""
    Calculates RRF (Reciprocal Rank Fusion) scores.
    This function generates a score sequence of the given length, where the score at rank i is 1 / (r + i).
    RRF is a rank-aggregation method used in information retrieval; this function provides a way to generate weights based on document rank.
    Parameters:
    length: int, the length of the score sequence, i.e., the number of scores to generate.
    r: int, optional, default is 1. Rank offset; larger values flatten the weight curve, reducing the advantage of top-ranked items.
Returns:
numpy.ndarray, an array containing the scores calculated according to the given formula.
"""
return np.array([1 / (r + i) for i in range(length)])
class BGEReranker(Reranker):
"""
BGEReranker class is a subclass of Reranker that reranks given queries and passages.
This class uses the FlagReranker model from FlagEmbedding to score and reorder passages.
Args:
model_path (str): Path to the FlagReranker model.
use_fp16 (bool): Whether to use half-precision floating-point numbers for computation. Default is True.
"""
def __init__(self, model_path: str, use_fp16: bool = True):
from FlagEmbedding import FlagReranker
self.model_path = model_path
self.model = FlagReranker(self.model_path, use_fp16=use_fp16)
def rerank(self, queries: List[str], passages: List[str]):
"""
Reranks given queries and passages.
Args:
queries (List[str]): List of queries.
passages (List[str]): List of passages, where each passage is a string.
Returns:
new_passages (List[str]): List of passages after reranking.
"""
# Calculate initial ranking scores for passages
rank_scores = rrf_score(len(passages))
passage_scores = np.zeros(len(passages)) + rank_scores
# For each query, compute passage scores using the model and accumulate them
for query in queries:
scores = self.model.compute_score([[query, x] for x in passages])
sorted_idx = np.argsort(-np.array(scores))
for rank, passage_id in enumerate(sorted_idx):
passage_scores[passage_id] += rank_scores[rank]
# Perform final sorting of passages based on accumulated scores
merged_sorted_idx = np.argsort(-passage_scores)
new_passages = [passages[x] for x in merged_sorted_idx]
return new_passages
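
A usage sketch, assuming FlagEmbedding is installed and a BGE reranker checkpoint is available at the placeholder path:

from kag.common.reranker import BGEReranker
reranker = BGEReranker(model_path="/models/bge-reranker-large", use_fp16=True)  # placeholder path
queries = ["who maintains OpenSPG?"]
passages = [
    "OpenSPG is an open knowledge graph engine.",
    "The OpenSPG project is maintained by the OpenSPG Authors.",
]
print(reranker.rerank(queries, passages))
# rrf_score(3, r=1) == [1.0, 0.5, 0.333...]: earlier ranks get larger weights.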

View File

@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import numpy as np
from typing import List
def rrf_score(length, r: int = 1):
    """Reciprocal Rank Fusion weights: the score at rank i is 1 / (r + i)."""
    return np.array([1 / (r + i) for i in range(length)])
class Reranker:
"""
This class provides a framework for a reranker,
which is intended to re-rank the matches between queries and document passages.
"""
def __init__(self):
"""
Constructor for initializing the reranker class.
Currently, there are no specific initialization parameters or operations.
"""
pass
def rerank(self, queries: List[str], passages: List[str]):
"""
        Re-rank document passages against the given queries,
        reordering the input passages according to a subclass-defined strategy.
        Parameters:
        queries (List[str]): A list of query strings.
        passages (List[str]): A list of document passages to be re-ranked.
        Subclasses must override this method; the base implementation raises NotImplementedError.
"""
raise NotImplementedError("rerank not implemented yet.")

View File

@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from kag.common.retriever.kag_retriever import DefaultRetriever
from kag.common.retriever.semantic_retriever import SemanticRetriever
from kag.common.retriever.retriever import Retriever
__all__ = [
"DefaultRetriever",
"SemanticRetriever",
"Retriever"
]

View File

@ -0,0 +1,422 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
import ast
import json
from tenacity import retry, stop_after_attempt
from kag.common.base.prompt_op import PromptOp
from kag.common.vectorizer import Vectorizer
from knext.graph_algo.client import GraphAlgoClient
from kag.interface.retriever.chunk_retriever_abc import ChunkRetrieverABC
from typing import List, Dict
import numpy as np
import logging
from knext.reasoner.client import ReasonerClient
from knext.schema.client import CHUNK_TYPE, OTHER_TYPE
from knext.project.client import ProjectClient
from kag.common.utils import processing_phrases
from kag.common.llm.client.llm_client import LLMClient
from knext.search.client import SearchClient
from kag.solver.logic.core_modules.common.schema_utils import SchemaUtils
from kag.solver.logic.core_modules.config import LogicFormConfiguration
logger = logging.getLogger(__name__)
class DefaultRetriever(ChunkRetrieverABC):
"""
KAGRetriever class for retrieving and processing knowledge graph data from a graph database.
    This retriever follows the HippoRAG approach of combining dense passage retrieval (DPR) with personalized PageRank (PPR); developers can also define their own retriever.
Parameters:
- project_id (str, optional): Project ID to load specific project configurations.
"""
def __init__(self, project_id: str = None):
self.project_id = project_id or os.environ.get("KAG_PROJECT_ID")
self._init_llm()
self._init_search(self.project_id)
biz_scene = os.getenv('KAG_PROMPT_BIZ_SCENE', 'default')
language = os.getenv('KAG_PROMPT_LANGUAGE', 'en')
self.ner_prompt = PromptOp.load(biz_scene, "question_ner")(language=language, project_id=self.project_id)
self.std_prompt = PromptOp.load(biz_scene, "std")(language=language)
self.pagerank_threshold = 0.9
self.match_threshold = 0.8
self.pagerank_weight = 0.5
self.reranker_model_path = os.getenv("KAG_RETRIEVER_RERANKER_MODEL_PATH")
if self.reranker_model_path:
            from kag.common.reranker.bge_reranker import BGEReranker
self.reranker = BGEReranker(self.reranker_model_path, use_fp16=True)
else:
self.reranker = None
self.with_semantic = True
def _init_llm(self):
        llm_config = ast.literal_eval(os.getenv("KAG_LLM", "{}"))  # safer than eval for config literals
project_id = int(self.project_id)
config = ProjectClient().get_config(project_id)
llm_config.update(config.get("llm", {}))
self.llm = LLMClient.from_config(llm_config)
def _init_search(self, project_id):
host_addr = os.getenv("KAG_PROJECT_HOST_ADDR")
self.schema_util = SchemaUtils(LogicFormConfiguration({
"project_id": project_id,
"host_addr": host_addr,
}))
self.sc: SearchClient = SearchClient(host_addr, int(project_id))
        vectorizer_config = ast.literal_eval(os.getenv("KAG_VECTORIZER", "{}"))
self.vectorizer = Vectorizer.from_config(
vectorizer_config
)
self.reason: ReasonerClient = ReasonerClient(host_addr, int(project_id))
self.graph_algo = GraphAlgoClient(host_addr, int(project_id))
@retry(stop=stop_after_attempt(3))
def named_entity_recognition(self, query: str):
"""
Perform named entity recognition.
This method invokes the pre-configured service client (self.llm) to process the input query,
using the named entity recognition (NER) prompt (self.ner_prompt).
Parameters:
query (str): The text input provided by the user or system for named entity recognition.
Returns:
The result returned by the service client, with the type and format depending on the used service.
"""
return self.llm.invoke({"input": query}, self.ner_prompt)
@retry(stop=stop_after_attempt(3))
def named_entity_standardization(self, query: str, entities: List[Dict]):
"""
Entity standardization function.
This function calls a remote service to process the input query and named entities,
standardizing the entities. This is useful for unifying different representations of the same entity in text,
improving the performance of natural language processing tasks.
Parameters:
- query: A string containing the query with named entities.
- entities: A list of dictionaries, each containing information about named entities.
Returns:
- The result of the remote service call, typically standardized named entity information.
"""
return self.llm.invoke(
{"input": query, "named_entities": entities}, self.std_prompt
)
@staticmethod
def append_official_name(source_entities: List[Dict], entities_with_official_name: List[Dict]):
"""
Appends official names to entities.
Parameters:
source_entities (List[Dict]): A list of source entities.
entities_with_official_name (List[Dict]): A list of entities with official names.
"""
tmp_dict = {}
for tmp_entity in entities_with_official_name:
name = tmp_entity["entity"]
category = tmp_entity["category"]
official_name = tmp_entity["official_name"]
key = f"{category}{name}"
tmp_dict[key] = official_name
for tmp_entity in source_entities:
name = tmp_entity["entity"]
category = tmp_entity["category"]
key = f"{category}{name}"
if key in tmp_dict:
official_name = tmp_dict[key]
tmp_entity["official_name"] = official_name
def calculate_sim_scores(self, query: str, doc_nums: int):
"""
Calculate the vector similarity scores between a query and document chunks.
Parameters:
query (str): The user's query text.
doc_nums (int): The number of document chunks to return.
Returns:
dict: A dictionary with keys as document chunk IDs and values as the vector similarity scores.
"""
scores = dict()
try:
query_vector = self.vectorizer.vectorize(query)
top_k = self.sc.search_vector(
label=self.schema_util.get_label_within_prefix(CHUNK_TYPE),
property_key="content",
query_vector=query_vector,
topk=doc_nums
)
scores = {item["node"]["id"]: item["score"] for item in top_k}
except Exception as e:
logger.error(
f"run calculate_sim_scores failed, info: {e}", exc_info=True
)
return scores
def calculate_pagerank_scores(self, start_nodes: List[Dict]):
"""
Calculate and retrieve PageRank scores for the given starting nodes.
Parameters:
start_nodes (list): A list containing document fragment IDs to be used as starting nodes for the PageRank algorithm.
Returns:
ppr_doc_scores (dict): A dictionary containing each document fragment ID and its corresponding PageRank score.
        This method uses the PageRank algorithm in the graph store to compute scores for document fragments. If `start_nodes` is empty,
        it returns an empty dictionary. Otherwise, it attempts to retrieve PageRank scores from the graph store and converts the result
        into a dictionary where keys are document fragment IDs and values are their respective PageRank scores. Any exception raised
        during the computation is logged.
"""
scores = dict()
if len(start_nodes) != 0:
try:
scores = self.graph_algo.calculate_pagerank_scores(
self.schema_util.get_label_within_prefix(CHUNK_TYPE),
start_nodes
)
except Exception as e:
logger.error(
f"run calculate_pagerank_scores failed, info: {e}, start_nodes: {start_nodes}", exc_info=True
)
return scores
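    # Each start node follows the shape produced by match_entities, e.g.
    # (values hypothetical): {"name": "Insulin", "type": "Medicine"}.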
def match_entities(self, queries: Dict[str, str], top_k: int = 1):
"""
Match entities based on the provided queries.
:param queries: A dictionary mapping keywords to their type labels.
:param top_k: The number of top results to return. Default is 1.
:return: A tuple containing a list of matched entities and their scores.
"""
matched_entities = []
matched_entities_scores = []
for query, query_type in queries.items():
query = processing_phrases(query)
if query_type not in self.schema_util.node_en_zh.keys():
query_type = self.schema_util.get_label_within_prefix(OTHER_TYPE)
typed_nodes = self.sc.search_vector(
label=query_type,
property_key="name",
query_vector=self.vectorizer.vectorize(query),
topk=top_k,
)
if query_type != self.schema_util.get_label_within_prefix(OTHER_TYPE):
nontyped_nodes = self.sc.search_vector(
label=self.schema_util.get_label_within_prefix(OTHER_TYPE),
property_key="name",
query_vector=self.vectorizer.vectorize(query),
topk=top_k,
)
else:
nontyped_nodes = typed_nodes
if len(typed_nodes) == 0 and len(nontyped_nodes) != 0:
matched_entities.append(
{"name": nontyped_nodes[0]["node"]["name"], "type": OTHER_TYPE}
)
matched_entities_scores.append(nontyped_nodes[0]["score"])
elif len(typed_nodes) != 0 and len(nontyped_nodes) != 0:
if typed_nodes[0]["score"] > 0.8:
matched_entities.append(
{"name": typed_nodes[0]["node"]["name"], "type": query_type}
)
matched_entities_scores.append(typed_nodes[0]["score"])
else:
matched_entities.append(
{"name": nontyped_nodes[0]["node"]["name"], "type": OTHER_TYPE}
)
matched_entities_scores.append(nontyped_nodes[0]["score"])
matched_entities.append(
{"name": typed_nodes[0]["node"]["name"], "type": query_type}
)
matched_entities_scores.append(typed_nodes[0]["score"])
elif len(typed_nodes) != 0 and len(nontyped_nodes) == 0:
if typed_nodes[0]["score"] > 0.8:
matched_entities.append(
{"name": typed_nodes[0]["node"]["name"], "type": query_type}
)
matched_entities_scores.append(typed_nodes[0]["score"])
if not matched_entities:
logger.info(f"No entities matched for {queries}")
return matched_entities, matched_entities_scores
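    # A worked example of the 0.8 threshold above (names and scores are
    # hypothetical): for query "insulin" typed as "Medicine", a typed hit
    # ("Insulin", 0.92) is kept on its own since 0.92 > 0.8; a weaker typed
    # hit ("Insulin", 0.74) is kept together with the best untyped Others
    # hit, so downstream PageRank can seed from either candidate.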
def calculate_combined_scores(self, sim_scores: Dict[str, float], pagerank_scores: Dict[str, float]):
"""
Calculate and return the combined scores that integrate both similarity scores and PageRank scores.
Parameters:
sim_scores (Dict[str, float]): A dictionary containing similarity scores, where keys are identifiers and values are scores.
pagerank_scores (Dict[str, float]): A dictionary containing PageRank scores, where keys are identifiers and values are scores.
Returns:
Dict[str, float]: A dictionary containing the combined scores, where keys are identifiers and values are the combined scores.
"""
def min_max_normalize(x):
if len(x) == 0:
return []
if np.max(x) - np.min(x) > 0:
return (x - np.min(x)) / (np.max(x) - np.min(x))
else:
return x - np.min(x)
all_keys = set(pagerank_scores.keys()).union(set(sim_scores.keys()))
for key in all_keys:
sim_scores.setdefault(key, 0.0)
pagerank_scores.setdefault(key, 0.0)
sim_scores = dict(zip(sim_scores.keys(), min_max_normalize(
np.array(list(sim_scores.values()))
)))
pagerank_scores = dict(zip(pagerank_scores.keys(), min_max_normalize(
np.array(list(pagerank_scores.values()))
)))
combined_scores = dict()
for key in pagerank_scores.keys():
combined_scores[key] = (sim_scores[key] * (1 - self.pagerank_weight) +
pagerank_scores[key] * self.pagerank_weight
)
return combined_scores
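    # A worked example with pagerank_weight = 0.5 (all numbers hypothetical):
    #   sim_scores      = {"a": 0.2, "b": 0.6}  -> normalized {"a": 0.0, "b": 1.0}
    #   pagerank_scores = {"a": 0.9, "b": 0.1}  -> normalized {"a": 1.0, "b": 0.0}
    #   combined        = {"a": 0.0*0.5 + 1.0*0.5, "b": 1.0*0.5 + 0.0*0.5}
    #                   = {"a": 0.5, "b": 0.5}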
def recall_docs(self, query: str, top_k: int = 5, **kwargs):
"""
Recall relevant documents based on the query string.
Parameters:
- query (str): The user's query string.
- top_k (int, optional): The number of documents to return, default is 5.
Keyword Arguments:
- kwargs: Additional keyword arguments.
Returns:
- list: A list containing the top_k most relevant documents.
"""
assert isinstance(query, str), "Query must be a string"
chunk_nums = top_k * 20
if chunk_nums == 0:
return []
ner_list = self.named_entity_recognition(query)
logger.debug(f"ner_list: {ner_list}")
if self.with_semantic:
std_ner_list = self.named_entity_standardization(query, ner_list)
self.append_official_name(ner_list, std_ner_list)
entities = {}
for item in ner_list:
entity = item.get("entity", "")
category = item.get("category", "")
official_name = item.get("official_name", "")
if not entity or not (category or official_name):
continue
if category.lower() in ["works", "person", "other"]:
entities[entity] = category
else:
entities[entity] = official_name or category
sim_scores = self.calculate_sim_scores(query, chunk_nums)
matched_entities, matched_scores = self.match_entities(entities)
pagerank_scores = self.calculate_pagerank_scores(matched_entities)
if not matched_entities:
combined_scores = sim_scores
elif matched_entities and np.min(matched_scores) > self.pagerank_threshold:
combined_scores = pagerank_scores
else:
combined_scores = self.calculate_combined_scores(sim_scores, pagerank_scores)
sorted_scores = sorted(
combined_scores.items(), key=lambda item: item[1], reverse=True
)
logger.debug(f"sorted_scores: {sorted_scores}")
return self.get_all_docs_by_id(query, sorted_scores, top_k)
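    # Score-selection summary for the branches above: with no matched
    # entities the method falls back to pure vector similarity; when every
    # matched entity clears pagerank_threshold, the PageRank scores are used
    # on their own; otherwise the two signals are blended by
    # calculate_combined_scores.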
def get_all_docs_by_id(self, query: str, doc_ids: list, top_k: int):
"""
Retrieve a list of documents based on their IDs.
Parameters:
- query (str): The query string for text matching.
- doc_ids (list): Ranked documents, either as (id, score) tuples or as ids keyed into a score mapping.
- top_k (int): The maximum number of documents to return.
Returns:
- list: A list of matched documents.
"""
matched_docs = []
hits_docs = set()
counter = 0
for doc_id in doc_ids:
if counter == top_k:
break
if isinstance(doc_id, tuple):
doc_score = doc_id[1]
doc_id = doc_id[0]
else:
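    # when entries are not (id, score) tuples, doc_ids must support score
    # lookup by id (e.g. a dict mapping id to score)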
doc_score = doc_ids[doc_id]
counter += 1
node = self.reason.query_node(label=self.schema_util.get_label_within_prefix(CHUNK_TYPE), id_value=doc_id)
node_dict = dict(node.items())
matched_docs.append(f"#{node_dict['name']}#{node_dict['content']}#{doc_score}")
hits_docs.add(node_dict['name'])
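    # Fallback below: ensure at least one exact text match survives; the
    # lowest-ranked vector hit is evicted to make room when the text match
    # is not already among the hits.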
try:
text_matched = self.sc.search_text(query, [self.schema_util.get_label_within_prefix(CHUNK_TYPE)], topk=1)
if text_matched:
for item in text_matched:
title = item["node"]["name"]
if title not in hits_docs:
if len(matched_docs) > 0:
matched_docs.pop()
else:
logger.warning(f"{query} matched docs is empty")
matched_docs.append(f'#{item["node"]["name"]}#{item["node"]["content"]}#{item["score"]}')
break
except Exception as e:
logger.warning(f"{query} query chunk failed: {e}", exc_info=True)
logger.debug(f"matched_docs: {matched_docs}")
return matched_docs
def rerank_docs(self, queries: List[str], passages: List[str]):
"""
Re-ranks the given passages based on the provided queries.
Parameters:
- queries (List[str]): A list of queries.
- passages (List[str]): A list of passages.
Returns:
- List[str]: A re-ranked list of passages.
"""
if self.reranker is None:
return passages
return self.reranker.rerank(queries, passages)
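# A hedged usage sketch, assuming this class is the DefaultRetriever
# referenced elsewhere in the repository (the project_id and query are
# illustrative; vector store, graph store, and LLM services are assumed to
# be configured for the project):
#   retriever = DefaultRetriever(project_id="1")
#   docs = retriever.recall_docs("Who discovered insulin?", top_k=5)
#   docs = retriever.rerank_docs(["Who discovered insulin?"], docs)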

View File

@ -0,0 +1,134 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import io
import json
from pathlib import Path
from abc import ABC, abstractmethod
from typing import Any, Union, Iterable, Tuple
from typing import Dict
import logging
logger = logging.getLogger(__name__)
Item = Dict[str, Any]
RetrievalResult = Iterable[Tuple[Item, float]]
class Retriever(ABC):
"""
A retriever indexes a collection of items and supports fast retrieval of the
desired items given a query.
"""
@classmethod
def from_config(cls, config: Union[str, Path, Dict[str, Any]]) -> "Retriever":
"""
Create retriever from `config`.
If `config` is a string or path, it will be loaded as a dictionary depending
on its file extension. Currently, the following formats are supported:
* .json: JSON
* .json5: JSON with comments support
* .yaml: YAML
:param config: retriever config
:type config: str, Path or Dict[str, Any]
:return: retriever instance
:rtype: Retriever
"""
from kag.common.utils import dynamic_import_class
if isinstance(config, (str, Path)):
config_path = config
if not isinstance(config_path, Path):
config_path = Path(config_path)
if config_path.name.endswith(".yaml"):
import yaml
with io.open(config_path, "r", encoding="utf-8") as fin:
config = yaml.safe_load(fin)
elif config_path.name.endswith(".json5"):
import json5
with io.open(config_path, "r", encoding="utf-8") as fin:
config = json5.load(fin)
elif config_path.name.endswith(".json"):
with io.open(config_path, "r", encoding="utf-8") as fin:
config = json.load(fin)
else:
message = "only .json, .json5 and .yaml are supported currently; "
message += "can not load retriever config from %r" % str(config_path)
raise RuntimeError(message)
elif isinstance(config, dict):
pass
else:
message = "only str, Path and dict are supported; "
message += "invalid retriever config: %r" % (config,)
raise RuntimeError(message)
class_name = config.get("retriever")
if class_name is None:
message = "retriever class name is not specified"
raise RuntimeError(message)
retriever_class = dynamic_import_class(class_name, "retriever")
if not issubclass(retriever_class, Retriever):
message = "class %r is not a retriever class" % (class_name,)
raise RuntimeError(message)
retriever = retriever_class._from_config(config)
return retriever
@classmethod
@abstractmethod
def _from_config(cls, config: Dict[str, Any]) -> "Retriever":
"""
Create retriever from `config`. This method is supposed to be implemented
by derived classes.
:param config: retriever config
:type config: Dict[str, Any]
:return: retriever instance
:rtype: Retriever
"""
message = "abstract method _from_config is not implemented"
raise NotImplementedError(message)
def index(self, items: Union[Item, Iterable[Item]]) -> None:
"""
Add one or more items to the index of the retriever.
NOTE: This method may not be supported by the retriever.
:param items: items to index
:type items: Item or Iterable[Item]
"""
message = "method index is not supported by the retriever"
raise RuntimeError(message)
@abstractmethod
def retrieve(
self, queries: Union[str, Iterable[str]], top_k: int = 10
) -> Union[RetrievalResult, Iterable[RetrievalResult]]:
"""
Retrieve items for the given query or queries.
:param queries: queries to retrieve
:type queries: str or Iterable[str]
:param int top_k: how many of the most relevant items to return for each query; defaults to 10
:return: retrieval results of the queries
:rtype: RetrievalResult or Iterable[RetrievalResult]
"""
message = "abstract method retrieve is not implemented"
raise NotImplementedError(message)
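# A minimal concrete subclass, shown as a sketch only. It scores items by
# naive keyword overlap; the class name, the "items" config key, and the
# item field "content" are illustrative, not part of the library API.
class KeywordRetriever(Retriever):
    def __init__(self, items=None):
        self._items = list(items or [])
    @classmethod
    def _from_config(cls, config: Dict[str, Any]) -> "Retriever":
        # hypothetical config key "items": a list of item dicts to pre-index
        return cls(items=config.get("items"))
    def index(self, items: Union[Item, Iterable[Item]]) -> None:
        if isinstance(items, dict):
            items = [items]
        self._items.extend(items)
    def retrieve(
        self, queries: Union[str, Iterable[str]], top_k: int = 10
    ) -> Union[RetrievalResult, Iterable[RetrievalResult]]:
        def score(query: str, item: Item) -> float:
            # count how many query words occur in the item's text
            words = set(query.lower().split())
            text = str(item.get("content", "")).lower()
            return float(sum(1 for word in words if word in text))
        if isinstance(queries, str):
            ranked = sorted(
                ((item, score(queries, item)) for item in self._items),
                key=lambda pair: pair[1],
                reverse=True,
            )
            return ranked[:top_k]
        return [self.retrieve(query, top_k) for query in queries]
# Hypothetical usage via the factory, assuming the class above is importable
# as "my_module.KeywordRetriever":
#   retriever = Retriever.from_config(
#       {"retriever": "my_module.KeywordRetriever",
#        "items": [{"content": "hello world"}]}
#   )
#   results = retriever.retrieve("hello", top_k=3)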

View File

@ -0,0 +1,115 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
from typing import Any, Iterable, Tuple
from typing import Dict
import logging
from kag.common.utils import processing_phrases
from kag.common.semantic_infer import SemanticEnhance
from kag.common.retriever.kag_retriever import DefaultRetriever
logger = logging.getLogger(__name__)
Item = Dict[str, Any]
RetrievalResult = Iterable[Tuple[Item, float]]
class SemanticRetriever(DefaultRetriever, SemanticEnhance):
def __init__(self, project_id: str = None, **kwargs):
DefaultRetriever.__init__(self, project_id)
SemanticEnhance.__init__(self, **kwargs)
self.general_label = "Entity"
self.max_expand = 2
self.concept_sim_t = 0.9
def get_top_phrases(self, query_ner_list, query_ner_type_list, context=None):
"""
Semantic enhancement rework: replace entity -[sim]-> node edges with entity -[semantic]-> node edges.
"""
phrase_ids = []
query_phrases = []
max_scores = []
for query, query_type in zip(query_ner_list, query_ner_type_list):
query = processing_phrases(query)
query_phrases.append(query)
query_node = self.graph_store.get_node(
label=self.general_label, id_value=query
)
if query_node is not None:
n_type = [i for i in query_node.labels if i != self.general_label][0]
phrase_ids.append(
{
"name": query_node["name"],
"type": n_type,
"_source": "exact_match",
}
)
max_scores.append(self.pagerank_threshold)
query_concepts = [
n
for n in self.expand_semantic_concept(query, context=context)
if processing_phrases(n["name"]) not in [processing_phrases(query)]
]
for ix, concept in enumerate(query_concepts):
if ix >= self.max_expand:
continue
concept["name"] = processing_phrases(concept["name"])
concept_node = self.graph_store.get_node(
label=self.concept_label, id_value=concept["name"]
)
if concept_node is not None:
phrase_ids.append(
{
"name": concept_node["name"],
"type": self.concept_label,
"_source": "expand_concept",
}
)
max_scores.append(self.pagerank_threshold)
else:
recall_concepts = self.graph_store.vector_search(
label=self.concept_label,
property_key="name",
query_text_or_vector=concept["name"],
topk=5,
)
all_nodes = [
n["node"]["name"]
for n in recall_concepts
if n["score"] >= self.concept_sim_t
]
for name in all_nodes:
if name in {n["name"] for n in phrase_ids}:
continue
semantic_node = {
"name": name,
"type": self.concept_label,
"_source": "sim_concept",
}
phrase_ids.append(semantic_node)
max_scores.append(self.pagerank_threshold)
break
if len(phrase_ids) == 0:
logger.error(
f"ERROR, no phrases found for {query_ner_list}, {query_ner_type_list}"
)
return phrase_ids, max_scores
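# A hedged usage sketch (the project_id, query terms, and type labels are
# illustrative; the underlying graph store, vectorizer, and LLM services
# are assumed to be configured for the project):
#   retriever = SemanticRetriever(project_id="1")
#   phrases, scores = retriever.get_top_phrases(
#       ["insulin"], ["Medicine"], context="diabetes treatment"
#   )
# Each returned phrase dict carries a "_source" tag (exact_match,
# expand_concept, or sim_concept) recording how the seed node was found.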

Some files were not shown because too many files have changed in this diff