# autogen/python/packages/autogen-studio/autogenstudio/datamodel.py

from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from sqlalchemy import ForeignKey, Integer, orm
from sqlmodel import (
JSON,
Column,
DateTime,
Field,
Relationship,
SQLModel,
func,
)
from sqlmodel import (
Enum as SqlEnum,
)
# added for python3.11 and sqlmodel 0.0.22 incompatibility
# Clear pydantic's "protected namespaces" so that field names starting with
# "model_" (e.g. AgentModelLink.model_id below) are accepted without pydantic
# flagging a namespace conflict.
if hasattr(SQLModel, "model_config"):
    # Pydantic v2 style: ``model_config`` is a dict on the class. This mutates the
    # imported SQLModel base class in place, before any model below is defined.
    # NOTE(review): this also affects every other SQLModel subclass created later
    # in the same process, not only the models in this file — confirm intended.
    SQLModel.model_config["protected_namespaces"] = ()
elif hasattr(SQLModel, "Config"):
    # Pydantic v1 style: configuration lives on an inner ``Config`` class, so
    # subclass SQLModel and rebind the module-level name; every model defined
    # below then inherits from CustomSQLModel.
    # NOTE(review): protected_namespaces is a pydantic v2 option; pydantic v1
    # presumably ignores it — confirm this branch has any effect.
    class CustomSQLModel(SQLModel):
        class Config:
            protected_namespaces = ()

    SQLModel = CustomSQLModel
else:
    # Neither config style was found: continue without the override.
    print("Warning: Unable to set protected_namespaces.")
# pylint: disable=protected-access
class MessageMeta(SQLModel, table=False):
    """Optional structured metadata for a Message (one accepted type of ``Message.meta``).

    Declared with ``table=False``, so no database table is created for it.
    """

    task: Optional[str] = None
    messages: Optional[List[Dict[str, Any]]] = None
    # Defaults to "last".
    # NOTE(review): presumably mirrors WorkFlowSummaryMethod values — confirm.
    summary_method: Optional[str] = "last"
    files: Optional[List[dict]] = None
    time: Optional[datetime] = None
    log: Optional[List[dict]] = None
    usage: Optional[List[dict]] = None
class Message(SQLModel, table=True):
    """A persisted message row, optionally belonging to a Session.

    ``session_id`` references ``session.id`` with ``ON DELETE CASCADE``, so
    deleting a session row removes its messages when the database enforces
    foreign keys.
    NOTE(review): SQLite only enforces this with ``PRAGMA foreign_keys=ON`` —
    confirm the engine setup.
    """

    __table_args__ = {"sqlite_autoincrement": True}

    id: Optional[int] = Field(default=None, primary_key=True)
    # default_factory fills the value when the model is built in Python;
    # server_default=now() covers rows inserted without this column.
    # NOTE(review): datetime.now is timezone-naive while the column is
    # DateTime(timezone=True) — confirm the intended timezone handling.
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), server_default=func.now()),
    )  # pylint: disable=not-callable
    # onupdate=now(): SQLAlchemy supplies now() for this column on UPDATE.
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
    )  # pylint: disable=not-callable
    user_id: Optional[str] = None
    role: str
    content: str
    session_id: Optional[int] = Field(
        default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"))
    )
    connection_id: Optional[str] = None
    # Stored in a JSON column. default_factory=dict gives every instance its own
    # empty dict rather than a shared ``{}`` literal default, matching the
    # default_factory=list convention used for the other JSON columns.
    # NOTE(review): a MessageMeta instance is presumably converted to a dict
    # before being persisted in the JSON column — confirm at call sites.
    meta: Optional[Union[MessageMeta, dict]] = Field(default_factory=dict, sa_column=Column(JSON))
class Session(SQLModel, table=True):
    """A persisted session row (table ``session``), optionally tied to a Workflow."""

    __table_args__ = {"sqlite_autoincrement": True}

    id: Optional[int] = Field(default=None, primary_key=True)
    # default_factory fills the value in Python; server_default=now() covers
    # rows inserted without this column.
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), server_default=func.now()),
    ) # pylint: disable=not-callable
    # onupdate=now(): SQLAlchemy supplies now() for this column on UPDATE.
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
    ) # pylint: disable=not-callable
    user_id: Optional[str] = None
    # References workflow.id; no ON DELETE action is declared on this key.
    workflow_id: Optional[int] = Field(default=None, foreign_key="workflow.id")
    name: Optional[str] = None
    description: Optional[str] = None
class AgentSkillLink(SQLModel, table=True):
    """Association table for the many-to-many Agent <-> Skill relationship.

    Used as ``link_model`` by ``Agent.skills`` and ``Skill.agents``.
    """

    # NOTE(review): with a composite primary key, SQLAlchemy's SQLite dialect
    # presumably does not emit AUTOINCREMENT, so this is likely a no-op — confirm.
    __table_args__ = {"sqlite_autoincrement": True}
    # Composite primary key (agent_id, skill_id); each part is a foreign key.
    # NOTE(review): annotated ``int`` but defaulting to None.
    agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
    skill_id: int = Field(default=None, primary_key=True, foreign_key="skill.id")
class AgentModelLink(SQLModel, table=True):
    """Association table for the many-to-many Agent <-> Model relationship.

    Used as ``link_model`` by ``Agent.models`` and ``Model.agents``.
    """

    # NOTE(review): with a composite primary key, SQLAlchemy's SQLite dialect
    # presumably does not emit AUTOINCREMENT, so this is likely a no-op — confirm.
    __table_args__ = {"sqlite_autoincrement": True}
    # Composite primary key (agent_id, model_id); each part is a foreign key.
    agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
    # The "model_" prefix of this field name is why the module clears pydantic's
    # protected namespaces before defining the models.
    model_id: int = Field(default=None, primary_key=True, foreign_key="model.id")
class Skill(SQLModel, table=True):
    """A persisted skill (table ``skill``), attachable to agents via AgentSkillLink."""

    __table_args__ = {"sqlite_autoincrement": True}

    id: Optional[int] = Field(default=None, primary_key=True)
    # default_factory fills the value in Python; server_default=now() covers
    # rows inserted without this column.
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), server_default=func.now()),
    ) # pylint: disable=not-callable
    # onupdate=now(): SQLAlchemy supplies now() for this column on UPDATE.
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
    ) # pylint: disable=not-callable
    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"
    name: str
    # NOTE(review): presumably the skill's source text — confirm with callers.
    content: str
    description: Optional[str] = None
    # JSON columns; each instance starts with its own empty list.
    secrets: Optional[List[dict]] = Field(default_factory=list, sa_column=Column(JSON))
    libraries: Optional[List[str]] = Field(default_factory=list, sa_column=Column(JSON))
    # Reverse side of Agent.skills.
    agents: List["Agent"] = Relationship(back_populates="skills", link_model=AgentSkillLink)
class LLMConfig(SQLModel, table=False):
    """Data model for LLM Config for AutoGen.

    Declared with ``table=False`` (no database table); it is one of the accepted
    types of ``AgentConfig.llm_config``.
    """

    config_list: List[Any] = Field(default_factory=list)
    # Default is the int 0 (pydantic does not coerce defaults).
    temperature: float = 0
    cache_seed: Optional[int] = None
    timeout: Optional[int] = None
    max_tokens: Optional[int] = 2048
    extra_body: Optional[dict] = None
class ModelTypes(str, Enum):
    """Supported model provider types, used by ``Model.api_type``."""

    # NOTE: the value "open_ai" differs from the member name "openai". A
    # SQLAlchemy Enum column built from an Enum class persists member *names*
    # by default, so the database stores "openai" here.
    openai = "open_ai"
    google = "google"
    azure = "azure"
    anthropic = "anthropic"
    mistral = "mistral"
    together = "together"
    groq = "groq"
class Model(SQLModel, table=True):
    """A persisted model/endpoint configuration (table ``model``)."""

    __table_args__ = {"sqlite_autoincrement": True}

    id: Optional[int] = Field(default=None, primary_key=True)
    # default_factory fills the value in Python; server_default=now() covers
    # rows inserted without this column.
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), server_default=func.now()),
    ) # pylint: disable=not-callable
    # onupdate=now(): SQLAlchemy supplies now() for this column on UPDATE.
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
    ) # pylint: disable=not-callable
    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"
    model: str
    # NOTE(review): stored as a plain string column — confirm whether secrets
    # should be encrypted at rest.
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    # Enum column; defaults to ModelTypes.openai.
    api_type: ModelTypes = Field(default=ModelTypes.openai, sa_column=Column(SqlEnum(ModelTypes)))
    api_version: Optional[str] = None
    description: Optional[str] = None
    # Reverse side of Agent.models.
    agents: List["Agent"] = Relationship(back_populates="models", link_model=AgentModelLink)
class CodeExecutionConfigTypes(str, Enum):
    """Allowed values for ``AgentConfig.code_execution_config``."""

    local = "local"
    docker = "docker"
    none = "none"
class AgentConfig(SQLModel, table=False):
    """Configuration payload for an Agent (stored in ``Agent.config``).

    Declared with ``table=False``, so the ``sa_column`` settings below do not
    create any database columns.
    """

    name: Optional[str] = None
    human_input_mode: str = "NEVER"
    max_consecutive_auto_reply: int = 10
    system_message: Optional[str] = None
    is_termination_msg: Optional[Union[bool, str, Callable]] = None
    code_execution_config: CodeExecutionConfigTypes = Field(
        default=CodeExecutionConfigTypes.local, sa_column=Column(SqlEnum(CodeExecutionConfigTypes))
    )
    default_auto_reply: Optional[str] = ""
    description: Optional[str] = None
    # Either an LLMConfig or a bool; defaults to False.
    llm_config: Optional[Union[LLMConfig, bool]] = Field(default=False, sa_column=Column(JSON))
    admin_name: Optional[str] = "Admin"
    messages: Optional[List[Dict]] = Field(default_factory=list)
    max_round: Optional[int] = 100
    speaker_selection_method: Optional[str] = "auto"
    # Either a bool or a list of nested AgentConfig (self-referencing forward ref).
    allow_repeat_speaker: Optional[Union[bool, List["AgentConfig"]]] = True
class AgentType(str, Enum):
    """Kinds of agent, used by ``Agent.type``."""

    assistant = "assistant"
    userproxy = "userproxy"
    groupchat = "groupchat"
class WorkflowAgentType(str, Enum):
    """Role an agent plays inside a workflow, used by ``WorkflowAgentLink.agent_type``."""

    sender = "sender"
    receiver = "receiver"
    planner = "planner"
    sequential = "sequential"
class WorkflowAgentLink(SQLModel, table=True):
    """Association table linking workflows to agents.

    The primary key is (workflow_id, agent_id, agent_type, sequence_id), so the
    same agent may appear in one workflow several times with different roles or
    sequence positions. Used as ``link_model`` by ``Workflow.agents`` and
    ``Agent.workflows``.
    """

    # NOTE(review): with a composite primary key, SQLAlchemy's SQLite dialect
    # presumably does not emit AUTOINCREMENT, so this is likely a no-op — confirm.
    __table_args__ = {"sqlite_autoincrement": True}
    workflow_id: int = Field(default=None, primary_key=True, foreign_key="workflow.id")
    agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
    # Primary-key part declared on the explicit sa_column.
    agent_type: WorkflowAgentType = Field(
        default=WorkflowAgentType.sender,
        sa_column=Column(SqlEnum(WorkflowAgentType), primary_key=True),
    )
    # Position of the agent within the workflow; defaults to 0.
    sequence_id: Optional[int] = Field(default=0, primary_key=True)
class AgentLink(SQLModel, table=True):
    """Self-referential association table between agents (parent -> child).

    Both columns reference ``agent.id`` and together form the primary key.
    Used as ``link_model`` by ``Agent.parents`` and ``Agent.agents``.
    """

    # NOTE(review): with a composite primary key, SQLAlchemy's SQLite dialect
    # presumably does not emit AUTOINCREMENT, so this is likely a no-op — confirm.
    __table_args__ = {"sqlite_autoincrement": True}
    parent_id: Optional[int] = Field(default=None, foreign_key="agent.id", primary_key=True)
    agent_id: Optional[int] = Field(default=None, foreign_key="agent.id", primary_key=True)
class Agent(SQLModel, table=True):
    """A persisted agent definition (table ``agent``).

    Relationships:
      * ``skills`` / ``models``: many-to-many via AgentSkillLink / AgentModelLink.
      * ``workflows``: many-to-many via WorkflowAgentLink.
      * ``parents`` / ``agents``: self-referential many-to-many via AgentLink.
        ``agents`` are the agents linked with this agent as ``parent_id``;
        ``parents`` are the agents that appear as ``parent_id`` in links whose
        ``agent_id`` is this agent.
    """

    __table_args__ = {"sqlite_autoincrement": True}

    id: Optional[int] = Field(default=None, primary_key=True)
    # default_factory fills the value in Python; server_default=now() covers
    # rows inserted without this column.
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), server_default=func.now()),
    ) # pylint: disable=not-callable
    # onupdate=now(): SQLAlchemy supplies now() for this column on UPDATE.
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
    ) # pylint: disable=not-callable
    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"
    type: AgentType = Field(default=AgentType.assistant, sa_column=Column(SqlEnum(AgentType)))
    # JSON column whose default is a fresh AgentConfig instance.
    # NOTE(review): presumably serialized to a dict before being persisted —
    # confirm at call sites.
    config: Union[AgentConfig, dict] = Field(default_factory=AgentConfig, sa_column=Column(JSON))
    skills: List[Skill] = Relationship(back_populates="agents", link_model=AgentSkillLink)
    models: List[Model] = Relationship(back_populates="agents", link_model=AgentModelLink)
    workflows: List["Workflow"] = Relationship(link_model=WorkflowAgentLink, back_populates="agents")
    # Agents that list this agent as a child: join on AgentLink.agent_id, then
    # follow AgentLink.parent_id.
    parents: List["Agent"] = Relationship(
        back_populates="agents",
        link_model=AgentLink,
        sa_relationship_kwargs=dict(
            primaryjoin="Agent.id==AgentLink.agent_id",
            secondaryjoin="Agent.id==AgentLink.parent_id",
        ),
    )
    # Child agents of this agent: join on AgentLink.parent_id, then follow
    # AgentLink.agent_id.
    agents: List["Agent"] = Relationship(
        back_populates="parents",
        link_model=AgentLink,
        sa_relationship_kwargs=dict(
            primaryjoin="Agent.id==AgentLink.parent_id",
            secondaryjoin="Agent.id==AgentLink.agent_id",
        ),
    )
    task_instruction: Optional[str] = None
class WorkFlowType(str, Enum):
    """Workflow execution styles, used by ``Workflow.type``."""

    autonomous = "autonomous"
    sequential = "sequential"
class WorkFlowSummaryMethod(str, Enum):
    """Summary strategies for a workflow, used by ``Workflow.summary_method``."""

    last = "last"
    none = "none"
    llm = "llm"
class Workflow(SQLModel, table=True):
    """A persisted workflow (table ``workflow``) composed of agents via WorkflowAgentLink."""

    __table_args__ = {"sqlite_autoincrement": True}

    id: Optional[int] = Field(default=None, primary_key=True)
    # default_factory fills the value in Python; server_default=now() covers
    # rows inserted without this column.
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), server_default=func.now()),
    ) # pylint: disable=not-callable
    # onupdate=now(): SQLAlchemy supplies now() for this column on UPDATE.
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
    ) # pylint: disable=not-callable
    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"
    name: str
    # Required here, unlike the optional description on the other models.
    description: str
    agents: List[Agent] = Relationship(back_populates="workflows", link_model=WorkflowAgentLink)
    type: WorkFlowType = Field(default=WorkFlowType.autonomous, sa_column=Column(SqlEnum(WorkFlowType)))
    summary_method: Optional[WorkFlowSummaryMethod] = Field(
        default=WorkFlowSummaryMethod.last,
        sa_column=Column(SqlEnum(WorkFlowSummaryMethod)),
    )
    # JSON column; each instance starts with its own empty list.
    sample_tasks: Optional[List[str]] = Field(default_factory=list, sa_column=Column(JSON))
class Response(SQLModel):
    """Generic result envelope: a message, a success flag and optional payload.

    NOTE(review): presumably returned by the API layer — confirm with callers.
    """

    message: str
    status: bool
    data: Optional[Any] = None
class SocketMessage(SQLModel, table=False):
    """Message addressed to a connection identified by ``connection_id``.

    NOTE(review): presumably sent over a WebSocket connection — confirm.
    """

    connection_id: str
    data: Dict[str, Any]
    type: str