autogen/python/packages/autogen-studio/autogenstudio/database/dbmanager.py

492 lines
20 KiB
Python

import threading
from datetime import datetime
from typing import Optional
from loguru import logger
from sqlalchemy import exc
from sqlmodel import Session, SQLModel, and_, create_engine, select
from ..datamodel import (
Agent,
AgentLink,
AgentModelLink,
AgentSkillLink,
Model,
Response,
Skill,
Workflow,
WorkflowAgentLink,
WorkflowAgentType,
)
from .utils import init_db_samples
# Link kinds understood by DBManager.link / unlink / get_linked_entities.
# Each name is "<primary>_<secondary>": e.g. "workflow_agent" links a Workflow
# (primary) to an Agent (secondary).
valid_link_types = [
    "agent_model",
    "agent_skill",
    "agent_agent",
    "workflow_agent",
]
class WorkflowAgentMap(SQLModel):
    """Pair an Agent with the WorkflowAgentLink row that attaches it to a workflow.

    Instances are built by ``DBManager.get_linked_entities`` for the
    ``"workflow_agent"`` link type, one per joined (link, agent) row.
    """

    # The agent that is attached to the workflow.
    agent: Agent
    # The association row; DBManager reads/writes its workflow_id, agent_id,
    # agent_type and sequence_id columns.
    link: WorkflowAgentLink
class DBManager:
    """A class to manage database operations.

    Methods that touch rows open their own short-lived ``Session`` on
    ``self.engine`` (``get_items`` instead uses the session passed to it) and
    report their outcome as a ``Response`` with ``message``, ``status`` and,
    where applicable, ``data``.
    """

    # Class-level, so concurrent create_db_and_tables() calls are serialised
    # across every DBManager instance in the process.
    _init_lock = threading.Lock()

    def __init__(self, engine_uri: str):
        """Create the SQLAlchemy engine for ``engine_uri``.

        Args:
            engine_uri: SQLAlchemy database URL.
        """
        # check_same_thread is a SQLite connect arg, so it is only passed for SQLite URLs.
        # NOTE(review): forcing True means a pooled SQLite connection cannot be reused
        # from another thread — confirm DB calls all happen on one thread.
        connection_args = {"check_same_thread": True} if "sqlite" in engine_uri else {}
        self.engine = create_engine(engine_uri, connect_args=connection_args)
        # run_migration(engine_uri=engine_uri)

    def create_db_and_tables(self):
        """Create all SQLModel tables, then seed the sample data.

        Errors are logged, never raised. If table creation fails, seeding is
        skipped.
        """
        with self._init_lock:
            try:
                SQLModel.metadata.create_all(self.engine)
            except Exception as e:
                logger.error("Error while creating database tables:" + str(e))
                return
            try:
                init_db_samples(self)
            except Exception as e:
                logger.info("Error while initializing database samples: " + str(e))

    def upsert(self, model: SQLModel):
        """Insert ``model``, or update the stored row that has the same ``id``.

        Args:
            model: A SQLModel table instance.

        Returns:
            Response: on success ``data`` is the persisted row as a dict; on
            failure ``status`` is False, ``message`` describes the error and
            ``data`` is None.
        """
        status = True
        model_class = type(model)
        existing_model = None
        data = None
        error = None

        with Session(self.engine) as session:
            try:
                existing_model = session.exec(select(model_class).where(model_class.id == model.id)).first()
                if existing_model:
                    model.updated_at = datetime.now()
                    # NOTE(review): every dumped field (including created_at) is copied onto
                    # the stored row — confirm callers send the original created_at.
                    for key, value in model.model_dump().items():
                        setattr(existing_model, key, value)
                    model = existing_model
                session.add(model)
                session.commit()
                session.refresh(model)
                # Serialise while the session is open. After a rollback or close, the
                # instance may be expired/detached and attribute access can fail.
                data = model.model_dump()
            except Exception as e:
                session.rollback()
                logger.error("Error while updating " + str(model_class.__name__) + ": " + str(e))
                status = False
                error = e

        if not status:
            message = f"Error while updating {model_class.__name__}: {error}"
        elif existing_model:
            message = f"{model_class.__name__} Updated Successfully "
        else:
            message = f"{model_class.__name__} Created Successfully"

        return Response(message=message, status=status, data=data)

    def _model_to_dict(self, model_obj):
        """Return ``{column_name: value}`` for every table column of ``model_obj``."""
        return {col.name: getattr(model_obj, col.name) for col in model_obj.__table__.columns}

    def get_items(
        self,
        model_class: SQLModel,
        session: Session,
        filters: Optional[dict] = None,
        return_json: bool = False,
        order: str = "desc",
    ):
        """List rows of ``model_class`` using the caller's ``session``.

        Args:
            model_class: The SQLModel table class to query.
            session: An open session (rolled back if the query fails).
            filters: ``{column_name: value}`` equality filters, combined with AND.
                An unknown column name makes the query fail (status False).
            return_json: Return rows as column dicts instead of model instances.
            order: "desc" or any other truthy value (ascending) to sort by
                ``created_at`` when the model has it; falsy disables sorting.

        Returns:
            Response: ``data`` is the list of rows (empty on failure).
        """
        result = []
        status = True
        try:
            statement = select(model_class)
            if filters:
                conditions = [getattr(model_class, col) == value for col, value in filters.items()]
                statement = statement.where(and_(*conditions))
            # Ordering applies whether or not filters were given.
            if order and hasattr(model_class, "created_at"):
                created_at = model_class.created_at
                statement = statement.order_by(created_at.desc() if order == "desc" else created_at.asc())
            rows = session.exec(statement).all()
            result = [self._model_to_dict(row) for row in rows] if return_json else rows
            status_message = f"{model_class.__name__} Retrieved Successfully"
        except Exception as e:
            session.rollback()
            status = False
            status_message = f"Error while fetching {model_class.__name__}"
            logger.error("Error while getting items: " + str(model_class.__name__) + " " + str(e))

        return Response(message=status_message, status=status, data=result)

    def get(
        self,
        model_class: SQLModel,
        filters: Optional[dict] = None,
        return_json: bool = False,
        order: str = "desc",
    ):
        """List rows of ``model_class`` in a fresh session (see ``get_items``)."""
        with Session(self.engine) as session:
            return self.get_items(model_class, session, filters, return_json, order)

    def delete(self, model_class: SQLModel, filters: Optional[dict] = None):
        """Delete every row of ``model_class`` matching ``filters``.

        Warning: when ``filters`` is empty or None, ALL rows of the table are
        deleted.

        Returns:
            Response: ``status`` is False on an integrity error (the row is still
            referenced) or any other exception. Finding no matching rows is
            reported as status True with message "Row not found".
        """
        status_message = ""
        status = True
        with Session(self.engine) as session:
            try:
                statement = select(model_class)
                if filters:
                    conditions = [getattr(model_class, col) == value for col, value in filters.items()]
                    statement = statement.where(and_(*conditions))
                rows = session.exec(statement).all()
                if rows:
                    for row in rows:
                        session.delete(row)
                    # Single commit: all matching rows are removed together or not at all.
                    session.commit()
                    status_message = f"{model_class.__name__} Deleted Successfully"
                else:
                    logger.info(f"Row with filters {filters} not found")
                    status_message = "Row not found"
            except exc.IntegrityError as e:
                session.rollback()
                logger.error("Integrity ... Error while deleting: " + str(e))
                status_message = f"The {model_class.__name__} is linked to another entity and cannot be deleted."
                status = False
            except Exception as e:
                session.rollback()
                logger.error("Error while deleting: " + str(e))
                status_message = f"Error while deleting: {e}"
                status = False

        return Response(message=status_message, status=status, data=None)

    def get_linked_entities(
        self,
        link_type: str,
        primary_id: int,
        return_json: bool = False,
        agent_type: Optional[str] = None,
        sequence_id: Optional[int] = None,
    ):
        """
        Get all entities linked to the primary entity.

        Args:
            link_type (str): The type of link to retrieve, one of ``valid_link_types``.
            primary_id (int): Agent id for the agent_* link types, Workflow id for
                "workflow_agent".
            return_json (bool): Convert each linked entity to a dict via ``model_dump``.
            agent_type (Optional[str]): Currently unused; kept for signature compatibility.
            sequence_id (Optional[int]): Currently unused; kept for signature compatibility.

        Returns:
            Response: ``data`` holds the linked Models / Skills / Agents for the
            agent_* link types. For "workflow_agent" it holds ``WorkflowAgentMap``
            items ordered by ``link.sequence_id``, with None last. On an invalid
            link type, a missing agent or an error, ``status`` is False and
            ``data`` is empty.
        """
        if link_type not in valid_link_types:
            return Response(
                message=f"Invalid link type: {link_type}. Valid link types are: {valid_link_types}",
                status=False,
                data=[],
            )

        # Relationship attribute on Agent holding the targets of each agent_* link type.
        agent_relationships = {"agent_model": "models", "agent_skill": "skills", "agent_agent": "agents"}
        linked_entities = []
        status = True
        status_message = ""

        with Session(self.engine) as session:
            try:
                if link_type == "workflow_agent":
                    rows = session.exec(
                        select(WorkflowAgentLink, Agent)
                        .join(Agent, WorkflowAgentLink.agent_id == Agent.id)
                        .where(WorkflowAgentLink.workflow_id == primary_id)
                    ).all()
                    linked_entities = sorted(
                        (WorkflowAgentMap(agent=agent, link=link) for link, agent in rows),
                        # Links without a sequence_id sort last instead of raising TypeError.
                        key=lambda item: (item.link.sequence_id is None, item.link.sequence_id or 0),
                    )
                else:
                    agent = session.exec(select(Agent).where(Agent.id == primary_id)).first()
                    if agent is None:
                        status = False
                        status_message = f"Agent with id {primary_id} not found"
                    else:
                        # Read the relationship while the session is still open.
                        linked_entities = list(getattr(agent, agent_relationships[link_type]))
                if return_json:
                    linked_entities = [row.model_dump() for row in linked_entities]
            except Exception as e:
                logger.error("Error while getting linked entities: " + str(e))
                status_message = f"Error while getting linked entities: {e}"
                status = False
                linked_entities = []

        return Response(message=status_message, status=status, data=linked_entities)

    @staticmethod
    def _load_link_endpoints(session: Session, link_type: str, primary_id: int, secondary_id: int):
        """Fetch the (primary, secondary) rows a ``link_type`` connects; either may be None."""
        endpoint_classes = {
            "agent_model": (Agent, Model),
            "agent_skill": (Agent, Skill),
            "agent_agent": (Agent, Agent),
            "workflow_agent": (Workflow, Agent),
        }
        primary_class, secondary_class = endpoint_classes[link_type]
        primary = session.exec(select(primary_class).where(primary_class.id == primary_id)).first()
        secondary = session.exec(select(secondary_class).where(secondary_class.id == secondary_id)).first()
        return primary, secondary

    @staticmethod
    def _find_link(
        session: Session,
        link_type: str,
        primary_id: int,
        secondary_id: int,
        agent_type: Optional[str] = None,
        sequence_id: Optional[int] = None,
    ):
        """Return the existing link row for the given ids (plus, for "workflow_agent",
        ``agent_type`` and ``sequence_id``), or None if there is none."""
        if link_type == "agent_model":
            statement = select(AgentModelLink).where(
                AgentModelLink.agent_id == primary_id,
                AgentModelLink.model_id == secondary_id,
            )
        elif link_type == "agent_skill":
            statement = select(AgentSkillLink).where(
                AgentSkillLink.agent_id == primary_id,
                AgentSkillLink.skill_id == secondary_id,
            )
        elif link_type == "agent_agent":
            statement = select(AgentLink).where(
                AgentLink.parent_id == primary_id,
                AgentLink.agent_id == secondary_id,
            )
        elif link_type == "workflow_agent":
            statement = select(WorkflowAgentLink).where(
                WorkflowAgentLink.workflow_id == primary_id,
                WorkflowAgentLink.agent_id == secondary_id,
                WorkflowAgentLink.agent_type == agent_type,
                WorkflowAgentLink.sequence_id == sequence_id,
            )
        else:
            raise ValueError(f"Invalid link type: {link_type}")
        return session.exec(statement).first()

    def link(
        self,
        link_type: str,
        primary_id: int,
        secondary_id: int,
        agent_type: Optional[str] = None,
        sequence_id: Optional[int] = None,
    ) -> Response:
        """
        Link two entities together.

        Args:
            link_type (str): The type of link to create, e.g., "agent_model".
            primary_id (int): The identifier for the primary model.
            secondary_id (int): The identifier for the secondary model.
            agent_type (Optional[str]): The type of agent, e.g., "sender" or receiver
                (stored on the link for "workflow_agent").
            sequence_id (Optional[int]): Position stored on the link for "workflow_agent".

        Returns:
            Response: The response of the linking operation, including success status and message.
            ``status`` is False for an invalid link type, a missing entity, an
            already-existing link, or an exception.
        """
        # TBD verify that is creator of the primary entity being linked
        if link_type not in valid_link_types:
            return Response(
                message=f"Invalid link type: {link_type}. Valid link types are: {valid_link_types}",
                status=False,
            )

        with Session(self.engine) as session:
            try:
                primary_model, secondary_model = self._load_link_endpoints(
                    session, link_type, primary_id, secondary_id
                )
                # Stop here: nothing may be added or committed when an endpoint is missing.
                if primary_model is None or secondary_model is None:
                    return Response(message="One or both entity records do not exist.", status=False)

                secondary_name = secondary_model.__class__.__name__
                primary_name = primary_model.__class__.__name__
                existing_link = self._find_link(
                    session, link_type, primary_id, secondary_id, agent_type, sequence_id
                )
                if existing_link:
                    return Response(
                        message=f"{secondary_name} already linked to {primary_name}",
                        status=False,
                    )

                if link_type == "agent_model":
                    primary_model.models.append(secondary_model)
                elif link_type == "agent_skill":
                    primary_model.skills.append(secondary_model)
                elif link_type == "agent_agent":
                    primary_model.agents.append(secondary_model)
                else:
                    # "workflow_agent": the link row also stores agent_type and
                    # sequence_id, so it is created explicitly.
                    session.add(
                        WorkflowAgentLink(
                            workflow_id=primary_id,
                            agent_id=secondary_id,
                            agent_type=agent_type,
                            sequence_id=sequence_id,
                        )
                    )
                session.add(primary_model)
                session.commit()
                status = True
                status_message = f"{secondary_name} successfully linked to {primary_name}"
            except Exception as e:
                session.rollback()
                logger.error("Error while linking: " + str(e))
                status = False
                status_message = f"Error while linking due to an exception: {e}"

        return Response(message=status_message, status=status)

    def unlink(
        self,
        link_type: str,
        primary_id: int,
        secondary_id: int,
        agent_type: Optional[str] = None,
        sequence_id: Optional[int] = 0,
    ) -> Response:
        """
        Unlink two entities.

        Args:
            link_type (str): The type of link to remove, e.g., "agent_model".
            primary_id (int): The identifier for the primary model.
            secondary_id (int): The identifier for the secondary model.
            agent_type (Optional[str]): The type of agent, e.g., "sender" or receiver
                (must match the stored link for "workflow_agent").
            sequence_id (Optional[int]): Must match the stored link for "workflow_agent".

        Returns:
            Response: The response of the unlinking operation, including success status and message.
        """
        logger.debug(
            f"unlink primary={primary_id} secondary={secondary_id} "
            f"sequence={sequence_id} agent_type={agent_type}"
        )
        if link_type not in valid_link_types:
            return Response(
                message=f"Invalid link type: {link_type}. Valid link types are: {valid_link_types}",
                status=False,
            )

        with Session(self.engine) as session:
            try:
                existing_link = self._find_link(
                    session, link_type, primary_id, secondary_id, agent_type, sequence_id
                )
                if existing_link:
                    session.delete(existing_link)
                    session.commit()
                    status = True
                    status_message = "Link removed successfully."
                else:
                    status = False
                    status_message = "Link does not exist."
            except Exception as e:
                session.rollback()
                logger.error("Error while unlinking: " + str(e))
                status = False
                status_message = f"Error while unlinking due to an exception: {e}"

        return Response(message=status_message, status=status)