autogen/python/packages/autogen-studio/autogenstudio/datamodel/db.py

282 lines
11 KiB
Python

# defines how core data types in autogenstudio are serialized and stored in the database
from datetime import datetime
from enum import Enum
from typing import List, Optional, Tuple, Type, Union
from uuid import UUID, uuid4
from loguru import logger
from pydantic import BaseModel
from sqlalchemy import ForeignKey, Integer, UniqueConstraint
from sqlmodel import JSON, Column, DateTime, Field, Relationship, SQLModel, func
from .types import AgentConfig, MessageConfig, MessageMeta, ModelConfig, TeamConfig, TeamResult, ToolConfig
# added for python3.11 and sqlmodel 0.0.22 incompatibility
if hasattr(SQLModel, "model_config"):
SQLModel.model_config["protected_namespaces"] = ()
elif hasattr(SQLModel, "Config"):
class CustomSQLModel(SQLModel):
class Config:
protected_namespaces = ()
SQLModel = CustomSQLModel
else:
logger.warning("Unable to set protected_namespaces.")
# pylint: disable=protected-access
class ComponentTypes(Enum):
TEAM = "team"
AGENT = "agent"
MODEL = "model"
TOOL = "tool"
@property
def model_class(self) -> Type[SQLModel]:
return {
ComponentTypes.TEAM: Team,
ComponentTypes.AGENT: Agent,
ComponentTypes.MODEL: Model,
ComponentTypes.TOOL: Tool,
}[self]
class LinkTypes(Enum):
AGENT_MODEL = "agent_model"
AGENT_TOOL = "agent_tool"
TEAM_AGENT = "team_agent"
@property
# type: ignore
def link_config(self) -> Tuple[Type[SQLModel], Type[SQLModel], Type[SQLModel]]:
return {
LinkTypes.AGENT_MODEL: (Agent, Model, AgentModelLink),
LinkTypes.AGENT_TOOL: (Agent, Tool, AgentToolLink),
LinkTypes.TEAM_AGENT: (Team, Agent, TeamAgentLink),
}[self]
@property
def primary_class(self) -> Type[SQLModel]: # type: ignore
return self.link_config[0]
@property
def secondary_class(self) -> Type[SQLModel]: # type: ignore
return self.link_config[1]
@property
def link_table(self) -> Type[SQLModel]: # type: ignore
return self.link_config[2]
# link models
class AgentToolLink(SQLModel, table=True):
__table_args__ = (
UniqueConstraint("agent_id", "sequence", name="unique_agent_tool_sequence"),
{"sqlite_autoincrement": True},
)
agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
tool_id: int = Field(default=None, primary_key=True, foreign_key="tool.id")
sequence: Optional[int] = Field(default=0, primary_key=True)
class AgentModelLink(SQLModel, table=True):
__table_args__ = (
UniqueConstraint("agent_id", "sequence", name="unique_agent_tool_sequence"),
{"sqlite_autoincrement": True},
)
agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
model_id: int = Field(default=None, primary_key=True, foreign_key="model.id")
sequence: Optional[int] = Field(default=0, primary_key=True)
class TeamAgentLink(SQLModel, table=True):
__table_args__ = (
UniqueConstraint("agent_id", "sequence", name="unique_agent_tool_sequence"),
{"sqlite_autoincrement": True},
)
team_id: int = Field(default=None, primary_key=True, foreign_key="team.id")
agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
sequence: Optional[int] = Field(default=0, primary_key=True)
# database models
class Tool(SQLModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
) # pylint: disable=not-callable
updated_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
config: Union[ToolConfig, dict] = Field(sa_column=Column(JSON))
agents: List["Agent"] = Relationship(back_populates="tools", link_model=AgentToolLink)
class Model(SQLModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
) # pylint: disable=not-callable
updated_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
config: Union[ModelConfig, dict] = Field(sa_column=Column(JSON))
agents: List["Agent"] = Relationship(back_populates="models", link_model=AgentModelLink)
class Team(SQLModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
) # pylint: disable=not-callable
updated_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
config: Union[TeamConfig, dict] = Field(sa_column=Column(JSON))
agents: List["Agent"] = Relationship(back_populates="teams", link_model=TeamAgentLink)
class Agent(SQLModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
) # pylint: disable=not-callable
updated_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
config: Union[AgentConfig, dict] = Field(sa_column=Column(JSON))
tools: List[Tool] = Relationship(back_populates="agents", link_model=AgentToolLink)
models: List[Model] = Relationship(back_populates="agents", link_model=AgentModelLink)
teams: List[Team] = Relationship(back_populates="agents", link_model=TeamAgentLink)
class Message(SQLModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
) # pylint: disable=not-callable
updated_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
config: Union[MessageConfig, dict] = Field(default_factory=MessageConfig, sa_column=Column(JSON))
session_id: Optional[int] = Field(
default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"))
)
run_id: Optional[UUID] = Field(default=None, foreign_key="run.id")
message_meta: Optional[Union[MessageMeta, dict]] = Field(default={}, sa_column=Column(JSON))
class Session(SQLModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
) # pylint: disable=not-callable
updated_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
team_id: Optional[int] = Field(default=None, sa_column=Column(Integer, ForeignKey("team.id", ondelete="CASCADE")))
name: Optional[str] = None
class RunStatus(str, Enum):
CREATED = "created"
ACTIVE = "active"
COMPLETE = "complete"
ERROR = "error"
STOPPED = "stopped"
class Run(SQLModel, table=True):
"""Represents a single execution run within a session"""
__table_args__ = {"sqlite_autoincrement": True}
id: UUID = Field(default_factory=uuid4, primary_key=True, index=True)
created_at: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime(timezone=True), server_default=func.now())
)
updated_at: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime(timezone=True), onupdate=func.now())
)
session_id: Optional[int] = Field(
default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"), nullable=False)
)
status: RunStatus = Field(default=RunStatus.CREATED)
# Store the original user task
task: Union[MessageConfig, dict] = Field(default_factory=MessageConfig, sa_column=Column(JSON))
# Store TeamResult which contains TaskResult
team_result: Union[TeamResult, dict] = Field(default=None, sa_column=Column(JSON))
error_message: Optional[str] = None
version: Optional[str] = "0.0.1"
messages: Union[List[Message], List[dict]] = Field(default_factory=list, sa_column=Column(JSON))
class Config:
json_encoders = {UUID: str, datetime: lambda v: v.isoformat()}
class GalleryConfig(SQLModel, table=False):
id: UUID = Field(default_factory=uuid4, primary_key=True, index=True)
title: Optional[str] = None
description: Optional[str] = None
run: Run
team: TeamConfig = None
tags: Optional[List[str]] = None
visibility: str = "public" # public, private, shared
class Config:
json_encoders = {UUID: str, datetime: lambda v: v.isoformat()}
class Gallery(SQLModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
)
updated_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
)
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
config: Union[GalleryConfig, dict] = Field(default_factory=GalleryConfig, sa_column=Column(JSON))