mirror of https://github.com/OpenSPG/KAG
191 lines
6.4 KiB
Python
191 lines
6.4 KiB
Python
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
|
|
import pprint
|
|
from typing import Dict, List, Any
|
|
|
|
from knext.schema.client import BASIC_TYPES
|
|
from kag.builder.model.spg_record import SPGRecord
|
|
from knext.schema.model.base import BaseSpgType
|
|
|
|
|
|
class Node(object):
    """A graph vertex: an id, a display name, a type label and a property map.

    Equality compares ``name``, ``label`` and ``properties``; ``id`` is not
    part of the comparison. Instances are unhashable.
    """

    id: str
    name: str
    label: str
    properties: Dict[str, str]
    # Class attribute, so this single dict is shared by every Node instance.
    # Nothing in this class reads or writes it.
    # NOTE(review): check external callers before removing it.
    hash_map: Dict[int, str] = dict()

    def __init__(self, _id: str, name: str, label: str, properties: Dict[str, str]):
        self.name = name
        self.label = label
        self.properties = properties
        self.id = _id

    @classmethod
    def from_spg_record(cls, idx, spg_record: "SPGRecord"):
        """Build a Node from an SPGRecord.

        Args:
            idx: Identifier assigned to the new node.
            spg_record: Source record. Its ``"name"`` property becomes the node
                name and its ``spg_type_name`` becomes the label.

        Returns:
            A new Node. The object returned by ``spg_record.properties`` is
            passed through without copying.
        """
        return cls(
            _id=idx,
            name=spg_record.get_property("name"),
            label=spg_record.spg_type_name,
            properties=spg_record.properties,
        )

    @staticmethod
    def unique_key(spg_record):
        """Return ``"<spg_type_name>_<name>"`` for ``spg_record``.

        ``""`` is passed as the default argument of ``get_property("name", ...)``.
        """
        return spg_record.spg_type_name + '_' + spg_record.get_property("name", "")

    def to_dict(self):
        """Serialize to a plain dict with keys ``id``/``name``/``label``/``properties``."""
        return {
            "id": self.id,
            "name": self.name,
            "label": self.label,
            "properties": self.properties,
        }

    @classmethod
    def from_dict(cls, input: Dict):
        """Inverse of :meth:`to_dict`; raises ``KeyError`` if a key is missing."""
        return cls(
            _id=input["id"],
            name=input["name"],
            label=input["label"],
            properties=input["properties"],
        )

    def __eq__(self, other):
        # For non-Node operands defer to Python's default comparison instead of
        # failing with AttributeError (e.g. ``node == None`` evaluates to False).
        if not isinstance(other, Node):
            return NotImplemented
        return (
            self.name == other.name
            and self.label == other.label
            and self.properties == other.properties
        )

    # Defining __eq__ already disables hashing; state it explicitly. Nodes are
    # mutable and carry a properties dict, so they are not safe dict/set keys.
    __hash__ = None

    def __repr__(self):
        return (
            f"{type(self).__name__}(id={self.id!r}, name={self.name!r}, "
            f"label={self.label!r}, properties={self.properties!r})"
        )
|
|
|
|
|
|
class Edge(object):
    """A directed, labelled edge between two nodes.

    Only each endpoint's ``id`` and ``label`` are kept (as ``from_id`` /
    ``from_type`` and ``to_id`` / ``to_type``); the node objects themselves are
    not retained. Equality ignores the edge's own ``id``. Instances are
    unhashable.
    """

    id: str
    from_id: str
    from_type: str
    to_id: str
    to_type: str
    label: str
    properties: Dict[str, str]

    def __init__(
        self, _id: str, from_node: "Node", to_node: "Node", label: str, properties: Dict[str, str]
    ):
        """Create an edge from ``from_node`` to ``to_node``.

        Args:
            _id: Edge identifier. When falsy (e.g. ``""``), ``id(self)`` is used
                instead. That value is an int, unique only among objects alive
                at the same time in this process, and not stable across runs.
            from_node: Source endpoint; only its ``id`` and ``label`` are read.
            to_node: Target endpoint; only its ``id`` and ``label`` are read.
            label: Relation name.
            properties: Edge properties, stored without copying.
        """
        self.from_id = from_node.id
        self.from_type = from_node.label
        self.to_id = to_node.id
        self.to_type = to_node.label
        self.label = label
        self.properties = properties
        self.id = _id if _id else id(self)

    @classmethod
    def from_spg_record(
        cls, s_idx, subject_record: "SPGRecord", o_idx, object_record: "SPGRecord", label: str
    ):
        """Build an edge between the nodes derived from two SPG records.

        The edge gets a generated id (see ``__init__``) and empty properties.
        """
        from_node = Node.from_spg_record(s_idx, subject_record)
        to_node = Node.from_spg_record(o_idx, object_record)

        return cls(_id="", from_node=from_node, to_node=to_node, label=label, properties={})

    def to_dict(self):
        """Serialize using the ``from``/``to``/``fromType``/``toType`` key names."""
        return {
            "id": self.id,
            "from": self.from_id,
            "to": self.to_id,
            "fromType": self.from_type,
            "toType": self.to_type,
            "label": self.label,
            "properties": self.properties,
        }

    @classmethod
    def from_dict(cls, input: Dict):
        """Inverse of :meth:`to_dict`; raises ``KeyError`` if a key is missing.

        Endpoints are rebuilt as placeholder Nodes whose name equals their id.
        """
        return cls(
            _id=input["id"],
            from_node=Node(_id=input["from"], name=input["from"], label=input["fromType"], properties={}),
            to_node=Node(_id=input["to"], name=input["to"], label=input["toType"], properties={}),
            label=input["label"],
            properties=input["properties"],
        )

    def __eq__(self, other):
        # For non-Edge operands defer to Python's default comparison instead of
        # failing with AttributeError.
        if not isinstance(other, Edge):
            return NotImplemented
        return (
            self.from_id == other.from_id
            and self.to_id == other.to_id
            and self.label == other.label
            and self.properties == other.properties
            and self.from_type == other.from_type
            and self.to_type == other.to_type
        )

    # Defining __eq__ already disables hashing; state it explicitly. Edges are
    # mutable and carry a properties dict, so they are not safe dict/set keys.
    __hash__ = None

    def __repr__(self):
        return (
            f"{type(self).__name__}(id={self.id!r}, from={self.from_type}:{self.from_id!r}, "
            f"to={self.to_type}:{self.to_id!r}, label={self.label!r}, "
            f"properties={self.properties!r})"
        )
|
|
|
|
|
|
class SubGraph(object):
    """A collection of Nodes and Edges.

    Serializes to and from ``{"resultNodes": [...], "resultEdges": [...]}``.
    """

    id: str
    # Class-level defaults; __init__ rebinds both on every instance.
    nodes: List[Node] = list()
    edges: List[Edge] = list()

    def __init__(self, nodes: List[Node], edges: List[Edge]):
        # The given lists are stored as-is (not copied), so add_node/add_edge/
        # merge append to the caller's list objects.
        self.nodes = nodes
        self.edges = edges

    def add_node(self, id: str, name: str, label: str, properties=None):
        """Append a new Node and return ``self`` so calls can be chained."""
        if not properties:
            properties = dict()
        self.nodes.append(Node(_id=id, name=name, label=label, properties=properties))
        return self

    def add_edge(self, s_id: str, s_label: str, p: str, o_id: str, o_label: str, properties=None):
        """Append an edge ``(s_id:s_label) -[p]-> (o_id:o_label)`` and return ``self``.

        Placeholder Nodes (name == id, no properties) are built only to supply
        the endpoint id/label; they are not added to ``self.nodes``. The edge
        id is generated by ``Edge`` because ``""`` is passed.
        """
        if not properties:
            properties = dict()
        s_node = Node(_id=s_id, name=s_id, label=s_label, properties={})
        o_node = Node(_id=o_id, name=o_id, label=o_label, properties={})
        self.edges.append(Edge(_id="", from_node=s_node, to_node=o_node, label=p, properties=properties))
        return self

    def to_dict(self):
        """Serialize nodes and edges via their own ``to_dict`` methods."""
        return {
            "resultNodes": [n.to_dict() for n in self.nodes],
            "resultEdges": [e.to_dict() for e in self.edges],
        }

    def __repr__(self):
        return pprint.pformat(self.to_dict())

    def merge(self, sub_graph: 'SubGraph'):
        """Append all nodes and edges of ``sub_graph`` (no de-duplication).

        Returns ``self`` for chaining, consistent with ``add_node``/``add_edge``.
        """
        self.nodes.extend(sub_graph.nodes)
        self.edges.extend(sub_graph.edges)
        return self

    @classmethod
    def from_spg_record(
        cls, spg_types: Dict[str, BaseSpgType], spg_records: List[SPGRecord]
    ):
        """Convert SPG records into a new SubGraph.

        One node is added per record: id and name come from the record, and the
        label is the last dot-separated segment of ``spg_type_name``. A
        property whose schema object type (last dot-separated segment) is not
        in ``BASIC_TYPES`` is treated as a relation:

        * its value is split on commas;
        * one edge is added per non-empty, whitespace-stripped target id;
        * the property is removed from the node's properties.

        All other properties stay on the node.

        Args:
            spg_types: Schema types keyed by full SPG type name. A record whose
                type is missing here keeps all properties as node attributes.
            spg_records: Records to convert; this method does not modify them.

        Returns:
            The new SubGraph.
        """
        sub_graph = cls([], [])
        for record in spg_records:
            s_id = record.id
            s_name = record.name
            s_label = record.spg_type_name.split('.')[-1]
            # Iterate the record's dict but pop from a copy: deleting keys from
            # a dict while iterating it raises RuntimeError, and the input
            # record must not be mutated.
            properties = dict(record.properties)

            spg_type = spg_types.get(record.spg_type_name)
            # Without a schema type there is no way to tell relations from
            # attributes, so everything is kept as a node property.
            schema_props = spg_type.properties if spg_type is not None else {}

            for prop_name, prop_value in record.properties.items():
                if prop_name not in schema_props:
                    continue
                prop = schema_props.get(prop_name)
                o_label = prop.object_type_name.split('.')[-1]
                if o_label in BASIC_TYPES:
                    continue
                # Relation property: one edge per comma-separated target id.
                for o_id in (prop_value or "").split(','):
                    o_id = o_id.strip()
                    if not o_id:
                        continue
                    sub_graph.add_edge(s_id=s_id, s_label=s_label, p=prop_name, o_id=o_id, o_label=o_label)
                properties.pop(prop_name)
            sub_graph.add_node(id=s_id, name=s_name, label=s_label, properties=properties)

        return sub_graph

    @classmethod
    def from_dict(cls, input: Dict[str, Any]):
        """Inverse of :meth:`to_dict`; raises ``KeyError`` if a key is missing."""
        return cls(
            nodes=[Node.from_dict(node) for node in input["resultNodes"]],
            edges=[Edge.from_dict(edge) for edge in input["resultEdges"]],
        )
|