Add Session Saving to AGS (#4369)

* fix import issue related to agentchat update #4245

* update uv lock file

* fix db auto_upgrade logic issue.

* im prove msg rendering issue

* Support termination condition combination. Closes #4325

* fix db instantiation bug

* update yarn.lock, closes #4260 #4262

* remove deps for now with vulnerabilities found by dependabot #4262

* update db tests

* add ability to load sessions from db ..

* format updates, add format checks to ags

* format check fixes

* linting and ruff check fixes

* make tests for ags non-parrallel to avoid db race conditions.

* format updates

* fix concurrency issue

* minor ui tweaks, move run start to websocket

* lint fixes

* update uv.lock

* Update python/packages/autogen-studio/autogenstudio/datamodel/types.py

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

* Update python/packages/autogen-studio/autogenstudio/teammanager.py

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

* reuse user proxy from agentchat

* ui tweaks

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Hussein Mozannar <hmozannar@microsoft.com>
This commit is contained in:
Victor Dibia 2024-11-26 15:39:36 -08:00 committed by GitHub
parent df183be35a
commit fe96f7de24
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
50 changed files with 2741 additions and 4315 deletions

View File

@ -1,3 +1,18 @@
from .datamodel import *
from .database.db_manager import DatabaseManager
from .datamodel import Agent, AgentConfig, Model, ModelConfig, Team, TeamConfig, Tool, ToolConfig
from .teammanager import TeamManager
from .version import __version__
from .teammanager import *
__all__ = [
"Tool",
"Model",
"DatabaseManager",
"Team",
"Agent",
"ToolConfig",
"ModelConfig",
"TeamConfig",
"AgentConfig",
"TeamManager",
"__version__",
]

View File

@ -15,7 +15,7 @@ def ui(
host: str = "127.0.0.1",
port: int = 8081,
workers: int = 1,
reload: Annotated[bool, typer.Option("--reload")] = True,
reload: Annotated[bool, typer.Option("--reload")] = False,
docs: bool = True,
appdir: str = None,
database_uri: Optional[str] = None,
@ -48,11 +48,7 @@ def ui(
port=port,
workers=workers,
reload=reload,
reload_excludes=[
"**/alembic/*",
"**/alembic.ini",
"**/versions/*"
] if reload else None
reload_excludes=["**/alembic/*", "**/alembic.ini", "**/versions/*"] if reload else None,
)

View File

@ -1 +0,0 @@
from .agents.userproxy import UserProxyAgent

View File

@ -1,47 +0,0 @@
from typing import Callable, List, Optional, Sequence, Union, Awaitable
from inspect import iscoroutinefunction
from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import ChatMessage, TextMessage
from autogen_core.base import CancellationToken
import asyncio
class UserProxyAgent(BaseChatAgent):
"""An agent that can represent a human user in a chat."""
def __init__(
self,
name: str,
description: Optional[str] = "a",
input_func: Optional[Union[Callable[..., str],
Callable[..., Awaitable[str]]]] = None
) -> None:
super().__init__(name, description=description)
self.input_func = input_func or input
self._is_async = iscoroutinefunction(
input_func) if input_func else False
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
return [TextMessage]
async def _get_input(self, prompt: str) -> str:
"""Handle both sync and async input functions"""
if self._is_async:
return await self.input_func(prompt)
else:
return await asyncio.get_event_loop().run_in_executor(None, self.input_func, prompt)
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
try:
user_input = await self._get_input("Enter your response: ")
return Response(chat_message=TextMessage(content=user_input, source=self.name))
except Exception as e:
# Consider logging the error here
raise RuntimeError(f"Failed to get user input: {str(e)}") from e
async def on_reset(self, cancellation_token: CancellationToken) -> None:
pass

View File

@ -1,3 +1,3 @@
from .db_manager import DatabaseManager
from .component_factory import ComponentFactory, Component
from .component_factory import Component, ComponentFactory
from .config_manager import ConfigurationManager
from .db_manager import DatabaseManager

View File

@ -1,43 +1,45 @@
import os
from pathlib import Path
from typing import Callable, List, Literal, Union, Optional, Dict, Any, Type
from datetime import datetime
import json
from autogen_agentchat.task import MaxMessageTermination, TextMentionTermination, StopMessageTermination
import yaml
import logging
from packaging import version
from datetime import datetime
from pathlib import Path
from typing import Callable, Dict, List, Literal, Optional, Union
from ..datamodel import (
TeamConfig, AgentConfig, ModelConfig, ToolConfig,
TeamTypes, AgentTypes, ModelTypes, ToolTypes,
ComponentType, ComponentConfig, ComponentConfigInput, TerminationConfig, TerminationTypes, Response
)
from ..components import UserProxyAgent
from autogen_agentchat.agents import AssistantAgent
import aiofiles
import yaml
from autogen_agentchat.agents import AssistantAgent, UserProxyAgent
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination, TextMentionTermination
from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_core.components.tools import FunctionTool
from autogen_ext.models import OpenAIChatCompletionClient
from ..datamodel.types import (
AgentConfig,
AgentTypes,
ComponentConfig,
ComponentConfigInput,
ComponentTypes,
ModelConfig,
ModelTypes,
TeamConfig,
TeamTypes,
TerminationConfig,
TerminationTypes,
ToolConfig,
ToolTypes,
)
from ..utils.utils import Version
logger = logging.getLogger(__name__)
# Type definitions for supported components
TeamComponent = Union[RoundRobinGroupChat, SelectorGroupChat]
AgentComponent = Union[AssistantAgent] # Will grow with more agent types
# Will grow with more model types
AgentComponent = Union[AssistantAgent]
ModelComponent = Union[OpenAIChatCompletionClient]
ToolComponent = Union[FunctionTool] # Will grow with more tool types
TerminationComponent = Union[MaxMessageTermination,
StopMessageTermination, TextMentionTermination]
TerminationComponent = Union[MaxMessageTermination, StopMessageTermination, TextMentionTermination]
# Config type definitions
Component = Union[TeamComponent, AgentComponent, ModelComponent, ToolComponent, TerminationComponent]
Component = Union[TeamComponent, AgentComponent, ModelComponent, ToolComponent]
ReturnType = Literal['object', 'dict', 'config']
Component = Union[RoundRobinGroupChat, SelectorGroupChat,
AssistantAgent, OpenAIChatCompletionClient, FunctionTool]
ReturnType = Literal["object", "dict", "config"]
DEFAULT_SELECTOR_PROMPT = """You are in a role play game. The following roles are available:
{roles}.
@ -48,18 +50,18 @@ Read the following conversation. Then select the next role from {participants} t
Read the above conversation. Then select the next role from {participants} to play. Only return the role.
"""
CONFIG_RETURN_TYPES = Literal['object', 'dict', 'config']
CONFIG_RETURN_TYPES = Literal["object", "dict", "config"]
class ComponentFactory:
"""Creates and manages agent components with versioned configuration loading"""
SUPPORTED_VERSIONS = {
ComponentType.TEAM: ["1.0.0"],
ComponentType.AGENT: ["1.0.0"],
ComponentType.MODEL: ["1.0.0"],
ComponentType.TOOL: ["1.0.0"],
ComponentType.TERMINATION: ["1.0.0"]
ComponentTypes.TEAM: ["1.0.0"],
ComponentTypes.AGENT: ["1.0.0"],
ComponentTypes.MODEL: ["1.0.0"],
ComponentTypes.TOOL: ["1.0.0"],
ComponentTypes.TERMINATION: ["1.0.0"],
}
def __init__(self):
@ -68,10 +70,7 @@ class ComponentFactory:
self._last_cache_clear = datetime.now()
async def load(
self,
component: ComponentConfigInput,
input_func: Optional[Callable] = None,
return_type: ReturnType = 'object'
self, component: ComponentConfigInput, input_func: Optional[Callable] = None, return_type: ReturnType = "object"
) -> Union[Component, dict, ComponentConfig]:
"""
Universal loader for any component type
@ -103,24 +102,23 @@ class ComponentFactory:
)
# Return early if dict or config requested
if return_type == 'dict':
if return_type == "dict":
return config.model_dump()
elif return_type == 'config':
elif return_type == "config":
return config
# Otherwise create and return component instance
handlers = {
ComponentType.TEAM: lambda c: self.load_team(c, input_func),
ComponentType.AGENT: lambda c: self.load_agent(c, input_func),
ComponentType.MODEL: self.load_model,
ComponentType.TOOL: self.load_tool,
ComponentType.TERMINATION: self.load_termination
ComponentTypes.TEAM: lambda c: self.load_team(c, input_func),
ComponentTypes.AGENT: lambda c: self.load_agent(c, input_func),
ComponentTypes.MODEL: self.load_model,
ComponentTypes.TOOL: self.load_tool,
ComponentTypes.TERMINATION: self.load_termination,
}
handler = handlers.get(config.component_type)
if not handler:
raise ValueError(
f"Unknown component type: {config.component_type}")
raise ValueError(f"Unknown component type: {config.component_type}")
return await handler(config)
@ -128,7 +126,9 @@ class ComponentFactory:
logger.error(f"Failed to load component: {str(e)}")
raise
async def load_directory(self, directory: Union[str, Path], return_type: ReturnType = 'object') -> List[Union[Component, dict, ComponentConfig]]:
async def load_directory(
self, directory: Union[str, Path], return_type: ReturnType = "object"
) -> List[Union[Component, dict, ComponentConfig]]:
"""
Import all component configurations from a directory.
"""
@ -137,13 +137,12 @@ class ComponentFactory:
directory = Path(directory)
# Using Path.iterdir() instead of os.listdir
for path in list(directory.glob("*")):
if path.suffix.lower().endswith(('.json', '.yaml', '.yml')):
if path.suffix.lower().endswith((".json", ".yaml", ".yml")):
try:
component = await self.load(path, return_type=return_type)
components.append(component)
except Exception as e:
logger.info(
f"Failed to load component: {str(e)}, {path}")
logger.info(f"Failed to load component: {str(e)}, {path}")
return components
except Exception as e:
@ -156,14 +155,14 @@ class ComponentFactory:
raise ValueError("component_type is required in configuration")
config_types = {
ComponentType.TEAM: TeamConfig,
ComponentType.AGENT: AgentConfig,
ComponentType.MODEL: ModelConfig,
ComponentType.TOOL: ToolConfig,
ComponentType.TERMINATION: TerminationConfig # Add mapping for termination
ComponentTypes.TEAM: TeamConfig,
ComponentTypes.AGENT: AgentConfig,
ComponentTypes.MODEL: ModelConfig,
ComponentTypes.TOOL: ToolConfig,
ComponentTypes.TERMINATION: TerminationConfig, # Add mapping for termination
}
component_type = ComponentType(config_dict["component_type"])
component_type = ComponentTypes(config_dict["component_type"])
config_class = config_types.get(component_type)
if not config_class:
@ -174,28 +173,44 @@ class ComponentFactory:
async def load_termination(self, config: TerminationConfig) -> TerminationComponent:
"""Create termination condition instance from configuration."""
try:
if config.termination_type == TerminationTypes.MAX_MESSAGES:
if config.termination_type == TerminationTypes.COMBINATION:
if not config.conditions or len(config.conditions) < 2:
raise ValueError("Combination termination requires at least 2 conditions")
if not config.operator:
raise ValueError("Combination termination requires an operator (and/or)")
# Load first two conditions
conditions = [await self.load_termination(cond) for cond in config.conditions[:2]]
result = conditions[0] & conditions[1] if config.operator == "and" else conditions[0] | conditions[1]
# Process remaining conditions if any
for condition in config.conditions[2:]:
next_condition = await self.load_termination(condition)
result = result & next_condition if config.operator == "and" else result | next_condition
return result
elif config.termination_type == TerminationTypes.MAX_MESSAGES:
if config.max_messages is None:
raise ValueError("max_messages parameter required for MaxMessageTermination")
return MaxMessageTermination(max_messages=config.max_messages)
elif config.termination_type == TerminationTypes.STOP_MESSAGE:
return StopMessageTermination()
elif config.termination_type == TerminationTypes.TEXT_MENTION:
if not config.text:
raise ValueError(
"text parameter required for TextMentionTermination")
raise ValueError("text parameter required for TextMentionTermination")
return TextMentionTermination(text=config.text)
else:
raise ValueError(
f"Unsupported termination type: {config.termination_type}")
raise ValueError(f"Unsupported termination type: {config.termination_type}")
except Exception as e:
logger.error(f"Failed to create termination condition: {str(e)}")
raise ValueError(
f"Termination condition creation failed: {str(e)}")
raise ValueError(f"Termination condition creation failed: {str(e)}") from e
async def load_team(
self,
config: TeamConfig,
input_func: Optional[Callable] = None
) -> TeamComponent:
async def load_team(self, config: TeamConfig, input_func: Optional[Callable] = None) -> TeamComponent:
"""Create team instance from configuration."""
try:
# Load participants (agents) with input_func
@ -216,33 +231,25 @@ class ComponentFactory:
# Create team based on type
if config.team_type == TeamTypes.ROUND_ROBIN:
return RoundRobinGroupChat(
participants=participants,
termination_condition=termination
)
return RoundRobinGroupChat(participants=participants, termination_condition=termination)
elif config.team_type == TeamTypes.SELECTOR:
if not model_client:
raise ValueError(
"SelectorGroupChat requires a model_client")
raise ValueError("SelectorGroupChat requires a model_client")
selector_prompt = config.selector_prompt if config.selector_prompt else DEFAULT_SELECTOR_PROMPT
return SelectorGroupChat(
participants=participants,
model_client=model_client,
termination_condition=termination,
selector_prompt=selector_prompt
selector_prompt=selector_prompt,
)
else:
raise ValueError(f"Unsupported team type: {config.team_type}")
except Exception as e:
logger.error(f"Failed to create team {config.name}: {str(e)}")
raise ValueError(f"Team creation failed: {str(e)}")
raise ValueError(f"Team creation failed: {str(e)}") from e
async def load_agent(
self,
config: AgentConfig,
input_func: Optional[Callable] = None
) -> AgentComponent:
async def load_agent(self, config: AgentConfig, input_func: Optional[Callable] = None) -> AgentComponent:
"""Create agent instance from configuration."""
try:
# Load model client if specified
@ -263,7 +270,7 @@ class ComponentFactory:
return UserProxyAgent(
name=config.name,
description=config.description or "A human user",
input_func=input_func # Pass through to UserProxyAgent
input_func=input_func, # Pass through to UserProxyAgent
)
elif config.agent_type == AgentTypes.ASSISTANT:
return AssistantAgent(
@ -271,15 +278,14 @@ class ComponentFactory:
description=config.description or "A helpful assistant",
model_client=model_client,
tools=tools,
system_message=system_message
system_message=system_message,
)
else:
raise ValueError(
f"Unsupported agent type: {config.agent_type}")
raise ValueError(f"Unsupported agent type: {config.agent_type}")
except Exception as e:
logger.error(f"Failed to create agent {config.name}: {str(e)}")
raise ValueError(f"Agent creation failed: {str(e)}")
raise ValueError(f"Agent creation failed: {str(e)}") from e
async def load_model(self, config: ModelConfig) -> ModelComponent:
"""Create model instance from configuration."""
@ -291,20 +297,15 @@ class ComponentFactory:
return self._model_cache[cache_key]
if config.model_type == ModelTypes.OPENAI:
model = OpenAIChatCompletionClient(
model=config.model,
api_key=config.api_key,
base_url=config.base_url
)
model = OpenAIChatCompletionClient(model=config.model, api_key=config.api_key, base_url=config.base_url)
self._model_cache[cache_key] = model
return model
else:
raise ValueError(
f"Unsupported model type: {config.model_type}")
raise ValueError(f"Unsupported model type: {config.model_type}")
except Exception as e:
logger.error(f"Failed to create model {config.model}: {str(e)}")
raise ValueError(f"Model creation failed: {str(e)}")
raise ValueError(f"Model creation failed: {str(e)}") from e
async def load_tool(self, config: ToolConfig) -> ToolComponent:
"""Create tool instance from configuration."""
@ -321,9 +322,7 @@ class ComponentFactory:
if config.tool_type == ToolTypes.PYTHON_FUNCTION:
tool = FunctionTool(
name=config.name,
description=config.description,
func=self._func_from_string(config.content)
name=config.name, description=config.description, func=self._func_from_string(config.content)
)
self._tool_cache[cache_key] = tool
return tool
@ -334,7 +333,6 @@ class ComponentFactory:
logger.error(f"Failed to create tool '{config.name}': {str(e)}")
raise
# Helper methods remain largely the same
async def _load_from_file(self, path: Union[str, Path]) -> dict:
"""Load configuration from JSON or YAML file."""
path = Path(path)
@ -342,15 +340,16 @@ class ComponentFactory:
raise FileNotFoundError(f"Config file not found: {path}")
try:
with open(path) as f:
if path.suffix == '.json':
return json.load(f)
elif path.suffix in ('.yml', '.yaml'):
return yaml.safe_load(f)
async with aiofiles.open(path) as f:
content = await f.read()
if path.suffix == ".json":
return json.loads(content)
elif path.suffix in (".yml", ".yaml"):
return yaml.safe_load(content)
else:
raise ValueError(f"Unsupported file format: {path.suffix}")
except Exception as e:
raise ValueError(f"Failed to load file {path}: {str(e)}")
raise ValueError(f"Failed to load file {path}: {str(e)}") from e
def _func_from_string(self, content: str) -> callable:
"""Convert function string to callable."""
@ -362,24 +361,25 @@ class ComponentFactory:
return item
raise ValueError("No function found in provided code")
except Exception as e:
raise ValueError(f"Failed to create function: {str(e)}")
raise ValueError(f"Failed to create function: {str(e)}") from e
def _is_version_supported(self, component_type: ComponentType, ver: str) -> bool:
def _is_version_supported(self, component_type: ComponentTypes, ver: str) -> bool:
"""Check if version is supported for component type."""
try:
v = version.parse(ver)
return ver in self.SUPPORTED_VERSIONS[component_type]
except version.InvalidVersion:
version = Version(ver)
supported = [Version(v) for v in self.SUPPORTED_VERSIONS[component_type]]
return any(version == v for v in supported)
except ValueError:
return False
async def cleanup(self) -> None:
"""Cleanup resources and clear caches."""
for model in self._model_cache.values():
if hasattr(model, 'cleanup'):
if hasattr(model, "cleanup"):
await model.cleanup()
for tool in self._tool_cache.values():
if hasattr(tool, 'cleanup'):
if hasattr(tool, "cleanup"):
await tool.cleanup()
self._model_cache.clear()

View File

@ -1,13 +1,11 @@
import logging
from typing import Optional, Union, Dict, Any, List
from pathlib import Path
from loguru import logger
from ..datamodel import (
Model, Team, Agent, Tool,
Response, ComponentTypes, LinkTypes,
ComponentConfigInput
)
from typing import Any, Dict, List, Optional, Union
from loguru import logger
from ..datamodel.db import Agent, LinkTypes, Model, Team, Tool
from ..datamodel.types import ComponentConfigInput, ComponentTypes, Response
from .component_factory import ComponentFactory
from .db_manager import DatabaseManager
@ -16,10 +14,10 @@ class ConfigurationManager:
"""Manages persistence and relationships of components using ComponentFactory for validation"""
DEFAULT_UNIQUENESS_FIELDS = {
ComponentTypes.MODEL: ['model_type', 'model'],
ComponentTypes.TOOL: ['name'],
ComponentTypes.AGENT: ['agent_type', 'name'],
ComponentTypes.TEAM: ['team_type', 'name']
ComponentTypes.MODEL: ["model_type", "model"],
ComponentTypes.TOOL: ["name"],
ComponentTypes.AGENT: ["agent_type", "name"],
ComponentTypes.TEAM: ["team_type", "name"],
}
def __init__(self, db_manager: DatabaseManager, uniqueness_fields: Dict[ComponentTypes, List[str]] = None):
@ -27,7 +25,9 @@ class ConfigurationManager:
self.component_factory = ComponentFactory()
self.uniqueness_fields = uniqueness_fields or self.DEFAULT_UNIQUENESS_FIELDS
async def import_component(self, component_config: ComponentConfigInput, user_id: str, check_exists: bool = False) -> Response:
async def import_component(
self, component_config: ComponentConfigInput, user_id: str, check_exists: bool = False
) -> Response:
"""
Import a component configuration, validate it, and store the resulting component.
@ -41,23 +41,21 @@ class ConfigurationManager:
"""
try:
# Get validated config as dict
config = await self.component_factory.load(component_config, return_type='dict')
config = await self.component_factory.load(component_config, return_type="dict")
# Get component type
component_type = self._determine_component_type(config)
if not component_type:
raise ValueError(
f"Unable to determine component type from config")
raise ValueError("Unable to determine component type from config")
# Check existence if requested
if check_exists:
existing = self._check_exists(component_type, config, user_id)
if existing:
return Response(
message=self._format_exists_message(
component_type, config),
message=self._format_exists_message(component_type, config),
status=True,
data={"id": existing.id}
data={"id": existing.id},
)
# Route to appropriate storage method
@ -70,8 +68,7 @@ class ConfigurationManager:
elif component_type == ComponentTypes.TOOL:
return await self._store_tool(config, user_id)
else:
raise ValueError(
f"Unsupported component type: {component_type}")
raise ValueError(f"Unsupported component type: {component_type}")
except Exception as e:
logger.error(f"Failed to import component: {str(e)}")
@ -90,23 +87,21 @@ class ConfigurationManager:
Response containing import results for all files
"""
try:
configs = await self.component_factory.load_directory(directory, return_type='dict')
configs = await self.component_factory.load_directory(directory, return_type="dict")
results = []
for config in configs:
result = await self.import_component(config, user_id, check_exists)
results.append({
"component": self._get_component_type(config),
"status": result.status,
"message": result.message,
"id": result.data.get("id") if result.status else None
})
results.append(
{
"component": self._get_component_type(config),
"status": result.status,
"message": result.message,
"id": result.data.get("id") if result.status else None,
}
)
return Response(
message="Directory import complete",
status=True,
data=results
)
return Response(message="Directory import complete", status=True, data=results)
except Exception as e:
logger.error(f"Failed to import directory: {str(e)}")
@ -116,10 +111,7 @@ class ConfigurationManager:
"""Store team component and manage its relationships with agents"""
try:
# Store the team
team_db = Team(
user_id=user_id,
config=config
)
team_db = Team(user_id=user_id, config=config)
team_result = self.db_manager.upsert(team_db)
if not team_result.status:
return team_result
@ -131,27 +123,17 @@ class ConfigurationManager:
if check_exists:
# Check for existing agent
agent_type = self._determine_component_type(participant)
existing_agent = self._check_exists(
agent_type, participant, user_id)
existing_agent = self._check_exists(agent_type, participant, user_id)
if existing_agent:
# Link existing agent
self.db_manager.link(
LinkTypes.TEAM_AGENT,
team_id,
existing_agent.id
)
logger.info(
f"Linked existing agent to team: {existing_agent}")
self.db_manager.link(LinkTypes.TEAM_AGENT, team_id, existing_agent.id)
logger.info(f"Linked existing agent to team: {existing_agent}")
continue
# Store and link new agent
agent_result = await self._store_agent(participant, user_id, check_exists)
if agent_result.status:
self.db_manager.link(
LinkTypes.TEAM_AGENT,
team_id,
agent_result.data["id"]
)
self.db_manager.link(LinkTypes.TEAM_AGENT, team_id, agent_result.data["id"])
return team_result
@ -163,10 +145,7 @@ class ConfigurationManager:
"""Store agent component and manage its relationships with tools and model"""
try:
# Store the agent
agent_db = Agent(
user_id=user_id,
config=config
)
agent_db = Agent(user_id=user_id, config=config)
agent_result = self.db_manager.upsert(agent_db)
if not agent_result.status:
return agent_result
@ -177,64 +156,39 @@ class ConfigurationManager:
if "model_client" in config:
if check_exists:
# Check for existing model
model_type = self._determine_component_type(
config["model_client"])
existing_model = self._check_exists(
model_type, config["model_client"], user_id)
model_type = self._determine_component_type(config["model_client"])
existing_model = self._check_exists(model_type, config["model_client"], user_id)
if existing_model:
# Link existing model
self.db_manager.link(
LinkTypes.AGENT_MODEL,
agent_id,
existing_model.id
)
logger.info(
f"Linked existing model to agent: {existing_model.config.model_type}")
self.db_manager.link(LinkTypes.AGENT_MODEL, agent_id, existing_model.id)
logger.info(f"Linked existing model to agent: {existing_model.config.model_type}")
else:
# Store and link new model
model_result = await self._store_model(config["model_client"], user_id)
if model_result.status:
self.db_manager.link(
LinkTypes.AGENT_MODEL,
agent_id,
model_result.data["id"]
)
self.db_manager.link(LinkTypes.AGENT_MODEL, agent_id, model_result.data["id"])
else:
# Store and link new model without checking
model_result = await self._store_model(config["model_client"], user_id)
if model_result.status:
self.db_manager.link(
LinkTypes.AGENT_MODEL,
agent_id,
model_result.data["id"]
)
self.db_manager.link(LinkTypes.AGENT_MODEL, agent_id, model_result.data["id"])
# Handle tools
for tool_config in config.get("tools", []):
if check_exists:
# Check for existing tool
tool_type = self._determine_component_type(tool_config)
existing_tool = self._check_exists(
tool_type, tool_config, user_id)
existing_tool = self._check_exists(tool_type, tool_config, user_id)
if existing_tool:
# Link existing tool
self.db_manager.link(
LinkTypes.AGENT_TOOL,
agent_id,
existing_tool.id
)
logger.info(
f"Linked existing tool to agent: {existing_tool.config.name}")
self.db_manager.link(LinkTypes.AGENT_TOOL, agent_id, existing_tool.id)
logger.info(f"Linked existing tool to agent: {existing_tool.config.name}")
continue
# Store and link new tool
tool_result = await self._store_tool(tool_config, user_id)
if tool_result.status:
self.db_manager.link(
LinkTypes.AGENT_TOOL,
agent_id,
tool_result.data["id"]
)
self.db_manager.link(LinkTypes.AGENT_TOOL, agent_id, tool_result.data["id"])
return agent_result
@ -245,10 +199,7 @@ class ConfigurationManager:
async def _store_model(self, config: dict, user_id: str) -> Response:
"""Store model component (leaf node - no relationships)"""
try:
model_db = Model(
user_id=user_id,
config=config
)
model_db = Model(user_id=user_id, config=config)
return self.db_manager.upsert(model_db)
except Exception as e:
@ -258,17 +209,16 @@ class ConfigurationManager:
async def _store_tool(self, config: dict, user_id: str) -> Response:
"""Store tool component (leaf node - no relationships)"""
try:
tool_db = Tool(
user_id=user_id,
config=config
)
tool_db = Tool(user_id=user_id, config=config)
return self.db_manager.upsert(tool_db)
except Exception as e:
logger.error(f"Failed to store tool: {str(e)}")
return Response(message=str(e), status=False)
def _check_exists(self, component_type: ComponentTypes, config: dict, user_id: str) -> Optional[Union[Model, Tool, Agent, Team]]:
def _check_exists(
self, component_type: ComponentTypes, config: dict, user_id: str
) -> Optional[Union[Model, Tool, Agent, Team]]:
"""Check if component exists based on configured uniqueness fields."""
fields = self.uniqueness_fields.get(component_type, [])
if not fields:
@ -278,17 +228,13 @@ class ConfigurationManager:
ComponentTypes.MODEL: Model,
ComponentTypes.TOOL: Tool,
ComponentTypes.AGENT: Agent,
ComponentTypes.TEAM: Team
ComponentTypes.TEAM: Team,
}.get(component_type)
components = self.db_manager.get(
component_class, {"user_id": user_id}).data
components = self.db_manager.get(component_class, {"user_id": user_id}).data
for component in components:
matches = all(
component.config.get(field) == config.get(field)
for field in fields
)
matches = all(component.config.get(field) == config.get(field) for field in fields)
if matches:
return component

View File

@ -1,59 +1,77 @@
from pathlib import Path
import threading
from datetime import datetime
from pathlib import Path
from typing import Optional
from loguru import logger
from sqlalchemy import exc, text, func
from sqlalchemy import exc, func, inspect, text
from sqlmodel import Session, SQLModel, and_, create_engine, select
from ..datamodel import LinkTypes, Response
from .schema_manager import SchemaManager
from ..datamodel import (
Response,
LinkTypes
)
# from .dbutils import init_db_samples
class DatabaseManager:
"""A class to manage database operations"""
_init_lock = threading.Lock()
def __init__(
self,
engine_uri: str,
base_dir: Optional[Path | str] = None,
auto_upgrade: bool = True
):
def __init__(self, engine_uri: str, base_dir: Optional[Path] = None):
"""
Initialize DatabaseManager with optional custom base directory.
Initialize DatabaseManager with database connection settings.
Does not perform any database operations.
Args:
engine_uri: Database connection URI
base_dir: Custom base directory for Alembic files. If None, uses current working directory
auto_upgrade: Whether to automatically upgrade schema when differences found
engine_uri: Database connection URI (e.g. sqlite:///db.sqlite3)
base_dir: Base directory for migration files. If None, uses current directory
"""
# Convert string path to Path object if necessary
if isinstance(base_dir, str):
base_dir = Path(base_dir)
connection_args = {
"check_same_thread": True
} if "sqlite" in engine_uri else {}
connection_args = {"check_same_thread": True} if "sqlite" in engine_uri else {}
self.engine = create_engine(engine_uri, connect_args=connection_args)
self.schema_manager = SchemaManager(
engine=self.engine,
base_dir=base_dir,
auto_upgrade=auto_upgrade,
)
# Check and upgrade on startup
upgraded, status = self.schema_manager.check_and_upgrade()
if upgraded:
logger.info("Database schema was upgraded automatically")
else:
logger.info(f"Schema status: {status}")
def initialize_database(self, auto_upgrade: bool = False, force_init_alembic: bool = True) -> Response:
"""
Initialize database and migrations in the correct order.
Args:
auto_upgrade: If True, automatically generate and apply migrations for schema changes
force_init_alembic: If True, reinitialize alembic configuration even if it exists
"""
if not self._init_lock.acquire(blocking=False):
return Response(message="Database initialization already in progress", status=False)
try:
inspector = inspect(self.engine)
tables_exist = inspector.get_table_names()
if not tables_exist:
# Fresh install - create tables and initialize migrations
logger.info("Creating database tables...")
SQLModel.metadata.create_all(self.engine)
if self.schema_manager.initialize_migrations(force=force_init_alembic):
return Response(message="Database initialized successfully", status=True)
return Response(message="Failed to initialize migrations", status=False)
# Handle existing database
if auto_upgrade:
logger.info("Checking database schema...")
if self.schema_manager.ensure_schema_up_to_date(): # <-- Use this instead
return Response(message="Database schema is up to date", status=True)
return Response(message="Database upgrade failed", status=False)
return Response(message="Database is ready", status=True)
except Exception as e:
error_msg = f"Database initialization failed: {str(e)}"
logger.error(error_msg)
return Response(message=error_msg, status=False)
finally:
self._init_lock.release()
def reset_db(self, recreate_tables: bool = True):
"""
@ -65,11 +83,7 @@ class DatabaseManager:
"""
if not self._init_lock.acquire(blocking=False):
logger.warning("Database reset already in progress")
return Response(
message="Database reset already in progress",
status=False,
data=None
)
return Response(message="Database reset already in progress", status=False, data=None)
try:
# Dispose existing connections
@ -77,16 +91,16 @@ class DatabaseManager:
with Session(self.engine) as session:
try:
# Disable foreign key checks for SQLite
if 'sqlite' in str(self.engine.url):
session.exec(text('PRAGMA foreign_keys=OFF'))
if "sqlite" in str(self.engine.url):
session.exec(text("PRAGMA foreign_keys=OFF"))
# Drop all tables
SQLModel.metadata.drop_all(self.engine)
logger.info("All tables dropped successfully")
# Re-enable foreign key checks for SQLite
if 'sqlite' in str(self.engine.url):
session.exec(text('PRAGMA foreign_keys=ON'))
if "sqlite" in str(self.engine.url):
session.exec(text("PRAGMA foreign_keys=ON"))
session.commit()
@ -99,48 +113,29 @@ class DatabaseManager:
if recreate_tables:
logger.info("Recreating tables...")
self.create_db_and_tables()
self.initialize_database(auto_upgrade=False, force_init_alembic=True)
return Response(
message="Database reset successfully" if recreate_tables else "Database tables dropped successfully",
status=True,
data=None
data=None,
)
except Exception as e:
error_msg = f"Error while resetting database: {str(e)}"
logger.error(error_msg)
return Response(
message=error_msg,
status=False,
data=None
)
return Response(message=error_msg, status=False, data=None)
finally:
if self._init_lock.locked():
self._init_lock.release()
logger.info("Database reset lock released")
def create_db_and_tables(self):
"""Create a new database and tables"""
with self._init_lock:
try:
SQLModel.metadata.create_all(self.engine)
logger.info("Database tables created successfully")
try:
# init_db_samples(self)
pass
except Exception as e:
logger.info(
"Error while initializing database samples: " + str(e))
except Exception as e:
logger.info("Error while creating database tables:" + str(e))
def upsert(self, model: SQLModel, return_json: bool = True):
"""Create or update an entity
Args:
model (SQLModel): The model instance to create or update
return_json (bool, optional): If True, returns the model as a dictionary.
return_json (bool, optional): If True, returns the model as a dictionary.
If False, returns the SQLModel instance. Defaults to True.
Returns:
@ -152,8 +147,7 @@ class DatabaseManager:
with Session(self.engine) as session:
try:
existing_model = session.exec(
select(model_class).where(model_class.id == model.id)).first()
existing_model = session.exec(select(model_class).where(model_class.id == model.id)).first()
if existing_model:
model.updated_at = datetime.now()
for key, value in model.model_dump().items():
@ -166,8 +160,7 @@ class DatabaseManager:
session.refresh(model)
except Exception as e:
session.rollback()
logger.error("Error while updating/creating " +
str(model_class.__name__) + ": " + str(e))
logger.error("Error while updating/creating " + str(model_class.__name__) + ": " + str(e))
status = False
return Response(
@ -199,25 +192,21 @@ class DatabaseManager:
try:
statement = select(model_class)
if filters:
conditions = [getattr(model_class, col) ==
value for col, value in filters.items()]
conditions = [getattr(model_class, col) == value for col, value in filters.items()]
statement = statement.where(and_(*conditions))
if hasattr(model_class, "created_at") and order:
order_by_clause = getattr(
model_class.created_at, order)() # Dynamically apply asc/desc
order_by_clause = getattr(model_class.created_at, order)() # Dynamically apply asc/desc
statement = statement.order_by(order_by_clause)
items = session.exec(statement).all()
result = [self._model_to_dict(
item) if return_json else item for item in items]
result = [self._model_to_dict(item) if return_json else item for item in items]
status_message = f"{model_class.__name__} Retrieved Successfully"
except Exception as e:
session.rollback()
status = False
status_message = f"Error while fetching {model_class.__name__}"
logger.error("Error while getting items: " +
str(model_class.__name__) + " " + str(e))
logger.error("Error while getting items: " + str(model_class.__name__) + " " + str(e))
return Response(message=status_message, status=status, data=result)
@ -230,8 +219,7 @@ class DatabaseManager:
try:
statement = select(model_class)
if filters:
conditions = [
getattr(model_class, col) == value for col, value in filters.items()]
conditions = [getattr(model_class, col) == value for col, value in filters.items()]
statement = statement.where(and_(*conditions))
rows = session.exec(statement).all()
@ -290,8 +278,7 @@ class DatabaseManager:
select(link_table).where(
and_(
getattr(link_table, primary_id_field) == primary_id,
getattr(
link_table, secondary_id_field) == secondary_id
getattr(link_table, secondary_id_field) == secondary_id,
)
)
).first()
@ -302,37 +289,24 @@ class DatabaseManager:
# Get the next sequence number if not provided
if sequence is None:
max_seq_result = session.exec(
select(func.max(link_table.sequence)).where(
getattr(link_table, primary_id_field) == primary_id
)
select(func.max(link_table.sequence)).where(getattr(link_table, primary_id_field) == primary_id)
).first()
sequence = 0 if max_seq_result is None else max_seq_result + 1
# Create new link
new_link = link_table(**{
primary_id_field: primary_id,
secondary_id_field: secondary_id,
'sequence': sequence
})
new_link = link_table(
**{primary_id_field: primary_id, secondary_id_field: secondary_id, "sequence": sequence}
)
session.add(new_link)
session.commit()
return Response(
message=f"Entities linked successfully with sequence {sequence}",
status=True
)
return Response(message=f"Entities linked successfully with sequence {sequence}", status=True)
except Exception as e:
session.rollback()
return Response(message=f"Error linking entities: {str(e)}", status=False)
def unlink(
self,
link_type: LinkTypes,
primary_id: int,
secondary_id: int,
sequence: Optional[int] = None
):
def unlink(self, link_type: LinkTypes, primary_id: int, secondary_id: int, sequence: Optional[int] = None):
"""Unlink two entities and reorder sequences if needed."""
with Session(self.engine) as session:
try:
@ -349,13 +323,12 @@ class DatabaseManager:
statement = select(link_table).where(
and_(
getattr(link_table, primary_id_field) == primary_id,
getattr(link_table, secondary_id_field) == secondary_id
getattr(link_table, secondary_id_field) == secondary_id,
)
)
if sequence is not None:
statement = statement.where(
link_table.sequence == sequence)
statement = statement.where(link_table.sequence == sequence)
existing_link = session.exec(statement).first()
@ -379,10 +352,7 @@ class DatabaseManager:
session.commit()
return Response(
message="Entities unlinked successfully and sequences reordered",
status=True
)
return Response(message="Entities unlinked successfully and sequences reordered", status=True)
except Exception as e:
session.rollback()
@ -414,22 +384,14 @@ class DatabaseManager:
.order_by(link_table.sequence)
).all()
result = [
item.model_dump() if return_json else item for item in items]
result = [item.model_dump() if return_json else item for item in items]
return Response(
message="Linked entities retrieved successfully",
status=True,
data=result
)
return Response(message="Linked entities retrieved successfully", status=True, data=result)
except Exception as e:
logger.error(f"Error getting linked entities: {str(e)}")
return Response(
message=f"Error getting linked entities: {str(e)}",
status=False,
data=[]
)
return Response(message=f"Error getting linked entities: {str(e)}", status=False, data=[])
# Add new close method
async def close(self):

View File

@ -1,68 +1,71 @@
import os
from pathlib import Path
import shutil
from typing import Optional, Tuple, List
from loguru import logger
from pathlib import Path
from typing import List, Optional, Tuple
import sqlmodel
from alembic import command
from alembic.autogenerate import compare_metadata
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from alembic.autogenerate import compare_metadata
from sqlalchemy import Engine
from sqlmodel import SQLModel
from alembic.util.exc import CommandError
from loguru import logger
from sqlalchemy import Engine, text
from sqlmodel import SQLModel
class SchemaManager:
"""
Manages database schema validation and migrations using Alembic.
Provides automatic schema validation, migrations, and safe upgrades.
Args:
engine: SQLAlchemy engine instance
auto_upgrade: Whether to automatically upgrade schema when differences found
init_mode: Controls initialization behavior:
- "none": No automatic initialization (raises error if not set up)
- "auto": Initialize if not present (default)
- "force": Always reinitialize, removing existing configuration
Operations are initiated explicitly by DatabaseManager.
"""
def __init__(
self,
engine: Engine,
base_dir: Optional[Path] = None,
auto_upgrade: bool = True,
init_mode: str = "auto"
):
if init_mode not in ["none", "auto", "force"]:
raise ValueError("init_mode must be one of: none, auto, force")
"""
Initialize configuration only - no filesystem or DB operations.
Args:
engine: SQLAlchemy engine instance
base_dir: Base directory for Alembic files. If None, uses current working directory
"""
# Convert string path to Path object if necessary
if isinstance(base_dir, str):
base_dir = Path(base_dir)
self.engine = engine
self.auto_upgrade = auto_upgrade
# Use provided base_dir or default to class file location
self.base_dir = base_dir or Path(__file__).parent
self.alembic_dir = self.base_dir / 'alembic'
self.alembic_ini_path = self.base_dir / 'alembic.ini'
self.alembic_dir = self.base_dir / "alembic"
self.alembic_ini_path = self.base_dir / "alembic.ini"
# Create base directory if it doesn't exist
self.base_dir.mkdir(parents=True, exist_ok=True)
def initialize_migrations(self, force: bool = False) -> bool:
try:
if force:
logger.info("Force reinitialization of migrations...")
self._cleanup_existing_alembic()
if not self._initialize_alembic():
return False
else:
try:
self._validate_alembic_setup()
logger.info("Using existing Alembic configuration")
self._update_configuration()
except FileNotFoundError:
logger.info("Initializing new Alembic configuration")
if not self._initialize_alembic():
return False
# Initialize based on mode
if init_mode == "force":
self._cleanup_existing_alembic()
self._initialize_alembic()
else:
try:
self._validate_alembic_setup()
logger.info("Using existing Alembic configuration")
# Update existing configuration
self._update_configuration()
except FileNotFoundError:
if init_mode == "none":
raise
logger.info("Initializing new Alembic configuration")
self._initialize_alembic()
# Only generate initial revision if alembic is properly initialized
logger.info("Creating initial migration...")
return self.generate_revision("Initial schema") is not None
except Exception as e:
logger.error(f"Failed to initialize migrations: {e}")
return False
def _update_configuration(self) -> None:
"""Updates existing Alembic configuration with current settings."""
@ -70,11 +73,11 @@ class SchemaManager:
# Update alembic.ini
config_content = self._generate_alembic_ini_content()
with open(self.alembic_ini_path, 'w') as f:
with open(self.alembic_ini_path, "w") as f:
f.write(config_content)
# Update env.py
env_path = self.alembic_dir / 'env.py'
env_path = self.alembic_dir / "env.py"
if env_path.exists():
self._update_env_py(env_path)
else:
@ -82,37 +85,22 @@ class SchemaManager:
def _cleanup_existing_alembic(self) -> None:
"""
Safely removes existing Alembic configuration while preserving versions directory.
Completely remove existing Alembic configuration including versions.
For fresh initialization, we don't need to preserve anything.
"""
logger.info(
"Cleaning up existing Alembic configuration while preserving versions...")
logger.info("Cleaning up existing Alembic configuration...")
# Create a backup of versions directory if it exists
if self.alembic_dir.exists() and (self.alembic_dir / 'versions').exists():
logger.info("Preserving existing versions directory")
# Remove alembic directory contents EXCEPT versions
# Remove entire alembic directory if it exists
if self.alembic_dir.exists():
for item in self.alembic_dir.iterdir():
if item.name != 'versions':
try:
if item.is_dir():
shutil.rmtree(item)
logger.info(f"Removed directory: {item}")
else:
item.unlink()
logger.info(f"Removed file: {item}")
except Exception as e:
logger.error(f"Failed to remove {item}: {e}")
import shutil
shutil.rmtree(self.alembic_dir)
logger.info(f"Removed alembic directory: {self.alembic_dir}")
# Remove alembic.ini if it exists
if self.alembic_ini_path.exists():
try:
self.alembic_ini_path.unlink()
logger.info(
f"Removed existing alembic.ini: {self.alembic_ini_path}")
except Exception as e:
logger.error(f"Failed to remove alembic.ini: {e}")
self.alembic_ini_path.unlink()
logger.info("Removed alembic.ini")
def _ensure_alembic_setup(self, *, force: bool = False) -> None:
"""
@ -124,51 +112,52 @@ class SchemaManager:
try:
self._validate_alembic_setup()
if force:
logger.info(
"Force initialization requested. Cleaning up existing configuration...")
logger.info("Force initialization requested. Cleaning up existing configuration...")
self._cleanup_existing_alembic()
self._initialize_alembic()
except FileNotFoundError:
logger.info("Alembic configuration not found. Initializing...")
if self.alembic_dir.exists():
logger.warning(
"Found existing alembic directory but missing configuration")
logger.warning("Found existing alembic directory but missing configuration")
self._cleanup_existing_alembic()
self._initialize_alembic()
logger.info("Alembic initialization complete")
def _initialize_alembic(self) -> None:
logger.info("Initializing Alembic configuration...")
# Create directories first
self.alembic_dir.mkdir(exist_ok=True)
versions_dir = self.alembic_dir / 'versions'
versions_dir.mkdir(exist_ok=True)
# Create env.py BEFORE running command.init
env_path = self.alembic_dir / 'env.py'
if not env_path.exists():
self._create_minimal_env_py(env_path)
logger.info("Created new env.py")
# Write alembic.ini
config_content = self._generate_alembic_ini_content()
with open(self.alembic_ini_path, 'w') as f:
f.write(config_content)
logger.info("Created alembic.ini")
# Now run alembic init
def _initialize_alembic(self) -> bool:
"""Initialize alembic structure and configuration"""
try:
config = self.get_alembic_config()
# Ensure parent directory exists
self.alembic_dir.parent.mkdir(exist_ok=True)
# Run alembic init to create fresh directory structure
logger.info("Initializing alembic directory structure...")
# Create initial config file for alembic init
config_content = self._generate_alembic_ini_content()
with open(self.alembic_ini_path, "w") as f:
f.write(config_content)
# Use the config we just created
config = Config(str(self.alembic_ini_path))
command.init(config, str(self.alembic_dir))
logger.info("Initialized Alembic directory structure")
except CommandError as e:
if "already exists" not in str(e):
raise
# Update script template after initialization
self.update_script_template()
# Update env.py with our customizations
self._update_env_py(self.alembic_dir / "env.py")
logger.info("Alembic initialization complete")
return True
except Exception as e:
# Explicitly convert error to string
logger.error(f"Failed to initialize alembic: {str(e)}")
return False
def _create_minimal_env_py(self, env_path: Path) -> None:
"""Creates a minimal env.py file for Alembic."""
content = '''
content = """
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
@ -201,7 +190,7 @@ def run_migrations_online() -> None:
)
with connectable.connect() as connection:
context.configure(
connection=connection,
connection=connection,
target_metadata=target_metadata,
compare_type=True
)
@ -211,9 +200,9 @@ def run_migrations_online() -> None:
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()'''
run_migrations_online()"""
with open(env_path, 'w') as f:
with open(env_path, "w") as f:
f.write(content)
def _generate_alembic_ini_content(self) -> str:
@ -260,6 +249,29 @@ format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
""".strip()
def update_script_template(self):
"""Update the Alembic script template to include SQLModel."""
template_path = self.alembic_dir / "script.py.mako"
try:
with open(template_path, "r") as f:
content = f.read()
# Add sqlmodel import to imports section
import_section = "from alembic import op\nimport sqlalchemy as sa"
new_imports = "from alembic import op\nimport sqlalchemy as sa\nimport sqlmodel"
content = content.replace(import_section, new_imports)
with open(template_path, "w") as f:
f.write(content)
logger.info("Updated script template")
return True
except Exception as e:
logger.error(f"Failed to update script template: {e}")
return False
def _update_env_py(self, env_path: Path) -> None:
"""
Updates the env.py file to use SQLModel metadata.
@ -268,27 +280,45 @@ datefmt = %H:%M:%S
self._create_minimal_env_py(env_path)
return
try:
with open(env_path, 'r') as f:
with open(env_path, "r") as f:
content = f.read()
# Add SQLModel import
# Add SQLModel import if not present
if "from sqlmodel import SQLModel" not in content:
content = "from sqlmodel import SQLModel\n" + content
# Replace target_metadata
content = content.replace("target_metadata = None", "target_metadata = SQLModel.metadata")
# Update both configure blocks properly
content = content.replace(
"target_metadata = None",
"target_metadata = SQLModel.metadata"
"""context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)""",
"""context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=True,
)""",
)
# Add compare_type=True to context.configure
if "context.configure(" in content and "compare_type=True" not in content:
content = content.replace(
"context.configure(",
"context.configure(compare_type=True,"
)
content = content.replace(
""" context.configure(
connection=connection, target_metadata=target_metadata
)""",
""" context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True,
)""",
)
with open(env_path, 'w') as f:
with open(env_path, "w") as f:
f.write(content)
logger.info("Updated env.py with SQLModel metadata")
@ -297,6 +327,7 @@ datefmt = %H:%M:%S
raise
# Fixed: use keyword-only argument
def _ensure_alembic_setup(self, *, force: bool = False) -> None:
"""
Ensures Alembic is properly set up, initializing if necessary.
@ -307,32 +338,24 @@ datefmt = %H:%M:%S
try:
self._validate_alembic_setup()
if force:
logger.info(
"Force initialization requested. Cleaning up existing configuration...")
logger.info("Force initialization requested. Cleaning up existing configuration...")
self._cleanup_existing_alembic()
self._initialize_alembic()
except FileNotFoundError:
logger.info("Alembic configuration not found. Initializing...")
if self.alembic_dir.exists():
logger.warning(
"Found existing alembic directory but missing configuration")
logger.warning("Found existing alembic directory but missing configuration")
self._cleanup_existing_alembic()
self._initialize_alembic()
logger.info("Alembic initialization complete")
def _validate_alembic_setup(self) -> None:
"""Validates that Alembic is properly configured."""
required_files = [
self.alembic_ini_path,
self.alembic_dir / 'env.py',
self.alembic_dir / 'versions'
]
required_files = [self.alembic_ini_path, self.alembic_dir / "env.py", self.alembic_dir / "versions"]
missing = [f for f in required_files if not f.exists()]
if missing:
raise FileNotFoundError(
f"Alembic configuration incomplete. Missing: {', '.join(str(f) for f in missing)}"
)
raise FileNotFoundError(f"Alembic configuration incomplete. Missing: {', '.join(str(f) for f in missing)}")
def get_alembic_config(self) -> Config:
"""
@ -430,7 +453,7 @@ datefmt = %H:%M:%S
def check_and_upgrade(self) -> Tuple[bool, str]:
"""
Checks schema status and upgrades if necessary (and auto_upgrade is True).
Checks schema status and upgrades if necessary.
Returns:
Tuple[bool, str]: (action_taken, status_message)
@ -438,13 +461,11 @@ datefmt = %H:%M:%S
needs_upgrade, status = self.check_schema_status()
if needs_upgrade:
if self.auto_upgrade:
if self.upgrade_schema():
return True, "Schema was automatically upgraded"
else:
return False, "Automatic schema upgrade failed"
# Remove the auto_upgrade check since we explicitly called this method
if self.upgrade_schema():
return True, "Schema was automatically upgraded"
else:
return False, f"Schema needs upgrade but auto_upgrade is disabled. Status: {status}"
return False, "Automatic schema upgrade failed"
return False, status
@ -460,11 +481,7 @@ datefmt = %H:%M:%S
"""
try:
config = self.get_alembic_config()
command.revision(
config,
message=message,
autogenerate=True
)
command.revision(config, message=message, autogenerate=True)
return self.get_head_revision()
except Exception as e:
@ -512,26 +529,40 @@ datefmt = %H:%M:%S
def ensure_schema_up_to_date(self) -> bool:
"""
Ensures the database schema is up to date, generating and applying migrations if needed.
Returns:
bool: True if schema is up to date or was successfully updated
Reset migrations and create fresh migration for current schema state.
"""
try:
# Check for unmigrated changes
differences = self.get_schema_differences()
if differences:
# Generate new migration
revision = self.generate_revision("auto-generated")
if not revision:
return False
logger.info(f"Generated new migration: {revision}")
logger.info("Resetting migrations and updating to current schema...")
# Apply any pending migrations
upgraded, status = self.check_and_upgrade()
if not upgraded and "needs upgrade" in status.lower():
# 1. Clear the entire alembic directory
if self.alembic_dir.exists():
shutil.rmtree(self.alembic_dir)
logger.info("Cleared alembic directory")
# 2. Clear alembic_version table
with self.engine.connect() as connection:
connection.execute(text("DROP TABLE IF EXISTS alembic_version"))
connection.commit()
logger.info("Reset alembic version")
# 3. Reinitialize alembic from scratch
if not self._initialize_alembic():
logger.error("Failed to reinitialize alembic")
return False
# 4. Generate fresh migration from current schema
revision = self.generate_revision("current_schema")
if not revision:
logger.error("Failed to generate new migration")
return False
logger.info(f"Generated fresh migration: {revision}")
# 5. Apply the migration
if not self.upgrade_schema():
logger.error("Failed to apply migration")
return False
logger.info("Successfully applied migration")
return True
except Exception as e:

View File

@ -1,2 +1,11 @@
from .db import *
from .types import *
from .db import Agent, LinkTypes, Message, Model, Run, RunStatus, Session, Team, Tool
from .types import (
AgentConfig,
ComponentConfigInput,
MessageConfig,
ModelConfig,
Response,
TeamConfig,
TeamResult,
ToolConfig,
)

View File

@ -2,24 +2,28 @@
from datetime import datetime
from enum import Enum
from typing import List, Optional, Union, Tuple, Type
from sqlalchemy import ForeignKey, Integer, UniqueConstraint
from sqlmodel import JSON, Column, DateTime, Field, SQLModel, func, Relationship, SQLModel
from typing import List, Optional, Tuple, Type, Union
from uuid import UUID, uuid4
from .types import ToolConfig, ModelConfig, AgentConfig, TeamConfig, MessageConfig, MessageMeta
from loguru import logger
from pydantic import BaseModel
from sqlalchemy import ForeignKey, Integer, UniqueConstraint
from sqlmodel import JSON, Column, DateTime, Field, Relationship, SQLModel, func
from .types import AgentConfig, MessageConfig, MessageMeta, ModelConfig, TeamConfig, TeamResult, ToolConfig
# added for python3.11 and sqlmodel 0.0.22 incompatibility
if hasattr(SQLModel, "model_config"):
SQLModel.model_config["protected_namespaces"] = ()
elif hasattr(SQLModel, "Config"):
class CustomSQLModel(SQLModel):
class Config:
protected_namespaces = ()
SQLModel = CustomSQLModel
else:
print("Warning: Unable to set protected_namespaces.")
logger.warning("Unable to set protected_namespaces.")
# pylint: disable=protected-access
@ -36,7 +40,7 @@ class ComponentTypes(Enum):
ComponentTypes.TEAM: Team,
ComponentTypes.AGENT: Agent,
ComponentTypes.MODEL: Model,
ComponentTypes.TOOL: Tool
ComponentTypes.TOOL: Tool,
}[self]
@ -51,7 +55,7 @@ class LinkTypes(Enum):
return {
LinkTypes.AGENT_MODEL: (Agent, Model, AgentModelLink),
LinkTypes.AGENT_TOOL: (Agent, Tool, AgentToolLink),
LinkTypes.TEAM_AGENT: (Team, Agent, TeamAgentLink)
LinkTypes.TEAM_AGENT: (Team, Agent, TeamAgentLink),
}[self]
@property
@ -70,40 +74,34 @@ class LinkTypes(Enum):
# link models
class AgentToolLink(SQLModel, table=True):
__table_args__ = (
UniqueConstraint('agent_id', 'sequence',
name='unique_agent_tool_sequence'),
{'sqlite_autoincrement': True}
UniqueConstraint("agent_id", "sequence", name="unique_agent_tool_sequence"),
{"sqlite_autoincrement": True},
)
agent_id: int = Field(default=None, primary_key=True,
foreign_key="agent.id")
agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
tool_id: int = Field(default=None, primary_key=True, foreign_key="tool.id")
sequence: Optional[int] = Field(default=0, primary_key=True)
class AgentModelLink(SQLModel, table=True):
__table_args__ = (
UniqueConstraint('agent_id', 'sequence',
name='unique_agent_tool_sequence'),
{'sqlite_autoincrement': True}
UniqueConstraint("agent_id", "sequence", name="unique_agent_tool_sequence"),
{"sqlite_autoincrement": True},
)
agent_id: int = Field(default=None, primary_key=True,
foreign_key="agent.id")
model_id: int = Field(default=None, primary_key=True,
foreign_key="model.id")
agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
model_id: int = Field(default=None, primary_key=True, foreign_key="model.id")
sequence: Optional[int] = Field(default=0, primary_key=True)
class TeamAgentLink(SQLModel, table=True):
__table_args__ = (
UniqueConstraint('agent_id', 'sequence',
name='unique_agent_tool_sequence'),
{'sqlite_autoincrement': True}
UniqueConstraint("agent_id", "sequence", name="unique_agent_tool_sequence"),
{"sqlite_autoincrement": True},
)
team_id: int = Field(default=None, primary_key=True, foreign_key="team.id")
agent_id: int = Field(default=None, primary_key=True,
foreign_key="agent.id")
agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
sequence: Optional[int] = Field(default=0, primary_key=True)
# database models
@ -120,10 +118,8 @@ class Tool(SQLModel, table=True):
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
config: Union[ToolConfig, dict] = Field(
default_factory=ToolConfig, sa_column=Column(JSON))
agents: List["Agent"] = Relationship(
back_populates="tools", link_model=AgentToolLink)
config: Union[ToolConfig, dict] = Field(default_factory=ToolConfig, sa_column=Column(JSON))
agents: List["Agent"] = Relationship(back_populates="tools", link_model=AgentToolLink)
class Model(SQLModel, table=True):
@ -139,10 +135,8 @@ class Model(SQLModel, table=True):
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
config: Union[ModelConfig, dict] = Field(
default_factory=ModelConfig, sa_column=Column(JSON))
agents: List["Agent"] = Relationship(
back_populates="models", link_model=AgentModelLink)
config: Union[ModelConfig, dict] = Field(default_factory=ModelConfig, sa_column=Column(JSON))
agents: List["Agent"] = Relationship(back_populates="models", link_model=AgentModelLink)
class Team(SQLModel, table=True):
@ -158,10 +152,8 @@ class Team(SQLModel, table=True):
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
config: Union[TeamConfig, dict] = Field(
default_factory=TeamConfig, sa_column=Column(JSON))
agents: List["Agent"] = Relationship(
back_populates="teams", link_model=TeamAgentLink)
config: Union[TeamConfig, dict] = Field(default_factory=TeamConfig, sa_column=Column(JSON))
agents: List["Agent"] = Relationship(back_populates="teams", link_model=TeamAgentLink)
class Agent(SQLModel, table=True):
@ -177,14 +169,10 @@ class Agent(SQLModel, table=True):
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
config: Union[AgentConfig, dict] = Field(
default_factory=AgentConfig, sa_column=Column(JSON))
tools: List[Tool] = Relationship(
back_populates="agents", link_model=AgentToolLink)
models: List[Model] = Relationship(
back_populates="agents", link_model=AgentModelLink)
teams: List[Team] = Relationship(
back_populates="agents", link_model=TeamAgentLink)
config: Union[AgentConfig, dict] = Field(default_factory=AgentConfig, sa_column=Column(JSON))
tools: List[Tool] = Relationship(back_populates="agents", link_model=AgentToolLink)
models: List[Model] = Relationship(back_populates="agents", link_model=AgentModelLink)
teams: List[Team] = Relationship(back_populates="agents", link_model=TeamAgentLink)
class Message(SQLModel, table=True):
@ -200,17 +188,12 @@ class Message(SQLModel, table=True):
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
config: Union[MessageConfig, dict] = Field(
default_factory=MessageConfig, sa_column=Column(JSON))
config: Union[MessageConfig, dict] = Field(default_factory=MessageConfig, sa_column=Column(JSON))
session_id: Optional[int] = Field(
default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"))
)
run_id: Optional[UUID] = Field(
default=None, foreign_key="run.id"
)
message_meta: Optional[Union[MessageMeta, dict]] = Field(
default={}, sa_column=Column(JSON))
run_id: Optional[UUID] = Field(default=None, foreign_key="run.id")
message_meta: Optional[Union[MessageMeta, dict]] = Field(default={}, sa_column=Column(JSON))
class Session(SQLModel, table=True):
@ -226,9 +209,7 @@ class Session(SQLModel, table=True):
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
team_id: Optional[int] = Field(
default=None, sa_column=Column(Integer, ForeignKey("team.id", ondelete="CASCADE"))
)
team_id: Optional[int] = Field(default=None, sa_column=Column(Integer, ForeignKey("team.id", ondelete="CASCADE")))
name: Optional[str] = None
@ -242,41 +223,59 @@ class RunStatus(str, Enum):
class Run(SQLModel, table=True):
"""Represents a single execution run within a session"""
__table_args__ = {"sqlite_autoincrement": True}
# Primary key using UUID
id: UUID = Field(
default_factory=uuid4,
primary_key=True,
index=True
id: UUID = Field(default_factory=uuid4, primary_key=True, index=True)
created_at: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime(timezone=True), server_default=func.now())
)
updated_at: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime(timezone=True), onupdate=func.now())
)
session_id: Optional[int] = Field(
default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"), nullable=False)
)
status: RunStatus = Field(default=RunStatus.CREATED)
# Timestamps using the same pattern as other models
# Store the original user task
task: Union[MessageConfig, dict] = Field(default_factory=MessageConfig, sa_column=Column(JSON))
# Store TeamResult which contains TaskResult
team_result: Union[TeamResult, dict] = Field(default=None, sa_column=Column(JSON))
error_message: Optional[str] = None
version: Optional[str] = "0.0.1"
messages: Union[List[Message], List[dict]] = Field(default_factory=list, sa_column=Column(JSON))
class Config:
json_encoders = {UUID: str, datetime: lambda v: v.isoformat()}
class GalleryConfig(SQLModel, table=False):
id: UUID = Field(default_factory=uuid4, primary_key=True, index=True)
title: Optional[str] = None
description: Optional[str] = None
run: Run
team: TeamConfig = None
tags: Optional[List[str]] = None
visibility: str = "public" # public, private, shared
class Config:
json_encoders = {UUID: str, datetime: lambda v: v.isoformat()}
class Gallery(SQLModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), server_default=func.now())
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
)
updated_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), onupdate=func.now())
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
)
# Foreign key to Session
session_id: Optional[int] = Field(
default=None,
sa_column=Column(
Integer,
ForeignKey("session.id", ondelete="CASCADE"),
nullable=False
)
)
# Run status and metadata
status: RunStatus = Field(default=RunStatus.CREATED)
error_message: Optional[str] = None
# Metadata storage following pattern from Message model
run_meta: dict = Field(default={}, sa_column=Column(JSON))
# Version tracking like other models
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
config: Union[GalleryConfig, dict] = Field(default_factory=GalleryConfig, sa_column=Column(JSON))

View File

@ -1,10 +1,10 @@
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
from autogen_agentchat.base import TaskResult
from pydantic import BaseModel
from autogen_agentchat.base._task import TaskResult
class ModelTypes(str, Enum):
@ -29,9 +29,10 @@ class TerminationTypes(str, Enum):
MAX_MESSAGES = "MaxMessageTermination"
STOP_MESSAGE = "StopMessageTermination"
TEXT_MENTION = "TextMentionTermination"
COMBINATION = "CombinationTermination"
class ComponentType(str, Enum):
class ComponentTypes(str, Enum):
TEAM = "team"
AGENT = "agent"
MODEL = "model"
@ -40,11 +41,9 @@ class ComponentType(str, Enum):
class BaseConfig(BaseModel):
model_config = {
"protected_namespaces": ()
}
model_config = {"protected_namespaces": ()}
version: str = "1.0.0"
component_type: ComponentType
component_type: ComponentTypes
class MessageConfig(BaseModel):
@ -58,7 +57,7 @@ class ModelConfig(BaseConfig):
model_type: ModelTypes
api_key: Optional[str] = None
base_url: Optional[str] = None
component_type: ComponentType = ComponentType.MODEL
component_type: ComponentTypes = ComponentTypes.MODEL
class ToolConfig(BaseConfig):
@ -66,7 +65,7 @@ class ToolConfig(BaseConfig):
description: str
content: str
tool_type: ToolTypes
component_type: ComponentType = ComponentType.TOOL
component_type: ComponentTypes = ComponentTypes.TOOL
class AgentConfig(BaseConfig):
@ -76,14 +75,18 @@ class AgentConfig(BaseConfig):
model_client: Optional[ModelConfig] = None
tools: Optional[List[ToolConfig]] = None
description: Optional[str] = None
component_type: ComponentType = ComponentType.AGENT
component_type: ComponentTypes = ComponentTypes.AGENT
class TerminationConfig(BaseConfig):
termination_type: TerminationTypes
# Fields for basic terminations
max_messages: Optional[int] = None
text: Optional[str] = None
component_type: ComponentType = ComponentType.TERMINATION
# Fields for combinations
operator: Optional[Literal["and", "or"]] = None
conditions: Optional[List["TerminationConfig"]] = None
component_type: ComponentTypes = ComponentTypes.TERMINATION
class TeamConfig(BaseConfig):
@ -93,7 +96,7 @@ class TeamConfig(BaseConfig):
model_client: Optional[ModelConfig] = None
selector_prompt: Optional[str] = None
termination_condition: Optional[TerminationConfig] = None
component_type: ComponentType = ComponentType.TEAM
component_type: ComponentTypes = ComponentTypes.TEAM
class TeamResult(BaseModel):
@ -111,6 +114,7 @@ class MessageMeta(BaseModel):
log: Optional[List[dict]] = None
usage: Optional[List[dict]] = None
# web request/response data models
@ -126,12 +130,6 @@ class SocketMessage(BaseModel):
type: str
ComponentConfig = Union[
TeamConfig,
AgentConfig,
ModelConfig,
ToolConfig,
TerminationConfig
]
ComponentConfig = Union[TeamConfig, AgentConfig, ModelConfig, ToolConfig, TerminationConfig]
ComponentConfigInput = Union[str, Path, dict, ComponentConfig]

View File

@ -1,108 +0,0 @@
# metrics - agent_frequency, execution_count, tool_count,
from typing import Dict, List, Optional
from .datamodel import Message, MessageMeta
class Profiler:
"""
Profiler class to profile agent task runs and compute metrics
for performance evaluation.
"""
def __init__(self):
self.metrics: List[Dict] = []
def _is_code(self, message: Message) -> bool:
"""
Check if the message contains code.
:param message: The message instance to check.
:return: True if the message contains code, False otherwise.
"""
content = message.get("message").get("content").lower()
return "```" in content
def _is_tool(self, message: Message) -> bool:
"""
Check if the message uses a tool.
:param message: The message instance to check.
:return: True if the message uses a tool, False otherwise.
"""
content = message.get("message").get("content").lower()
return "from skills import" in content
def _is_code_execution(self, message: Message) -> bool:
"""
Check if the message indicates code execution.
:param message: The message instance to check.
:return: dict with is_code and status keys.
"""
content = message.get("message").get("content").lower()
if "exitcode:" in content:
status = "exitcode: 0" in content
return {"is_code": True, "status": status}
else:
return {"is_code": False, "status": False}
def _is_terminate(self, message: Message) -> bool:
"""
Check if the message indicates termination.
:param message: The message instance to check.
:return: True if the message indicates termination, False otherwise.
"""
content = message.get("message").get("content").lower()
return "terminate" in content
def profile(self, agent_message: Message):
"""
Profile the agent task run and compute metrics.
:param agent: The agent instance that ran the task.
:param task: The task instance that was run.
"""
meta = MessageMeta(**agent_message.meta)
print(meta.log)
usage = meta.usage
messages = meta.messages
profile = []
bar = []
stats = {}
total_code_executed = 0
success_code_executed = 0
agents = []
for message in messages:
agent = message.get("sender")
is_code = self._is_code(message)
is_tool = self._is_tool(message)
is_code_execution = self._is_code_execution(message)
total_code_executed += is_code_execution["is_code"]
success_code_executed += 1 if is_code_execution["status"] else 0
row = {
"agent": agent,
"tool_call": is_code,
"code_execution": is_code_execution,
"terminate": self._is_terminate(message),
}
bar_row = {
"agent": agent,
"tool_call": "tool call" if is_tool else "no tool call",
"code_execution": (
"success"
if is_code_execution["status"]
else "failure" if is_code_execution["is_code"] else "no code"
),
"message": 1,
}
profile.append(row)
bar.append(bar_row)
agents.append(agent)
code_success_rate = (success_code_executed / total_code_executed if total_code_executed > 0 else 0) * 100
stats["code_success_rate"] = code_success_rate
stats["total_code_executed"] = total_code_executed
return {"profile": profile, "bar": bar, "stats": stats, "agents": set(agents), "usage": usage}

View File

@ -1,50 +1,39 @@
from typing import AsyncGenerator, Callable, Union, Optional
import time
from .database import ComponentFactory, Component
from .datamodel import TeamResult, TaskResult, ComponentConfigInput
from autogen_agentchat.messages import ChatMessage, AgentMessage
from typing import AsyncGenerator, Callable, Optional, Union
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import AgentMessage, ChatMessage
from autogen_core.base import CancellationToken
from .database import Component, ComponentFactory
from .datamodel import ComponentConfigInput, TeamResult
class TeamManager:
def __init__(self) -> None:
self.component_factory = ComponentFactory()
async def _create_team(
self,
team_config: ComponentConfigInput,
input_func: Optional[Callable] = None
) -> Component:
async def _create_team(self, team_config: ComponentConfigInput, input_func: Optional[Callable] = None) -> Component:
"""Create team instance with common setup logic"""
return await self.component_factory.load(
team_config,
input_func=input_func
)
return await self.component_factory.load(team_config, input_func=input_func)
def _create_result(self, task_result: TaskResult, start_time: float) -> TeamResult:
"""Create TeamResult with timing info"""
return TeamResult(
task_result=task_result,
usage="",
duration=time.time() - start_time
)
return TeamResult(task_result=task_result, usage="", duration=time.time() - start_time)
async def run_stream(
self,
task: str,
team_config: ComponentConfigInput,
input_func: Optional[Callable] = None,
cancellation_token: Optional[CancellationToken] = None
cancellation_token: Optional[CancellationToken] = None,
) -> AsyncGenerator[Union[AgentMessage, ChatMessage, TaskResult], None]:
"""Stream the team's execution results"""
start_time = time.time()
try:
team = await self._create_team(team_config, input_func)
stream = team.run_stream(
task=task,
cancellation_token=cancellation_token
)
stream = team.run_stream(task=task, cancellation_token=cancellation_token)
async for message in stream:
if cancellation_token and cancellation_token.is_cancelled():
@ -63,15 +52,12 @@ class TeamManager:
task: str,
team_config: ComponentConfigInput,
input_func: Optional[Callable] = None,
cancellation_token: Optional[CancellationToken] = None
cancellation_token: Optional[CancellationToken] = None,
) -> TeamResult:
"""Original non-streaming run method with optional cancellation"""
start_time = time.time()
team = await self._create_team(team_config, input_func)
result = await team.run(
task=task,
cancellation_token=cancellation_token
)
result = await team.run(task=task, cancellation_token=cancellation_token)
return self._create_result(result, start_time)

View File

@ -1 +0,0 @@
from .utils import *

View File

@ -10,7 +10,6 @@ from typing import Any, Dict, List, Tuple, Union
from dotenv import load_dotenv
from loguru import logger
from ..datamodel import Model
from ..version import APP_NAME
@ -153,8 +152,7 @@ def get_modified_files(start_timestamp: float, end_timestamp: float, source_dir:
for root, dirs, files in os.walk(source_dir):
# Update directories and files to exclude those to be ignored
dirs[:] = [d for d in dirs if d not in ignore_files]
files[:] = [f for f in files if f not in ignore_files and os.path.splitext(f)[
1] not in ignore_extensions]
files[:] = [f for f in files if f not in ignore_files and os.path.splitext(f)[1] not in ignore_extensions]
for file in files:
file_path = os.path.join(root, file)
@ -163,9 +161,7 @@ def get_modified_files(start_timestamp: float, end_timestamp: float, source_dir:
# Verify if the file was modified within the given timestamp range
if start_timestamp <= file_mtime <= end_timestamp:
file_relative_path = (
"files/user" +
file_path.split(
"files/user", 1)[1] if "files/user" in file_path else ""
"files/user" + file_path.split("files/user", 1)[1] if "files/user" in file_path else ""
)
file_type = get_file_type(file_path)
@ -253,41 +249,27 @@ def sanitize_model(model: Model):
model = model.model_dump()
valid_keys = ["model", "base_url", "api_key", "api_type", "api_version"]
# only add key if value is not None
sanitized_model = {k: v for k, v in model.items() if (
v is not None and v != "") and k in valid_keys}
sanitized_model = {k: v for k, v in model.items() if (v is not None and v != "") and k in valid_keys}
return sanitized_model
def test_model(model: Model):
"""
Test the model endpoint by sending a simple message to the model and returning the response.
"""
class Version:
def __init__(self, ver_str: str):
try:
# Split into major.minor.patch
self.major, self.minor, self.patch = map(int, ver_str.split("."))
except (ValueError, AttributeError) as err:
raise ValueError(f"Invalid version format: {ver_str}. Expected: major.minor.patch") from err
print("Testing model", model)
def __str__(self):
return f"{self.major}.{self.minor}.{self.patch}"
def __eq__(self, other):
if isinstance(other, str):
other = Version(other)
return (self.major, self.minor, self.patch) == (other.major, other.minor, other.patch)
# def summarize_chat_history(task: str, messages: List[Dict[str, str]], client: ModelClient):
# """
# Summarize the chat history using the model endpoint and returning the response.
# """
# summarization_system_prompt = f"""
# You are a helpful assistant that is able to review the chat history between a set of agents (userproxy agents, assistants etc) as they try to address a given TASK and provide a summary. Be SUCCINCT but also comprehensive enough to allow others (who cannot see the chat history) understand and recreate the solution.
# The task requested by the user is:
# ===
# {task}
# ===
# The summary should focus on extracting the actual solution to the task from the chat history (assuming the task was addressed) such that any other agent reading the summary will understand what the actual solution is. Use a neutral tone and DO NOT directly mention the agents. Instead only focus on the actions that were carried out (e.g. do not say 'assistant agent generated some code visualization code ..' instead say say 'visualization code was generated ..'. The answer should be framed as a response to the user task. E.g. if the task is "What is the height of the Eiffel tower", the summary should be "The height of the Eiffel Tower is ...").
# """
# summarization_prompt = [
# {
# "role": "system",
# "content": summarization_system_prompt,
# },
# {
# "role": "user",
# "content": f"Summarize the following chat history. {str(messages)}",
# },
# ]
# response = client.create(messages=summarization_prompt, cache_seed=None)
# return response.choices[0].message.content
def __gt__(self, other):
if isinstance(other, str):
other = Version(other)
return (self.major, self.minor, self.patch) > (other.major, other.minor, other.patch)

View File

@ -1,3 +1,3 @@
VERSION = "0.4.0.dev37"
VERSION = "0.4.0.dev38"
__version__ = VERSION
APP_NAME = "autogenstudio"

View File

@ -1,18 +1,19 @@
# api/app.py
import os
from contextlib import asynccontextmanager
from typing import AsyncGenerator
# import logging
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from loguru import logger
from .routes import sessions, runs, teams, agents, models, tools, ws
from .deps import init_managers, cleanup_managers
from .config import settings
from .initialization import AppInitializer
from ..version import VERSION
from .config import settings
from .deps import cleanup_managers, init_managers
from .initialization import AppInitializer
from .routes import agents, models, runs, sessions, teams, tools, ws
# Configure logging
# logger = logging.getLogger(__name__)
@ -54,6 +55,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
except Exception as e:
logger.error(f"Error during shutdown: {str(e)}")
# Create FastAPI application
app = FastAPI(lifespan=lifespan, debug=True)
@ -143,6 +145,7 @@ async def get_version():
"data": {"version": VERSION},
}
# Health check endpoint
@ -154,6 +157,7 @@ async def health_check():
"message": "Service is healthy",
}
# Mount static file directories
app.mount("/api", api)
app.mount(
@ -172,7 +176,7 @@ async def internal_error_handler(request, exc):
return {
"status": False,
"message": "Internal server error",
"detail": str(exc) if settings.API_DOCS else "Internal server error"
"detail": str(exc) if settings.API_DOCS else "Internal server error",
}

View File

@ -1,14 +1,14 @@
# api/deps.py
from typing import Optional
from fastapi import Depends, HTTPException, status
import logging
from contextlib import contextmanager
from typing import Optional
from ..database import DatabaseManager
from .managers.connection import WebSocketManager
from fastapi import Depends, HTTPException, status
from ..database import ConfigurationManager, DatabaseManager
from ..teammanager import TeamManager
from .config import settings
from ..database import ConfigurationManager
from .managers.connection import WebSocketManager
logger = logging.getLogger(__name__)
@ -25,17 +25,16 @@ def get_db_context():
"""Provide a transactional scope around a series of operations."""
if not _db_manager:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Database manager not initialized"
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database manager not initialized"
)
try:
yield _db_manager
except Exception as e:
logger.error(f"Database operation failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Database operation failed"
)
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database operation failed"
) from e
# Dependency providers
@ -44,8 +43,7 @@ async def get_db() -> DatabaseManager:
"""Dependency provider for database manager"""
if not _db_manager:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Database manager not initialized"
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database manager not initialized"
)
return _db_manager
@ -54,8 +52,7 @@ async def get_websocket_manager() -> WebSocketManager:
"""Dependency provider for connection manager"""
if not _websocket_manager:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Connection manager not initialized"
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Connection manager not initialized"
)
return _websocket_manager
@ -63,12 +60,10 @@ async def get_websocket_manager() -> WebSocketManager:
async def get_team_manager() -> TeamManager:
"""Dependency provider for team manager"""
if not _team_manager:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Team manager not initialized"
)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Team manager not initialized")
return _team_manager
# Authentication dependency
@ -83,6 +78,7 @@ async def get_current_user(
# Implement your user authentication here
return "user_id" # Replace with actual user identification
# Manager initialization and cleanup
@ -94,20 +90,16 @@ async def init_managers(database_uri: str, config_dir: str, app_root: str) -> No
try:
# Initialize database manager
_db_manager = DatabaseManager(
engine_uri=database_uri, auto_upgrade=settings.UPGRADE_DATABASE, base_dir=app_root)
_db_manager.create_db_and_tables()
_db_manager = DatabaseManager(engine_uri=database_uri, base_dir=app_root)
_db_manager.initialize_database(auto_upgrade=settings.UPGRADE_DATABASE)
# init default team config
_team_config_manager = ConfigurationManager(db_manager=_db_manager)
import_result = await _team_config_manager.import_directory(
config_dir, settings.DEFAULT_USER_ID, check_exists=True)
await _team_config_manager.import_directory(config_dir, settings.DEFAULT_USER_ID, check_exists=True)
# Initialize connection manager
_websocket_manager = WebSocketManager(
db_manager=_db_manager
)
_websocket_manager = WebSocketManager(db_manager=_db_manager)
logger.info("Connection manager initialized")
# Initialize team manager
@ -149,6 +141,7 @@ async def cleanup_managers() -> None:
logger.info("All managers cleaned up")
# Utility functions for dependency management
@ -157,19 +150,17 @@ def get_manager_status() -> dict:
return {
"database_manager": _db_manager is not None,
"websocket_manager": _websocket_manager is not None,
"team_manager": _team_manager is not None
"team_manager": _team_manager is not None,
}
# Combined dependencies
async def get_managers():
"""Get all managers in one dependency"""
return {
"db": await get_db(),
"connection": await get_websocket_manager(),
"team": await get_team_manager()
}
return {"db": await get_db(), "connection": await get_websocket_manager(), "team": await get_team_manager()}
# Error handling for manager operations
@ -183,19 +174,21 @@ class ManagerOperationError(Exception):
self.detail = detail
super().__init__(f"{manager_name} failed during {operation}: {detail}")
# Dependency for requiring specific managers
def require_managers(*manager_names: str):
"""Decorator to require specific managers for a route"""
async def dependency():
status = get_manager_status()
missing = [name for name in manager_names if not status.get(
f"{name}_manager")]
missing = [name for name in manager_names if not status.get(f"{name}_manager")]
if missing:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Required managers not available: {', '.join(missing)}"
detail=f"Required managers not available: {', '.join(missing)}",
)
return True
return Depends(dependency)

View File

@ -2,15 +2,17 @@
import os
from pathlib import Path
from typing import Dict
from pydantic import BaseModel
from loguru import logger
from dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel
from .config import Settings
class _AppPaths(BaseModel):
"""Internal model representing all application paths"""
app_root: Path
static_root: Path
user_files: Path
@ -47,9 +49,7 @@ class AppInitializer:
"""Generate database URI based on settings or environment"""
if db_uri := os.getenv("AUTOGENSTUDIO_DATABASE_URI"):
return db_uri
return self.settings.DATABASE_URI.replace(
"./", str(app_root) + "/"
)
return self.settings.DATABASE_URI.replace("./", str(app_root) + "/")
def _init_paths(self) -> _AppPaths:
"""Initialize and return AppPaths instance"""
@ -60,14 +60,13 @@ class AppInitializer:
user_files=app_root / "files" / "user",
ui_root=self._app_path / "ui",
config_dir=app_root / self.settings.CONFIG_DIR,
database_uri=self._get_database_uri(app_root)
database_uri=self._get_database_uri(app_root),
)
def _create_directories(self) -> None:
"""Create all required directories"""
self.app_root.mkdir(parents=True, exist_ok=True)
dirs = [self.static_root, self.user_files,
self.ui_root, self.config_dir]
dirs = [self.static_root, self.user_files, self.ui_root, self.config_dir]
for path in dirs:
path.mkdir(parents=True, exist_ok=True)

View File

@ -1,16 +1,17 @@
import asyncio
from autogen_agentchat.base._task import TaskResult
from fastapi import WebSocket, WebSocketDisconnect
from typing import Callable, Dict, Optional, Any
from uuid import UUID
import logging
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional, Union
from uuid import UUID
from ...datamodel import Run, RunStatus, TeamResult
from ...database import DatabaseManager
from ...teammanager import TeamManager
from autogen_agentchat.base._task import TaskResult
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
from autogen_core.base import CancellationToken
from fastapi import WebSocket, WebSocketDisconnect
from ...database import DatabaseManager
from ...datamodel import Message, MessageConfig, Run, RunStatus, TeamResult
from ...teammanager import TeamManager
logger = logging.getLogger(__name__)
@ -26,12 +27,20 @@ class WebSocketManager:
self._closed_connections: set[UUID] = set()
self._input_responses: Dict[UUID, asyncio.Queue] = {}
self._cancel_message = TeamResult(task_result=TaskResult(messages=[TextMessage(
source="user", content="Run cancelled by user")], stop_reason="cancelled by user"), usage="", duration=0).model_dump()
self._cancel_message = TeamResult(
task_result=TaskResult(
messages=[TextMessage(source="user", content="Run cancelled by user")], stop_reason="cancelled by user"
),
usage="",
duration=0,
).model_dump()
def _get_stop_message(self, reason: str) -> dict:
return TeamResult(task_result=TaskResult(messages=[TextMessage(
source="user", content=reason)], stop_reason=reason), usage="", duration=0).model_dump()
return TeamResult(
task_result=TaskResult(messages=[TextMessage(source="user", content=reason)], stop_reason=reason),
usage="",
duration=0,
).model_dump()
async def connect(self, websocket: WebSocket, run_id: UUID) -> bool:
try:
@ -41,87 +50,118 @@ class WebSocketManager:
# Initialize input queue for this connection
self._input_responses[run_id] = asyncio.Queue()
run = await self._get_run(run_id)
if run:
run.status = RunStatus.ACTIVE
self.db_manager.upsert(run)
await self._send_message(run_id, {
"type": "system",
"status": "connected",
"timestamp": datetime.now(timezone.utc).isoformat()
})
await self._send_message(
run_id, {"type": "system", "status": "connected", "timestamp": datetime.now(timezone.utc).isoformat()}
)
return True
except Exception as e:
logger.error(f"Connection error for run {run_id}: {e}")
return False
async def start_stream(
self,
run_id: UUID,
team_manager: TeamManager,
task: str,
team_config: dict
) -> None:
async def start_stream(self, run_id: UUID, task: str, team_config: dict) -> None:
"""Start streaming task execution with proper run management"""
if run_id not in self._connections or run_id in self._closed_connections:
raise ValueError(f"No active connection for run {run_id}")
team_manager = TeamManager()
cancellation_token = CancellationToken()
self._cancellation_tokens[run_id] = cancellation_token
final_result = None
try:
# Create input function for this run
# Update run with task and status
run = await self._get_run(run_id)
if run:
run.task = MessageConfig(content=task, source="user").model_dump()
run.status = RunStatus.ACTIVE
self.db_manager.upsert(run)
input_func = self.create_input_func(run_id)
async for message in team_manager.run_stream(
task=task,
team_config=team_config,
input_func=input_func, # Pass the input function
cancellation_token=cancellation_token
task=task, team_config=team_config, input_func=input_func, cancellation_token=cancellation_token
):
if cancellation_token.is_cancelled() or run_id in self._closed_connections:
logger.info(
f"Stream cancelled or connection closed for run {run_id}")
logger.info(f"Stream cancelled or connection closed for run {run_id}")
break
formatted_message = self._format_message(message)
if formatted_message:
await self._send_message(run_id, formatted_message)
# Save message if it's a content message
if isinstance(message, (AgentMessage, ChatMessage)):
await self._save_message(run_id, message)
# Capture final result if it's a TeamResult
elif isinstance(message, TeamResult):
final_result = message.model_dump()
if not cancellation_token.is_cancelled() and run_id not in self._closed_connections:
await self._update_run_status(run_id, RunStatus.COMPLETE)
if final_result:
await self._update_run(run_id, RunStatus.COMPLETE, team_result=final_result)
else:
logger.warning(f"No final result captured for completed run {run_id}")
await self._update_run_status(run_id, RunStatus.COMPLETE)
else:
await self._send_message(run_id, {
"type": "completion",
"status": "cancelled",
"data": self._cancel_message,
"timestamp": datetime.now(timezone.utc).isoformat()
})
await self._update_run_status(run_id, RunStatus.STOPPED)
await self._send_message(
run_id,
{
"type": "completion",
"status": "cancelled",
"data": self._cancel_message,
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
# Update run with cancellation result
await self._update_run(run_id, RunStatus.STOPPED, team_result=self._cancel_message)
except Exception as e:
logger.error(f"Stream error for run {run_id}: {e}")
await self._handle_stream_error(run_id, e)
finally:
self._cancellation_tokens.pop(run_id, None)
async def _save_message(self, run_id: UUID, message: Union[AgentMessage, ChatMessage]) -> None:
"""Save a message to the database"""
run = await self._get_run(run_id)
if run:
db_message = Message(
session_id=run.session_id,
run_id=run_id,
config=message.model_dump(),
user_id=None, # You might want to pass this from somewhere
)
self.db_manager.upsert(db_message)
async def _update_run(
self, run_id: UUID, status: RunStatus, team_result: Optional[dict] = None, error: Optional[str] = None
) -> None:
"""Update run status and result"""
run = await self._get_run(run_id)
if run:
run.status = status
if team_result:
run.team_result = team_result
if error:
run.error_message = error
self.db_manager.upsert(run)
def create_input_func(self, run_id: UUID) -> Callable:
"""Creates an input function for a specific run"""
async def input_handler(prompt: str = "") -> str:
try:
async def input_handler(prompt: str = "", cancellation_token: Optional[CancellationToken] = None) -> str:
try:
# Send input request to client
await self._send_message(run_id, {
"type": "input_request",
"prompt": prompt,
"data": {
"source": "system",
"content": prompt
await self._send_message(
run_id,
{
"type": "input_request",
"prompt": prompt,
"data": {"source": "system", "content": prompt},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
"timestamp": datetime.now(timezone.utc).isoformat()
})
)
# Wait for response
if run_id in self._input_responses:
@ -141,26 +181,37 @@ class WebSocketManager:
if run_id in self._input_responses:
await self._input_responses[run_id].put(response)
else:
logger.warning(
f"Received input response for inactive run {run_id}")
logger.warning(f"Received input response for inactive run {run_id}")
async def stop_run(self, run_id: UUID, reason: str) -> None:
"""Stop a running task"""
if run_id in self._cancellation_tokens:
logger.info(f"Stopping run {run_id}")
# self._cancellation_tokens[run_id].cancel()
# Send final message if connection still exists and not closed
if run_id in self._connections and run_id not in self._closed_connections:
try:
await self._send_message(run_id, {
"type": "completion",
"status": "cancelled",
"data": self._get_stop_message(reason),
"timestamp": datetime.now(timezone.utc).isoformat()
})
except Exception:
pass
stop_message = self._get_stop_message(reason)
try:
# Update run record first
await self._update_run(run_id, status=RunStatus.STOPPED, team_result=stop_message)
# Then handle websocket communication if connection is active
if run_id in self._connections and run_id not in self._closed_connections:
await self._send_message(
run_id,
{
"type": "completion",
"status": "cancelled",
"data": stop_message,
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
# Finally cancel the token
self._cancellation_tokens[run_id].cancel()
except Exception as e:
logger.error(f"Error stopping run {run_id}: {e}")
# We might want to force disconnect here if db update failed
# await self.disconnect(run_id) # Optional
async def disconnect(self, run_id: UUID) -> None:
"""Clean up connection and associated resources"""
@ -185,8 +236,7 @@ class WebSocketManager:
message: Message dictionary to send
"""
if run_id in self._closed_connections:
logger.warning(
f"Attempted to send message to closed connection for run {run_id}")
logger.warning(f"Attempted to send message to closed connection for run {run_id}")
return
try:
@ -194,36 +244,36 @@ class WebSocketManager:
websocket = self._connections[run_id]
await websocket.send_json(message)
except WebSocketDisconnect:
logger.warning(
f"WebSocket disconnected while sending message for run {run_id}")
logger.warning(f"WebSocket disconnected while sending message for run {run_id}")
await self.disconnect(run_id)
except Exception as e:
logger.error(
f"Error sending message for run {run_id}: {e}, {message}")
logger.error(f"Error sending message for run {run_id}: {e}, {message}")
# Don't try to send error message here to avoid potential recursive loop
await self._update_run_status(run_id, RunStatus.ERROR, str(e))
await self.disconnect(run_id)
async def _handle_stream_error(self, run_id: UUID, error: Exception) -> None:
"""Handle stream errors with connection state awareness
Args:
run_id: UUID of the run
error: Exception that occurred
"""
"""Handle stream errors with proper run updates"""
if run_id not in self._closed_connections:
try:
await self._send_message(run_id, {
error_result = TeamResult(
task_result=TaskResult(
messages=[TextMessage(source="system", content=str(error))], stop_reason="error"
),
usage="",
duration=0,
).model_dump()
await self._send_message(
run_id,
{
"type": "completion",
"status": "error",
"error": str(error),
"timestamp": datetime.now(timezone.utc).isoformat()
})
except Exception as send_error:
logger.error(
f"Failed to send error message for run {run_id}: {send_error}")
"data": error_result,
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
await self._update_run_status(run_id, RunStatus.ERROR, str(error))
await self._update_run(run_id, RunStatus.ERROR, team_result=error_result, error=str(error))
def _format_message(self, message: Any) -> Optional[dict]:
"""Format message for WebSocket transmission
@ -236,10 +286,7 @@ class WebSocketManager:
"""
try:
if isinstance(message, (AgentMessage, ChatMessage)):
return {
"type": "message",
"data": message.model_dump()
}
return {"type": "message", "data": message.model_dump()}
elif isinstance(message, TeamResult):
return {
"type": "result",
@ -260,16 +307,10 @@ class WebSocketManager:
Returns:
Optional[Run]: Run object if found, None otherwise
"""
response = self.db_manager.get(
Run, filters={"id": run_id}, return_json=False)
response = self.db_manager.get(Run, filters={"id": run_id}, return_json=False)
return response.data[0] if response.status and response.data else None
async def _update_run_status(
self,
run_id: UUID,
status: RunStatus,
error: Optional[str] = None
) -> None:
async def _update_run_status(self, run_id: UUID, status: RunStatus, error: Optional[str] = None) -> None:
"""Update run status in database
Args:
@ -285,14 +326,27 @@ class WebSocketManager:
async def cleanup(self) -> None:
"""Clean up all active connections and resources when server is shutting down"""
logger.info(
f"Cleaning up {len(self.active_connections)} active connections")
logger.info(f"Cleaning up {len(self.active_connections)} active connections")
try:
# First cancel all running tasks
for run_id in self.active_runs.copy():
if run_id in self._cancellation_tokens:
self._cancellation_tokens[run_id].cancel()
run = await self._get_run(run_id)
if run and run.status == RunStatus.ACTIVE:
interrupted_result = TeamResult(
task_result=TaskResult(
messages=[TextMessage(source="system", content="Run interrupted by server shutdown")],
stop_reason="server_shutdown",
),
usage="",
duration=0,
).model_dump()
run.status = RunStatus.STOPPED
run.team_result = interrupted_result
self.db_manager.upsert(run)
# Then disconnect all websockets with timeout
# 10 second timeout for entire cleanup

View File

@ -1,181 +1,51 @@
# api/routes/agents.py
from fastapi import APIRouter, Depends, HTTPException
from typing import Dict
from ..deps import get_db
from fastapi import APIRouter, Depends, HTTPException
from ...database import DatabaseManager # Add this import
from ...datamodel import Agent, Model, Tool
from ..deps import get_db
router = APIRouter()
@router.get("/")
async def list_agents(
user_id: str,
db=Depends(get_db)
) -> Dict:
async def list_agents(user_id: str, db: DatabaseManager = Depends(get_db)) -> Dict:
"""List all agents for a user"""
response = db.get(Agent, filters={"user_id": user_id})
return {
"status": True,
"data": response.data
}
return {"status": True, "data": response.data}
@router.get("/{agent_id}")
async def get_agent(
agent_id: int,
user_id: str,
db=Depends(get_db)
) -> Dict:
async def get_agent(agent_id: int, user_id: str, db: DatabaseManager = Depends(get_db)) -> Dict:
"""Get a specific agent"""
response = db.get(
Agent,
filters={"id": agent_id, "user_id": user_id}
)
response = db.get(Agent, filters={"id": agent_id, "user_id": user_id})
if not response.status or not response.data:
raise HTTPException(status_code=404, detail="Agent not found")
return {
"status": True,
"data": response.data[0]
}
return {"status": True, "data": response.data[0]}
@router.post("/")
async def create_agent(
agent: Agent,
db=Depends(get_db)
) -> Dict:
async def create_agent(agent: Agent, db: DatabaseManager = Depends(get_db)) -> Dict:
"""Create a new agent"""
response = db.upsert(agent)
if not response.status:
raise HTTPException(status_code=400, detail=response.message)
return {
"status": True,
"data": response.data
}
return {"status": True, "data": response.data}
@router.delete("/{agent_id}")
async def delete_agent(
agent_id: int,
user_id: str,
db=Depends(get_db)
) -> Dict:
async def delete_agent(agent_id: int, user_id: str, db: DatabaseManager = Depends(get_db)) -> Dict:
"""Delete an agent"""
response = db.delete(
filters={"id": agent_id, "user_id": user_id},
model_class=Agent
)
return {
"status": True,
"message": "Agent deleted successfully"
}
db.delete(filters={"id": agent_id, "user_id": user_id}, model_class=Agent)
return {"status": True, "message": "Agent deleted successfully"}
# Agent-Model link endpoints
@router.post("/{agent_id}/models/{model_id}")
async def link_agent_model(
agent_id: int,
model_id: int,
db=Depends(get_db)
) -> Dict:
async def link_agent_model(agent_id: int, model_id: int, db: DatabaseManager = Depends(get_db)) -> Dict:
"""Link a model to an agent"""
response = db.link(
link_type="agent_model",
primary_id=agent_id,
secondary_id=model_id
)
return {
"status": True,
"message": "Model linked to agent successfully"
}
@router.delete("/{agent_id}/models/{model_id}")
async def unlink_agent_model(
agent_id: int,
model_id: int,
db=Depends(get_db)
) -> Dict:
"""Unlink a model from an agent"""
response = db.unlink(
link_type="agent_model",
primary_id=agent_id,
secondary_id=model_id
)
return {
"status": True,
"message": "Model unlinked from agent successfully"
}
@router.get("/{agent_id}/models")
async def get_agent_models(
agent_id: int,
db=Depends(get_db)
) -> Dict:
"""Get all models linked to an agent"""
response = db.get_linked_entities(
link_type="agent_model",
primary_id=agent_id,
return_json=True
)
return {
"status": True,
"data": response.data
}
# Agent-Tool link endpoints
@router.post("/{agent_id}/tools/{tool_id}")
async def link_agent_tool(
agent_id: int,
tool_id: int,
db=Depends(get_db)
) -> Dict:
"""Link a tool to an agent"""
response = db.link(
link_type="agent_tool",
primary_id=agent_id,
secondary_id=tool_id
)
return {
"status": True,
"message": "Tool linked to agent successfully"
}
@router.delete("/{agent_id}/tools/{tool_id}")
async def unlink_agent_tool(
agent_id: int,
tool_id: int,
db=Depends(get_db)
) -> Dict:
"""Unlink a tool from an agent"""
response = db.unlink(
link_type="agent_tool",
primary_id=agent_id,
secondary_id=tool_id
)
return {
"status": True,
"message": "Tool unlinked from agent successfully"
}
@router.get("/{agent_id}/tools")
async def get_agent_tools(
agent_id: int,
db=Depends(get_db)
) -> Dict:
"""Get all tools linked to an agent"""
response = db.get_linked_entities(
link_type="agent_tool",
primary_id=agent_id,
return_json=True
)
return {
"status": True,
"data": response.data
}
db.link(link_type="agent_model", primary_id=agent_id, secondary_id=model_id)
return {"status": True, "message": "Model linked to agent successfully"}

View File

@ -0,0 +1,62 @@
# api/routes/gallery.py
from fastapi import APIRouter, Depends, HTTPException
from ...database import DatabaseManager
from ...datamodel import Gallery, GalleryConfig, Response, Run, Session
from ..deps import get_db
router = APIRouter()
@router.post("/")
async def create_gallery_entry(
gallery_data: GalleryConfig, user_id: str, db: DatabaseManager = Depends(get_db)
) -> Response:
# First validate that user owns all runs
for run in gallery_data.runs:
run_result = db.get(Run, filters={"id": run.id})
if not run_result.status or not run_result.data:
raise HTTPException(status_code=404, detail=f"Run {run.id} not found")
# Get associated session to check ownership
session_result = db.get(Session, filters={"id": run_result.data[0].session_id})
if not session_result.status or not session_result.data or session_result.data[0].user_id != user_id:
raise HTTPException(status_code=403, detail=f"Not authorized to add run {run.id} to gallery")
# Create gallery entry
gallery = Gallery(user_id=user_id, config=gallery_data)
result = db.upsert(gallery)
return result
@router.get("/{gallery_id}")
async def get_gallery_entry(gallery_id: int, user_id: str, db: DatabaseManager = Depends(get_db)) -> Response:
result = db.get(Gallery, filters={"id": gallery_id})
if not result.status or not result.data:
raise HTTPException(status_code=404, detail="Gallery entry not found")
gallery = result.data[0]
if gallery.config["visibility"] != "public" and gallery.user_id != user_id:
raise HTTPException(status_code=403, detail="Not authorized to view this gallery entry")
return result
@router.get("/")
async def list_gallery_entries(user_id: str, db: DatabaseManager = Depends(get_db)) -> Response:
result = db.get(Gallery, filters={"user_id": user_id})
return result
@router.delete("/{gallery_id}")
async def delete_gallery_entry(gallery_id: int, user_id: str, db: DatabaseManager = Depends(get_db)) -> Response:
# Check ownership first
result = db.get(Gallery, filters={"id": gallery_id})
if not result.status or not result.data:
raise HTTPException(status_code=404, detail="Gallery entry not found")
if result.data[0].user_id != user_id:
raise HTTPException(status_code=403, detail="Not authorized to delete this gallery entry")
# Delete if authorized
return db.delete(Gallery, filters={"id": gallery_id})

View File

@ -1,95 +1,42 @@
# api/routes/models.py
from fastapi import APIRouter, Depends, HTTPException
from typing import Dict
from fastapi import APIRouter, Depends, HTTPException
from openai import OpenAIError
from ..deps import get_db
from ...datamodel import Model
from ...utils import test_model
from ..deps import get_db
router = APIRouter()
@router.get("/")
async def list_models(
user_id: str,
db=Depends(get_db)
) -> Dict:
async def list_models(user_id: str, db=Depends(get_db)) -> Dict:
"""List all models for a user"""
response = db.get(Model, filters={"user_id": user_id})
return {
"status": True,
"data": response.data
}
return {"status": True, "data": response.data}
@router.get("/{model_id}")
async def get_model(
model_id: int,
user_id: str,
db=Depends(get_db)
) -> Dict:
async def get_model(model_id: int, user_id: str, db=Depends(get_db)) -> Dict:
"""Get a specific model"""
response = db.get(
Model,
filters={"id": model_id, "user_id": user_id}
)
response = db.get(Model, filters={"id": model_id, "user_id": user_id})
if not response.status or not response.data:
raise HTTPException(status_code=404, detail="Model not found")
return {
"status": True,
"data": response.data[0]
}
return {"status": True, "data": response.data[0]}
@router.post("/")
async def create_model(
model: Model,
db=Depends(get_db)
) -> Dict:
async def create_model(model: Model, db=Depends(get_db)) -> Dict:
"""Create a new model"""
response = db.upsert(model)
if not response.status:
raise HTTPException(status_code=400, detail=response.message)
return {
"status": True,
"data": response.data
}
return {"status": True, "data": response.data}
@router.delete("/{model_id}")
async def delete_model(
model_id: int,
user_id: str,
db=Depends(get_db)
) -> Dict:
async def delete_model(model_id: int, user_id: str, db=Depends(get_db)) -> Dict:
"""Delete a model"""
response = db.delete(
filters={"id": model_id, "user_id": user_id},
model_class=Model
)
return {
"status": True,
"message": "Model deleted successfully"
}
@router.post("/test")
async def test_model_endpoint(model: Model) -> Dict:
"""Test a model configuration"""
try:
response = test_model(model)
return {
"status": True,
"message": "Model tested successfully",
"data": response
}
except OpenAIError as e:
raise HTTPException(
status_code=400,
detail=f"OpenAI API error: {str(e)}"
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error testing model: {str(e)}"
)
db.delete(filters={"id": model_id, "user_id": user_id}, model_class=Model)
return {"status": True, "message": "Model deleted successfully"}

View File

@ -1,14 +1,12 @@
# /api/runs routes
from fastapi import APIRouter, Body, Depends, HTTPException
from uuid import UUID
from typing import Dict
from uuid import UUID
from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel
from ..deps import get_db, get_websocket_manager, get_team_manager
from ...datamodel import Run, Session, Message, Team, RunStatus, MessageConfig
from ...teammanager import TeamManager
from autogen_core.base import CancellationToken
from ...datamodel import Message, MessageConfig, Run, RunStatus, Session, Team
from ..deps import get_db, get_team_manager, get_websocket_manager
router = APIRouter()
@ -23,54 +21,45 @@ async def create_run(
request: CreateRunRequest,
db=Depends(get_db),
) -> Dict:
"""Create a new run"""
"""Create a new run with initial state"""
session_response = db.get(
Session,
filters={"id": request.session_id, "user_id": request.user_id},
return_json=False
Session, filters={"id": request.session_id, "user_id": request.user_id}, return_json=False
)
if not session_response.status or not session_response.data:
raise HTTPException(status_code=404, detail="Session not found")
try:
run = db.upsert(Run(session_id=request.session_id), return_json=False)
return {
"status": run.status,
"data": {"run_id": str(run.data.id)}
}
# }
# Create run with default state
run = db.upsert(
Run(
session_id=request.session_id,
status=RunStatus.CREATED,
task=None, # Will be set when run starts
team_result=None,
),
return_json=False,
)
return {"status": run.status, "data": {"run_id": str(run.data.id)}}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/{run_id}/start")
async def start_run(
run_id: UUID,
message: Message = Body(...),
ws_manager=Depends(get_websocket_manager),
team_manager=Depends(get_team_manager),
db=Depends(get_db),
) -> Dict:
"""Start streaming task execution"""
# We might want to add these endpoints:
if isinstance(message.config, dict):
message.config = MessageConfig(**message.config)
session = db.get(Session, filters={
"id": message.session_id}, return_json=False)
@router.get("/{run_id}")
async def get_run(run_id: UUID, db=Depends(get_db)) -> Dict:
"""Get run details including task and result"""
run = db.get(Run, filters={"id": run_id}, return_json=False)
if not run.status or not run.data:
raise HTTPException(status_code=404, detail="Run not found")
team = db.get(
Team, filters={"id": session.data[0].team_id}, return_json=False)
return {"status": True, "data": run.data[0]}
try:
await ws_manager.start_stream(run_id, team_manager, message.config.content, team.data[0].config)
return {
"status": True,
"message": "Stream started successfully",
"data": {"run_id": str(run_id)}
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{run_id}/messages")
async def get_run_messages(run_id: UUID, db=Depends(get_db)) -> Dict:
"""Get all messages for a run"""
messages = db.get(Message, filters={"run_id": run_id}, order="created_at asc", return_json=False)
return {"status": True, "data": messages.data}

View File

@ -1,72 +1,45 @@
# api/routes/sessions.py
from fastapi import APIRouter, Depends, HTTPException
from typing import Dict
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger
from ...datamodel import Message, Run, Session
from ..deps import get_db
from ...datamodel import Session, Message
router = APIRouter()
@router.get("/")
async def list_sessions(
user_id: str,
db=Depends(get_db)
) -> Dict:
async def list_sessions(user_id: str, db=Depends(get_db)) -> Dict:
"""List all sessions for a user"""
response = db.get(Session, filters={"user_id": user_id})
return {
"status": True,
"data": response.data
}
return {"status": True, "data": response.data}
@router.get("/{session_id}")
async def get_session(
session_id: int,
user_id: str,
db=Depends(get_db)
) -> Dict:
async def get_session(session_id: int, user_id: str, db=Depends(get_db)) -> Dict:
"""Get a specific session"""
response = db.get(
Session,
filters={"id": session_id, "user_id": user_id}
)
response = db.get(Session, filters={"id": session_id, "user_id": user_id})
if not response.status or not response.data:
raise HTTPException(status_code=404, detail="Session not found")
return {
"status": True,
"data": response.data[0]
}
return {"status": True, "data": response.data[0]}
@router.post("/")
async def create_session(
session: Session,
db=Depends(get_db)
) -> Dict:
async def create_session(session: Session, db=Depends(get_db)) -> Dict:
"""Create a new session"""
response = db.upsert(session)
if not response.status:
raise HTTPException(status_code=400, detail=response.message)
return {
"status": True,
"data": response.data
}
return {"status": True, "data": response.data}
@router.put("/{session_id}")
async def update_session(
session_id: int,
user_id: str,
session: Session,
db=Depends(get_db)
) -> Dict:
async def update_session(session_id: int, user_id: str, session: Session, db=Depends(get_db)) -> Dict:
"""Update an existing session"""
# First verify the session belongs to user
existing = db.get(
Session,
filters={"id": session_id, "user_id": user_id}
)
existing = db.get(Session, filters={"id": session_id, "user_id": user_id})
if not existing.status or not existing.data:
raise HTTPException(status_code=404, detail="Session not found")
@ -75,40 +48,74 @@ async def update_session(
if not response.status:
raise HTTPException(status_code=400, detail=response.message)
return {
"status": True,
"data": response.data,
"message": "Session updated successfully"
}
return {"status": True, "data": response.data, "message": "Session updated successfully"}
@router.delete("/{session_id}")
async def delete_session(
session_id: int,
user_id: str,
db=Depends(get_db)
) -> Dict:
async def delete_session(session_id: int, user_id: str, db=Depends(get_db)) -> Dict:
"""Delete a session"""
response = db.delete(
filters={"id": session_id, "user_id": user_id},
model_class=Session
)
return {
"status": True,
"message": "Session deleted successfully"
}
db.delete(filters={"id": session_id, "user_id": user_id}, model_class=Session)
return {"status": True, "message": "Session deleted successfully"}
@router.get("/{session_id}/messages")
async def list_messages(
session_id: int,
user_id: str,
db=Depends(get_db)
) -> Dict:
"""List all messages for a session"""
filters = {"session_id": session_id, "user_id": user_id}
response = db.get(Message, filters=filters, order="asc")
return {
"status": True,
"data": response.data
}
@router.get("/{session_id}/runs")
async def list_session_runs(session_id: int, user_id: str, db=Depends(get_db)) -> Dict:
"""Get complete session history organized by runs"""
try:
# 1. Verify session exists and belongs to user
session = db.get(Session, filters={"id": session_id, "user_id": user_id}, return_json=False)
if not session.status:
raise HTTPException(status_code=500, detail="Database error while fetching session")
if not session.data:
raise HTTPException(status_code=404, detail="Session not found or access denied")
# 2. Get ordered runs for session
runs = db.get(Run, filters={"session_id": session_id}, order="asc", return_json=False)
if not runs.status:
raise HTTPException(status_code=500, detail="Database error while fetching runs")
# 3. Build response with messages per run
run_data = []
if runs.data: # It's ok to have no runs
for run in runs.data:
try:
# Get messages for this specific run
messages = db.get(Message, filters={"run_id": run.id}, order="asc", return_json=False)
if not messages.status:
logger.error(f"Failed to fetch messages for run {run.id}")
# Continue processing other runs even if one fails
messages.data = []
run_data.append(
{
"id": str(run.id),
"created_at": run.created_at,
"status": run.status,
"task": run.task,
"team_result": run.team_result,
"messages": messages.data or [],
}
)
except Exception as e:
logger.error(f"Error processing run {run.id}: {str(e)}")
# Include run with error state instead of failing entirely
run_data.append(
{
"id": str(run.id),
"created_at": run.created_at,
"status": "ERROR",
"task": run.task,
"team_result": None,
"messages": [],
"error": f"Failed to process run: {str(e)}",
}
)
return {"status": True, "data": {"runs": run_data}}
except HTTPException:
raise # Re-raise HTTP exceptions
except Exception as e:
logger.error(f"Unexpected error in list_messages: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error while fetching session data") from e

View File

@ -1,146 +1,41 @@
# api/routes/teams.py
from fastapi import APIRouter, Depends, HTTPException
from typing import Dict
from ..deps import get_db
from fastapi import APIRouter, Depends, HTTPException
from ...datamodel import Team
from ..deps import get_db
router = APIRouter()
@router.get("/")
async def list_teams(
user_id: str,
db=Depends(get_db)
) -> Dict:
async def list_teams(user_id: str, db=Depends(get_db)) -> Dict:
"""List all teams for a user"""
response = db.get(Team, filters={"user_id": user_id})
return {
"status": True,
"data": response.data
}
return {"status": True, "data": response.data}
@router.get("/{team_id}")
async def get_team(
team_id: int,
user_id: str,
db=Depends(get_db)
) -> Dict:
async def get_team(team_id: int, user_id: str, db=Depends(get_db)) -> Dict:
"""Get a specific team"""
response = db.get(
Team,
filters={"id": team_id, "user_id": user_id}
)
response = db.get(Team, filters={"id": team_id, "user_id": user_id})
if not response.status or not response.data:
raise HTTPException(status_code=404, detail="Team not found")
return {
"status": True,
"data": response.data[0]
}
return {"status": True, "data": response.data[0]}
@router.post("/")
async def create_team(
team: Team,
db=Depends(get_db)
) -> Dict:
async def create_team(team: Team, db=Depends(get_db)) -> Dict:
"""Create a new team"""
response = db.upsert(team)
if not response.status:
raise HTTPException(status_code=400, detail=response.message)
return {
"status": True,
"data": response.data
}
return {"status": True, "data": response.data}
@router.delete("/{team_id}")
async def delete_team(
team_id: int,
user_id: str,
db=Depends(get_db)
) -> Dict:
async def delete_team(team_id: int, user_id: str, db=Depends(get_db)) -> Dict:
"""Delete a team"""
response = db.delete(
filters={"id": team_id, "user_id": user_id},
model_class=Team
)
return {
"status": True,
"message": "Team deleted successfully"
}
# Team-Agent link endpoints
@router.post("/{team_id}/agents/{agent_id}")
async def link_team_agent(
team_id: int,
agent_id: int,
db=Depends(get_db)
) -> Dict:
"""Link an agent to a team"""
response = db.link(
link_type="team_agent",
primary_id=team_id,
secondary_id=agent_id
)
return {
"status": True,
"message": "Agent linked to team successfully"
}
@router.post("/{team_id}/agents/{agent_id}/{sequence_id}")
async def link_team_agent_sequence(
team_id: int,
agent_id: int,
sequence_id: int,
db=Depends(get_db)
) -> Dict:
"""Link an agent to a team with sequence"""
response = db.link(
link_type="team_agent",
primary_id=team_id,
secondary_id=agent_id,
sequence_id=sequence_id
)
return {
"status": True,
"message": "Agent linked to team with sequence successfully"
}
@router.delete("/{team_id}/agents/{agent_id}")
async def unlink_team_agent(
team_id: int,
agent_id: int,
db=Depends(get_db)
) -> Dict:
"""Unlink an agent from a team"""
response = db.unlink(
link_type="team_agent",
primary_id=team_id,
secondary_id=agent_id
)
return {
"status": True,
"message": "Agent unlinked from team successfully"
}
@router.get("/{team_id}/agents")
async def get_team_agents(
team_id: int,
db=Depends(get_db)
) -> Dict:
"""Get all agents linked to a team"""
response = db.get_linked_entities(
link_type="team_agent",
primary_id=team_id,
return_json=True
)
return {
"status": True,
"data": response.data
}
db.delete(filters={"id": team_id, "user_id": user_id}, model_class=Team)
return {"status": True, "message": "Team deleted successfully"}

View File

@ -1,103 +1,41 @@
# api/routes/tools.py
from fastapi import APIRouter, Depends, HTTPException
from typing import Dict
from ..deps import get_db
from fastapi import APIRouter, Depends, HTTPException
from ...datamodel import Tool
from ..deps import get_db
router = APIRouter()
@router.get("/")
async def list_tools(
user_id: str,
db=Depends(get_db)
) -> Dict:
async def list_tools(user_id: str, db=Depends(get_db)) -> Dict:
"""List all tools for a user"""
response = db.get(Tool, filters={"user_id": user_id})
return {
"status": True,
"data": response.data
}
return {"status": True, "data": response.data}
@router.get("/{tool_id}")
async def get_tool(
tool_id: int,
user_id: str,
db=Depends(get_db)
) -> Dict:
async def get_tool(tool_id: int, user_id: str, db=Depends(get_db)) -> Dict:
"""Get a specific tool"""
response = db.get(
Tool,
filters={"id": tool_id, "user_id": user_id}
)
response = db.get(Tool, filters={"id": tool_id, "user_id": user_id})
if not response.status or not response.data:
raise HTTPException(status_code=404, detail="Tool not found")
return {
"status": True,
"data": response.data[0]
}
return {"status": True, "data": response.data[0]}
@router.post("/")
async def create_tool(
tool: Tool,
db=Depends(get_db)
) -> Dict:
async def create_tool(tool: Tool, db=Depends(get_db)) -> Dict:
"""Create a new tool"""
response = db.upsert(tool)
if not response.status:
raise HTTPException(status_code=400, detail=response.message)
return {
"status": True,
"data": response.data
}
return {"status": True, "data": response.data}
@router.delete("/{tool_id}")
async def delete_tool(
tool_id: int,
user_id: str,
db=Depends(get_db)
) -> Dict:
async def delete_tool(tool_id: int, user_id: str, db=Depends(get_db)) -> Dict:
"""Delete a tool"""
response = db.delete(
filters={"id": tool_id, "user_id": user_id},
model_class=Tool
)
return {
"status": True,
"message": "Tool deleted successfully"
}
@router.post("/{tool_id}/test")
async def test_tool(
tool_id: int,
user_id: str,
db=Depends(get_db)
) -> Dict:
"""Test a tool configuration"""
# Get tool
tool_response = db.get(
Tool,
filters={"id": tool_id, "user_id": user_id}
)
if not tool_response.status or not tool_response.data:
raise HTTPException(status_code=404, detail="Tool not found")
tool = tool_response.data[0]
try:
# Implement tool testing logic here
# This would depend on the tool type and configuration
return {
"status": True,
"message": "Tool tested successfully",
"data": {"tool_id": tool_id}
}
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error testing tool: {str(e)}"
)
db.delete(filters={"id": tool_id, "user_id": user_id}, model_class=Tool)
return {"status": True, "message": "Tool deleted successfully"}

View File

@ -1,17 +1,17 @@
# api/ws.py
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, HTTPException
from typing import Dict
from uuid import UUID
import logging
import asyncio
import json
from datetime import datetime
from uuid import UUID
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from loguru import logger
from ..deps import get_websocket_manager, get_db, get_team_manager
from ...datamodel import Run, RunStatus
from ..deps import get_db, get_websocket_manager
from ..managers import WebSocketManager
router = APIRouter()
logger = logging.getLogger(__name__)
@router.websocket("/runs/{run_id}")
@ -20,12 +20,12 @@ async def run_websocket(
run_id: UUID,
ws_manager: WebSocketManager = Depends(get_websocket_manager),
db=Depends(get_db),
team_manager=Depends(get_team_manager)
):
"""WebSocket endpoint for run communication"""
# Verify run exists and is in valid state
run_response = db.get(Run, filters={"id": run_id}, return_json=False)
if not run_response.status or not run_response.data:
logger.warning(f"Run not found: {run_id}")
await websocket.close(code=4004, reason="Run not found")
return
@ -48,18 +48,32 @@ async def run_websocket(
raw_message = await websocket.receive_text()
message = json.loads(raw_message)
if message.get("type") == "stop":
print(f"Received stop request for run {run_id}")
reason = message.get(
"reason") or "User requested stop/cancellation"
if message.get("type") == "start":
# Handle start message
logger.info(f"Received start request for run {run_id}")
task = message.get("task")
team_config = message.get("team_config")
if task and team_config:
# await ws_manager.start_stream(run_id, task, team_config)
asyncio.create_task(ws_manager.start_stream(run_id, task, team_config))
else:
logger.warning(f"Invalid start message format for run {run_id}")
await websocket.send_json(
{
"type": "error",
"error": "Invalid start message format",
"timestamp": datetime.utcnow().isoformat(),
}
)
elif message.get("type") == "stop":
logger.info(f"Received stop request for run {run_id}")
reason = message.get("reason") or "User requested stop/cancellation"
await ws_manager.stop_run(run_id, reason=reason)
break
elif message.get("type") == "ping":
await websocket.send_json({
"type": "pong",
"timestamp": datetime.utcnow().isoformat()
})
await websocket.send_json({"type": "pong", "timestamp": datetime.utcnow().isoformat()})
elif message.get("type") == "input_response":
# Handle input response from client
@ -67,16 +81,13 @@ async def run_websocket(
if response is not None:
await ws_manager.handle_input_response(run_id, response)
else:
logger.warning(
f"Invalid input response format for run {run_id}")
logger.warning(f"Invalid input response format for run {run_id}")
except json.JSONDecodeError:
logger.warning(f"Invalid JSON received: {raw_message}")
await websocket.send_json({
"type": "error",
"error": "Invalid message format",
"timestamp": datetime.utcnow().isoformat()
})
await websocket.send_json(
{"type": "error", "error": "Invalid message format", "timestamp": datetime.utcnow().isoformat()}
)
except WebSocketDisconnect:
logger.info(f"WebSocket disconnected for run {run_id}")

View File

@ -17,8 +17,6 @@
"typecheck": "tsc --noEmit"
},
"dependencies": {
"@ant-design/charts": "^2.2.3",
"@ant-design/plots": "^2.2.2",
"@dagrejs/dagre": "^1.1.4",
"@headlessui/react": "^2.2.0",
"@heroicons/react": "^2.0.18",
@ -38,7 +36,7 @@
"gatsby-source-filesystem": "^5.14.0",
"gatsby-transformer-sharp": "^5.14.0",
"install": "^0.13.0",
"lucide-react": "^0.456.0",
"lucide-react": "^0.460.0",
"postcss": "^8.4.49",
"react": "^18.2.0",
"react-dom": "^18.2.0",

View File

@ -105,22 +105,15 @@ export type ThreadStatus =
| "timeout";
export interface WebSocketMessage {
type: "message" | "result" | "completion" | "input_request";
data?: {
source?: string;
models_usage?: RequestUsage | null;
content?: string;
task_result?: TaskResult;
};
status?: ThreadStatus;
type: "message" | "result" | "completion" | "input_request" | "error";
data?: AgentMessageConfig | TaskResult;
status?: RunStatus;
error?: string;
timestamp?: string;
}
export interface TaskResult {
messages: AgentMessageConfig[];
usage: string;
duration: number;
stop_reason?: string;
}
@ -187,3 +180,36 @@ export interface TeamConfig extends BaseConfig {
export interface Team extends DBModel {
config: TeamConfig;
}
export interface TeamResult {
task_result: TaskResult;
usage: string;
duration: number;
}
export interface Run {
id: string;
created_at: string;
status: RunStatus;
task: AgentMessageConfig;
team_result: TeamResult | null;
messages: Message[]; // Change to Message[]
error_message?: string;
}
// Separate transient state
interface TransientRunState {
pendingInput?: {
prompt: string;
isPending: boolean;
};
}
export type RunStatus =
| "created"
| "active" // covers 'streaming'
| "awaiting_input"
| "timeout"
| "complete"
| "error"
| "stopped";

View File

@ -10,7 +10,6 @@ import {
Node,
Edge,
Background,
Controls,
NodeTypes,
useReactFlow,
ReactFlowProvider,
@ -24,16 +23,16 @@ import {
AgentMessageConfig,
AgentConfig,
TeamConfig,
Run,
} from "../../../../types/datamodel";
import { ThreadState } from "../types";
import { CustomEdge } from "./edge";
import { CustomEdge, CustomEdgeData } from "./edge";
import { useConfigStore } from "../../../../../hooks/store";
import { AgentFlowToolbar } from "./toolbar";
import { EdgeMessageModal } from "./edgemessagemodal";
interface AgentFlowProps {
teamConfig: TeamConfig;
messages: AgentMessageConfig[];
threadState: ThreadState;
run: Run;
}
interface MessageSequence {
@ -51,40 +50,37 @@ interface BidirectionalPattern {
const NODE_DIMENSIONS = {
default: { width: 170, height: 100 },
end: { width: 120, height: 80 },
end: { width: 170, height: 80 },
task: { width: 170, height: 100 },
};
const getLayoutedElements = (
nodes: Node[],
edges: Edge[],
edges: CustomEdge[],
direction: "TB" | "LR"
) => {
const g = new Dagre.graphlib.Graph().setDefaultEdgeLabel(() => ({}));
// Updated graph settings
g.setGraph({
rankdir: direction,
nodesep: 80,
ranksep: 120,
ranker: "tight-tree",
nodesep: 110,
ranksep: 100,
ranker: "network-simplex",
marginx: 30,
marginy: 30,
});
// Add nodes (unchanged)
nodes.forEach((node) => {
const dimensions =
node.data.type === "end" ? NODE_DIMENSIONS.end : NODE_DIMENSIONS.default;
g.setNode(node.id, { ...node, ...dimensions });
});
// Create a map to track bidirectional edges
const bidirectionalPairs = new Map<
string,
{ source: string; target: string }[]
>();
// First pass - identify bidirectional pairs
edges.forEach((edge) => {
const forwardKey = `${edge.source}->${edge.target}`;
const reverseKey = `${edge.target}->${edge.source}`;
@ -99,33 +95,19 @@ const getLayoutedElements = (
});
});
// Second pass - add edges with weights
bidirectionalPairs.forEach((pairs, pairKey) => {
if (pairs.length === 2) {
// Bidirectional edge
const [first, second] = pairs;
g.setEdge(first.source, first.target, {
weight: 2,
minlen: 1,
});
g.setEdge(second.source, second.target, {
weight: 1,
minlen: 1,
});
g.setEdge(first.source, first.target, { weight: 2, minlen: 1 });
g.setEdge(second.source, second.target, { weight: 1, minlen: 1 });
} else {
// Regular edge
const edge = pairs[0];
g.setEdge(edge.source, edge.target, {
weight: 1,
minlen: 1,
});
g.setEdge(edge.source, edge.target, { weight: 1, minlen: 1 });
}
});
// Run layout
Dagre.layout(g);
// Position nodes
const positionedNodes = nodes.map((node) => {
const { x, y } = g.node(node.id);
const dimensions =
@ -139,7 +121,6 @@ const getLayoutedElements = (
};
});
// Create a map of node positions for edge calculations
const nodePositions = new Map(
positionedNodes.map((node) => [
node.id,
@ -150,7 +131,6 @@ const getLayoutedElements = (
])
);
// Process edges with positions
const positionedEdges = edges.map((edge) => {
const sourcePos = nodePositions.get(edge.source)!;
const targetPos = nodePositions.get(edge.target)!;
@ -164,10 +144,7 @@ const getLayoutedElements = (
};
});
return {
nodes: positionedNodes,
edges: positionedEdges,
};
return { nodes: positionedNodes, edges: positionedEdges };
};
const createNode = (
@ -175,11 +152,11 @@ const createNode = (
type: "user" | "agent" | "end",
agentConfig?: AgentConfig,
isActive: boolean = false,
threadState?: ThreadState
run?: Run
): Node => {
const isStreamingOrWaiting =
threadState?.status === "streaming" ||
threadState?.status === "awaiting_input";
const isProcessing =
run?.status === "active" || run?.status === "awaiting_input";
if (type === "user") {
return {
id,
@ -193,7 +170,7 @@ const createNode = (
isActive,
status: "",
reason: "",
draggable: !isStreamingOrWaiting,
draggable: !isProcessing,
},
};
}
@ -206,8 +183,8 @@ const createNode = (
data: {
type: "end",
label: "End",
status: threadState?.status,
reason: threadState?.reason || "",
status: run?.status,
reason: run?.error_message || "",
agentType: "",
description: "",
isActive: false,
@ -216,6 +193,23 @@ const createNode = (
};
}
// if (type === "task") {
// return {
// id,
// type: "agentNode",
// position: { x: 0, y: 0 },
// data: {
// type: "task",
// label: "Task",
// description: run?.task.content || "",
// isActive: false,
// status: null,
// reason: null,
// draggable: false,
// },
// };
// }
return {
id,
type: "agentNode",
@ -228,7 +222,7 @@ const createNode = (
isActive,
status: "",
reason: "",
draggable: !isStreamingOrWaiting,
draggable: !isProcessing,
},
};
};
@ -241,23 +235,28 @@ const edgeTypes = {
custom: CustomEdge,
};
const AgentFlow: React.FC<AgentFlowProps> = ({
teamConfig,
messages,
threadState,
}) => {
const AgentFlow: React.FC<AgentFlowProps> = ({ teamConfig, run }) => {
const { fitView } = useReactFlow();
const [nodes, setNodes] = useState<Node[]>([]);
const [edges, setEdges] = useState<Edge[]>([]);
const [edges, setEdges] = useState<CustomEdge[]>([]);
const [shouldRefit, setShouldRefit] = useState(false);
const [isFullscreen, setIsFullscreen] = useState(false);
// Get settings from store
const { agentFlow: settings, setAgentFlowSettings } = useConfigStore();
const { agentFlow: settings } = useConfigStore();
const [modalOpen, setModalOpen] = useState(false);
const [selectedEdge, setSelectedEdge] = useState<CustomEdge | null>(null);
const handleEdgeClick = useCallback((edge: CustomEdge) => {
if (!edge.data?.messages) return; // Early return if no data/messages
setSelectedEdge(edge);
setModalOpen(true);
}, []);
const onNodesChange = useCallback((changes: NodeChange[]) => {
setNodes((nds) => applyNodeChanges(changes, nds));
}, []);
const flowWrapper = useRef<HTMLDivElement>(null);
useEffect(() => {
@ -265,36 +264,38 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
const timeoutId = setTimeout(() => {
fitView({ padding: 0.2, duration: 200 });
setShouldRefit(false);
}, 100); // Increased delay slightly
}, 100);
return () => clearTimeout(timeoutId);
}
}, [shouldRefit, fitView]);
// Process messages into nodes and edges
const processMessages = useCallback(
(messages: AgentMessageConfig[]) => {
if (messages.length === 0) return { nodes: [], edges: [] };
if (!run.task) return { nodes: [], edges: [] };
const nodeMap = new Map<string, Node>();
const transitionCounts = new Map<string, MessageSequence>();
const bidirectionalPatterns = new Map<string, BidirectionalPattern>();
// Process first message source
const firstAgentConfig = teamConfig.participants.find(
(p) => p.name === messages[0].source
);
nodeMap.set(
messages[0].source,
createNode(
// Add first message node if it exists
if (messages.length > 0) {
const firstAgentConfig = teamConfig.participants.find(
(p) => p.name === messages[0].source
);
nodeMap.set(
messages[0].source,
messages[0].source === "user" ? "user" : "agent",
firstAgentConfig,
false
)
);
createNode(
messages[0].source,
messages[0].source === "user" ? "user" : "agent",
firstAgentConfig,
false,
run
)
);
}
// Group messages by transitions
// Process message transitions
for (let i = 0; i < messages.length - 1; i++) {
const currentMsg = messages[i];
const nextMsg = messages[i + 1];
@ -319,7 +320,6 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
transition.messages.push(currentMsg);
}
// Ensure all nodes are in the nodeMap
if (!nodeMap.has(nextMsg.source)) {
const agentConfig = teamConfig.participants.find(
(p) => p.name === nextMsg.source
@ -330,13 +330,14 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
nextMsg.source,
nextMsg.source === "user" ? "user" : "agent",
agentConfig,
false
false,
run
)
);
}
}
// Identify bidirectional patterns
// Process bidirectional patterns
transitionCounts.forEach((transition, key) => {
const [source, target] = key.split("->");
const reverseKey = `${target}->${source}`;
@ -352,10 +353,9 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
});
// Create edges with bidirectional routing
const newEdges: Edge[] = [];
const newEdges: CustomEdge[] = [];
const processedKeys = new Set<string>();
// Helper function to create edge label based on settings
const createEdgeLabel = (transition: MessageSequence) => {
if (!settings.showLabels) return "";
if (transition.totalTokens > 0) {
@ -374,7 +374,6 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
const bidirectionalPattern = bidirectionalPatterns.get(patternKey);
if (bidirectionalPattern) {
// Create paired edges for bidirectional pattern
const forwardKey = `${source}->${target}`;
const reverseKey = `${target}->${source}`;
@ -386,7 +385,7 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
isSecondary: boolean,
edgeId: string,
pairedEdgeId: string
) => ({
): CustomEdge => ({
id: edgeId,
source: transition.source,
target: transition.target,
@ -421,7 +420,6 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
processedKeys.add(forwardKey);
processedKeys.add(reverseKey);
} else {
// Handle regular edges (including self-loops)
newEdges.push({
id: `${transition.source}-${transition.target}-${key}`,
source: transition.source,
@ -432,7 +430,7 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
messages: settings.showMessages ? transition.messages : [],
},
animated:
threadState?.status === "streaming" &&
run.status === "active" &&
key === Array.from(transitionCounts.keys()).pop(),
style: {
stroke: "#2563eb",
@ -442,25 +440,22 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
}
});
// Handle end node logic
if (threadState && messages.length > 0) {
// Add end node if run is complete/error/stopped
if (run && messages.length > 0) {
const lastMessage = messages[messages.length - 1];
if (["complete", "error", "cancelled"].includes(threadState.status)) {
nodeMap.set(
"end",
createNode("end", "end", undefined, false, threadState)
);
if (["complete", "error", "stopped"].includes(run.status)) {
nodeMap.set("end", createNode("end", "end", undefined, false, run));
const edgeColor =
{
complete: "#2563eb",
cancelled: "red",
error: "red",
streaming: "#2563eb",
awaiting_input: "#2563eb",
timeout: "red",
}[threadState.status] || "#2563eb";
const edgeColor = {
complete: "#2563eb",
stopped: "red",
error: "red",
active: "#2563eb",
awaiting_input: "#2563eb",
timeout: "red",
created: "#2563eb",
}[run.status];
newEdges.push({
id: `${lastMessage.source}-end`,
@ -481,12 +476,9 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
}
}
return {
nodes: Array.from(nodeMap.values()),
edges: newEdges,
};
return { nodes: Array.from(nodeMap.values()), edges: newEdges };
},
[teamConfig.participants, threadState, settings]
[teamConfig.participants, run, settings]
);
const handleToggleFullscreen = useCallback(() => {
@ -508,7 +500,9 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
}, [isFullscreen, handleToggleFullscreen]);
useEffect(() => {
const { nodes: newNodes, edges: newEdges } = processMessages(messages);
const { nodes: newNodes, edges: newEdges } = processMessages(
run.messages.map((m) => m.config)
);
const { nodes: layoutedNodes, edges: layoutedEdges } = getLayoutedElements(
newNodes,
newEdges,
@ -518,17 +512,22 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
setNodes(layoutedNodes);
setEdges(layoutedEdges);
if (messages.length > 0) {
if (run.messages.length > 0) {
setTimeout(() => {
fitView({ padding: 0.2, duration: 200 });
}, 50);
}
}, [messages, processMessages, settings.direction, threadState, fitView]);
}, [run.messages, processMessages, settings.direction, run.status, fitView]);
// Define common ReactFlow props
const reactFlowProps = {
nodes,
edges,
edges: edges.map((edge) => ({
...edge,
data: {
...edge.data,
onClick: () => handleEdgeClick(edge),
},
})),
nodeTypes,
edgeTypes,
defaultViewport: { x: 0, y: 0, zoom: 1 },
@ -538,7 +537,6 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
proOptions: { hideAttribution: true },
};
// Define common toolbar props
const toolbarProps = useMemo(
() => ({
isFullscreen,
@ -547,16 +545,16 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
}),
[isFullscreen, handleToggleFullscreen, fitView]
);
return (
<div
ref={flowWrapper}
className={`transition-all duration-200 ${
isFullscreen
? "fixed inset-4 z-[9999] shadow" // Modal-like styling
? "fixed inset-4 z-[50] shadow"
: "w-full h-full min-h-[300px]"
} bg-tertiary rounded-lg`}
>
{/* Backdrop when fullscreen */}
{isFullscreen && (
<div
className="fixed inset-0 -z-10 bg-background/80 backdrop-blur-sm"
@ -566,11 +564,18 @@ const AgentFlow: React.FC<AgentFlowProps> = ({
<ReactFlow {...reactFlowProps}>
{settings.showGrid && <Background />}
{/* <Controls /> */}
<div className="absolute top-0 right-0 z-50">
<AgentFlowToolbar {...toolbarProps} />
</div>
</ReactFlow>
<EdgeMessageModal
open={modalOpen}
onClose={() => {
setModalOpen(false);
setSelectedEdge(null);
}}
edge={selectedEdge}
/>
</div>
);
};

View File

@ -66,7 +66,7 @@ function AgentNode({ data, isConnectable }: AgentNodeProps) {
if (data.type === "end") {
return {
wrapper: `relative min-w-[120px] shadow rounded-lg overflow-hidden ${activeStyles}`,
wrapper: `relative min-w-[180px] shadow rounded-lg overflow-hidden ${activeStyles}`,
border:
data.status === "complete"
? "var(--accent)"
@ -77,7 +77,7 @@ function AgentNode({ data, isConnectable }: AgentNodeProps) {
}
return {
wrapper: `min-w-[150px] rounded-lg shadow overflow-hidden ${activeStyles}`,
wrapper: `min-w-[180px] rounded-lg shadow overflow-hidden ${activeStyles}`,
border: undefined,
};
};

View File

@ -1,38 +1,23 @@
import React from "react";
import { Tooltip } from "antd";
import { AgentMessageConfig } from "../../../../types/datamodel";
import {
Edge,
EdgeLabelRenderer,
type EdgeProps,
getSmoothStepPath,
} from "@xyflow/react";
import { RenderMessage } from "../rendermessage";
interface EdgeTooltipContentProps {
messages: AgentMessageConfig[];
}
const EDGE_OFFSET = 140;
interface CustomEdgeData {
export interface CustomEdgeData extends Record<string, unknown> {
label?: string;
messages: AgentMessageConfig[];
routingType?: "primary" | "secondary";
bidirectionalPair?: string;
onClick?: () => void;
}
const EdgeTooltipContent: React.FC<EdgeTooltipContentProps> = ({
messages,
}) => {
return (
<div className="p-2 overflow-auto max-h-[200px] scroll max-w-[350px]">
<div className="text-xs mb-2">{messages.length} messages</div>
<div className="edge-tooltip">
{messages.map((message, index) => (
<RenderMessage key={index} message={message} />
))}
</div>
</div>
);
};
export type CustomEdge = Edge<CustomEdgeData>;
interface CustomEdgeProps extends Omit<EdgeProps, "data"> {
data: CustomEdgeData;
@ -63,27 +48,24 @@ export const CustomEdge: React.FC<CustomEdgeProps> = ({
let labelX = 0;
let labelY = 0;
if (isSelfLoop) {
const rightOffset = 120;
const verticalOffset = sourceY - targetY;
const verticalPadding = 6;
const radius = 8;
if (data.routingType === "secondary" || isSelfLoop) {
// Calculate the midpoint and offset
const midY = (sourceY + targetY) / 2;
const offset = EDGE_OFFSET;
// Create path that goes out from output, right, up, left, into input
edgePath = `
M ${sourceX} ${targetY - verticalPadding}
L ${sourceX + rightOffset - radius} ${targetY - verticalPadding}
Q ${sourceX + rightOffset} ${targetY - verticalPadding} ${
sourceX + rightOffset
} ${targetY - verticalPadding + radius}
L ${sourceX + rightOffset} ${sourceY + verticalPadding - radius}
Q ${sourceX + rightOffset} ${sourceY + verticalPadding} ${
sourceX + rightOffset - radius
} ${sourceY + verticalPadding}
L ${sourceX} ${sourceY + verticalPadding}
M ${sourceX},${sourceY}
L ${sourceX},${sourceY + 10}
L ${sourceX + offset},${sourceY + 10}
L ${sourceX + offset},${targetY - 10}
L ${targetX},${targetY - 10}
L ${targetX},${targetY}
`;
labelX = sourceX + rightOffset + 10;
labelY = targetY + verticalOffset / 2;
// Set label position
labelX = sourceX + offset;
labelY = midY;
} else {
[edgePath, labelX, labelY] = getSmoothStepPath({
sourceX,
@ -98,7 +80,7 @@ export const CustomEdge: React.FC<CustomEdgeProps> = ({
if (!data.routingType || isSelfLoop) return { x, y };
// Make vertical separation more pronounced
const verticalOffset = data.routingType === "secondary" ? -35 : 35; // Increased from 20 to 35
const verticalOffset = data.routingType === "secondary" ? -35 : 35;
const horizontalOffset = data.routingType === "secondary" ? -25 : 25;
// Calculate edge angle to determine if it's more horizontal or vertical
@ -109,7 +91,7 @@ export const CustomEdge: React.FC<CustomEdgeProps> = ({
// Always apply some vertical offset
const basePosition = {
x: isMoreHorizontal ? x : x + horizontalOffset,
y: y + (data.routingType === "secondary" ? -35 : 35), // Always apply vertical offset
y: y + (data.routingType === "secondary" ? -20 : 20),
};
return basePosition;
@ -137,29 +119,23 @@ export const CustomEdge: React.FC<CustomEdgeProps> = ({
position: "absolute",
transform: `translate(-50%, -50%) translate(${labelPosition.x}px,${labelPosition.y}px)`,
pointerEvents: "all",
// Add a slight transition for smooth updates
transition: "transform 0.2s ease-in-out",
transition: "all 0.2s ease-in-out",
}}
onClick={data.onClick}
>
<Tooltip
title={
data.messages && data.messages.length > 0 ? (
<EdgeTooltipContent messages={data.messages} />
) : (
data?.label
)
}
overlayStyle={{ maxWidth: "none" }}
<div
className="px-2 py-1 rounded bg-secondary hover:bg-tertiary text-primary
cursor-pointer transform hover:scale-110 transition-all
flex items-center gap-1"
style={{
whiteSpace: "nowrap",
}}
>
<div
className="px-2 py-1 rounded bg-secondary bg-opacity-50 text-primary text-sm"
style={{
whiteSpace: "nowrap", // Prevent label from wrapping
}}
>
{data.label}
</div>
</Tooltip>
{messageCount > 0 && (
<span className="text-xs text-secondary">({messageCount})</span>
)}
<span className="text-sm">{data.label}</span>
</div>
</div>
</EdgeLabelRenderer>
)}

View File

@ -0,0 +1,103 @@
// edgemessagemodal.tsx
import React, { useState, useMemo } from "react";
import { Modal, Input } from "antd";
import { RenderMessage } from "../rendermessage";
import { CustomEdge } from "./edge";
const { Search } = Input;
interface EdgeMessageModalProps {
open: boolean;
onClose: () => void;
edge: CustomEdge | null;
}
export const EdgeMessageModal: React.FC<EdgeMessageModalProps> = ({
open,
onClose,
edge,
}) => {
const [searchTerm, setSearchTerm] = useState("");
const totalTokens = useMemo(() => {
if (!edge?.data?.messages) return 0;
return edge.data.messages.reduce((acc, msg) => {
const promptTokens = msg.models_usage?.prompt_tokens || 0;
const completionTokens = msg.models_usage?.completion_tokens || 0;
return acc + promptTokens + completionTokens;
}, 0);
}, [edge?.data?.messages]);
const filteredMessages = useMemo(() => {
if (!edge?.data?.messages) return [];
if (!searchTerm) return edge.data.messages;
return edge.data.messages.filter(
(msg) =>
typeof msg.content === "string" &&
msg.content.toLowerCase().includes(searchTerm.toLowerCase())
);
}, [edge?.data?.messages, searchTerm]);
if (!edge) return null;
return (
<Modal
title={
<div className="space-y-2">
<div className="font-medium">
{edge.source} {edge.target}
</div>
<div className="text-sm text-secondary flex justify-between">
{edge.data && (
<span>
{edge.data.messages.length} message
{`${edge.data.messages.length > 1 ? "s" : ""}`}
</span>
)}
<span>{totalTokens.toLocaleString()} tokens</span>
</div>
{edge.data && edge.data.messages.length > 0 && (
<div className="text-xs py-2 font-normal">
{" "}
The above represents the number of times the {`${edge.target}`}{" "}
node sent a message{" "}
<span className="font-semibold underline text-accent">after</span>{" "}
the {`${edge.source}`} node.{" "}
</div>
)}
</div>
}
open={open}
onCancel={onClose}
width={800}
footer={null}
>
<div className="max-h-[70vh] overflow-y-auto space-y-4 scroll pr-2">
<Search
placeholder="Search message content..."
value={searchTerm}
onChange={(e) => setSearchTerm(e.target.value)}
allowClear
className="sticky top-0 z-10"
/>
<div className="space-y-4 ">
{filteredMessages.map((msg, idx) => (
<RenderMessage
key={idx}
message={msg}
isLast={idx === filteredMessages.length - 1}
/>
))}
{filteredMessages.length === 0 && (
<div className="text-center text-secondary py-8">
No messages found
</div>
)}
</div>
</div>
</Modal>
);
};

View File

@ -4,200 +4,119 @@ import { getServerUrl } from "../../../utils";
import { SessionManager } from "../../shared/session/manager";
import { IStatus } from "../../../types/app";
import {
Run,
Message,
ThreadStatus,
WebSocketMessage,
TeamConfig,
AgentMessageConfig,
RunStatus,
TeamResult,
} from "../../../types/datamodel";
import { useConfigStore } from "../../../../hooks/store";
import { appContext } from "../../../../hooks/provider";
import ChatInput from "./chatinput";
import { ModelUsage, ThreadState, TIMEOUT_CONFIG } from "./types";
import { MessageList } from "./messagelist";
import TeamManager from "../../shared/team/manager";
import { teamAPI } from "../../shared/team/api";
import { sessionAPI } from "../../shared/session/api";
import RunView from "./runview";
import { TIMEOUT_CONFIG } from "./types";
const logo = require("../../../../images/landing/welcome.svg").default;
export default function ChatView({
initMessages,
}: {
initMessages: Message[];
}) {
export default function ChatView() {
const serverUrl = getServerUrl();
const [loading, setLoading] = React.useState(false);
const [error, setError] = React.useState<IStatus | null>({
status: true,
message: "All good",
});
const [messages, setMessages] = React.useState<Message[]>(initMessages);
const [threadMessages, setThreadMessages] = React.useState<
Record<string, ThreadState>
>({});
const chatContainerRef = React.useRef<HTMLDivElement>(null);
const timeoutRefs = React.useRef<Record<string, NodeJS.Timeout>>({});
// Core state
const [existingRuns, setExistingRuns] = React.useState<Run[]>([]);
const [currentRun, setCurrentRun] = React.useState<Run | null>(null);
const chatContainerRef = React.useRef<HTMLDivElement | null>(null);
// Context and config
const { user } = React.useContext(appContext);
const { session, sessions } = useConfigStore();
const [activeSockets, setActiveSockets] = React.useState<
Record<string, WebSocket>
>({});
const activeSocketsRef = React.useRef<Record<string, WebSocket>>({});
const [activeSocket, setActiveSocket] = React.useState<WebSocket | null>(
null
);
const [teamConfig, setTeamConfig] = React.useState<TeamConfig | null>(null);
const [teamConfig, setTeamConfig] = React.useState<any>(null);
const inputTimeoutRef = React.useRef<NodeJS.Timeout | null>(null);
const activeSocketRef = React.useRef<WebSocket | null>(null);
React.useEffect(() => {
if (chatContainerRef.current) {
chatContainerRef.current.scrollTo({
top: chatContainerRef.current.scrollHeight,
behavior: "smooth",
});
// Create a Message object from AgentMessageConfig
const createMessage = (
config: AgentMessageConfig,
runId: string,
sessionId: number
): Message => ({
created_at: new Date().toISOString(),
updated_at: new Date().toISOString(),
config,
session_id: sessionId,
run_id: runId,
user_id: user?.email || undefined,
});
// Load existing runs when session changes
const loadSessionRuns = async () => {
if (!session?.id || !user?.email) return;
try {
const response = await sessionAPI.getSessionRuns(session.id, user.email);
setExistingRuns(response.runs);
} catch (error) {
console.error("Error loading session runs:", error);
message.error("Failed to load chat history");
}
}, [messages, threadMessages]);
};
React.useEffect(() => {
if (session && session.team_id && user && user.email) {
teamAPI.getTeam(session.team_id, user?.email).then((team) => {
if (session?.id) {
loadSessionRuns();
} else {
setExistingRuns([]);
setCurrentRun(null);
}
}, [session?.id]);
// Load team config
React.useEffect(() => {
if (session?.team_id && user?.email) {
teamAPI.getTeam(session.team_id, user.email).then((team) => {
setTeamConfig(team.config);
// console.log("Team Config", team.config);
});
}
}, [session]);
const updateSocket = (runId: string, socket: WebSocket | null) => {
if (socket) {
activeSocketsRef.current[runId] = socket;
setActiveSockets((prev) => ({ ...prev, [runId]: socket }));
} else {
delete activeSocketsRef.current[runId];
setActiveSockets((prev) => {
const next = { ...prev };
delete next[runId];
return next;
});
}
};
React.useEffect(() => {
setTimeout(() => {
if (chatContainerRef.current && existingRuns.length > 0) {
// Scroll to bottom to show latest run
chatContainerRef.current.scrollTo({
top: chatContainerRef.current.scrollHeight,
behavior: "auto", // Use 'auto' instead of 'smooth' for initial load
});
}
}, 450);
}, [existingRuns.length, currentRun?.messages]);
// Cleanup socket on unmount
React.useEffect(() => {
return () => {
Object.values(activeSockets).forEach((socket) => socket.close());
if (inputTimeoutRef.current) {
clearTimeout(inputTimeoutRef.current);
}
activeSocket?.close();
};
}, [activeSockets]);
const handleTimeoutForRun = (runId: string) => {
const socket = activeSocketsRef.current[runId];
if (socket && socket.readyState === WebSocket.OPEN) {
// Send stop message to backend, just like when user clicks stop
socket.send(
JSON.stringify({
type: "stop",
reason: TIMEOUT_CONFIG.DEFAULT_MESSAGE,
})
);
}
// Update thread state with timeout reason
setThreadMessages((prev) => {
const currentThread = prev[runId];
if (!currentThread) return prev;
return {
...prev,
[runId]: {
...currentThread,
status: "cancelled", // Use existing cancelled status
reason: "Input request timed out after 3 minutes",
isExpanded: true,
inputRequest: currentThread.inputRequest
? {
prompt: currentThread.inputRequest.prompt,
isPending: true,
}
: undefined,
},
};
});
if (timeoutRefs.current[runId]) {
clearTimeout(timeoutRefs.current[runId]);
delete timeoutRefs.current[runId];
}
};
const handleInputResponse = async (runId: string, response: string) => {
// Clear timeout when response is received
if (timeoutRefs.current[runId]) {
clearTimeout(timeoutRefs.current[runId]);
delete timeoutRefs.current[runId];
}
if (response === "TIMEOUT") {
handleTimeoutForRun(runId);
return;
}
const socket = activeSockets[runId];
if (socket && socket.readyState === WebSocket.OPEN) {
try {
socket.send(
JSON.stringify({
type: "input_response",
response: response,
})
);
setThreadMessages((prev) => ({
...prev,
[runId]: {
...prev[runId],
status: "streaming",
inputRequest: undefined,
},
}));
} catch (error) {
console.error("Error sending input response:", error);
message.error("Failed to send response");
setThreadMessages((prev) => ({
...prev,
[runId]: {
...prev[runId],
status: "error",
reason: "Failed to send input response",
},
}));
}
} else {
message.error("Connection lost. Please try again.");
}
};
const getBaseUrl = (url: string): string => {
try {
// Remove protocol (http:// or https://)
let baseUrl = url.replace(/(^\w+:|^)\/\//, "");
// Handle both localhost and production cases
if (baseUrl.startsWith("localhost")) {
// For localhost, keep the port if it exists
baseUrl = baseUrl.replace("/api", "");
} else if (baseUrl === "/api") {
// For production where url is just '/api'
baseUrl = window.location.host;
} else {
// For other cases, remove '/api' and trailing slash
baseUrl = baseUrl.replace("/api", "").replace(/\/$/, "");
}
return baseUrl;
} catch (error) {
console.error("Error processing server URL:", error);
throw new Error("Invalid server URL configuration");
}
};
}, [activeSocket]);
const createRun = async (sessionId: number): Promise<string> => {
const payload = { session_id: sessionId, user_id: user?.email || "" };
const response = await fetch(`${serverUrl}/runs/`, {
method: "POST",
headers: { "Content-Type": "application/json" },
@ -212,334 +131,343 @@ export default function ChatView({
return data.data.run_id;
};
const startRun = async (runId: string, query: string) => {
const messagePayload = {
user_id: user?.email,
session_id: session?.id,
config: {
content: query,
source: "user",
},
};
const handleWebSocketMessage = (message: WebSocketMessage) => {
setCurrentRun((current) => {
if (!current || !session?.id) return null;
const response = await fetch(`${serverUrl}/runs/${runId}/start`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(messagePayload),
});
if (!response.ok) {
throw new Error("Failed to start run");
}
return await response.json();
};
const connectWebSocket = (runId: string, query: string) => {
const baseUrl = getBaseUrl(serverUrl);
const wsProtocol = window.location.protocol === "https:" ? "wss:" : "ws:";
const wsUrl = `${wsProtocol}//${baseUrl}/api/ws/runs/${runId}`;
const socket = new WebSocket(wsUrl);
let isClosing = false;
const clearTimeoutForRun = () => {
if (timeoutRefs.current[runId]) {
clearTimeout(timeoutRefs.current[runId]);
delete timeoutRefs.current[runId];
}
};
const closeSocket = () => {
if (!isClosing && socket.readyState !== WebSocket.CLOSED) {
isClosing = true;
socket.close();
updateSocket(runId, null);
}
};
socket.onopen = async () => {
try {
updateSocket(runId, socket);
setThreadMessages((prev) => ({
...prev,
[runId]: {
messages: [],
status: "streaming",
isExpanded: true,
},
}));
setMessages((prev: Message[]) =>
prev.map((msg: Message) => {
if (msg.run_id === runId && msg.config.source === "bot") {
return {
...msg,
config: {
...msg.config,
content: "Starting...",
},
};
}
return msg;
})
);
await startRun(runId, query);
} catch (error) {
closeSocket();
setThreadMessages((prev) => ({
...prev,
[runId]: {
...prev[runId],
status: "error",
isExpanded: true,
},
}));
}
};
socket.onmessage = (event) => {
const message: WebSocketMessage = JSON.parse(event.data);
console.log("WebSocket message:", message);
switch (message.type) {
case "input_request":
clearTimeoutForRun();
timeoutRefs.current[runId] = setTimeout(() => {
handleTimeoutForRun(runId);
}, TIMEOUT_CONFIG.DURATION_MS);
setThreadMessages((prev) => ({
...prev,
[runId]: {
...prev[runId],
status: "awaiting_input",
inputRequest: {
prompt: message.data?.content || "",
isPending: false,
},
},
}));
break;
case "error":
if (inputTimeoutRef.current) {
clearTimeout(inputTimeoutRef.current);
inputTimeoutRef.current = null;
}
if (activeSocket) {
activeSocket.close();
setActiveSocket(null);
activeSocketRef.current = null;
}
console.log("Error: ", message.error);
case "message":
clearTimeoutForRun();
if (!message.data) return current;
setThreadMessages((prev) => {
const currentThread = prev[runId] || {
messages: [],
status: "streaming",
isExpanded: true,
};
// Create new Message object from websocket data
const newMessage = createMessage(
message.data as AgentMessageConfig,
current.id,
session.id
);
const models_usage: ModelUsage | undefined = message.data
?.models_usage
? {
prompt_tokens: message.data.models_usage.prompt_tokens,
completion_tokens:
message.data.models_usage.completion_tokens,
}
: undefined;
return {
...current,
messages: [...current.messages, newMessage],
};
return {
...prev,
[runId]: {
...currentThread,
messages: [
...currentThread.messages,
{
source: message.data?.source || "",
content: message.data?.content || "",
models_usage,
},
],
status: "streaming",
},
};
});
break;
case "input_request":
if (inputTimeoutRef.current) {
clearTimeout(inputTimeoutRef.current);
}
inputTimeoutRef.current = setTimeout(() => {
const socket = activeSocketRef.current;
console.log("Input timeout", socket);
if (socket?.readyState === WebSocket.OPEN) {
socket.send(
JSON.stringify({
type: "stop",
reason: TIMEOUT_CONFIG.DEFAULT_MESSAGE,
code: TIMEOUT_CONFIG.WEBSOCKET_CODE,
})
);
setCurrentRun((prev) =>
prev
? {
...prev,
status: "stopped",
error_message: TIMEOUT_CONFIG.DEFAULT_MESSAGE,
}
: null
);
}
}, TIMEOUT_CONFIG.DURATION_MS);
return {
...current,
status: "awaiting_input",
};
case "result":
case "completion":
clearTimeoutForRun();
// When run completes, move it to existingRuns
const status: RunStatus =
message.status === "complete"
? "complete"
: message.status === "error"
? "error"
: "stopped";
setThreadMessages((prev) => {
const currentThread = prev[runId];
if (!currentThread) return prev;
const isTeamResult = (data: any): data is TeamResult => {
return (
data &&
"task_result" in data &&
"usage" in data &&
"duration" in data
);
};
const status: ThreadStatus = message.status || "complete";
const reason =
message.data?.task_result?.stop_reason ||
(message.error ? `Error: ${message.error}` : undefined);
const updatedRun = {
...current,
status,
team_result:
message.data && isTeamResult(message.data) ? message.data : null,
};
return {
...prev,
[runId]: {
...currentThread,
status,
reason,
isExpanded: true,
finalResult: message.data?.task_result?.messages
?.filter((msg: any) => msg.content !== "TERMINATE")
.pop(),
},
};
});
closeSocket();
break;
}
};
socket.onclose = (event) => {
clearTimeoutForRun();
if (!isClosing) {
updateSocket(runId, null);
setThreadMessages((prev) => {
const thread = prev[runId];
if (thread && thread.status === "streaming") {
return {
...prev,
[runId]: {
...thread,
status:
event.code === TIMEOUT_CONFIG.WEBSOCKET_CODE
? "timeout"
: "complete",
reason: event.reason || "Connection closed",
},
};
// Add to existing runs if complete
if (status === "complete") {
if (inputTimeoutRef.current) {
clearTimeout(inputTimeoutRef.current);
inputTimeoutRef.current = null;
}
if (activeSocket) {
activeSocket.close();
setActiveSocket(null);
activeSocketRef.current = null;
}
setExistingRuns((prev) => [...prev, updatedRun]);
return null;
}
return prev;
});
return updatedRun;
default:
return current;
}
};
socket.onerror = (error) => {
clearTimeoutForRun();
setThreadMessages((prev) => {
const thread = prev[runId];
if (!thread) return prev;
return {
...prev,
[runId]: {
...thread,
status: "error",
reason: "WebSocket connection error occurred",
isExpanded: true,
},
};
});
closeSocket();
};
return socket;
});
};
const cancelRun = async (runId: string) => {
const socket = activeSockets[runId];
if (socket && socket.readyState === WebSocket.OPEN) {
socket.send(
JSON.stringify({ type: "stop", reason: "Cancelled by user" })
const handleError = (error: any) => {
console.error("Error:", error);
message.error("Error during request processing");
setCurrentRun((current) => {
if (!current) return null;
const errorRun = {
...current,
status: "error" as const,
error_message:
error instanceof Error ? error.message : "Unknown error occurred",
};
// Add failed run to existing runs
setExistingRuns((prev) => [...prev, errorRun]);
return null; // Clear current run
});
setError({
status: false,
message:
error instanceof Error ? error.message : "Unknown error occurred",
});
};
const handleInputResponse = async (response: string) => {
if (!activeSocketRef.current || !currentRun) return;
if (activeSocketRef.current.readyState !== WebSocket.OPEN) {
console.error(
"Socket not in OPEN state:",
activeSocketRef.current.readyState
);
handleError(new Error("WebSocket connection not available"));
return;
}
// Clear timeout when response received
if (inputTimeoutRef.current) {
clearTimeout(inputTimeoutRef.current);
inputTimeoutRef.current = null;
}
try {
console.log("Sending input response:", response);
activeSocketRef.current.send(
JSON.stringify({
type: "input_response",
response: response,
})
);
setThreadMessages((prev) => ({
...prev,
[runId]: {
...prev[runId],
status: "cancelled",
reason: "Cancelled by user",
isExpanded: true,
},
}));
setCurrentRun((current) => {
if (!current) return null;
return {
...current,
status: "active",
};
});
} catch (error) {
handleError(error);
}
};
// Clean up timeouts when component unmounts
React.useEffect(() => {
return () => {
Object.entries(timeoutRefs.current).forEach(([_, timeout]) =>
clearTimeout(timeout)
const handleCancel = async () => {
if (!activeSocketRef.current || !currentRun) return;
// Clear timeout when manually cancelled
if (inputTimeoutRef.current) {
clearTimeout(inputTimeoutRef.current);
inputTimeoutRef.current = null;
}
try {
activeSocketRef.current.send(
JSON.stringify({
type: "stop",
reason: "Cancelled by user",
})
);
timeoutRefs.current = {};
};
}, []);
setCurrentRun((current) => {
if (!current) return null;
return {
...current,
status: "stopped",
};
});
} catch (error) {
handleError(error);
}
};
const runTask = async (query: string) => {
setError(null);
setLoading(true);
if (!session?.id) {
// Add explicit cleanup
if (activeSocket) {
activeSocket.close();
setActiveSocket(null);
activeSocketRef.current = null;
}
if (inputTimeoutRef.current) {
clearTimeout(inputTimeoutRef.current);
inputTimeoutRef.current = null;
}
if (!session?.id || !teamConfig) {
// Add teamConfig check
setLoading(false);
return;
}
let runId: string | null = null;
try {
runId = (await createRun(session.id)) + "";
const runId = await createRun(session.id);
const userMessage: Message = {
config: {
// Initialize run state BEFORE websocket connection
setCurrentRun({
id: runId,
created_at: new Date().toISOString(),
status: "created", // Start with created status
messages: [],
task: {
content: query,
source: "user",
},
session_id: session.id,
run_id: runId,
};
const botMessage: Message = {
config: {
content: "Thinking...",
source: "bot",
},
session_id: session.id,
run_id: runId,
};
setMessages((prev) => [...prev, userMessage, botMessage]);
connectWebSocket(runId, query); // Now passing query to connectWebSocket
} catch (err) {
console.error("Error:", err);
message.error("Error during request processing");
if (runId) {
if (activeSockets[runId]) {
activeSockets[runId].close();
}
setThreadMessages((prev) => ({
...prev,
[runId!]: {
...prev[runId!],
status: "error",
isExpanded: true,
},
}));
}
setError({
status: false,
message: err instanceof Error ? err.message : "Unknown error occurred",
team_result: null,
error_message: undefined,
});
// Setup WebSocket
const socket = setupWebSocket(runId, query);
setActiveSocket(socket);
activeSocketRef.current = socket;
} catch (error) {
handleError(error);
} finally {
setLoading(false);
}
};
React.useEffect(() => {
// session changed
if (session) {
setMessages([]);
setThreadMessages({});
const setupWebSocket = (runId: string, query: string): WebSocket => {
if (!session || !session.id) {
throw new Error("Invalid session configuration");
}
}, [session]);
// Close existing socket if any
if (activeSocket?.readyState === WebSocket.OPEN) {
activeSocket.close();
}
const baseUrl = getBaseUrl(serverUrl);
const wsProtocol = window.location.protocol === "https:" ? "wss:" : "ws:";
const wsUrl = `${wsProtocol}//${baseUrl}/api/ws/runs/${runId}`;
const socket = new WebSocket(wsUrl);
// Initialize current run
setCurrentRun({
id: runId,
created_at: new Date().toISOString(),
status: "active",
task: createMessage(
{ content: query, source: "user" },
runId,
session.id || 0
).config,
team_result: null,
messages: [],
error_message: undefined,
});
socket.onopen = () => {
// Send start message with teamConfig
socket.send(
JSON.stringify({
type: "start",
task: query,
team_config: teamConfig,
})
);
};
socket.onmessage = (event) => {
try {
const message = JSON.parse(event.data);
handleWebSocketMessage(message);
} catch (error) {
console.error("WebSocket message parsing error:", error);
}
};
socket.onclose = () => {
activeSocketRef.current = null;
setActiveSocket(null);
};
socket.onerror = (error) => {
handleError(error);
};
return socket;
};
// Helper for WebSocket URL
const getBaseUrl = (url: string): string => {
try {
let baseUrl = url.replace(/(^\w+:|^)\/\//, "");
if (baseUrl.startsWith("localhost")) {
baseUrl = baseUrl.replace("/api", "");
} else if (baseUrl === "/api") {
baseUrl = window.location.host;
} else {
baseUrl = baseUrl.replace("/api", "").replace(/\/$/, "");
}
return baseUrl;
} catch (error) {
console.error("Error processing server URL:", error);
throw new Error("Invalid server URL configuration");
}
};
return (
<div className="text-primary h-[calc(100vh-195px)] bg-primary relative rounded flex-1 scroll">
@ -549,45 +477,62 @@ export default function ChatView({
</div>
<TeamManager />
</div>
<div className="flex flex-col h-full">
<div
className="flex-1 overflow-y-auto scroll mt-2 relative min-h-0"
ref={chatContainerRef}
className="flex-1 overflow-y-auto scroll mt-2 min-h-0 relative"
>
<MessageList
messages={messages}
threadMessages={threadMessages}
setThreadMessages={setThreadMessages}
onRetry={runTask}
onCancel={cancelRun}
onInputResponse={handleInputResponse} // Add the new prop
loading={loading}
teamConfig={teamConfig}
/>
<div id="scroll-gradient" className="scroll-gradient h-8 top-0">
{" "}
<span className=" inline-block h-6"></span>{" "}
</div>
{sessions !== null && sessions?.length === 0 ? (
<div className="flex h-[calc(100%-100px)] flex-col items-center justify-center w-full">
<div className="mt-4 text-sm text-secondary text-center">
<img src={logo} alt="Welcome" className="w-72 h-72 mb-4" />
Welcome! Create a session to get started!
</div>
</div>
) : (
<>
{teamConfig && (
<>
{/* Existing Runs */}
{existingRuns.map((run, index) => (
<RunView
teamConfig={teamConfig}
key={run.id + "-review-" + index}
run={run}
isFirstRun={index === 0}
/>
))}
{/* Current Run */}
{currentRun && (
<RunView
run={currentRun}
teamConfig={teamConfig}
onInputResponse={handleInputResponse}
onCancel={handleCancel}
isFirstRun={existingRuns.length === 0}
/>
)}
</>
)}
</>
)}
</div>
{sessions !== null && sessions?.length === 0 ? (
<div className="flex h-[calc(100%-100px)] flex-col items-center justify-center w-full">
<div className="mt-4 text-sm text-secondary text-center">
<img src={logo} alt="Welcome" className="w-72 h-72 mb-4" />
Welcome! Create a session to get started!
</div>
{session && (
<div className="flex-shrink-0">
<ChatInput
onSubmit={runTask}
loading={loading}
error={error}
disabled={currentRun?.status === "awaiting_input"}
/>
</div>
) : (
<>
{session && (
<div className="flex-shrink-0">
<ChatInput
onSubmit={runTask}
loading={loading}
error={error}
disabled={Object.values(threadMessages).some(
(thread) => thread.status === "awaiting_input"
)} // Disable input while waiting for user input
/>
</div>
)}
</>
)}
</div>
</div>

View File

@ -12,15 +12,6 @@ import { ThreadState, TIMEOUT_CONFIG } from "./types";
import { RenderMessage } from "./rendermessage";
import LoadingDots from "../../shared/atoms";
interface ThreadViewProps {
thread: ThreadState;
isStreaming: boolean;
runId: string;
onCancel: (runId: string) => void;
onInputResponse: (runId: string, response: string) => void;
threadContainerRef: (el: HTMLDivElement | null) => void;
}
interface InputRequestProps {
prompt: string;
onSubmit: (response: string) => void;
@ -170,105 +161,4 @@ const InputRequestView: React.FC<InputRequestProps> = ({
);
};
const ThreadView: React.FC<ThreadViewProps> = ({
thread,
isStreaming,
runId,
onCancel,
onInputResponse,
threadContainerRef,
}) => {
const isAwaitingInput = thread.status === "awaiting_input";
const isTimedOut = thread.status === "timeout";
const getStatusIcon = () => {
switch (thread.status) {
case "streaming":
return <Loader2 size={16} className="animate-spin text-accent" />;
case "awaiting_input":
return <MessageSquare size={16} className="text-accent" />;
case "complete":
return <CheckCircle size={16} className="text-accent" />;
case "error":
return <AlertTriangle size={16} className="text-red-500" />;
case "timeout":
return <Clock size={16} className="text-red-500" />;
default:
return null;
}
};
const getStatusText = () => {
if (isStreaming) {
return (
<>
<span className="inline-block mr-2">Agents working</span>
<LoadingDots size={8} />
</>
);
}
if (isAwaitingInput) return "Waiting for your input";
if (isTimedOut) return TIMEOUT_CONFIG.DEFAULT_MESSAGE;
if (thread.reason)
return (
<>
<span className="font-semibold mr-2">Stop Reason:</span>
{thread.reason}
</>
);
return null;
};
const handleTimeout = () => {
if (thread.inputRequest) {
onInputResponse(runId, "TIMEOUT");
}
};
return (
<div className="mt-2 border border-secondary rounded bg-primary">
<div className="sticky top-0 z-10 flex bg-primary rounded-t items-center justify-between p-3 border-b border-secondary bg-secondary/10">
<div className="text-sm text-primary flex items-center gap-2">
{getStatusIcon()}
{getStatusText()}
</div>
{(isStreaming || isAwaitingInput) && (
<button
onClick={() => onCancel(runId)}
className="flex items-center gap-1 px-3 py-1 rounded bg-red-500 hover:bg-red-600 text-white text-xs font-medium transition-colors"
>
<StopCircle size={12} />
<span>Stop</span>
</button>
)}
</div>
<div
ref={threadContainerRef}
className="max-h-[400px] overflow-y-auto scroll"
>
<div className="p-3 space-y-3">
{thread.messages.map((threadMsg, threadIndex) => (
<div key={`thread-${threadIndex}`}>
<RenderMessage
message={threadMsg}
isLast={threadIndex === thread.messages.length - 1}
/>
</div>
))}
{thread.inputRequest && (
<InputRequestView
prompt={thread.inputRequest.prompt}
onSubmit={(response) => onInputResponse(runId, response)}
disabled={!isAwaitingInput || isTimedOut}
onTimeout={handleTimeout}
/>
)}
</div>
</div>
</div>
);
};
export default ThreadView;
export default InputRequestView;

View File

@ -1,289 +0,0 @@
import React from "react";
import { ThreadState } from "./types";
import {
AgentMessageConfig,
Message,
TeamConfig,
} from "../../../types/datamodel";
import { RenderMessage } from "./rendermessage";
import {
StopCircle,
User,
Network,
MessageSquare,
Loader2,
CheckCircle,
AlertTriangle,
TriangleAlertIcon,
GroupIcon,
} from "lucide-react";
import AgentFlow from "./agentflow/agentflow";
import ThreadView from "./threadview";
import LoadingDots from "../../shared/atoms";
interface MessageListProps {
messages: Message[];
threadMessages: Record<string, ThreadState>;
setThreadMessages: React.Dispatch<
React.SetStateAction<Record<string, ThreadState>>
>;
onRetry: (content: string) => void;
onCancel: (runId: string) => void;
onInputResponse: (runId: string, response: string) => void;
loading?: boolean;
teamConfig?: TeamConfig;
}
interface MessagePair {
userMessage: Message;
botMessage: Message;
}
export const MessageList: React.FC<MessageListProps> = ({
messages,
threadMessages,
setThreadMessages,
onRetry,
onCancel,
onInputResponse, // New prop
loading = false,
teamConfig,
}) => {
const messagePairs = React.useMemo(() => {
const pairs: MessagePair[] = [];
for (let i = 0; i < messages.length; i += 2) {
if (messages[i] && messages[i + 1]) {
pairs.push({
userMessage: messages[i],
botMessage: messages[i + 1],
});
}
}
return pairs;
}, [messages]);
// Create a ref map to store refs for each thread container
const threadContainerRefs = React.useRef<
Record<string, HTMLDivElement | null>
>({});
// Effect to handle scrolling when thread messages update
React.useEffect(() => {
Object.entries(threadMessages).forEach(([runId, thread]) => {
if (thread.isExpanded && threadContainerRefs.current[runId]) {
const container = threadContainerRefs.current[runId];
if (container) {
container.scrollTo({
top: container.scrollHeight,
behavior: "smooth",
});
}
}
});
}, [threadMessages]);
const toggleThread = (runId: string) => {
setThreadMessages((prev) => ({
...prev,
[runId]: {
...prev[runId],
isExpanded: !prev[runId]?.isExpanded,
},
}));
};
const calculateThreadTokens = (messages: AgentMessageConfig[]) => {
return messages.reduce((total, msg) => {
if (!msg.models_usage) return total;
return (
total +
(msg.models_usage.prompt_tokens || 0) +
(msg.models_usage.completion_tokens || 0)
);
}, 0);
};
const getStatusIcon = (status: ThreadState["status"]) => {
switch (status) {
case "streaming":
return (
<div className="inline-block mr-1">
<Loader2
size={20}
className="inline-block mr-1 text-accent animate-spin"
/>{" "}
<span className="inline-block mr-2">Processing</span>{" "}
<LoadingDots size={8} />
</div>
);
case "awaiting_input": // New status
return (
<div className="text-sm mb-2">
<MessageSquare
size={20}
className="inline-block mr-1 text-accent"
/>{" "}
Waiting for your input
</div>
);
case "complete":
return (
<div className="text-sm mb-2">
<CheckCircle size={20} className="inline-block mr-1 text-accent" />{" "}
Task completed
</div>
);
case "error":
return (
<div className="text-sm mb-2">
<AlertTriangle
size={20}
className="inline-block mr-1 text-red-500"
/>{" "}
An error occurred.
</div>
);
case "cancelled":
return (
<div className="text-sm mb-2">
<StopCircle size={20} className="inline-block mr-1 text-red-500" />{" "}
Task was cancelled.
</div>
);
default:
return null;
}
};
return (
<div className="space-y-6 p-4 h-full">
{messagePairs.map(({ userMessage, botMessage }, pairIndex) => {
const isLast = pairIndex === messagePairs.length - 1;
const thread = threadMessages[botMessage.run_id];
const hasThread = thread && thread.messages.length > 0;
const isStreaming = thread?.status === "streaming";
const isAwaitingInput = thread?.status === "awaiting_input"; // New check
const isFirstMessage = pairIndex === 0;
return (
<div key={`pair-${botMessage.run_id}`} className="space-y-6">
{/* User message */}
{
<div
className={`${
isFirstMessage ? "mb-2" : "mt-8"
} mb-4 pt-2 border-t border-dashed border-secondary`}
>
{/* <div>Task Run 1. </div> */}
<div className="text-xs text-secondary">
Run {pairIndex + 1}
{!isFirstMessage && (
<>
{" "}
|{" "}
<TriangleAlertIcon className="w-4 h-4 -mt-1 inline-block mr-1 ml-1" />{" "}
Note: Each run does not share data with previous runs in
the same session yet.{" "}
</>
)}
</div>
</div>
}
<div className="flex flex-col items-end">
<div className="flex items-center gap-2 mb-1">
<span className="text-sm font-medium text-primary">You</span>
<div className="p-1.5 rounded bg-secondary text-accent">
<User size={20} />
</div>
</div>
<div className="w-full">
<RenderMessage message={userMessage.config} isLast={false} />
</div>
</div>
{/* Team response */}
<div className="flex flex-col items-start">
<div className="flex items-center gap-2 mb-1">
<div className="p-1.5 rounded bg-secondary text-primary">
<GroupIcon size={20} />
</div>
<span className="text-sm font-medium text-primary">
Agent Team
</span>
</div>
{/* Main response container */}
<div className="w-full">
<div className="p-4 bg-tertiary bordder border-secondary rounded">
<div className="text-primary">
{getStatusIcon(thread?.status)}{" "}
{!isAwaitingInput && thread?.finalResult?.content}
</div>
</div>
{/* Thread section */}
{hasThread && (
<div className="mt-2 pl-4 border-l-2 border-secondary/30">
<div className="flex pt-2">
<div className="flex-1">
<button
onClick={() => toggleThread(botMessage.run_id)}
className="flex items-center gap-1 text-sm text-secondary hover:text-primary transition-colors"
>
<MessageSquare size={16} />
<span className="text-accent">
{thread.isExpanded ? "Hide" : "Show"}
</span>{" "}
agent discussion
</button>
</div>
<div className="text-sm text-secondary">
{calculateThreadTokens(thread.messages)} tokens |{" "}
{thread.messages.length} messages
</div>
</div>
<div className="flex flex-row gap-4">
<div className="flex-1">
{thread.isExpanded && (
<ThreadView
thread={thread}
isStreaming={isStreaming}
runId={botMessage.run_id}
onCancel={onCancel}
onInputResponse={onInputResponse} // Pass through the new prop
threadContainerRef={(el) =>
(threadContainerRefs.current[botMessage.run_id] =
el)
}
/>
)}
</div>
<div className="bg-tertiary flex-1 rounded mt-2">
{teamConfig && thread.isExpanded && (
<AgentFlow
teamConfig={teamConfig}
messages={thread.messages}
threadState={thread}
/>
)}
</div>
</div>
</div>
)}
</div>
</div>
</div>
);
})}
{messages.length === 0 && !loading && (
<div className="text-center text-secondary h-full">
<div className="text-sm mt-4">Send a message to begin!</div>
</div>
)}
</div>
);
};

View File

@ -58,9 +58,9 @@ const RenderToolCall: React.FC<{ content: FunctionCall[] }> = ({ content }) => (
{content.map((call) => (
<div key={call.id} className="border rounded p-2">
<div className="font-medium">Function: {call.name}</div>
<pre className="text-sm mt-1 bg-secondary p-2 rounded">
<div className="text-sm mt-1 bg-secondary p-2 rounded">
{JSON.stringify(JSON.parse(call.arguments), null, 2)}
</pre>
</div>
</div>
))}
</div>
@ -90,9 +90,9 @@ const RenderToolResult: React.FC<{ content: FunctionExecutionResult[] }> = ({
}) => (
<div className="space-y-2">
{content.map((result) => (
<div key={result.call_id} className=" rounded p-2">
<div key={result.call_id} className=" rounded p-2 ">
<div className="font-medium">Result ID: {result.call_id}</div>
<pre className="text-sm mt-1 bg-secondary p-2 border rounded">
<pre className="text-sm mt-1 bg-secondary p-2 border rounded scroll overflow-x-scroll">
{result.content}
</pre>
</div>
@ -105,6 +105,7 @@ export const RenderMessage: React.FC<MessageProps> = ({
isLast = false,
className = "",
}) => {
if (!message) return null;
const isUser = messageUtils.isUser(message.source);
const content = message.content;

View File

@ -0,0 +1,271 @@
import React, { useState, useRef, useEffect } from "react";
import {
StopCircle,
MessageSquare,
Loader2,
CheckCircle,
AlertTriangle,
TriangleAlertIcon,
GroupIcon,
} from "lucide-react";
import { Run, Message, TeamConfig } from "../../../types/datamodel";
import AgentFlow from "./agentflow/agentflow";
import { RenderMessage } from "./rendermessage";
import LoadingDots from "../../shared/atoms";
import InputRequestView from "./inputrequest";
import { Tooltip } from "antd";
interface RunViewProps {
run: Run;
teamConfig?: TeamConfig;
onInputResponse?: (response: string) => void;
onCancel?: () => void;
isFirstRun?: boolean;
}
const RunView: React.FC<RunViewProps> = ({
run,
onInputResponse,
onCancel,
teamConfig,
isFirstRun = false,
}) => {
const [isExpanded, setIsExpanded] = useState(true);
const threadContainerRef = useRef<HTMLDivElement | null>(null);
const isActive = run.status === "active" || run.status === "awaiting_input";
// Replace existing scroll effect with this simpler one
useEffect(() => {
setTimeout(() => {
if (threadContainerRef.current) {
threadContainerRef.current.scrollTo({
top: threadContainerRef.current.scrollHeight,
behavior: "smooth",
});
}
}, 450);
}, [run.messages]); // Only depend on messages changing
const calculateThreadTokens = (messages: Message[]) => {
return messages.reduce((total, msg) => {
if (!msg.config.models_usage) return total;
return (
total +
(msg.config.models_usage.prompt_tokens || 0) +
(msg.config.models_usage.completion_tokens || 0)
);
}, 0);
};
const getStatusIcon = (status: Run["status"]) => {
switch (status) {
case "active":
return (
<div className="inline-block mr-1">
<Loader2
size={20}
className="inline-block mr-1 text-accent animate-spin"
/>
<span className="inline-block mr-2 ml-1 ">Processing</span>
<LoadingDots size={8} />
</div>
);
case "awaiting_input":
return (
<div className="text-sm mb-2">
<MessageSquare
size={20}
className="inline-block mr-2 text-accent"
/>
<span className="inline-block mr-2">Waiting for your input </span>
<LoadingDots size={8} />
</div>
);
case "complete":
return (
<div className="text-sm mb-2">
<CheckCircle size={20} className="inline-block mr-2 text-accent" />
Task completed
</div>
);
case "error":
return (
<div className="text-sm mb-2">
<AlertTriangle
size={20}
className="inline-block mr-2 text-red-500"
/>
{run.error_message || "An error occurred"}
</div>
);
case "stopped":
return (
<div className="text-sm mb-2">
<StopCircle size={20} className="inline-block mr-2 text-red-500" />
Task was stopped
</div>
);
default:
return null;
}
};
return (
<div className="space-y-6 mt-6 mr-2 ">
{/* Run Header */}
<div
className={`${
isFirstRun ? "mb-2" : "mt-8"
} mb-4 pt-2 border-t border-dashed border-secondary`}
>
<div className="text-xs text-secondary">
<Tooltip
title={
<div className="text-xs">
<div>ID: {run.id}</div>
<div>Created: {new Date(run.created_at).toLocaleString()}</div>
<div>Status: {run.status}</div>
</div>
}
>
<span className="cursor-help">Run ...{run.id.slice(-6)}</span>
</Tooltip>
{!isFirstRun && (
<>
{" "}
|{" "}
<TriangleAlertIcon className="w-4 h-4 -mt-1 inline-block mr-1 ml-1" />
Note: Each run does not share data with previous runs in the same
session yet.
</>
)}
</div>
</div>
{/* User Message */}
<div className="flex flex-col items-end w-full">
<div className="w-full">
<RenderMessage message={run.task} isLast={false} />
</div>
</div>
{/* Team Response */}
<div className="flex flex-col items-start">
<div className="flex items-center gap-2 mb-1">
<div className="p-1.5 rounded bg-secondary text-primary">
<GroupIcon size={20} />
</div>
<span className="text-sm font-medium text-primary">Agent Team</span>
</div>
<div className=" w-full">
{/* Main Response Container */}
<div className="p-4 bg-secondary border border-secondary rounded">
<div className="flex justify-between items-start mb-2">
<div className="text-primary">{getStatusIcon(run.status)}</div>
{/* Cancel Button - More prominent placement */}
{isActive && onCancel && (
<button
onClick={onCancel}
className="px-4 text-sm py-2 bg-red-500 hover:bg-red-600 text-white rounded-md transition-colors flex items-center gap-2"
>
<StopCircle size={16} />
Cancel Run
</button>
)}
</div>
{/* Final Response */}
{run.status !== "awaiting_input" && run.status !== "active" && (
<div className="text-sm">
<div className="text-xs mb-1 text-secondary -mt-2 border rounded p-2">
Stop reason: {run.team_result?.task_result?.stop_reason}
</div>
{run.messages[run.messages.length - 1]?.config?.content + ""}
</div>
)}
</div>
{/* Thread Section */}
<div className="">
{run.messages.length > 0 && (
<div className="mt-2 pl-4 border-l-2 border-secondary/30">
<div className="flex pt-2">
<div className="flex-1">
<button
onClick={() => setIsExpanded(!isExpanded)}
className="flex items-center gap-1 text-sm text-secondary hover:text-primary transition-colors"
>
<MessageSquare size={16} />
<span className="text-accent">
{isExpanded ? "Hide" : "Show"}
</span>{" "}
agent discussion
</button>
</div>
<div className="text-sm text-secondary">
{calculateThreadTokens(run.messages)} tokens |{" "}
{run.messages.length} messages
</div>
</div>
{isExpanded && (
<div className="flex flex-row gap-4">
{/* Messages Thread */}
<div
ref={threadContainerRef}
className="flex-1 mt-2 overflow-y-auto max-h-[400px] scroll-smooth scroll pb-2 relative"
>
<div id="scroll-gradient" className="scroll-gradient h-8">
{" "}
<span className=" inline-block h-6"></span>{" "}
</div>
{run.messages.map((msg, idx) => (
<div
key={"message_id" + idx + run.id}
className=" mr-2"
>
<RenderMessage
message={msg.config}
isLast={idx === run.messages.length - 1}
/>
</div>
))}
{/* Input Request UI */}
{run.status === "awaiting_input" && onInputResponse && (
<div className="mt-4 mr-2">
<InputRequestView
prompt="Type your response..."
onSubmit={onInputResponse}
/>
</div>
)}
<div className="text-primary mt-2">
<div className="w-4 h-4 inline-block border-secondary rounded-bl-lg border-l-2 border-b-2"></div>{" "}
<div className="inline-block ">
{getStatusIcon(run.status)}
</div>
</div>
</div>
{/* Agent Flow Visualization */}
<div className="bg-tertiary flex-1 rounded mt-2">
{teamConfig && (
<AgentFlow teamConfig={teamConfig} run={run} />
)}
</div>
</div>
)}
</div>
)}
</div>
</div>
</div>
</div>
);
};
export default RunView;

View File

@ -1,20 +1,18 @@
import React, { ReactNode } from "react";
import Markdown from "react-markdown";
import React from "react";
interface MarkdownViewProps {
children: string;
content: string;
className?: string;
}
export const MarkdownView: React.FC<MarkdownViewProps> = ({
children,
content,
className = "",
}) => {
return (
<div
className={`text-sm w-full prose dark:prose-invert text-primary rounded p-2 ${className}`}
>
<Markdown>{children}</Markdown>
<div className={`text-sm w-full text-primary rounded ${className}`}>
<Markdown>{content}</Markdown>
</div>
);
};

View File

@ -1,4 +1,4 @@
import { Session } from "../../../types/datamodel";
import { Session, SessionRuns } from "../../../types/datamodel";
import { getServerUrl } from "../../../utils";
export class SessionAPI {
@ -83,6 +83,23 @@ export class SessionAPI {
return data.data;
}
// session runs with messages
async getSessionRuns(
sessionId: number,
userId: string
): Promise<SessionRuns> {
const response = await fetch(
`${this.getBaseUrl()}/sessions/${sessionId}/runs?user_id=${userId}`,
{
headers: this.getHeaders(),
}
);
const data = await response.json();
if (!data.status)
throw new Error(data.message || "Failed to fetch session runs");
return data.data; // Returns { runs: RunMessage[] }
}
async deleteSession(sessionId: number, userId: string): Promise<void> {
const response = await fetch(
`${this.getBaseUrl()}/sessions/${sessionId}?user_id=${userId}`,

View File

@ -8,7 +8,7 @@ const IndexPage = ({ data }: any) => {
return (
<Layout meta={data.site.siteMetadata} title="Home" link={"/"}>
<main style={{ height: "100%" }} className=" h-full ">
<ChatView initMessages={[]} />
<ChatView />
</main>
</Layout>
);

File diff suppressed because it is too large Load Diff

View File

@ -16,17 +16,9 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"task_result=TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?'), ToolCallMessage(source='writing_agent', models_usage=RequestUsage(prompt_tokens=65, completion_tokens=15), content=[FunctionCall(id='call_x8C5nib1PJkMZGQ6zrNUlfa0', arguments='{\"city\":\"New York\"}', name='get_weather')]), ToolCallResultMessage(source='writing_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 73 degrees and Sunny.', call_id='call_x8C5nib1PJkMZGQ6zrNUlfa0')]), TextMessage(source='writing_agent', models_usage=RequestUsage(prompt_tokens=97, completion_tokens=14), content='The weather in New York is currently 73 degrees and sunny.'), TextMessage(source='writing_agent', models_usage=RequestUsage(prompt_tokens=123, completion_tokens=13), content='Would you like to know the weather in any other city?')], stop_reason='Maximum number of messages 5 reached, current message count: 5') usage='' duration=1.9984567165374756\n"
]
}
],
"outputs": [],
"source": [
"from autogenstudio.teammanager import TeamManager \n",
" \n",
@ -37,22 +29,9 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"source='user' models_usage=None content='What is the weather in New York?'\n",
"source='writing_agent' models_usage=RequestUsage(prompt_tokens=65, completion_tokens=15) content=[FunctionCall(id='call_Gwnfsa8ndnOsXTvRECTr92hr', arguments='{\"city\":\"New York\"}', name='get_weather')]\n",
"source='writing_agent' models_usage=None content=[FunctionExecutionResult(content='The weather in New York is 73 degrees and Sunny.', call_id='call_Gwnfsa8ndnOsXTvRECTr92hr')]\n",
"source='writing_agent' models_usage=RequestUsage(prompt_tokens=97, completion_tokens=14) content='The weather in New York is currently 73 degrees and sunny.'\n",
"source='writing_agent' models_usage=RequestUsage(prompt_tokens=123, completion_tokens=14) content='The weather in New York is currently 73 degrees and sunny.'\n",
"task_result=TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?'), ToolCallMessage(source='writing_agent', models_usage=RequestUsage(prompt_tokens=65, completion_tokens=15), content=[FunctionCall(id='call_Gwnfsa8ndnOsXTvRECTr92hr', arguments='{\"city\":\"New York\"}', name='get_weather')]), ToolCallResultMessage(source='writing_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 73 degrees and Sunny.', call_id='call_Gwnfsa8ndnOsXTvRECTr92hr')]), TextMessage(source='writing_agent', models_usage=RequestUsage(prompt_tokens=97, completion_tokens=14), content='The weather in New York is currently 73 degrees and sunny.'), TextMessage(source='writing_agent', models_usage=RequestUsage(prompt_tokens=123, completion_tokens=14), content='The weather in New York is currently 73 degrees and sunny.')], stop_reason='Maximum number of messages 5 reached, current message count: 5') usage='' duration=2.363379955291748\n"
]
}
],
"outputs": [],
"source": [
"result_stream = wm.run_stream(task=\"What is the weather in New York?\", team_config=\"team.json\") \n",
"async for response in result_stream:\n",
@ -74,28 +53,32 @@
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO [alembic.runtime.migration] Context impl SQLiteImpl.\n",
"INFO [alembic.runtime.migration] Will assume non-transactional DDL.\n",
"\u001b[32m2024-11-14 09:06:25.242\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mautogenstudio.database.schema_manager\u001b[0m:\u001b[36mupgrade_schema\u001b[0m:\u001b[36m390\u001b[0m - \u001b[1mSchema upgraded successfully to head\u001b[0m\n",
"\u001b[32m2024-11-14 09:06:25.243\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mautogenstudio.database.db_manager\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m34\u001b[0m - \u001b[1mDatabase schema was upgraded automatically\u001b[0m\n",
"\u001b[32m2024-11-14 09:06:25.244\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mautogenstudio.database.db_manager\u001b[0m:\u001b[36mcreate_db_and_tables\u001b[0m:\u001b[36m108\u001b[0m - \u001b[1mDatabase tables created successfully\u001b[0m\n"
]
"data": {
"text/plain": [
"Response(message='Database is ready', status=True, data=None)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from autogenstudio.database import DatabaseManager \n",
"import os \n",
"# delete database\n",
"# if os.path.exists(\"test.db\"):\n",
"# os.remove(\"test.db\") \n",
"\n",
"os.makedirs(\"test\", exist_ok=True)\n",
"# create a database\n",
"dbmanager = DatabaseManager(engine_uri=\"sqlite:///test.db\")\n",
"dbmanager.create_db_and_tables() "
"dbmanager = DatabaseManager(engine_uri=\"sqlite:///test.db\", base_dir=\"test\")\n",
"dbmanager.initialize_database() "
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@ -145,14 +128,14 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"53 teams in database\n"
"2 teams in database\n"
]
}
],
@ -172,7 +155,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -194,17 +177,9 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"message='Directory import complete' status=True data=[{'component': 'team', 'status': True, 'message': 'Team Created Successfully', 'id': 54}]\n"
]
}
],
"outputs": [],
"source": [
"result = await config_manager.import_directory(\".\", user_id=\"user_id\", check_exists=False)\n",
"print(result)"
@ -212,17 +187,9 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"54 teams in database\n"
]
}
],
"outputs": [],
"source": [
"all_teams = dbmanager.get(Team)\n",
"print(len(all_teams.data), \"teams in database\")"
@ -237,7 +204,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -282,22 +249,9 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"source='user' models_usage=None content='Plan a 3 day trip to Nepal.'\n",
"source='planner_agent' models_usage=RequestUsage(prompt_tokens=45, completion_tokens=54) content=\"Consider starting your 3-day trip to Nepal with a cultural tour in Kathmandu Valley, followed by an exploration of Pokhara's natural beauty on the second day, and finally, indulge in a thrilling safari at Chitwan National Park on the third day.\"\n",
"source='local_agent' models_usage=RequestUsage(prompt_tokens=116, completion_tokens=54) content=\"Consider starting your 3-day trip to Nepal with a cultural tour in Kathmandu Valley, followed by an exploration of Pokhara's natural beauty on the second day, and finally, indulge in a thrilling safari at Chitwan National Park on the third day.\"\n",
"source='language_agent' models_usage=RequestUsage(prompt_tokens=201, completion_tokens=45) content=\"Your travel plan lacks a mention of dealing with potential language barriers; it might be useful to learn some phrases in Nepali, as it's the official language of Nepal, or have a local translation app handy during your trip.\"\n",
"source='travel_summary_agent' models_usage=RequestUsage(prompt_tokens=270, completion_tokens=237) content='Day 1: Start your adventure in the capital city, Kathmandu. Take a guided tour of Kathmandu Valley to explore its UNESCO World Heritage Sites, such as the Durbar Squares, Swayambhunath Stupa, and Boudhanath Stupa. Engage with the locals and sample some traditional Nepalese cuisine.\\n\\nDay 2: Proceed to Pokhara, known for its stunning natural beauty. Visit the iconic Phewa Lake and enjoy a boat ride, then trek to the Peace Pagoda for a panoramic view of the city. Round off the day with a visit to the fascinating Pokhara Mountain Museum.\\n\\nDay 3: Travel to Chitwan National Park for a memorable safari. Explore the diverse wildlife and lush vegetation that make the park a UNESCO World Heritage site. Be on the lookout for rhinos, Bengal tigers, and a multitude of bird species.\\n\\nNote: Communication is key to enjoying your trip. The official language of Nepal is Nepali. It can be helpful to learn a few basic phrases or carry a translation app to help you interact with the local people and enrich your cultural experience.\\n\\nTERMINATE.'\n",
"TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Plan a 3 day trip to Nepal.'), TextMessage(source='planner_agent', models_usage=RequestUsage(prompt_tokens=45, completion_tokens=54), content=\"Consider starting your 3-day trip to Nepal with a cultural tour in Kathmandu Valley, followed by an exploration of Pokhara's natural beauty on the second day, and finally, indulge in a thrilling safari at Chitwan National Park on the third day.\"), TextMessage(source='local_agent', models_usage=RequestUsage(prompt_tokens=116, completion_tokens=54), content=\"Consider starting your 3-day trip to Nepal with a cultural tour in Kathmandu Valley, followed by an exploration of Pokhara's natural beauty on the second day, and finally, indulge in a thrilling safari at Chitwan National Park on the third day.\"), TextMessage(source='language_agent', models_usage=RequestUsage(prompt_tokens=201, completion_tokens=45), content=\"Your travel plan lacks a mention of dealing with potential language barriers; it might be useful to learn some phrases in Nepali, as it's the official language of Nepal, or have a local translation app handy during your trip.\"), TextMessage(source='travel_summary_agent', models_usage=RequestUsage(prompt_tokens=270, completion_tokens=237), content='Day 1: Start your adventure in the capital city, Kathmandu. Take a guided tour of Kathmandu Valley to explore its UNESCO World Heritage Sites, such as the Durbar Squares, Swayambhunath Stupa, and Boudhanath Stupa. Engage with the locals and sample some traditional Nepalese cuisine.\\n\\nDay 2: Proceed to Pokhara, known for its stunning natural beauty. Visit the iconic Phewa Lake and enjoy a boat ride, then trek to the Peace Pagoda for a panoramic view of the city. Round off the day with a visit to the fascinating Pokhara Mountain Museum.\\n\\nDay 3: Travel to Chitwan National Park for a memorable safari. Explore the diverse wildlife and lush vegetation that make the park a UNESCO World Heritage site. Be on the lookout for rhinos, Bengal tigers, and a multitude of bird species.\\n\\nNote: Communication is key to enjoying your trip. The official language of Nepal is Nepali. It can be helpful to learn a few basic phrases or carry a translation app to help you interact with the local people and enrich your cultural experience.\\n\\nTERMINATE.')], stop_reason=\"Text 'TERMINATE' mentioned\")\n"
]
}
],
"outputs": [],
"source": [
"\n",
"result = group_chat.run_stream(task=\"Plan a 3 day trip to Nepal.\")\n",
@ -317,7 +271,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -330,19 +284,9 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"source='user' models_usage=None content='hello there'\n",
"source='user_agent' models_usage=None content='Hello World thereEnter your response: '\n",
"TaskResult(messages=[TextMessage(source='user', models_usage=None, content='hello there'), TextMessage(source='user_agent', models_usage=None, content='Hello World thereEnter your response: ')], stop_reason=None)\n"
]
}
],
"outputs": [],
"source": [
"from autogen_core.base import CancellationToken \n",
"cancellation_token = CancellationToken()\n",

View File

@ -23,7 +23,8 @@ dependencies = [
"pydantic-settings",
"fastapi",
"typer",
"uvicorn",
"uvicorn",
"aiofiles",
"python-dotenv",
"websockets",
"numpy < 2.0.0",
@ -69,3 +70,21 @@ filterwarnings = [
[project.scripts]
autogenstudio = "autogenstudio.cli:run"
[tool.ruff]
extend = "../../pyproject.toml"
exclude = ["build", "dist"]
include = [
"autogenstudio/**"
]
[tool.ruff.lint]
ignore = ["B008"]
[tool.poe.tasks]
fmt = "ruff format"
format.ref = "fmt"
lint = "ruff check"
test = "pytest -n 0"

View File

@ -6,10 +6,10 @@ from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination, TextMentionTermination
from autogen_core.components.tools import FunctionTool
from autogenstudio.datamodel import (
from autogenstudio.datamodel.types import (
AgentConfig, ModelConfig, TeamConfig, ToolConfig, TerminationConfig,
ModelTypes, AgentTypes, TeamTypes, TerminationTypes, ToolTypes,
ComponentType
ComponentTypes
)
from autogenstudio.database import ComponentFactory
@ -41,7 +41,7 @@ def calculator(a: int, b: int, operation: str = '+') -> int:
raise ValueError("Invalid operation")
""",
tool_type=ToolTypes.PYTHON_FUNCTION,
component_type=ComponentType.TOOL,
component_type=ComponentTypes.TOOL,
version="1.0.0"
)
@ -52,7 +52,7 @@ def sample_model_config():
model_type=ModelTypes.OPENAI,
model="gpt-4",
api_key="test-key",
component_type=ComponentType.MODEL,
component_type=ComponentTypes.MODEL,
version="1.0.0"
)
@ -65,7 +65,7 @@ def sample_agent_config(sample_model_config: ModelConfig, sample_tool_config: To
system_message="You are a helpful assistant",
model_client=sample_model_config,
tools=[sample_tool_config],
component_type=ComponentType.AGENT,
component_type=ComponentTypes.AGENT,
version="1.0.0"
)
@ -75,7 +75,7 @@ def sample_termination_config():
return TerminationConfig(
termination_type=TerminationTypes.MAX_MESSAGES,
max_messages=10,
component_type=ComponentType.TERMINATION,
component_type=ComponentTypes.TERMINATION,
version="1.0.0"
)
@ -88,7 +88,7 @@ def sample_team_config(sample_agent_config: AgentConfig, sample_termination_conf
participants=[sample_agent_config],
termination_condition=sample_termination_config,
model_client=sample_model_config,
component_type=ComponentType.TEAM,
component_type=ComponentTypes.TEAM,
version="1.0.0"
)
@ -115,7 +115,7 @@ async def test_load_tool_invalid_config(component_factory: ComponentFactory):
description="",
content="",
tool_type=ToolTypes.PYTHON_FUNCTION,
component_type=ComponentType.TOOL,
component_type=ComponentTypes.TOOL,
version="1.0.0"
))
@ -125,7 +125,7 @@ async def test_load_tool_invalid_config(component_factory: ComponentFactory):
description="Invalid function",
content="def invalid_func(): return invalid syntax",
tool_type=ToolTypes.PYTHON_FUNCTION,
component_type=ComponentType.TOOL,
component_type=ComponentTypes.TOOL,
version="1.0.0"
)
with pytest.raises(ValueError):
@ -154,7 +154,7 @@ async def test_load_termination(component_factory: ComponentFactory):
max_msg_config = TerminationConfig(
termination_type=TerminationTypes.MAX_MESSAGES,
max_messages=5,
component_type=ComponentType.TERMINATION,
component_type=ComponentTypes.TERMINATION,
version="1.0.0"
)
termination = await component_factory.load_termination(max_msg_config)
@ -164,7 +164,7 @@ async def test_load_termination(component_factory: ComponentFactory):
# Test StopMessageTermination
stop_msg_config = TerminationConfig(
termination_type=TerminationTypes.STOP_MESSAGE,
component_type=ComponentType.TERMINATION,
component_type=ComponentTypes.TERMINATION,
version="1.0.0"
)
termination = await component_factory.load_termination(stop_msg_config)
@ -174,13 +174,108 @@ async def test_load_termination(component_factory: ComponentFactory):
text_mention_config = TerminationConfig(
termination_type=TerminationTypes.TEXT_MENTION,
text="DONE",
component_type=ComponentType.TERMINATION,
component_type=ComponentTypes.TERMINATION,
version="1.0.0"
)
termination = await component_factory.load_termination(text_mention_config)
assert isinstance(termination, TextMentionTermination)
assert termination._text == "DONE"
# Test AND combination
and_combo_config = TerminationConfig(
termination_type=TerminationTypes.COMBINATION,
operator="and",
conditions=[
TerminationConfig(
termination_type=TerminationTypes.MAX_MESSAGES,
max_messages=5,
component_type=ComponentTypes.TERMINATION,
version="1.0.0"
),
TerminationConfig(
termination_type=TerminationTypes.TEXT_MENTION,
text="DONE",
component_type=ComponentTypes.TERMINATION,
version="1.0.0"
)
],
component_type=ComponentTypes.TERMINATION,
version="1.0.0"
)
termination = await component_factory.load_termination(and_combo_config)
assert termination is not None
# Test OR combination
or_combo_config = TerminationConfig(
termination_type=TerminationTypes.COMBINATION,
operator="or",
conditions=[
TerminationConfig(
termination_type=TerminationTypes.MAX_MESSAGES,
max_messages=5,
component_type=ComponentTypes.TERMINATION,
version="1.0.0"
),
TerminationConfig(
termination_type=TerminationTypes.TEXT_MENTION,
text="DONE",
component_type=ComponentTypes.TERMINATION,
version="1.0.0"
)
],
component_type=ComponentTypes.TERMINATION,
version="1.0.0"
)
termination = await component_factory.load_termination(or_combo_config)
assert termination is not None
# Test invalid combinations
with pytest.raises(ValueError):
await component_factory.load_termination(TerminationConfig(
termination_type=TerminationTypes.COMBINATION,
conditions=[], # Empty conditions
component_type=ComponentTypes.TERMINATION,
version="1.0.0"
))
with pytest.raises(ValueError):
await component_factory.load_termination(TerminationConfig(
termination_type=TerminationTypes.COMBINATION,
operator="invalid", # type: ignore
conditions=[
TerminationConfig(
termination_type=TerminationTypes.MAX_MESSAGES,
max_messages=5,
component_type=ComponentTypes.TERMINATION,
version="1.0.0"
)
],
component_type=ComponentTypes.TERMINATION,
version="1.0.0"
))
# Test missing operator
with pytest.raises(ValueError):
await component_factory.load_termination(TerminationConfig(
termination_type=TerminationTypes.COMBINATION,
conditions=[
TerminationConfig(
termination_type=TerminationTypes.MAX_MESSAGES,
max_messages=5,
component_type=ComponentTypes.TERMINATION,
version="1.0.0"
),
TerminationConfig(
termination_type=TerminationTypes.TEXT_MENTION,
text="DONE",
component_type=ComponentTypes.TERMINATION,
version="1.0.0"
)
],
component_type=ComponentTypes.TERMINATION,
version="1.0.0"
))
@pytest.mark.asyncio
async def test_load_team(component_factory: ComponentFactory, sample_team_config: TeamConfig, sample_model_config: ModelConfig):
@ -201,13 +296,13 @@ async def test_load_team(component_factory: ComponentFactory, sample_team_config
system_message="You are another helpful assistant",
model_client=sample_model_config,
tools=sample_team_config.participants[0].tools,
component_type=ComponentType.AGENT,
component_type=ComponentTypes.AGENT,
version="1.0.0"
)
],
termination_condition=sample_team_config.termination_condition,
model_client=sample_model_config,
component_type=ComponentType.TEAM,
component_type=ComponentTypes.TEAM,
version="1.0.0"
)
team = await component_factory.load_team(selector_team_config)
@ -223,7 +318,7 @@ async def test_invalid_configs(component_factory: ComponentFactory):
name="test",
agent_type="InvalidAgent", # type: ignore
system_message="test",
component_type=ComponentType.AGENT,
component_type=ComponentTypes.AGENT,
version="1.0.0"
))
@ -233,7 +328,7 @@ async def test_invalid_configs(component_factory: ComponentFactory):
name="test",
team_type="InvalidTeam", # type: ignore
participants=[],
component_type=ComponentType.TEAM,
component_type=ComponentTypes.TEAM,
version="1.0.0"
))
@ -241,6 +336,6 @@ async def test_invalid_configs(component_factory: ComponentFactory):
with pytest.raises(ValueError):
await component_factory.load_termination(TerminationConfig(
termination_type="InvalidTermination", # type: ignore
component_type=ComponentType.TERMINATION,
component_type=ComponentTypes.TERMINATION,
version="1.0.0"
))

View File

@ -1,15 +1,16 @@
import os
import asyncio
import pytest
from sqlmodel import Session, text, select
from typing import Generator
from datetime import datetime
from autogenstudio.database import DatabaseManager
from autogenstudio.datamodel import (
Model, ModelConfig, Agent, AgentConfig, Tool, ToolConfig,
Team, TeamConfig, ModelTypes, AgentTypes, TeamTypes, ComponentType,
TerminationConfig, TerminationTypes, LinkTypes, ToolTypes
from autogenstudio.datamodel.types import (
ModelConfig, AgentConfig, ToolConfig,
TeamConfig, ModelTypes, AgentTypes, TeamTypes, ComponentTypes,
TerminationConfig, TerminationTypes, ToolTypes
)
from autogenstudio.datamodel.db import Model, Tool, Agent, Team, LinkTypes
@pytest.fixture
@ -18,13 +19,13 @@ def test_db() -> Generator[DatabaseManager, None, None]:
db_path = "test.db"
db = DatabaseManager(f"sqlite:///{db_path}")
db.reset_db()
db.create_db_and_tables()
# Initialize database instead of create_db_and_tables
db.initialize_database(auto_upgrade=False)
yield db
# Clean up
asyncio.run(db.close())
db.reset_db()
try:
# Close database connections before removing file
db.engine.dispose()
# Remove the database file
if os.path.exists(db_path):
os.remove(db_path)
except Exception as e:
@ -44,7 +45,7 @@ def sample_model(test_user: str) -> Model:
config=ModelConfig(
model="gpt-4",
model_type=ModelTypes.OPENAI,
component_type=ComponentType.MODEL,
component_type=ComponentTypes.MODEL,
version="1.0.0"
).model_dump()
)
@ -60,7 +61,7 @@ def sample_tool(test_user: str) -> Tool:
description="A test tool",
content="async def test_func(x: str) -> str:\n return f'Test {x}'",
tool_type=ToolTypes.PYTHON_FUNCTION,
component_type=ComponentType.TOOL,
component_type=ComponentTypes.TOOL,
version="1.0.0"
).model_dump()
)
@ -76,7 +77,7 @@ def sample_agent(test_user: str, sample_model: Model, sample_tool: Tool) -> Agen
agent_type=AgentTypes.ASSISTANT,
model_client=ModelConfig.model_validate(sample_model.config),
tools=[ToolConfig.model_validate(sample_tool.config)],
component_type=ComponentType.AGENT,
component_type=ComponentTypes.AGENT,
version="1.0.0"
).model_dump()
)
@ -92,11 +93,11 @@ def sample_team(test_user: str, sample_agent: Agent) -> Team:
participants=[AgentConfig.model_validate(sample_agent.config)],
termination_condition=TerminationConfig(
termination_type=TerminationTypes.STOP_MESSAGE,
component_type=ComponentType.TERMINATION,
component_type=ComponentTypes.TERMINATION,
version="1.0.0"
).model_dump(),
team_type=TeamTypes.ROUND_ROBIN,
component_type=ComponentType.TEAM,
component_type=ComponentTypes.TEAM,
version="1.0.0"
).model_dump()
)
@ -144,7 +145,7 @@ class TestDatabaseOperations:
config=ModelConfig(
model="gpt-4",
model_type=ModelTypes.OPENAI,
component_type=ComponentType.MODEL,
component_type=ComponentTypes.MODEL,
version="1.0.0"
).model_dump()
)
@ -153,7 +154,7 @@ class TestDatabaseOperations:
config=ModelConfig(
model="gpt-3.5",
model_type=ModelTypes.OPENAI,
component_type=ComponentType.MODEL,
component_type=ComponentTypes.MODEL,
version="1.0.0"
).model_dump()
)
@ -181,3 +182,59 @@ class TestDatabaseOperations:
model_names = [model.config["model"] for model in linked_models.data]
assert "gpt-4" in model_names
assert "gpt-3.5" in model_names
def test_upsert_operations(self, test_db: DatabaseManager, sample_model: Model):
"""Test upsert for both create and update scenarios"""
# Test Create
response = test_db.upsert(sample_model)
assert response.status is True
assert "Created Successfully" in response.message
# Test Update
sample_model.config["model"] = "gpt-4-turbo"
response = test_db.upsert(sample_model)
assert response.status is True
assert "Updated Successfully" in response.message
# Verify Update
result = test_db.get(Model, {"id": sample_model.id})
assert result.status is True
assert result.data[0].config["model"] == "gpt-4-turbo"
def test_delete_operations(self, test_db: DatabaseManager, sample_model: Model):
"""Test delete with various filters"""
# First insert the model
test_db.upsert(sample_model)
# Test deletion by id
response = test_db.delete(Model, {"id": sample_model.id})
assert response.status is True
assert "Deleted Successfully" in response.message
# Verify deletion
result = test_db.get(Model, {"id": sample_model.id})
assert len(result.data) == 0
# Test deletion with non-existent id
response = test_db.delete(Model, {"id": 999999})
assert "Row not found" in response.message
def test_initialize_database_scenarios(self):
"""Test different initialize_database parameters"""
db_path = "test_init.db"
db = DatabaseManager(f"sqlite:///{db_path}")
try:
# Test basic initialization
response = db.initialize_database()
assert response.status is True
# Test with auto_upgrade
response = db.initialize_database(auto_upgrade=True)
assert response.status is True
finally:
asyncio.run(db.close())
db.reset_db()
if os.path.exists(db_path):
os.remove(db_path)

View File

@ -573,9 +573,10 @@ dev = [
[[package]]
name = "autogenstudio"
version = "0.4.0.dev37"
version = "0.4.0.dev38"
source = { editable = "packages/autogen-studio" }
dependencies = [
{ name = "aiofiles" },
{ name = "alembic" },
{ name = "autogen-agentchat" },
{ name = "autogen-core" },
@ -605,6 +606,7 @@ web = [
[package.metadata]
requires-dist = [
{ name = "aiofiles" },
{ name = "alembic" },
{ name = "autogen-agentchat", editable = "packages/autogen-agentchat" },
{ name = "autogen-core", editable = "packages/autogen-core" },